# pc2vtk.py
#
# Convert data from Pencil Code format to vtk. Adapted from pc2vtk.pro.
#
# Author: Simon Candelaresi (iomsn1@googlemail.com).
"""
Contains the routines for converting PencilCode data into vtk format.
"""


def _vtk_axis_spacing(coords):
    """Return the uniform point spacing along one axis (1.0 for a single point)."""
    import numpy as np

    n_points = len(coords)
    if n_points < 2:
        # A one-point axis has no extent; dividing by (n_points - 1) would be a
        # division by zero. Fall back to the unit spacing that slices2vtk also
        # writes for its collapsed axis.
        return 1.0
    return (np.max(coords) - np.min(coords)) / (n_points - 1)


def var2vtk(
    var_file="var.dat",
    datadir="data",
    proc=-1,
    variables=None,
    b_ext=False,
    magic=None,
    destination="work",
    quiet=True,
    trimall=True,
    ti=-1,
    tf=-1,
):
    """
    Convert data from PencilCode format to vtk.

    call signature::

      var2vtk(var_file='var.dat', datadir='data', proc=-1,
             variables=None, b_ext=False, magic=None,
             destination='work', quiet=True, trimall=True, ti=-1, tf=-1)

    Read *var_file* and convert its content into vtk format. Write the result
    in *destination*.

    Keyword arguments:

      *var_file*:
        The original var_file.

      *datadir*:
        Directory where the data is stored.

      *proc*:
        Processor which should be read. Set to -1 for all processors.

      *variables*:
        List of variables (or a single variable name) which should be written.
        If None or empty, all variables are written.

      *b_ext*:
        Add the external magnetic field.

      *magic*:
        Extra derived fields handed to read.var via its *magic* argument.
        Fields needed for the requested *variables* are added automatically.
        The list passed in is never modified.

      *destination*:
        Destination file (without the '.vtk' suffix).

      *quiet*:
        Keep quiet when reading the var files.

      *trimall*:
        Trim the data cube to exclude ghost zones.

      *ti, tf*:
        Start and end index for animation. Leave negative for no animation.
        Overwrites variable var_file.
    """

    import numpy as np
    from pencil import read
    from pencil import math

    # Determine if we want an animation: both indices must be given.
    animation = ti >= 0 and tf >= 0

    # Work on a private copy so that neither a shared default nor the caller's
    # list accumulates entries from one call to the next.
    if magic is None:
        magic = []
    elif isinstance(magic, str):
        magic = [magic]
    else:
        magic = list(magic)

    def require_magic(name):
        # Request a derived field once only.
        if name not in magic:
            magic.append(name)

    if not variables:
        # If no variables are specified collect all by default.
        indx = read.index()
        variables = [key for key in indx.__dict__.keys() if "keys" not in key]
        if "uu" in variables:
            require_magic("vort")
            variables.append("vort")
        if ("rho" in variables or "lnrho" in variables) and "ss" in variables:
            require_magic("tt")
            variables.append("tt")
            require_magic("pp")
            variables.append("pp")
        if "aa" in variables:
            require_magic("bb")
            variables.append("bb")
            require_magic("jj")
            variables.append("jj")
            variables.extend(["ab", "b_mag", "j_mag"])
    else:
        # Convert a single variable name into a length 1 list.
        if isinstance(variables, str):
            variables = [variables]
        else:
            variables = list(variables)
        for name in ("tt", "pp", "bb", "jj", "vort"):
            if name in variables:
                require_magic(name)
        # Quantities derived below from bb / jj need those fields computed.
        if "b_mag" in variables:
            require_magic("bb")
        if "j_mag" in variables:
            require_magic("jj")
        if "ab" in variables:
            require_magic("bb")

    # The grid does not depend on the snapshot: read it once and build the
    # header a single time.
    grid = read.grid(datadir=datadir, proc=proc, trim=trimall, quiet=True)

    dimx = len(grid.x)
    dimy = len(grid.y)
    dimz = len(grid.z)
    dim = dimx * dimy * dimz
    dx = _vtk_axis_spacing(grid.x)
    dy = _vtk_axis_spacing(grid.y)
    dz = _vtk_axis_spacing(grid.z)

    header = (
        "# vtk DataFile Version 2.0\n"
        + "VAR files\n"
        + "BINARY\n"
        + "DATASET STRUCTURED_POINTS\n"
        + "DIMENSIONS {0:9} {1:9} {2:9}\n".format(dimx, dimy, dimz)
        + "ORIGIN {0:8.12} {1:8.12} {2:8.12}\n".format(
            grid.x[0], grid.y[0], grid.z[0]
        )
        + "SPACING {0:8.12} {1:8.12} {2:8.12}\n".format(dx, dy, dz)
        + "POINT_DATA {0:9}\n".format(dim)
    ).encode("utf-8")

    # The external field is only needed (and only read) when requested.
    if b_ext:
        params = read.param(quiet=True)
        B_ext = np.array(params.b_ext)

    # Without animation exactly one file is converted, whatever ti and tf are.
    time_indices = range(ti, tf + 1) if animation else [None]

    for t_idx in time_indices:
        if animation:
            var_file = "VAR" + str(t_idx)
            vtk_name = destination + str(t_idx) + ".vtk"
        else:
            vtk_name = destination + ".vtk"

        # Read the PencilCode variables. The trimming must match the grid read
        # above, otherwise the header dimensions disagree with the data.
        var = read.var(
            var_file=var_file,
            datadir=datadir,
            proc=proc,
            magic=magic,
            trimall=trimall,
            quiet=quiet,
        )

        # Add external magnetic field.
        if b_ext:
            var.bb[0, ...] += B_ext[0]
            var.bb[1, ...] += B_ext[1]
            var.bb[2, ...] += B_ext[2]

        # The context manager closes the file even if a variable fails.
        with open(vtk_name, "wb") as fd:
            fd.write(header)

            # Write the data.
            for v in variables:
                print("Writing {0}.".format(v))
                # Prepare the data to the correct format.
                if v == "ab":
                    data = math.dot(var.aa, var.bb)
                elif v == "b_mag":
                    data = np.sqrt(math.dot2(var.bb))
                elif v == "j_mag":
                    data = np.sqrt(math.dot2(var.jj))
                else:
                    data = getattr(var, v)
                # Big-endian 32 bit floats, independent of the host byte order.
                data = data.astype(">f4")
                # Check if we have vectors or scalars.
                if data.ndim == 4:
                    # Move the component axis last so the components of each
                    # point are contiguous in the output.
                    data = np.moveaxis(data, 0, 3)
                    fd.write("VECTORS {0} float\n".format(v).encode("utf-8"))
                else:
                    fd.write("SCALARS {0} float\n".format(v).encode("utf-8"))
                    fd.write("LOOKUP_TABLE default\n".encode("utf-8"))
                fd.write(data.tobytes())

        # Release the snapshot before the next one is read.
        del var


