import blast
from fwk.wutils import parseargs
from fwk.coloring import ccolors
import numpy as np

def initBL(Re, Minf, CFL0, nSections, xtrF=(None, None), span=0, verb=None):
    """ Initialize boundary layer solver.

    Params
    ------
    - Re: Flow Reynolds number. Must be strictly positive.
    - Minf: Freestream Mach number. Must be non-negative; a warning is
      printed when it is >= 1.
    - CFL0: Initial CFL number for time integration.
    - nSections: Number of sections in the domain. Must be non-negative.
    - xtrF: Forced transition locations, indexed as (top, bottom). None means
      no forced transition and is passed to the solver as -1; any other value
      must lie in [0, 1]. The caller's sequence is not modified.
    - span: Wing span (not used for 2D computations).
    - verb: Verbosity level of the solver, in [0, 3]. When None, the level is
      taken from the command-line arguments (parseargs().verb).

    Return
    ------
    - solver: blast::Driver instance.

    Raises
    ------
    - RuntimeError: if any parameter is out of its valid range.
    """
    if Re <= 0.:
        print(ccolors.ANSI_RED, "blast::vUtils Error : Negative Reynolds number.", Re, ccolors.ANSI_RESET)
        raise RuntimeError("Invalid parameter")
    if Minf < 0.:
        print(ccolors.ANSI_RED, "blast::vUtils Error : Negative Mach number.", Minf, ccolors.ANSI_RESET)
        raise RuntimeError("Invalid parameter")
    elif Minf >= 1.:
        print(ccolors.ANSI_YELLOW, "blast::vUtils Warning : (Super)sonic freestream Mach number.", Minf, ccolors.ANSI_RESET)
    if nSections < 0:
        print(ccolors.ANSI_RED, "blast::vUtils Fatal error : Negative number of sections.", nSections, ccolors.ANSI_RESET)
        raise RuntimeError("Invalid parameter")

    if verb is None:
        verbose = parseargs().verb
    elif not (0 <= verb <= 3):
        # Report the offending input `verb`; `verbose` is not yet bound here.
        print(ccolors.ANSI_RED, "blast::vUtils Fatal error : verbose not in valid range.", verb, ccolors.ANSI_RESET)
        raise RuntimeError("Invalid parameter")
    else:
        verbose = verb

    # Build a fresh list: never mutate the caller's sequence (or a default).
    # None is encoded as -1 for the solver.
    xtr = [-1 if loc is None else loc for loc in xtrF]
    for loc in xtr:
        if loc != -1 and not (0 <= loc <= 1):
            raise RuntimeError('Incorrect forced transition location')

    solver = blast.Driver(Re, Minf, CFL0, nSections, xtrF_top=xtr[0], xtrF_bot=xtr[1], _span=span, _verbose=verbose)
    return solver

def initBlast(iconfig, vconfig, iSolver='DART'):
    """Initialize blast coupling objects.

    Params
    ------
    - iconfig (dict): Dictionary used to initialize the inviscid solver 'iSolver'.
    - vconfig (dict): Dictionary used to initialize the boundary layer solver.
      Keys read here: 'Re', 'Minf', 'CFL0', 'xtrF', 'couplIter', 'couplTol',
      'iterPrint', 'resetInv', and 'nSections'. When 'nSections' is absent it
      is derived from len(vconfig['sections']) and stored back into vconfig
      (vconfig is mutated). vconfig is also forwarded to the inviscid
      solver interface.
    - iSolver (string): Name of the inviscid solver to use. Only 'DART' is
      currently supported.

    Return
    ------
    - coupler: coupler object
    - iSolverAPI: api interface of the inviscid solver selected with 'iSolver'.
    - vSolver: blast::Driver class.

    Raises
    ------
    - RuntimeError: if 'iSolver' is not supported, or if initBL rejects the
      viscous parameters.
    """

    if 'nSections' not in vconfig:
        vconfig['nSections'] = len(vconfig['sections'])
    # Viscous solver
    vSolver = initBL(vconfig['Re'], vconfig['Minf'], vconfig['CFL0'], vconfig['nSections'], xtrF=vconfig['xtrF'])

    # Inviscid solver
    if iSolver == 'DART':
        from blast.interfaces.dart.DartInterface import DartInterface as interface
    else:
        print(ccolors.ANSI_RED + 'Solver '+iSolver+' currently not implemented' + ccolors.ANSI_RESET)
        # Carry the reason in the exception itself so it survives in tracebacks/logs.
        raise RuntimeError(f'Solver {iSolver} currently not implemented')
    iSolverAPI = interface(iconfig, vSolver, vconfig)

    # Coupler
    import blast.coupler as blastCoupler
    coupler = blastCoupler.Coupler(iSolverAPI, vSolver, _maxCouplIter=vconfig['couplIter'], _couplTol=vconfig['couplTol'], _iterPrint=vconfig['iterPrint'], _resetInv=vconfig['resetInv'])
    return coupler, iSolverAPI, vSolver

