import blast
from fwk.wutils import parseargs
from fwk.coloring import ccolors
import numpy as np

def initBL(Re, Minf, CFL0, nSections, xtrF=(None, None), span=0, verb=None):
    """ Initialize boundary layer solver.

    Params
    ------
    - Re: Flow Reynolds number (must be strictly positive).
    - Minf: Freestream Mach number (must be non-negative; a warning is printed when >= 1).
    - CFL0: Initial CFL number for time integration.
    - nSections: Number of sections in the domain (must be non-negative).
    - xtrF: Forced transition locations as a (top, bottom) sequence. Each entry is
        None or -1 (both passed to the driver as -1), or a value in [0, 1].
        The sequence passed by the caller is not modified.
    - span: Wing span (not used for 2D computations).
    - verb: Verbosity level of the solver, in [0, 3]. If None, it is taken from the
        command-line arguments returned by parseargs().

    Return
    ------
    - solver: blast::Driver class.

    Raises
    ------
    - RuntimeError: if any of the parameters above is invalid.
    """
    if Re<=0.:
        print(ccolors.ANSI_RED, "blast::vUtils Error : Negative Reynolds number.", Re, ccolors.ANSI_RESET)
        raise RuntimeError("Invalid parameter")
    if Minf<0.:
        print(ccolors.ANSI_RED, "blast::vUtils Error : Negative Mach number.", Minf, ccolors.ANSI_RESET)
        raise RuntimeError("Invalid parameter")
    elif Minf>=1.:
        print(ccolors.ANSI_YELLOW, "blast::vUtils Warning : (Super)sonic freestream Mach number.", Minf, ccolors.ANSI_RESET)
    if nSections < 0:
        print(ccolors.ANSI_RED, "blast::vUtils Fatal error : Negative number of sections.", nSections, ccolors.ANSI_RESET)
        raise RuntimeError("Invalid parameter")

    if verb is None:
        verbose = parseargs().verb
    elif not (0 <= verb <= 3):
        # Report the rejected value itself ('verbose' is not assigned on this path).
        print(ccolors.ANSI_RED, "blast::vUtils Fatal error : verbose not in valid range.", verb, ccolors.ANSI_RESET)
        raise RuntimeError("Invalid parameter")
    else:
        verbose = verb

    # Build a new list instead of assigning in place, so neither the caller's
    # sequence nor the default argument is modified (and tuples are accepted).
    xtrF = [-1 if loc is None else loc for loc in xtrF]
    for loc in xtrF:
        if loc != -1 and not (0 <= loc <= 1):
            raise RuntimeError('Incorrect forced transition location')

    solver = blast.Driver(Re, Minf, CFL0, nSections, xtrF_top=xtrF[0], xtrF_bot=xtrF[1], _span=span, _verbose=verbose)
    return solver

def initBlast(iconfig, vconfig, iSolver='DART'):
    """Initialize blast coupling objects.

    Params
    ------
    - iconfig (dict): Dictionary used to initialize the inviscid solver 'iSolver'.
    - vconfig (dict): Dictionary used to initialize the boundary layer solver.
        Required keys: 'Re', 'Minf', 'CFL0', 'couplIter', 'couplTol', 'iterPrint',
        'resetInv', and 'sections' when 'nSections' is absent.
        Optional keys: 'nSections' (default: len(vconfig['sections'])),
        'sfx' (default: ''), 'xtrF' (default: [None, None]).
        Missing 'nSections' and 'sfx' are written back into vconfig.
    - iSolver (string): Name of the inviscid solver to use (only 'DART' is supported).

    Return
    ------
    - coupler: coupler object
    - iSolverAPI: api interface of the inviscid solver selected with 'iSolver'.
    - vSolver: blast::Driver class.

    Raises
    ------
    - RuntimeError: if 'iSolver' is not a supported solver name.
    """
    # Defaults are stored on vconfig itself; it is also handed to the inviscid
    # interface below (presumably so that it sees the same values).
    if 'nSections' not in vconfig:
        vconfig['nSections'] = len(vconfig['sections'])
    if 'sfx' not in vconfig:
        vconfig['sfx'] = ''

    # Viscous solver
    vSolver = initBL(vconfig['Re'], vconfig['Minf'], vconfig['CFL0'], vconfig['nSections'],
                     xtrF=vconfig.get('xtrF', [None, None]))

    # Inviscid solver
    if iSolver == 'DART':
        from blast.interfaces.dart.DartInterface import DartInterface as interface
    else:
        print(ccolors.ANSI_RED + 'Solver '+iSolver+' currently not implemented' + ccolors.ANSI_RESET)
        raise RuntimeError('Solver '+iSolver+' currently not implemented')
    iSolverAPI = interface(iconfig, vSolver, vconfig)

    # Coupler
    import blast.coupler as blastCoupler
    coupler = blastCoupler.Coupler(iSolverAPI, vSolver,
                                   _maxCouplIter=vconfig['couplIter'],
                                   _couplTol=vconfig['couplTol'],
                                   _iterPrint=vconfig['iterPrint'],
                                   _resetInv=vconfig['resetInv'],
                                   sfx=vconfig['sfx'])
    return coupler, iSolverAPI, vSolver

