# interpolation.py
#
# Interpolation routines for scalar and vector fields.
"""
Interpolation routines for scalar and vector fields.
"""


def vec_int(xyz, field, dxyz, oxyz, nxyz, interpolation="trilinear", splines=None):
    """
    vec_int(xyz, field, dxyz, oxyz, nxyz, interpolation='trilinear', splines=None)

    Interpolates the field around position xyz.

    Parameters
    ----------
    xyz : ndarray
        Position vector (x, y, z) around which will be interpolated.
        NOTE: a position outside the domain is clamped onto the boundary
        *in place*, i.e. the caller's array is modified.

    field : ndarray
        Vector field to be interpolated with shape [ncomp, nz, ny, nx];
        the leading axis holds the vector components.

    dxyz : ndarray
        Array with the three deltas (dx, dy, dz).

    oxyz : ndarray
        Array with the position of the origin (x0, y0, z0).

    nxyz : ndarray
        Number of grid points in each direction (nx, ny, nz).

    interpolation : string
        Interpolation method. Can be 'mean', 'trilinear' or 'tricubic'.
        If 'tricubic' interpolation is chosen, then the splines functions need
        to be specified.

    splines : sequence of three eqtools.trispline Spline
        Spline functions generated by eqtools.trispline, one per vector
        component. Each is evaluated as ``ev(z, y, x)``.
        Only used for 'tricubic' interpolation.

    Returns
    -------
    ndarray with interpolated vector field (one value per component).

    Raises
    ------
    ValueError
        If the interpolation method is unknown, or if 'tricubic' is
        requested without splines.
    """

    import numpy as np

    if interpolation not in ("mean", "trilinear", "tricubic"):
        raise ValueError(
            "interpolation must be 'mean', 'trilinear' or 'tricubic', "
            "got {0!r}".format(interpolation)
        )

    # If the point lies outside the domain, places it on the boundary.
    # This deliberately writes into xyz (historical behaviour of this routine).
    for p in range(3):
        upper = oxyz[p] + dxyz[p] * (nxyz[p] - 1)
        if xyz[p] < oxyz[p]:
            xyz[p] = oxyz[p]
        if xyz[p] > upper:
            xyz[p] = upper

    if interpolation == "tricubic":
        if splines is None:
            raise ValueError("'tricubic' interpolation requires the splines argument")
        return np.array(
            [splines[c].ev(xyz[2], xyz[1], xyz[0]) for c in range(3)]
        )

    # For every axis find the two bracketing grid indices and the linear
    # weights of the lower and upper neighbour.
    lows = []
    highs = []
    weights = []
    for p in range(3):
        last = nxyz[p] - 1
        if last == 0:
            # Singleton axis: only one grid point, no need for the spacing.
            idx = 0.0
        else:
            idx = (xyz[p] - oxyz[p]) / dxyz[p]
        # Guard against round-off pushing the fractional index off the grid.
        if idx < 0:
            idx = 0
        if idx > last:
            idx = last
        low = int(np.floor(idx))
        high = int(np.ceil(idx))
        lows.append(low)
        highs.append(high)
        if low == high:
            # The point sits exactly on a grid plane: a single neighbour.
            weights.append(np.array([1.0]))
        else:
            # The neighbour closer to the point gets the larger weight.
            weights.append(np.array([high - idx, idx - low]))

    # Cell surrounding the point; the axes are ordered [component, z, y, x].
    sub_field = field[
        :,
        lows[2] : highs[2] + 1,
        lows[1] : highs[1] + 1,
        lows[0] : highs[0] + 1,
    ]

    if interpolation == "mean":
        return np.mean(sub_field, axis=(1, 2, 3))

    # Trilinear: outer product of the per-axis weights, ordered (z, y, x).
    # The weights of each axis sum to one, so no normalisation is needed.
    weight = (
        weights[2].reshape((-1, 1, 1))
        * weights[1].reshape((1, -1, 1))
        * weights[0].reshape((1, 1, -1))
    )
    return np.sum(sub_field * weight, axis=(1, 2, 3))
