#!/usr/bin/env python
"""Python program used to get and set XML variables
"""

#from __future__ import print_function

import sys

# Refuse to run on interpreters older than python 2.7 and report the
# version that was actually found before exiting with a non-zero status.
if sys.version_info[:2] < (2, 7):
    banner = "*" * 70
    running_version = ".".join(str(part) for part in sys.version_info[0:3])
    print(banner)
    print("ERROR: {0} requires python >= 2.7.x. ".format(sys.argv[0]))
    print("It appears that you are running python {0}".format(running_version))
    print(banner)
    sys.exit(1)

#
# built-in modules
#
import argparse
import glob
from operator import itemgetter
import os
import re
import traceback

# Prefer lxml when it is installed and fall back to the ElementTree
# implementation from the standard library otherwise. Only a failed import
# triggers the fallback so that unrelated errors are not silently swallowed.
try:
    import lxml.etree as etree
except ImportError:
    import xml.etree.ElementTree as etree

# Read POSTPROCESS_PATH and STANDALONE from the env_postprocess.xml file in
# the current working directory. Both default to the empty string when the
# file has no matching <entry>; a missing file is a fatal error.
env_file = './env_postprocess.xml'
postprocess_path = ''
standalone = ''
if not os.path.isfile(env_file):
    err_msg = ('pp_config ERROR: env_postprocess.xml does not exist in this directory.')
    raise OSError(err_msg)

xml_tree = etree.ElementTree()
xml_tree.parse(env_file)
for entry_tag in xml_tree.findall('entry'):
    # when an id occurs more than once the last occurrence wins
    if 'POSTPROCESS_PATH' == entry_tag.get('id'):
        postprocess_path = entry_tag.get('value')
    if 'STANDALONE' == entry_tag.get('id'):
        standalone = entry_tag.get('value')

# Make the python packages installed by create_python_env.sh importable.
#
# Activation is skipped only when sys.real_prefix exists (the virtualenv
# check used by this script) AND cesm_utils can already be imported. In every
# other case the cesm-env2 environment under POSTPROCESS_PATH is activated.
_needs_activation = True
if hasattr(sys, 'real_prefix'):
    try:
        import cesm_utils
        _needs_activation = False
    except Exception:
        # a virtualenv is active but it does not provide cesm_utils;
        # fall through and activate the cesm-env2 environment instead
        _needs_activation = True

if _needs_activation:
    #
    # activate the virtual environment that was created by create_python_env.sh
    #
    activate_file = '{0}/cesm-env2/bin/activate_this.py'.format(postprocess_path)
    if not os.path.isfile(activate_file):
        err_msg = ('pp_config ERROR: the virtual environment in {0} does not exist. '.format(postprocess_path)
                   + 'Please run {0}/create_python_env.sh -cimeroot [cimeroot] -machine [machine name]'.format(postprocess_path))
        raise OSError(err_msg)

    try:
        # execfile() only exists in python 2; compile() + exec() runs the
        # activation script the same way on both python 2 and python 3.
        with open(activate_file) as activate_script:
            activate_code = compile(activate_script.read(), activate_file, 'exec')
        exec(activate_code, dict(__file__=activate_file))
    except Exception:
        raise OSError('pp_config ERROR: Unable to activate python virtualenv {0}'.format(activate_file))


if sys.version_info[0] == 2:
    from ConfigParser import SafeConfigParser as config_parser
else:
    from configparser import ConfigParser as config_parser

#
# import virtualenv installed modules
#
from cesm_utils import cesmEnvLib, processXmlLib
from asaptools import vprinter
import jinja2

# global variables
# post-processing script names accepted by the --getbatch option
_scripts = ['timeseries','averages','regrid','diagnostics','xconform']
# machine names accepted by the --machine option
_machines = ['cheyenne','edison','dav']
# model component names accepted by the --comp option
_comps = ['atm','ice','lnd','ocn']

# -------------------------------------------------------------------------------
# User input
# -------------------------------------------------------------------------------