def mesh(file, pars):
    """Load a mesh through tbox.gmsh.MeshLoader and return it.

    Parameters
    ----------
    file : str
        Name of the file containing the mesh (without extension).
    pars : dict
        Parameters forwarded as keyword arguments to MeshLoader.execute().

    Returns
    -------
    The object returned by MeshLoader(file, __file__).execute(**pars).
    """
    import tbox.gmsh as tmsh
    # NOTE(review): this module's own path is passed as the second argument;
    # presumably MeshLoader resolves 'file' relative to it — confirm.
    loader = tmsh.MeshLoader(file, __file__)
    return loader.execute(**pars)


def getSolution(sections, write=False, toW=('x', 'xelm', 'theta', 'H', 'deltaStar', 'cf', 'blowingVelocity'), sfx=''):
    """ Extract viscous solution.

    Parameters
    ----------

        - sections: list of sections, each a sequence of blast::BoundaryLayer objects.
          Each section must contain the sides "upper", "lower" and "wake" in that
          order: index offsets are taken from sec[0] and sec[1], and the blowing
          velocity is concatenated from sec[0], sec[1] and sec[2].
        - write: bool. Flag to write the results to files
          'blSlices/bl<iSec><sfx>.dat' (one per section).
        - toW: sequence of str, or 'all'. Keys of the solution written to file;
          'all' selects every nodal variable and attribute.
        - sfx: str. Suffix to add to the file name.

    Return
    ------

        - sol: list[dict]. One dictionary per section mapping a key to a numpy array.
          Nodal arrays hold the upper side in reverse order, followed by the lower
          side and the wake. 'xelm', 'yelm', 'zelm' hold one entry per element
          (element midpoints), i.e. (number of nodes - number of sides) entries.
          Returns [] when 'sections' is empty.

    Raises
    ------

        - RuntimeError: if a side name is not "upper", "lower" or "wake".
    """
    if not sections:
        return []

    nVar = sections[0][0].getnVar()

    varKeys = ['theta', 'H', 'N', 'ue', 'Ct']
    attrKeys = ['x', 'y', 'z', 'xoc', 'deltaStar',
                'cd', 'cf', 'Hstar', 'Hstar2', 'Hk', 'ctEq', 'us', 'delta',
                'vt', 'rhoe', 'Ret']
    elemsKeys = ['xelm', 'yelm', 'zelm']

    sol = []
    for sec in sections:
        nNodes = sum(len(side.x) for side in sec)
        nUp, nLow = sec[0].x.size(), sec[1].x.size()

        secSol = {key: np.zeros(nNodes - len(sec)) for key in elemsKeys}
        for key in varKeys + attrKeys:
            secSol[key] = np.zeros(nNodes)

        for side in sec:
            n = side.nNodes

            # Element coordinates: midpoint of consecutive nodes ('xelm' from 'x', ...).
            elemsCoord = np.zeros((n-1, 3))
            for k, key in enumerate(elemsKeys):
                coord = np.array(getattr(side, key[0]))
                elemsCoord[:, k] = 0.5 * (coord[1:] + coord[:-1])

            if side.name == "upper":
                # Upper side is stored reversed: its last entry (index nUp-1) is the
                # point reported as the stagnation point in the written files.
                for key in attrKeys:
                    secSol[key][:n] = np.flip(getattr(side, key))
                for k, key in enumerate(varKeys):
                    secSol[key][:n] = np.flip(side.u[k::nVar])
                for k, key in enumerate(elemsKeys):
                    secSol[key][:n-1] = np.flip(elemsCoord[:, k])
                continue

            if side.name == "lower":
                nodeSlice = np.s_[nUp:nUp+nLow]
                elemSlice = np.s_[(nUp-1):(nUp-1)+(nLow-1)]
            elif side.name == "wake":
                nodeSlice = np.s_[nUp+nLow:]
                elemSlice = np.s_[(nUp-1)+(nLow-1):]
            else:
                # Otherwise a slice left over from a previous side would be reused.
                raise RuntimeError('Unknown boundary layer side: '+str(side.name))

            for key in attrKeys:
                secSol[key][nodeSlice] = getattr(side, key)
            for k, key in enumerate(varKeys):
                # side.u interleaves the nVar unknowns node by node.
                secSol[key][nodeSlice] = side.u[k::nVar]
            for k, key in enumerate(elemsKeys):
                secSol[key][elemSlice] = elemsCoord[:, k]

        secSol['blowingVelocity'] = np.concatenate((np.flip(sec[0].blowingVelocity),
                                                    sec[1].blowingVelocity, sec[2].blowingVelocity))
        sol.append(secSol)

    if write:
        import datetime
        import os
        if toW == 'all':
            toW = varKeys + attrKeys
        os.makedirs('blSlices', exist_ok=True)
        print('Writing files: blSlices/bl*'+sfx+'.dat...', end = '')
        for iSec, (sec, secSol) in enumerate(zip(sections, sol)):
            nUp, nLow = sec[0].x.size(), sec[1].x.size()
            with open('blSlices/bl'+str(iSec)+sfx+'.dat', 'w+') as f:
                f.write('# BLASTER boundary layer output file\n')
                f.write('# Generated on '+datetime.datetime.now().strftime('%Y-%m-%d %H:%M')+'\n')
                f.write('\n')
                f.write('# Section '+str(iSec)+'\n')
                f.write('# ')
                for side in sec:
                    f.write('xtr_'+side.name+' = '+str(round(side.xtr,6))+' ')
                f.write('\n')
                f.write('# Stagnation point at index '+str(nUp-1)+'\n')
                f.write('# Wake at index '+str(nUp+nLow)+'\n')

                for i, key in enumerate(toW):
                    f.write(' {0:>21s}'.format(str(i)+' '+key))
                f.write('\n')
                # Rows follow the nodal arrays; shorter columns (e.g. 'xelm') are
                # padded with nan so every row has one value per column.
                for i in range(len(secSol['x'])):
                    for key in toW:
                        col = secSol[key]
                        f.write(' {0:21.16f}'.format(col[i] if i < len(col) else np.nan))
                    f.write('\n')
        print('done.')
    return sol

