#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Codegeneration utilities.
"""

# Copyright (C) 2008-2009 Martin Sandve Alnes and Simula Resarch Laboratory
#
# This file is part of SyFi.
#
# SyFi is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# SyFi is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with SyFi. If not, see <http://www.gnu.org/licenses/>.
#
# Modified by Kent-Andre Mardal, 2010.
#
# First added:  2008-08-13
# Last changed: 2009-05-15

from itertools import chain
import re
import ufl
import ufc_utils
import SyFi

from ufl.classes import Measure
from ufl.algorithms import FormData, preprocess

from dolfin_utils.wrappers import generate_dolfin_code, UFCFormNames

from sfc.common.utilities import unique
from sfc.common.output import sfc_assert, sfc_error, sfc_info, sfc_debug
#from sfc.common import ParameterDict
from sfc.common.options import default_parameters
from sfc.representation import ElementRepresentation, FormRepresentation, \
        CellIntegralRepresentation, ExteriorFacetIntegralRepresentation, \
        InteriorFacetIntegralRepresentation
from sfc.codegeneration.formcg import FormCG
from sfc.codegeneration.dofmapcg import DofMapCG
from sfc.codegeneration.finiteelementcg import FiniteElementCG
from sfc.codegeneration.cellintegralcg import CellIntegralCG
from sfc.codegeneration.exteriorfacetintegralcg import ExteriorFacetIntegralCG
from sfc.codegeneration.interiorfacetintegralcg import InteriorFacetIntegralCG

# Template for each generated C++ header (.h) file. Placeholders:
#   name     - generated class name; used in the banner and the include guard
#   includes - '#include' lines placed after "ufc.h"
#   body     - the generated header code for the class
# The "namespace sfc" wrapper is emitted commented out, so generated classes
# end up in the global namespace. There is no version placeholder (a
# "version %(version)s" banner line was disabled), so callers pass no
# 'version' key.
_header_template = r"""/*
 * %(name)s.h
 *
 * This file was automatically generated by SFC.
 *
 * http://www.fenics.org/syfi/
 *
 */

#ifndef GENERATED_SFC_CODE_%(name)s
#define GENERATED_SFC_CODE_%(name)s

#include "ufc.h"

%(includes)s

//namespace sfc
//{
%(body)s
//}

#endif
"""

# Template for each generated C++ implementation (.cpp) file. Placeholders:
#   name     - generated class name; the file includes "<name>.h"
#   includes - system and local '#include' lines
#   body     - the generated implementation code (support code + class code)
# As above, the "namespace sfc" wrapper is commented out and there is no
# version placeholder.
_implementation_template = r"""/*
 * %(name)s.cpp
 *
 * This file was automatically generated by SFC.
 *
 * http://www.fenics.org/syfi/
 *
 */

#include "ufc.h"

#include "%(name)s.h"

%(includes)s

