import numpy as np
from scipy.interpolate import RBFInterpolator
class RbfInterpolator:
    """
    Radial Basis Function (RBF) interpolator transferring data between
    inviscid and viscous regions.

    Attributes:
    ----------
    _cfg : dict
        Configuration dictionary. Keys read by this class: 'neighbors',
        'rbftype', 'smoothing' and 'degree' (forwarded to
        scipy.interpolate.RBFInterpolator as neighbors/kernel/smoothing/degree),
        plus the optional 'Sym' (iterable of coordinate values; in 3D the
        viscous data is mirrored about each of them along column 1).
    _ndim : int
        Number of dimensions (2 or 3).
    """
    def __init__(self, cfg, ndim):
        """
        Initialize the RbfInterpolator.

        Parameters:
        ----------
        cfg : dict
            Configuration dictionary (see class docstring for the keys used).
        ndim : int
            Number of dimensions (must be 2 or 3).

        Raises:
        ------
        ValueError
            If ndim is neither 2 nor 3.
        """
        self._cfg = cfg
        if ndim not in (2, 3):
            raise ValueError('Number of dimensions must be 2 or 3 but {} was given'.format(ndim))
        self._ndim = ndim

    def inviscidToViscous(self, iDict, vDict):
        """
        Interpolate inviscid nodal data (velocity, Mach, density) onto the
        nodes of every viscous region and push it via ``updateVars``.

        Parameters:
        ----------
        iDict : sequence
            Inviscid region objects, indexed by region number. Each must expose
            ``name``, ``nodesCoord`` (2D array of coordinates), ``V`` (array
            with 3 velocity columns), ``M`` and ``Rho``.
        vDict : sequence
            Viscous sections; each section is a sequence of viscous region
            objects aligned by index with ``iDict``. Each region must expose
            ``name``, ``nodesCoord`` and ``updateVars(v, M, rho)``.
        """
        for section in vDict:
            for iReg, reg in enumerate(section):
                invReg = iDict[iReg]
                # The inviscid 'iWing' region and viscous regions whose name
                # contains 'vAirfoil' are interpolated in all ndim coordinates;
                # any other region uses only its first coordinate
                # (presumably wake-like regions - confirm with callers).
                nIn = self._ndim if invReg.name == 'iWing' else 1
                nOut = self._ndim if 'vAirfoil' in reg.name else 1
                xIn = invReg.nodesCoord[:, :nIn]
                xOut = reg.nodesCoord[:, :nOut]

                # Velocity components are interpolated one at a time so that a
                # constant component takes the nearest-neighbour path in
                # __rbfToSection independently of the others.
                v = np.zeros((reg.nodesCoord.shape[0], 3))
                for iDim in range(3):
                    v[:, iDim] = self.__rbfToSection(xIn, invReg.V[:, iDim], xOut)
                M = self.__rbfToSection(xIn, invReg.M, xOut)
                rho = self.__rbfToSection(xIn, invReg.Rho, xOut)
                reg.updateVars(v, M, rho)

    def viscousToInviscid(self, iDict, vDict):
        """
        Interpolate viscous blowing velocity onto the elements of every
        inviscid region, storing the result in ``blowingVel``.

        Parameters:
        ----------
        iDict : sequence
            Inviscid region objects. Each must expose ``name`` and
            ``elemsCoordTr`` (2D case) or ``elemsCoord`` (3D case), and
            receives the interpolated ``blowingVel``.
        vDict : sequence
            Viscous sections; each section is a sequence of viscous region
            objects aligned by index with ``iDict``, exposing ``blowingVel``
            and ``name``/``elemsCoordTr`` (2D) or ``elemsCoord`` (3D).
        """
        if self._ndim == 2:
            # Only the first viscous section is used in 2D.
            viscSection = vDict[0]
            for iReg, reg in enumerate(iDict):
                viscReg = viscSection[iReg]
                # ndim - 1 == 1 here, so both sides use the first transformed
                # coordinate; the name tests mirror inviscidToViscous.
                nIn = self._ndim - 1 if 'vAirfoil' in viscReg.name else 1
                nOut = self._ndim - 1 if reg.name == 'iWing' else 1
                reg.blowingVel = self.__rbfToSection(viscReg.elemsCoordTr[:, :nIn],
                                                     viscReg.blowingVel,
                                                     reg.elemsCoordTr[:, :nOut])
        elif self._ndim == 3:
            for iReg, reg in enumerate(iDict):
                # Gather this region's data from every section and stack it
                # once (instead of growing the arrays inside the loop). The
                # empty seed arrays keep the float64 dtype and the (0, 3) shape
                # when there are no sections.
                viscElemsCoord = np.vstack([np.zeros((0, 3))] +
                                           [sec[iReg].elemsCoord[:, :3] for sec in vDict])
                viscBlowing = np.concatenate([np.zeros(0)] +
                                             [sec[iReg].blowingVel for sec in vDict])
                if 'Sym' in self._cfg:
                    for s in self._cfg['Sym']:
                        # Reflect every point gathered so far (including earlier
                        # mirrors) about column-1 value s: y -> s - (y - s).
                        mirrored = viscElemsCoord.copy()
                        mirrored[:, 1] = s - (mirrored[:, 1] - s)
                        viscElemsCoord = np.vstack((viscElemsCoord, mirrored))
                        viscBlowing = np.concatenate((viscBlowing, viscBlowing))
                reg.blowingVel = self.__rbfToSection(viscElemsCoord, viscBlowing, reg.elemsCoord[:, :3])

    def __rbfToSection(self, x, var, xs):
        """
        Perform RBF interpolation.

        A constant ``var`` is interpolated with a nearest-neighbour, degree-0
        linear RBF, which simply reproduces the constant. Otherwise the kernel
        settings come from the configuration.

        Parameters:
        ----------
        x : array_like
            Input coordinates, shape (n, d).
        var : array_like
            Variable values at input coordinates, length n.
        xs : array_like
            Output coordinates, shape (m, d).

        Returns:
        -------
        ndarray
            Interpolated variable values at output coordinates.
        """
        if np.all(var == var[0]):
            varOut = RBFInterpolator(x, var, neighbors=1, kernel='linear', smoothing=0.0, degree=0)(xs)
        else:
            varOut = RBFInterpolator(x, var, neighbors=self._cfg['neighbors'], kernel=self._cfg['rbftype'], smoothing=self._cfg['smoothing'], degree=self._cfg['degree'])(xs)
        return varOut