def read(filename, delim=',', skip=1):
    """Read a delimited numeric text file and return its content as an array.

    Parameters
    ----------
    filename : str
        Path of the file to read.
    delim : str
        Column delimiter (default ',').
    skip : int
        Number of leading lines to skip, e.g. a header line (default 1).

    Returns
    -------
    numpy.ndarray
        Data loaded with numpy.loadtxt.
    """
    return np.loadtxt(filename, delimiter=delim, skiprows=skip)

def plot(cfg):
    """Plot the curves described by a configuration dictionary and show the figure.

    Required keys of cfg: 'curves' (arrays indexed as [:, 0] for x and [:, 1] for y),
    and per-curve sequences 'ls', 'color', 'lw' and 'labels' indexed like 'curves'.
    Optional keys: 'yreverse'/'xreverse' (axis inverted only when exactly True),
    'title', 'xlim', 'ylim', 'legend' (frameless legend only when exactly True),
    'xlabel', 'ylabel'.
    """
    from matplotlib import pyplot as plt

    for idx, curve in enumerate(cfg['curves']):
        plt.plot(curve[:, 0], curve[:, 1], cfg['ls'][idx],
                 color=cfg['color'][idx], lw=cfg['lw'][idx], label=cfg['labels'][idx])

    if cfg.get('yreverse') is True:
        plt.gca().invert_yaxis()
    if cfg.get('xreverse') is True:
        plt.gca().invert_xaxis()

    for key, setter in (('title', plt.title), ('xlim', plt.xlim), ('ylim', plt.ylim)):
        if key in cfg:
            setter(cfg[key])
    if cfg.get('legend') is True:
        plt.legend(frameon=False)
    for key, setter in (('xlabel', plt.xlabel), ('ylabel', plt.ylabel)):
        if key in cfg:
            setter(cfg[key])

    plt.show()