//namespace sfc
//{
%(body)s
//}
"""

def common_system_headers():
    """Return the system header names that every generated .cpp file includes.

    The names are emitted as '#include <name>' lines, in this order.
    """
    return tuple("iostream cmath stdexcept cstring".split())

def apply_code_dict(format_string, code_dict):
    """Template formatting with improved checking for template argument mismatch.

    Every "%(key)" placeholder found in format_string is checked against the
    keys of code_dict. Missing keys are printed (sorted, one per line)
    before formatting, to make template mismatches easy to diagnose. The
    result is plain ``format_string % code_dict``, so a missing key still
    raises KeyError afterwards. Keys in code_dict that the template does not
    use are ignored.
    """
    # Raw string: '\(' and '\w' are regex escapes, not Python string escapes.
    expected_keys = set(re.findall(r"%\((\w+)\)", format_string))
    missing_keys = expected_keys.difference(code_dict)
    if missing_keys:
        # Single-argument print(...) behaves the same in Python 2 and 3.
        print("When formatting code using template, missing keys:")
        print("\n".join(sorted(missing_keys)))
        print("")
    return format_string % code_dict

def generate_finite_element_code(ferep):
    """Generate UFC finite_element C++ code for one element representation.

    Returns (classname, (header_text, implementation_text), includes).
    classname is ferep.finite_element_classname. includes is the code
    generator's header include list concatenated with its implementation
    include list. The .cpp text also includes "<classname>.h" for every
    sub-element of ferep.
    """
    sfc_debug("Entering generate_finite_element_code")
    name = ferep.finite_element_classname
    generator = FiniteElementCG(ferep)
    fragments = generator.generate_code_dict()
    support = generator.generate_support_code()

    declaration = apply_code_dict(ufc_utils.finite_element_header, fragments)
    definition = apply_code_dict(ufc_utils.finite_element_implementation, fragments)
    # Support code goes before the class implementation, separated by blank lines.
    implementation_body = "\n\n\n".join((support, definition))

    needed = generator.hincludes() + generator.cincludes()

    system_names = common_system_headers()
    sub_headers = unique("%s.h" % sub.finite_element_classname
                         for sub in ferep.sub_elements)

    header_includes = "\n".join(['#include "%s"' % inc for inc in generator.hincludes()])
    impl_includes = "\n".join(['#include <%s>' % f for f in system_names])
    impl_includes = impl_includes + "\n" + "\n".join(
        ['#include "%s"' % inc for inc in chain(generator.cincludes(), sub_headers)])

    header_text = _header_template % {"body": declaration, "name": name,
                                      "includes": header_includes}
    impl_text = _implementation_template % {"body": implementation_body, "name": name,
                                            "includes": impl_includes}
    sfc_debug("Leaving generate_finite_element_code")
    return name, (header_text, impl_text), needed

def generate_dof_map_code(ferep):
    """Generate UFC dof_map C++ code for one element representation.

    Returns (classname, (header_text, implementation_text), includes).
    classname is ferep.dof_map_classname. includes is the code generator's
    header include list concatenated with its implementation include list.
    The .cpp text also includes "<dof_map_classname>.h" for every
    sub-element of ferep.
    """
    sfc_debug("Entering generate_dof_map_code")
    name = ferep.dof_map_classname
    generator = DofMapCG(ferep)
    fragments = generator.generate_code_dict()
    support = generator.generate_support_code()

    declaration = apply_code_dict(ufc_utils.dofmap_header, fragments)
    definition = apply_code_dict(ufc_utils.dofmap_implementation, fragments)
    # Support code goes before the class implementation, separated by blank lines.
    implementation_body = "\n\n\n".join((support, definition))

    needed = generator.hincludes() + generator.cincludes()

    system_names = common_system_headers()
    sub_headers = unique("%s.h" % sub.dof_map_classname
                         for sub in ferep.sub_elements)

    header_includes = "\n".join(['#include "%s"' % inc for inc in generator.hincludes()])
    impl_includes = "\n".join(['#include <%s>' % f for f in system_names])
    impl_includes = impl_includes + "\n" + "\n".join(
        ['#include "%s"' % inc for inc in chain(generator.cincludes(), sub_headers)])

    header_text = _header_template % {"body": declaration, "name": name,
                                      "includes": header_includes}
    impl_text = _implementation_template % {"body": implementation_body, "name": name,
                                            "includes": impl_includes}
    sfc_debug("Leaving generate_dof_map_code")
    return name, (header_text, impl_text), needed

def generate_cell_integral_code(integrals, formrep):
    """Generate UFC cell_integral C++ code for a group of integrals.

    A CellIntegralRepresentation is built from integrals and formrep and fed
    to CellIntegralCG. Returns (classname, (header_text,
    implementation_text), includes). includes is the code generator's header
    include list concatenated with its implementation include list.
    """
    sfc_debug("Entering generate_cell_integral_code")
    rep = CellIntegralRepresentation(integrals, formrep)

    generator = CellIntegralCG(rep)
    fragments = generator.generate_code_dict()
    support = generator.generate_support_code()

    # TODO: Support code might be placed in unnamed namespace?
    declaration = apply_code_dict(ufc_utils.cell_integral_header, fragments)
    definition = apply_code_dict(ufc_utils.cell_integral_implementation, fragments)
    implementation_body = "\n\n\n".join((support, definition))

    needed = generator.hincludes() + generator.cincludes()

    header_lines = ['#include "%s"' % inc for inc in generator.hincludes()]
    system_lines = ['#include <%s>' % f for f in common_system_headers()]
    local_lines = ['#include "%s"' % inc for inc in generator.cincludes()]
    impl_includes = "\n".join(system_lines) + "\n" + "\n".join(local_lines)

    name = rep.classname
    header_text = _header_template % {"body": declaration, "name": name,
                                      "includes": "\n".join(header_lines)}
    impl_text = _implementation_template % {"body": implementation_body, "name": name,
                                            "includes": impl_includes}
    sfc_debug("Leaving generate_cell_integral_code")
    return name, (header_text, impl_text), needed

def generate_exterior_facet_integral_code(integrals, formrep):
    """Generate UFC exterior_facet_integral C++ code for a group of integrals.

    An ExteriorFacetIntegralRepresentation is built from integrals and
    formrep and fed to ExteriorFacetIntegralCG. Returns (classname,
    (header_text, implementation_text), includes). includes is the code
    generator's header include list concatenated with its implementation
    include list.
    """
    sfc_debug("Entering generate_exterior_facet_integral_code")
    rep = ExteriorFacetIntegralRepresentation(integrals, formrep)

    generator = ExteriorFacetIntegralCG(rep)
    fragments = generator.generate_code_dict()
    support = generator.generate_support_code()

    # TODO: Support code might be placed in unnamed namespace?
    declaration = apply_code_dict(ufc_utils.exterior_facet_integral_header, fragments)
    definition = apply_code_dict(ufc_utils.exterior_facet_integral_implementation, fragments)
    implementation_body = "\n\n\n".join((support, definition))

    needed = generator.hincludes() + generator.cincludes()

    header_lines = ['#include "%s"' % inc for inc in generator.hincludes()]
    system_lines = ['#include <%s>' % f for f in common_system_headers()]
    local_lines = ['#include "%s"' % inc for inc in generator.cincludes()]
    impl_includes = "\n".join(system_lines) + "\n" + "\n".join(local_lines)

    name = rep.classname
    header_text = _header_template % {"body": declaration, "name": name,
                                      "includes": "\n".join(header_lines)}
    impl_text = _implementation_template % {"body": implementation_body, "name": name,
                                            "includes": impl_includes}
    sfc_debug("Leaving generate_exterior_facet_integral_code")
    return name, (header_text, impl_text), needed

def generate_interior_facet_integral_code(integrals, formrep):
    """Generate UFC interior_facet_integral C++ code for a group of integrals.

    An InteriorFacetIntegralRepresentation is built from integrals and
    formrep and fed to InteriorFacetIntegralCG. Returns (classname,
    (header_text, implementation_text), includes). includes is the code
    generator's header include list concatenated with its implementation
    include list.
    """
    sfc_debug("Entering generate_interior_facet_integral_code")
    rep = InteriorFacetIntegralRepresentation(integrals, formrep)

    generator = InteriorFacetIntegralCG(rep)
    fragments = generator.generate_code_dict()
    support = generator.generate_support_code()

    declaration = apply_code_dict(ufc_utils.interior_facet_integral_header, fragments)
    definition = apply_code_dict(ufc_utils.interior_facet_integral_implementation, fragments)
    implementation_body = "\n\n\n".join((support, definition))

    needed = generator.hincludes() + generator.cincludes()

    header_lines = ['#include "%s"' % inc for inc in generator.hincludes()]
    system_lines = ['#include <%s>' % f for f in common_system_headers()]
    local_lines = ['#include "%s"' % inc for inc in generator.cincludes()]
    impl_includes = "\n".join(system_lines) + "\n" + "\n".join(local_lines)

    name = rep.classname
    header_text = _header_template % {"body": declaration, "name": name,
                                      "includes": "\n".join(header_lines)}
    impl_text = _implementation_template % {"body": implementation_body, "name": name,
                                            "includes": impl_includes}
    sfc_debug("Leaving generate_interior_facet_integral_code")
    return name, (header_text, impl_text), needed

def generate_form_code(formrep):
    """Generate UFC form C++ code for a form representation.

    The .cpp text includes "<name>.h" for every finite element, dof map and
    integral class named by formrep (fe_names, dm_names, itg_names). It also
    includes the code generator's own implementation includes verbatim.

    Returns ((header_text, implementation_text), includes). includes is the
    code generator's header include list concatenated with its
    implementation include list.
    """
    sfc_debug("Entering generate_form_code")
    cg = FormCG(formrep)
    code_dict = cg.generate_code_dict()
    supportcode = cg.generate_support_code()

    hcode = apply_code_dict(ufc_utils.form_header, code_dict)
    ccode = supportcode + "\n"*3 + apply_code_dict(ufc_utils.form_implementation, code_dict)

    includes = cg.hincludes() + cg.cincludes()

    system_headers = common_system_headers()
    # Only the generated class names need a ".h" suffix. cg.cincludes() are
    # already complete file names (the other generate_*_code functions emit
    # them verbatim), so appending ".h" to them would produce "x.h.h".
    generated_classes = chain(formrep.fe_names, formrep.dm_names,
                              sorted(formrep.itg_names.values()))
    local_headers = unique(chain(("%s.h" % name for name in generated_classes),
                                 cg.cincludes()))

    hincludes =  "\n".join('#include "%s"' % inc for inc in cg.hincludes())
    cincludes =  "\n".join('#include <%s>' % f for f in system_headers)
    cincludes += "\n"
    cincludes += "\n".join('#include "%s"' % f for f in local_headers)

    hcode = _header_template         % { "body": hcode, "name": formrep.classname, "includes": hincludes }
    ccode = _implementation_template % { "body": ccode, "name": formrep.classname, "includes": cincludes }
    sfc_debug("Leaving generate_form_code")
    return (hcode, ccode), includes

def write_file(filename, text):
    """Write text to filename, replacing any existing content.

    The file is closed even if writing fails.
    """
    with open(filename, "w") as f:
        f.write(text)

def write_code(classname, code):
    """Write generated code for classname to files relative to the current directory.

    @param classname:
        Base name of the file(s) to write.
    @param code:
        Either a (header, implementation) tuple, written to
        "<classname>.h" and "<classname>.cpp", or a single string that is
        written to "<classname>.h" only.
    @return:
        Tuple with the name(s) of the written file(s), header first.
    """
    sfc_debug("Entering write_code")
    if isinstance(code, tuple):
        # Code is split in a header and implementation file
        hcode, ccode = code
        files = ((classname + ".h", hcode), (classname + ".cpp", ccode))
    else:
        # All code is combined in a header file
        files = ((classname + ".h", code),)
    for name, text in files:
        # 'with' closes each file deterministically, also if write() fails,
        # instead of relying on garbage collection of the file object.
        with open(name, "w") as f:
            f.write(text)
    ret = tuple(name for name, _ in files)
    sfc_debug("Leaving write_code")
    return ret

def compiler_input(input, objects=None, common_cell=None):
    """Map different kinds of input to a list of
    UFL elements and a list of FormData instances.

    The following input formats are allowed:
    - ufl.Form (converted with Form.compute_form_data)
    - ufl.algorithms.FormData
    - ufl.FiniteElementBase
    - list or tuple of the above

    @param input:
        One of the formats above.
    @param objects:
        Passed as object_names to Form.compute_form_data.
    @param common_cell:
        Cell given to elements whose cell is undefined. Also passed to
        Form.compute_form_data.

    Returns:
        elements, formdatas
        elements is a sorted list of the unique UFL elements, including the
        sub_elements of every FormData. formdatas is in input order.
    """
    sfc_debug("Entering compiler_input")
    sfc_assert(input, "Got no input!")
    fd = []
    fe = []
    # Accept a single object as well as a sequence of objects.
    if isinstance(input, (list, tuple)):
        items = input
    else:
        items = [input]
    for d in items:
        # Forms are first converted to FormData and then handled below.
        if isinstance(d, ufl.form.Form):
            d = d.compute_form_data(object_names=objects, common_cell=common_cell)
        if isinstance(d, ufl.algorithms.formdata.FormData):
            fd.append(d)
            fe.extend(d.sub_elements)
        elif isinstance(d, ufl.FiniteElementBase):
            if d.cell().is_undefined():
                # An element without a cell can only be compiled if a
                # defined common cell is available to attach to it.
                sfc_assert(common_cell and not common_cell.is_undefined(),
                           "Element has no defined cell, cannot compile this.")
                d = d.reconstruct(cell=common_cell)
                sfc_assert(not d.cell().is_undefined(), "Still undefined?")
            fe.append(d)
        else:
            sfc_error("Not a FormData or FiniteElementBase object: %s" % str(type(d)))
    # Remove duplicate elements and give them a deterministic order.
    fe = sorted(set(fe))
    for x in fe:
        sfc_assert(not x.cell().is_undefined(), "Can never have undefined cells at this point!")
    for x in fd:
        sfc_assert(not x.cell.is_undefined(), "Can never have undefined cells at this point!")
    sfc_debug("Leaving compiler_input")
    return fe, fd

# Preamble passed to generate_dolfin_code by generate_code. The single %s is
# replaced with '#include "<file>"' lines for all generated .h files.
dolfin_header_template = """/*
 * DOLFIN wrapper code generated by the SyFi Form Compiler.
 */