def commandline_options():
    """Process the command line arguments.

    Returns the argparse namespace. Every option declared with nargs=1
    (caseroot, debug, get, set, getbatch, comp, machine) is delivered as a
    one-element list when it is given on the command line; 'debug' keeps
    the plain integer default 0 when it is omitted.
    """
    parser = argparse.ArgumentParser(
        description=('pp_config: get and set post-processing '
                     'configuration variables.'))

    parser.add_argument('-backtrace', '--backtrace', action='store_true',
                        help='show exception backtraces as extra debugging '
                        'output')

    parser.add_argument('-caseroot','--caseroot', nargs=1, default=['.'],
                        help='path to postprocessing case directory. Defaults to current directory.')

    parser.add_argument('-debug', '--debug', nargs=1, required=False, type=int, default=0,
                        help='debugging verbosity level output: 0 = none, 1 = minimum, 2 = maximum. 0 is default')

    parser.add_argument('-get', '--get', nargs=1, default=[],
                        help='variable name to retrieve')

    # the backslash is doubled so that the help text shows a literal "\$"
    # without relying on an invalid string escape sequence
    parser.add_argument('-set', '--set', nargs=1, default=[],
                        help=('variable and value to set in the form: '
                              '"key=value". Values with a "$" must be escaped using "\\$" '
                              'unless you want the shell to replace the value with the value of the '
                              'expanded environment variable.'))

    parser.add_argument('-value', '--value', action='store_true',
                        help=('print only the value of the variable. '
                              'Works in conjunction with the --get option'))

    parser.add_argument('--getbatch', nargs=1, required=False, choices=_scripts,
                        help='batch script option.')

    parser.add_argument('-comp', '--comp', nargs=1, required=False, choices=_comps,
                        help='component name used in conjunction with --getbatch option.')

    parser.add_argument('-machine', '--machine', nargs=1, required=False, choices=_machines,
                        help='machine name used in conjunction with --getbatch option.')

    options = parser.parse_args()
    return options


# -------------------------------------------------------------------------------
# get_batch
# -------------------------------------------------------------------------------
def _read_pes_settings(pes_element):
    """Extract the batch settings carried by a '<script>_pes' XML element.

    Arguments:
    pes_element -- the XML element to read, or None when the XML file does
                   not contain the element

    Returns None when pes_element is None, otherwise a dictionary with the
    keys 'pes' (element text), 'queue', 'ppn' (the 'pes_per_node'
    attribute), 'wallclock' and 'nodes'. All attribute values are
    lower-cased; 'nodes' is the empty string when the attribute is absent.
    A missing 'queue', 'pes_per_node' or 'wallclock' attribute raises
    AttributeError.
    """
    if pes_element is None:
        return None
    return {'pes': pes_element.text,
            'queue': pes_element.get('queue').lower(),
            'ppn': pes_element.get('pes_per_node').lower(),
            'wallclock': pes_element.get('wallclock').lower(),
            'nodes': pes_element.get('nodes', '').lower()}


def get_batch(pp_path, script, mach, comp, project, debugMsg):
    """
    print the batch directives only for given script and machine

    Arguments:
    pp_path  -- post-processing root directory; it must contain
                Machines/machine_postprocess.xml and, when a match is found,
                Templates/cylc_batch_<mach>.tmpl
    script   -- script name; the settings are read from the XML element
                named '<script>_pes'
    mach     -- machine name, matched case-insensitively against the 'name'
                attribute of the <machine> elements
    comp     -- component name; only used for the 'averages', 'regrid' and
                'diagnostics' scripts, where it selects the
                components/component element with that (lower-cased) name
    project  -- value handed to the template as 'project'
    debugMsg -- accepted for interface compatibility; currently unused

    Prints the rendered template when settings are found and an INFO
    message otherwise. A missing '<script>_pes' element counts as "not
    found" for every script. Always returns 0.
    """
    xml_file = os.path.join(pp_path, 'Machines', 'machine_postprocess.xml')
    xml_tree = etree.ElementTree()
    xml_tree.parse(xml_file)

    pes_tag = '{0}_pes'.format(script)
    settings = None
    for xmlmachine in xml_tree.findall('machine'):
        if mach.lower() != xmlmachine.get('name', '').lower():
            continue

        # timeseries and xconform settings hang directly off the machine;
        # the remaining scripts keep theirs under the matching component
        if script in ('timeseries', 'xconform'):
            parents = [xmlmachine]
        elif script in ('averages', 'regrid', 'diagnostics') and comp:
            parents = [comp_xml
                       for comp_xml in xmlmachine.findall('components/component')
                       if comp_xml.get('name', '').lower() == comp]
        else:
            parents = []

        # when several elements match, the last one found wins
        for parent in parents:
            candidate = _read_pes_settings(parent.find(pes_tag))
            if candidate is not None:
                settings = candidate

    if settings is None:
        msg = "pp_config INFO: no matching XML records found for " \
              "comp='{0}', script='{1}' on machine='{2}' in XML file='{3}'".format(comp, script, mach, xml_file)
        print (msg)
        return 0

    # load up the template and print it out
    batchTmpl = 'cylc_batch_{0}.tmpl'.format(mach)
    templateLoader = jinja2.FileSystemLoader( searchpath='{0}/Templates'.format(pp_path) )
    templateEnv = jinja2.Environment( loader=templateLoader )
    template = templateEnv.get_template( batchTmpl )
    templateVars = dict(settings, processName=script, project=project)

    # render this template into the batchdirectives string
    batchdirectives = template.render( templateVars )

    # print the batchdirectives
    print (batchdirectives)

    return 0