def mesh(file, pars):
    """Load a mesh through tbox's gmsh MeshLoader.

    Parameters
    ----------
    file : str
        file containing the mesh (without extension)
    pars : dict
        parameters forwarded as keyword arguments to ``MeshLoader.execute``

    Returns
    -------
    Whatever ``MeshLoader.execute`` returns (presumably the loaded mesh
    object — confirm against tbox).
    """
    from tbox.gmsh import MeshLoader
    # __file__ (this module's path) is handed to the loader together with the
    # file name; presumably used as a lookup base — verify in tbox.
    loader = MeshLoader(file, __file__)
    return loader.execute(**pars)


def getSolution(vSolver, iSec=0):
    """Extract the viscous solution of one section into a dictionary.

    Args
    ----
    - vSolver: blast::Driver class. Driver of the viscous calculations
    - iSec (int): Marker of the section (default: 0)

    Return
    ------
    - dict mapping variable names to the entries of vSolver.getSolution(iSec),
      followed by the transition locations ('xtrT', 'xtrB') and the sectional
      drag values ('Cdt_w', 'Cdf', 'Cdp').

    Raises
    ------
    - RuntimeError: if iSec is negative.
    """
    if iSec < 0:
        raise RuntimeError("blast::viscU Invalid section number", iSec)

    raw = vSolver.getSolution(iSec)

    # Position k of the solver output is stored under fieldNames[k].
    fieldNames = ('theta', 'H', 'N', 'ue', 'Ct', 'deltaStar', 'Ret', 'Cd',
                  'Cf', 'Hstar', 'Hstar2', 'Hk', 'Cteq', 'Us', 'delta',
                  'x', 'xoc', 'y', 'z', 'ueInv')
    result = {name: raw[k] for k, name in enumerate(fieldNames)}

    # 'Cf' is not taken verbatim: raw entry 8 is scaled by ue**2 (entry 3).
    result['Cf'] = np.asarray(raw[8]) * np.asarray(raw[3]) ** 2

    # Per-section scalars: transition on side 0 (top) and side 1 (bottom),
    # then the sectional drag contributions.
    result['xtrT'] = vSolver.getxtr(iSec, 0)
    result['xtrB'] = vSolver.getxtr(iSec, 1)
    result['Cdt_w'] = vSolver.Cdt_sec[iSec]
    result['Cdf'] = vSolver.Cdf_sec[iSec]
    result['Cdp'] = vSolver.Cdp_sec[iSec]
    return result

