# Copyright 2024 University of Liège
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# RBF interpolator class
# Paul Dechamps

from blast.interfaces.interpolators.blInterpolator import Interpolator
import numpy as np
from scipy.interpolate import RBFInterpolator
class RbfInterpolator(Interpolator):
    """
    Radial Basis Function (RBF) Interpolator for inviscid and viscous data.

    Field transfers are performed with ``scipy.interpolate.RBFInterpolator``.

    Attributes:
    ----------
    _neighbors : int
        Number of nearest data points used by each local RBF interpolation.
    _rbftype : str
        RBF kernel name, passed as ``kernel`` to scipy's RBFInterpolator.
    _smoothing : float
        Smoothing parameter passed to scipy's RBFInterpolator.
    _degree : int
        Degree of the polynomial term added to the RBF interpolant.
    _sym : list
        Positions of symmetry planes along the second coordinate (column 1).
        In 3D, viscous data is mirrored across each plane before being
        interpolated onto the inviscid mesh.
    """
    def __init__(self, ndim, **kwargs):
        """
        Initialize the RbfInterpolator.

        Parameters:
        ----------
        ndim : int
            Number of dimensions (must be 2 or 3).
        kwargs : dict
            Optional arguments.

        Optional arguments:
        -------------------
        neighbors : int
            Number of neighbors to use for interpolation. Default is 10.
            The spelling ``nneighbors`` is also accepted; ``neighbors`` wins
            if both are given.
        rbftype : str
            Type of RBF kernel to use. Default is 'linear'.
        smoothing : float
            Smoothing factor. Default is 0.0.
        degree : int
            Degree of the added polynomial term. Default is 0.
        Sym : list
            Positions of symmetry planes along the second coordinate
            (used in 3D only). Default is no symmetry plane.
        """
        super().__init__(ndim)
        # 'neighbors' is the key this class has always read; 'nneighbors' is
        # the spelling that used to be documented, so honour it as a fallback.
        self._neighbors = kwargs.get('neighbors', kwargs.get('nneighbors', 10))
        self._rbftype = kwargs.get('rbftype', 'linear')
        self._smoothing = kwargs.get('smoothing', 0.0)
        self._degree = kwargs.get('degree', 0)
        self._sym = kwargs.get('Sym', [])

    def inviscidToViscous(self, iDict, vDict):
        """
        Interpolate inviscid data to viscous data.

        For every viscous section and region, the velocity (3 components),
        ``M`` and ``Rho`` fields of the inviscid region with the same index are
        interpolated onto the viscous nodes and handed to the viscous region's
        ``updateVars`` method.

        Parameters:
        ----------
        iDict : indexable
            Inviscid regions, indexed by region number.
        vDict : indexable
            Viscous sections; each section is a sequence of regions.
        """
        for sec in vDict:
            for iReg, reg in enumerate(sec):
                # Viscous region iReg is fed by inviscid region iReg.
                src = iDict[iReg]
                # Wing/airfoil regions use ndim coordinates; any other region
                # (presumably the wake) is interpolated along the first
                # coordinate only.
                xSrc = src.nodesCoord[:, :(self.ndim if src.name == 'iWing' else 1)]
                xTgt = reg.nodesCoord[:, :(self.ndim if 'vAirfoil' in reg.name else 1)]

                v = np.zeros((reg.nodesCoord.shape[0], 3))
                for iDim in range(3):
                    v[:, iDim] = self.__rbfToSection(xSrc, src.V[:, iDim], xTgt)
                M = self.__rbfToSection(xSrc, src.M, xTgt)
                rho = self.__rbfToSection(xSrc, src.Rho, xTgt)
                reg.updateVars(v, M, rho)

    def viscousToInviscid(self, iDict, vDict):
        """
        Interpolate viscous data to inviscid data.

        Sets ``blowingVel`` on every inviscid region. In 2D the data comes from
        the single viscous section ``vDict[0]``. In 3D the element data of all
        viscous sections is gathered (and mirrored across the symmetry planes)
        before interpolation. Other values of ``ndim`` leave ``iDict``
        untouched.

        Parameters:
        ----------
        iDict : indexable
            Inviscid regions, indexed by region number.
        vDict : indexable
            Viscous sections; each section is a sequence of regions.
        """
        if self.ndim == 2:
            for iReg, reg in enumerate(iDict):
                visc = vDict[0][iReg]
                # In 2D only the first column of ``elemsCoordTr`` is used, on
                # both the viscous and the inviscid side.
                reg.blowingVel = self.__rbfToSection(visc.elemsCoordTr[:, :1], visc.blowingVel, reg.elemsCoordTr[:, :1])
        elif self.ndim == 3:
            for iReg, reg in enumerate(iDict):
                # Gather region iReg's element data from every viscous section
                # in one pass (np.row_stack is deprecated in NumPy 2.0).
                viscElemsCoord = np.vstack([sec[iReg].elemsCoord[:, :3] for sec in vDict])
                viscBlowing = np.concatenate([sec[iReg].blowingVel for sec in vDict])
                for s in self._sym:
                    # Points lying exactly on the plane would mirror onto
                    # themselves; such duplicates make the RBF system singular,
                    # so only off-plane points are mirrored.
                    offPlane = viscElemsCoord[:, 1] != s
                    mirrored = viscElemsCoord[offPlane].copy()
                    mirrored[:, 1] = (mirrored[:, 1] - s) * (-1) + s
                    viscElemsCoord = np.vstack((viscElemsCoord, mirrored))
                    viscBlowing = np.concatenate((viscBlowing, viscBlowing[offPlane]))
                reg.blowingVel = self.__rbfToSection(viscElemsCoord, viscBlowing, reg.elemsCoord[:, :3])

    def __rbfToSection(self, x, var, xs):
        """
        Perform RBF interpolation.

        A constant field is interpolated with a single-neighbour linear RBF of
        polynomial degree 0, which reproduces the constant. Any other field
        uses the configured neighbours, kernel, smoothing and degree.

        Parameters:
        ----------
        x : array_like
            Input coordinates, shape (n, d).
        var : array_like
            Variable values at input coordinates, shape (n,).
        xs : array_like
            Output coordinates, shape (m, d).

        Returns:
        -------
        array_like
            Interpolated variable values at output coordinates.
        """
        if np.all(var == var[0]):
            varOut = RBFInterpolator(x, var, neighbors=1, kernel='linear', smoothing=0.0, degree=0)(xs)
        else:
            varOut = RBFInterpolator(x, var, neighbors=self._neighbors, kernel=self._rbftype, smoothing=self._smoothing, degree=self._degree)(xs)
        return varOut