# -------------------------------------------------------------------------------
# main
# -------------------------------------------------------------------------------
def main(options):
    """Run the get / set / getbatch actions requested on the command line.

    Arguments:
    options -- the namespace returned by commandline_options()

    Returns 0 on success and 1 when a request cannot be satisfied (no XML
    files in the case directory, unknown --get variable, malformed --set
    argument, or a --getbatch request without its --machine / --comp
    companion options).
    """
    debugMsg = vprinter.VPrinter(header='', verbosity=0)
    if options.debug:
        header = 'pp_config: DEBUG... '
        debugMsg = vprinter.VPrinter(header=header, verbosity=options.debug[0])

    # Resolve the case directory before changing into it; a relative
    # --caseroot would otherwise be applied a second time when the XML file
    # paths are built below.
    case_dir = os.path.abspath(options.caseroot[0])
    debugMsg("Using case directory : {0}".format(case_dir), header=True, verbosity=1)
    os.chdir(case_dir)

    xml_filenames = glob.glob('*.xml')

    xml_trees = []
    for filename in xml_filenames:
        file_path = os.path.join(case_dir, filename)
        if os.path.isfile(file_path):
            xml_trees.append(etree.parse(file_path))
        else:
            msg = 'pp_config WARNING: {0} does not exist.'.format(file_path)
            print(msg)

    if not xml_trees:
        print('pp_config ERROR: no XML files found in {0}'.format(case_dir))
        return 1

    # assume that all env*.xml files are the same version in a case
    xml_processor = processXmlLib.post_processing_xml_factory(xml_trees[0])

    envDict = os.environ.copy()
    for tree in xml_trees:
        xml_processor.xml_to_dict(tree, envDict)

    # this should be done in the xml_to_dict!
    # resolve all the xml variables
    envDict = cesmEnvLib.readXML(case_dir, xml_filenames)

    # 'get' user input
    if options.get:
        entry_id = options.get[0]
        try:
            entry_value = envDict[entry_id]
        except KeyError:
            print("pp_config ERROR: variable '{0}' was not found in the XML files in {1}".format(
                entry_id, case_dir))
            return 1
        if options.value:
            print("{0}".format(entry_value))
        else:
            print("{0}={1}".format(entry_id, entry_value))

    # 'set' user input
    if options.set:
        # split on the first '=' only so that the value may contain '='
        new_entry_id, separator, new_entry_value = options.set[0].partition('=')
        new_entry_id = new_entry_id.strip()
        if not separator or not new_entry_id:
            print("Option --set requires an argument of the form \"key=value\"; got '{0}'".format(
                options.set[0]))
            return 1
        new_entry_value = new_entry_value.strip().replace('\n','')

        # get the component name based on the entry_id: the first three
        # characters of the part before the first '_' are used when they
        # name a known component and that part is longer than three characters
        entry_parts = new_entry_id.split('_')
        comp = entry_parts[0][:3]
        if comp.upper() not in ['ATM','ICE','LND','OCN','CON'] or len(entry_parts[0]) == 3:
            comp = ''

        xml_processor.write(envDict, comp.lower(), new_entry_id, new_entry_value)

    # print out the batch information
    if options.getbatch:
        script = options.getbatch[0]
        if not options.machine:
            print("Option --getbatch requires a matching --machine option be specified")
            return 1
        mach = options.machine[0]
        comp = ''
        if script != 'timeseries' and script != 'xconform':
            if not options.comp:
                print("Option --getbatch with script not set to 'timeseries' or 'xconform' "
                      "requires --comp option be specified")
                return 1
            comp = options.comp[0]
        get_batch(envDict['POSTPROCESS_PATH'], script, mach, comp, envDict['PROJECT'], debugMsg)

    return 0

# -------------------------------------------------------------------------------
if __name__ == "__main__":

    # Parse the command line outside the try block so that 'options' is
    # always bound when the exception handler consults options.backtrace.
    options = commandline_options()

    try:
        status = main(options)
    except Exception as error:
        # report the failure, optionally with the full traceback
        print(str(error))
        if options.backtrace:
            traceback.print_exc()
        sys.exit(1)

    sys.exit(status)