%s
"""

def generate_code(input, objects, options = None, common_cell=None):
    """Generate code from input and options.

    Writes one .h/.cpp file pair per generated UFC class (finite elements,
    dof maps, cell/exterior facet/interior facet integrals and forms) using
    write_code, which writes to paths relative to the current working
    directory. If options.code.dolfin_wrappers is set and there is at least
    one form, it also writes a DOLFIN wrapper header "<options.code.prefix>.h".

    @param input:
        A ufl.Form, ufl.algorithms.FormData or ufl.FiniteElementBase, or a
        list of these. It is normalized by compiler_input.
    @param objects:
        Passed to compiler_input, which forwards it as object_names to
        Form.compute_form_data.
    @param options:
        Compiler parameters; default_parameters() if None. Passed on to
        ElementRepresentation and FormRepresentation. options.code.dolfin_wrappers
        and options.code.prefix are read here.
    @param common_cell:
        Passed to compiler_input; used for elements without a defined cell.

    @return:
        (hfilenames, cfilenames): names of the written .h files (with the
        DOLFIN wrapper header last, if generated) and .cpp files.

    @raise NotImplementedError:
        If any code generator reports include files of its own (needed_files
        is non-empty); fetching such files is not implemented yet.
    """
    sfc_debug("Entering generate_code")

    if options is None:
        options = default_parameters()

    ufl_elements, formdatas = compiler_input(input, objects, common_cell=common_cell)

    filenames = []     # names of files written by write_code
    needed_files = []  # include lists reported by the code generators
    formnames = []     # UFCFormNames per form, used for the DOLFIN wrappers

    # Generate UFC code for elements
    element_reps = {}
    last_nsd = None
    generated_elements = set()  # UFL elements whose code has already been written
    for e in ufl_elements:
        # Initialize global variables in SyFi with the right space dimension (not very nice!)
        # initSyFi is only called again when the dimension changes, and only
        # for integer dimensions. It must happen before the element
        # representations for that dimension are built.
        nsd = e.cell().geometric_dimension()
        if not nsd == last_nsd and isinstance(nsd, int): 
            SyFi.initSyFi(nsd)
            last_nsd = nsd
        
        # Construct ElementRepresentation objects for all elements
        quad_rule = None # TODO: What to do with this one?
        # NOTE(review): a plain assert is skipped under python -O.
        assert not "Quadrature" in e.family() # No quadrature rule defined!
        if e in element_reps:
            continue

        # element_reps is handed to ElementRepresentation as well; presumably
        # so sub-element representations can be looked up/shared - confirm.
        erep = ElementRepresentation(e, quad_rule, options, element_reps)
        element_reps[e] = erep
        
        # Build flat list of all subelements
        # (traverses erep and its sub_elements recursively via a stack,
        # collecting each not-yet-generated UFL element once)
        todo = [erep]
        ereps = []
        while todo:
            erep = todo.pop()
            # Skip already generated!
            if not erep.ufl_element in generated_elements:
                generated_elements.add(erep.ufl_element)
                ereps.append(erep)
            # Queue subelements for inspection
            for subrep in erep.sub_elements:
                todo.append(subrep)

        for erep in ereps:
            # Generate code for finite_element
            classname, code, includes = generate_finite_element_code(erep)
            filenames.extend( write_code(classname, code) )
            needed_files.extend(includes)
            
            # Generate code for dof_map
            classname, code, includes = generate_dof_map_code(erep)
            filenames.extend( write_code(classname, code) )
            needed_files.extend(includes)
    
    # Generate UFC code for forms
    for formdata in formdatas:
        # Initialize global variables in SyFi with the right space dimension (not very nice!)
        nsd = formdata.geometric_dimension
        if not nsd == last_nsd: 
            SyFi.initSyFi(nsd)
            last_nsd = nsd
        
        # This object extracts and collects various information about the ufl form
        formrep = FormRepresentation(formdata, element_reps, options)

        pf = formdata.preprocessed_form

        # TODO: Get integrals from formrep, in case of modifications there? Or rather make sure that FormRep doesn't do any such things.
        # ig maps each domain to the integrals of that domain.
        ig = pf.integral_groups()
        
        # Generate code for cell integrals
        # (one generated integral class per domain of this type)
        for domain in pf.domains(Measure.CELL): 
            integrals = ig[domain]
            classname, code, includes = generate_cell_integral_code(integrals, formrep)
            filenames.extend( write_code(classname, code) )
            needed_files.extend(includes)
        
        # Generate code for exterior facet integrals
        for domain in pf.domains(Measure.EXTERIOR_FACET): 
            integrals = ig[domain]
            classname, code, includes = generate_exterior_facet_integral_code(integrals, formrep)
            filenames.extend( write_code(classname, code) )
            needed_files.extend(includes)
        
        # Generate code for interior facet integrals
        for domain in pf.domains(Measure.INTERIOR_FACET): 
            integrals = ig[domain]
            classname, code, includes = generate_interior_facet_integral_code(integrals, formrep)
            filenames.extend( write_code(classname, code) )
            needed_files.extend(includes)
        
        # Generate code for form!
        code, includes = generate_form_code(formrep)
        filenames.extend( write_code(formrep.classname, code) )
        needed_files.extend(includes)

        # Collect classnames for use with dolfin wrappers
        # The namespace prefix is empty, matching the commented-out
        # "namespace sfc" in the generated file templates.
        namespace = "" # "sfc::"
        ufc_form_classname = namespace + formrep.classname
        ufc_finite_element_classnames = [namespace + name for name in formrep.fe_names]
        ufc_dof_map_classnames        = [namespace + name for name in formrep.dm_names]

        fn = UFCFormNames(formdata.name, formdata.coefficient_names,
                          ufc_form_classname, ufc_finite_element_classnames, ufc_dof_map_classnames)
        formnames.append(fn)
    
    # Get other needed files:
    # Any include reported by a code generator aborts compilation here.
    if needed_files:
        raise NotImplementedError("FIXME: Implement fetching non-ufc-class files like DofPtv and quadrature rule files.")

    # Remove duplicates, then split by file extension.
    filenames = list(unique(chain(filenames, needed_files)))
    hfilenames = [f for f in filenames if f.endswith(".h")]
    cfilenames = [f for f in filenames if f.endswith(".cpp")]

    # Generate DOLFIN wrapper code
    if options.code.dolfin_wrappers:
        if not formnames:
            print "NOT generating dolfin wrappers, missing forms!" # TODO: Generate FunctionSpaces for elements?
        else:
            prefix = options.code.prefix
            # The wrapper header includes every generated UFC header.
            header = dolfin_header_template % "\n".join('#include "%s"' % h for h in hfilenames)
            sfc_info("Generating DOLFIN wrapper code, formnames are %s." % "\n".join(map(str,formnames)))
            dolfin_code = generate_dolfin_code(prefix, header, formnames)
            dolfin_filename = "%s.h" % prefix
            write_file(dolfin_filename, dolfin_code)
            hfilenames.append(dolfin_filename)

    sfc_debug("Leaving generate_code")
    return hfilenames, cfilenames