def write(wData, Re, toW=('deltaStar', 'H', 'Hstar', 'Cf', 'Ct', 'ue', 'ueInv', 'delta'), sfx=''):
    """Write sectional coefficients and boundary layer variables to 'bl<sfx>.dat'.

    Params
    ------
    - wData (dict): solution dictionary (e.g. as returned by getSolution).
      Must contain 'Cdt_w', 'Cdp', 'Cdf', 'xtrT', 'xtrB', 'x', 'y', 'z' and
      every key listed in toW.
    - Re (float): Reynolds number written in the coefficient header.
    - toW (sequence of str): boundary layer variables written as extra
      columns, in this order.
    - sfx (str): suffix inserted in the file name, 'bl' + sfx + '.dat'.
    """
    fileName = 'bl' + sfx + '.dat'
    # Report the name of the file actually written.
    print('Writing file: ' + fileName + '...', end='')

    with open(fileName, 'w') as f:
        f.write('$Sectional aerodynamic coefficients\n')
        f.write('             Re             Cdw             Cdp             Cdf         xtr_top         xtr_bot\n')
        f.write('{0:15.6f} {1:15.6f} {2:15.6f} {3:15.6f} {4:15.6f} {5:15.6f}\n'.format(Re, wData['Cdt_w'], wData['Cdp'],
                                                                                       wData['Cdf'], wData['xtrT'],
                                                                                       wData['xtrB']))
        f.write('$Boundary layer variables\n')
        f.write('{0:>15s} {1:>15s} {2:>15s} {3:>15s}'.format('x', 'y', 'z', 'xoc'))
        for s in toW:
            f.write(' {0:>15s}'.format(s))
        f.write('\n')

        x = wData['x']
        nPoints = len(x)
        if nPoints:
            # The 'xoc' column rescales x to [0, 1] over its own range; the
            # bounds are computed once instead of on every row.
            xMin = min(x)
            xRange = max(x) - xMin
        for i in range(nPoints):
            f.write('{0:>15.6f} {1:>15.6f} {2:>15.6f} {3:>15.6f}'.format(x[i], wData['y'][i], wData['z'][i],
                                                                         (x[i] - xMin) / xRange))
            for s in toW:
                f.write(' {0:15.6f}'.format(wData[s][i]))
            f.write('\n')

    print('done.')

def read(filename, delim=',', skip=1):
    """Read a delimited numeric text file and return its values as an array.

    Params
    ------
    - filename: path of the file to read (anything accepted by numpy.loadtxt).
    - delim (str): column delimiter (default: ',').
    - skip (int): number of leading lines to skip, e.g. a header row (default: 1).

    Return
    ------
    - data (numpy.ndarray): the parsed values.
    """
    # The former pre-read of the header line was never used (and split on ','
    # regardless of `delim`), so the file is parsed directly.
    return np.loadtxt(filename, delimiter=delim, skiprows=skip)

def plot(cfg):
    """Plot the curves described by a configuration dictionary and show the figure.

    Params
    ------
    - cfg (dict): plot configuration.
      Required keys, with one entry per curve:
        'curves' (arrays indexable as [:, 0] for x and [:, 1] for y, e.g.
        NumPy arrays), 'ls' (format string), 'color', 'lw', 'labels'.
      Optional keys:
        'yreverse' / 'xreverse' (invert the axis when exactly True),
        'title', 'xlim', 'ylim', 'legend' (draw a frameless legend when
        exactly True), 'xlabel', 'ylabel'.
    """
    from matplotlib import pyplot as plt
    for i, curve in enumerate(cfg['curves']):
        plt.plot(curve[:, 0], curve[:, 1], cfg['ls'][i], color=cfg['color'][i], lw=cfg['lw'][i], label=cfg['labels'][i])
    if cfg.get('yreverse') is True:
        plt.gca().invert_yaxis()
    if cfg.get('xreverse') is True:
        plt.gca().invert_xaxis()
    if 'title' in cfg:
        plt.title(cfg['title'])
    if 'xlim' in cfg:
        plt.xlim(cfg['xlim'])
    if 'ylim' in cfg:
        # The y limits must go to the y axis (previously applied via xlim).
        plt.ylim(cfg['ylim'])
    if cfg.get('legend') is True:
        plt.legend(frameon=False)
    if 'xlabel' in cfg:
        plt.xlabel(cfg['xlabel'])
    if 'ylabel' in cfg:
        plt.ylabel(cfg['ylabel'])
    plt.show()