def _slice_coordinate(params, index_name, coords):
    """Return the coordinate of a slice plane, or None if it cannot be determined."""
    index = getattr(params, index_name)
    if index != -1:
        return coords[index]
    if params.slice_position == "m":
        # Slice through the middle of the axis.
        return coords[int(len(coords) / 2)]
    return None


def slices2vtk(field="", extension="", datadir="data", destination="slices", proc=-1):
    """
    Convert slices from PencilCode format to vtk.

    call signature::

      slices2vtk(field='', extension='', datadir='data', destination='slices', proc=-1)

    Read slice files specified by *field* and convert
    them into vtk format for the specified extensions.
    Write the result in *destination*, one file per extension and time step,
    named '<destination>_<extension>_<time index>.vtk'.
    NB: You need to have called src/read_videofiles.x before using this script.

    Keyword arguments:

      *field*:
        All allowed fields which can be written as slice files, e.g. b2, uu1, lnrho, ...
        See the pencil code manual for more (chapter: "List of parameters for `video.in'").
        Either a single name or a list of names.

      *extension*:
        List of slice positions (or a single one). Supported are
        'xy', 'xy2', 'xz' and 'yz'.

      *datadir*:
        Directory where the data is stored.

      *destination*:
        Destination files.

      *proc*:
        Processor which should be read. Set to -1 for all processors.

    Raises ValueError for an unsupported extension or if the position of a
    requested slice cannot be determined from the run parameters.
    """

    from pencil import read

    # Convert a single name into a length 1 list. Empty strings are passed on
    # unchanged to read.slices.
    if isinstance(field, str) and field:
        field = [field]
    if isinstance(extension, str) and extension:
        extension = [extension]

    # Refuse unknown extensions before anything is read or written; there is
    # no header layout defined for them.
    supported = ("xy", "xy2", "xz", "yz")
    unsupported = [ext for ext in extension if ext not in supported]
    if unsupported:
        raise ValueError(
            "slices2vtk: unsupported extension(s) {0}; supported are {1}".format(
                unsupported, list(supported)
            )
        )

    # Read the grid dimensions.
    grid = read.grid(datadir=datadir, proc=proc, trim=True, quiet=True)

    # Read the dimensions.
    dim = read.dim(datadir=datadir, proc=proc)

    # Read the user given parameters for the slice positions.
    params = read.param(quiet=True)

    # Read the slice file for all specified variables and extensions.
    slices = read.slices(field=field, extension=extension, datadir=datadir, proc=proc)

    # Determine the position of the slices. The second xy slice has its own
    # index, iz2.
    x0 = _slice_coordinate(params, "ix", grid.x)
    y0 = _slice_coordinate(params, "iy", grid.y)
    z0 = _slice_coordinate(params, "iz", grid.z)
    z02 = _slice_coordinate(params, "iz2", grid.z)

    # Per extension: (dimensions, origin, spacing). The collapsed axis has one
    # point and unit spacing.
    layouts = {
        "xy": (
            (dim.nx, dim.ny, 1),
            (grid.x[0], grid.y[0], z0),
            (grid.dx, grid.dy, 1.0),
        ),
        "xy2": (
            (dim.nx, dim.ny, 1),
            (grid.x[0], grid.y[0], z02),
            (grid.dx, grid.dy, 1.0),
        ),
        "xz": (
            (dim.nx, 1, dim.nz),
            (grid.x[0], y0, grid.z[0]),
            (grid.dx, 1.0, grid.dz),
        ),
        "yz": (
            (1, dim.ny, dim.nz),
            (x0, grid.y[0], grid.z[0]),
            (1.0, grid.dy, grid.dz),
        ),
    }

    # Fail early, before any file is created, if a requested slice has no
    # known position.
    for ext in extension:
        if any(coordinate is None for coordinate in layouts[ext][1]):
            raise ValueError(
                "slices2vtk: cannot determine the position of slice '{0}'".format(ext)
            )

    for t_idx, _ in enumerate(slices.t):
        for ext in extension:
            dimensions, origin, spacing = layouts[ext]
            n_points = dimensions[0] * dimensions[1] * dimensions[2]

            # Assemble the header.
            header = (
                "# vtk DataFile Version 2.0\n"
                + "slices {0}\n".format(ext)
                + "BINARY\n"
                + "DATASET STRUCTURED_POINTS\n"
                + "DIMENSIONS {0:9} {1:9} {2:9}\n".format(*dimensions)
                + "ORIGIN {0:8.12} {1:8.12} {2:8.12}\n".format(*origin)
                + "SPACING {0:8.12} {1:8.12} {2:8.12}\n".format(*spacing)
                + "POINT_DATA {0:9}\n".format(n_points)
            ).encode("utf-8")

            plane = getattr(slices, ext)

            # Open the destination file for writing; the context manager closes
            # it even if writing a field fails.
            with open(destination + "_" + ext + "_" + str(t_idx) + ".vtk", "wb") as fd:
                fd.write(header)

                # Write the data.
                for fi in field:
                    data = getattr(plane, fi)
                    fd.write(("SCALARS " + ext + "_" + fi + " float\n").encode("utf-8"))
                    fd.write("LOOKUP_TABLE default\n".encode("utf-8"))
                    # Convert only the current time step, as big-endian 32 bit
                    # floats independent of the host byte order.
                    fd.write(data[t_idx].astype(">f4").tobytes())


## Convert PencilCode average file to vtk.
# def aver2vtk(varfile = 'xyaverages.dat', datadir = 'data/',
#            destination = 'xyaverages', quiet = 1):
#    """
#    Convert average data from PencilCode format to vtk.
#
#    call signature::
#
#      aver2vtk(varfile = 'xyaverages.dat', datadir = 'data/',
#            destination = 'xyaverages', quiet = 1):
#
#    Read the average file specified in *varfile* and convert the data
#    into vtk format.
#    Write the result in *destination*.
#
#    Keyword arguments:
#
#      *varfile*:
#        Name of the average file. This also specifies which dimensions the
#        averages are taken.
#
#      *datadir*:
#        Directory where the data is stored.
#
#      *destination*:
#        Destination file.
#
#    """
#
#    # read the grid dimensions
#    grid = pc.read_grid(datadir = datadir, trim = True, quiet = True)
#
#    # read the specified average file
#    if varfile[0:2] == 'xy':
#        aver = pc.read_xyaver()
#        line_len = int(np.round(grid.Lz/grid.dz))
#        l0 = grid.z[int((len(grid.z)-line_len)/2)]
#        dl = grid.dz
#    elif varfile[0:2] == 'xz':
#        aver = pc.read_xzaver()
#        line_len = int(np.round(grid.Ly/grid.dy))
#        l0 = grid.y[int((len(grid.y)-line_len)/2)]
#        dl = grid.dy
#    elif varfile[0:2] == 'yz':
#        aver = pc.read_yzaver()
#        line_len = int(np.round(grid.Lx/grid.dx))
#        l0 = grid.x[int((len(grid.x)-line_len)/2)]
#        dl = grid.dx
#    else:
#        print("aver2vtk: ERROR: cannot determine average file\n")
#        print("aver2vtk: The name of the file has to be either xyaver.dat, xzaver.dat or yzaver.dat\n")
#        return -1
#    keys = aver.__dict__.keys()
#    t = aver.t
#    keys.remove('t')
#
#    # open the destination file
#    fd = open(destination + '.vtk', 'wb')
#
#    fd.write('# vtk DataFile Version 2.0\n'.encode('utf-8'))
#    fd.write(varfile[0:2] + 'averages\n'.encode('utf-8'))
#    fd.write('BINARY\n'.encode('utf-8'))
#    fd.write('DATASET STRUCTURED_POINTS\n'.encode('utf-8'))
#    fd.write('DIMENSIONS {0:9} {1:9} {2:9}\n'.format(len(t), line_len, 1).encode('utf-8'))
#    fd.write('ORIGIN {0:8.12} {1:8.12} {2:8.12}\n'.format(float(t[0]), l0, 0.).encode('utf-8'))
#    fd.write('SPACING {0:8.12} {1:8.12} {2:8.12}\n'.format(t[1]-t[0], dl, 1.).encode('utf-8'))
#    fd.write('POINT_DATA {0:9}\n'.format(len(t) * line_len))
#
#    # run through all variables
#    for var in keys:
#        fd.write(('SCALARS ' + var + ' float\n').encode('utf-8'))
#        fd.write('LOOKUP_TABLE default\n'.encode('utf-8'))
#        for j in range(line_len):
#            for i in range(len(t)):
#                    fd.write(struct.pack(">f", aver.__dict__[var][i,j]))
#
#    fd.close()
#
#
#
## Convert PencilCode power spectra to vtk.
# def power2vtk(powerfiles = ['power_mag.dat'],
#              datadir = 'data/', destination = 'power', thickness = 1):
#    """
#    Convert power spectra from PencilCode format to vtk.
#
#    call signature::
#
#      power2vtk(powerfiles = ['power_mag.dat'],
#            datadir = 'data/', destination = 'power.vtk', thickness = 1):
#
#    Read the power spectra stored in the power*.dat files
#    and convert them into vtk format.
#    Write the result in *destination*.
#
#    Keyword arguments:
#
#      *powerfiles*:
#        The files containing the power spectra.
#
#      *datadir*:
#        Directory where the data is stored.
#
#      *destination*:
#        Destination file.
#
#      *thickness*:
#        Dimension in z-direction. Setting it 2 will create n*m*2 dimensional
#        array of data. This is useful in Paraview for visualizing the spectrum
#        in 3 dimensions. Note that this will simply double the amount of data.
#
#    """
#
#    # this should correct for the case the user types only one variable
#    if (len(powerfiles) > 0):
#        if (len(powerfiles[0]) == 1):
#            powerfiles = [powerfiles]
#
#    # read the grid dimensions
#    grid = pc.read_grid(datadir = datadir, trim = True, quiet = True)
#
#    # leave k0 to 1 now, will fix this later
#    k0 = 1.
#    # leave dk to 1 now, will fix this later
#    dk = 1.
#
#    # open the destination file
#    fd = open(destination + '.vtk', 'wb')
#
#    # read the first power spectrum
#    t, power = pc.read_power(datadir + powerfiles[0])
#
#    fd.write('# vtk DataFile Version 2.0\n'.encode('utf-8'))
#    fd.write('power spectra\n'.encode('utf-8'))
#    fd.write('BINARY\n'.encode('utf-8'))
#    fd.write('DATASET STRUCTURED_POINTS\n'.encode('utf-8'))
#    if (thickness == 1):
#        fd.write('DIMENSIONS {0:9} {1:9} {2:9}\n'.format(len(t), power.shape[1], 1).encode('utf-8'))
#    else:
#        fd.write('DIMENSIONS {0:9} {1:9} {2:9}\n'.format(len(t), power.shape[1], 2).encode('utf-8'))
#    fd.write('ORIGIN {0:8.12} {1:8.12} {2:8.12}\n'.format(float(t[0]), k0, 0.).encode('utf-8'))
#    fd.write('SPACING {0:8.12} {1:8.12} {2:8.12}\n'.format(t[1]-t[0], dk, 1.).encode('utf-8'))
#    if (thickness == 1):
#        fd.write('POINT_DATA {0:9}\n'.format(power.shape[0] * power.shape[1]).encode('utf-8'))
#    else:
#        fd.write('POINT_DATA {0:9}\n'.format(power.shape[0] * power.shape[1] * 2).encode('utf-8'))
#
#    for powfile in powerfiles:
#        # read the power spectrum
#        t, power = pc.read_power(datadir + powfile)
#
#        fd.write(('SCALARS ' + powfile[:-4] + ' float\n').encode('utf-8'))
#        fd.write('LOOKUP_TABLE default\n'.encode('utf-8'))
#
#        if (thickness == 1):
#            for j in range(power.shape[1]):
#                for i in range(len(t)):
#                        fd.write(struct.pack(">f", power[i,j]))
#        else:
#            for k in [1,2]:
#                for j in range(power.shape[1]):
#                    for i in range(len(t)):
#                            fd.write(struct.pack(">f", power[i,j]))
#
#    fd.close()
