# rvid_box.py
#
# Generate 3D visualisation from 4x 2D video slices.
#
# Author: S. Long (shiting.long@aalto.fi) and F. Gent (fred.gent.ncl@gmail.com)
"""
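This script defines two functions, plot_box and plot, and relies on a third,
pcn.read.slices, to read the data; plot_box is the only function the user
needs to call.
One can call plot_box from a Python interpreter after importing rvid_box.py.
An example call: plot_box(slice_obj, fields=['lnTT'], islice=50, norm='log'),
where slice_obj is the object returned by pcn.read.slices().
The output images are saved to figdir (default './images/', relative to the
current working directory), one file per field and time index.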

The three functions are described below:

1) plot_box : the main function; it plots a box by assembling the xy, xy2, xz and yz slices together.
It takes slice data already read by pcn.read.slices (passed in as slice_obj) and calls plot for each
selected time index.

Arguments:

fields --- a list of strings naming the variables to plot (default ['uu1']).

datadir --- a string variable, the path to the data directory (currently unused by plot_box).

proc --- an integer giving the processor to read a slice from (currently unused by plot_box).

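norm --- a string variable, the scale of the colorbar; the default 'linear' uses a colour range symmetric about zero,
    and it can be changed to logarithmic by setting it to 'log' (base 10). Fields whose name contains 'ln'
    are treated as natural logarithms and converted to base 10.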

quiet --- a boolean variable, True by default; set it to False to print progress information.

islice --- the index (starting from 0) of a single snapshot in slice_obj.t to plot; the default, -1,
    plots all snapshots with tstart <= t <= tend.

colorscale --- a string variable, the color scale for the plot. Default is 'brbg_r'.
    Plotly supports most matplotlib colormaps, and you can also define a custom colorscale.
    See https://plot.ly/python/colorscales/ for more information.

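unit --- a string variable, the name of a unit attribute of par (default 'unit_velocity'); when par is a
    parameter object, the data are multiplied by that unit times rescale. The colorbar title is set by cbar_label.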

axestitle --- a three-tuple indicating the title of x, y and z axis. It is 'x', 'y', 'z' by default.

viewpoint --- a three-tuple giving the position of the camera eye (default is (-1.35, -2.1, 0.5));
    the camera centre can be changed in the layout of the figure in function plot().

imageformat --- a string variable, indicating the format of the output image. Set to 'png' by default.
    It supports png, jpeg, webp, svg and pdf format.

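2) pcn.read.slices: reads the slices (it is not defined in this file; call it first and pass its result to plot_box)
3) plot: plots the box for one time index and writes the image to disk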

"""


import numpy as np
import os
import plotly.graph_objs as go
import plotly.io as pio
import plotly.io._orca
import retrying

# Replace the retry policy of plotly's orca image request with our own:
# wait ~1 s (wait_random_min=1000 ms) between attempts.  ``__wrapped__`` is the
# undecorated request function underneath plotly's own retry decorator.
# NOTE(review): presumably the stock policy gave up too early while the orca
# server was still starting - confirm.
# ``retrying.retry`` retries forever when no stop condition is given, so a dead
# orca server used to hang the script indefinitely; give up after 120 s.
unwrapped = plotly.io._orca.request_image_with_retrying.__wrapped__
wrapped = retrying.retry(wait_random_min=1000, stop_max_delay=120000)(unwrapped)
plotly.io._orca.request_image_with_retrying = wrapped


def _transform_slices(field, itt, norm, dtype, cmin, cmax):
    """Fetch the four slices at time index ``itt`` and apply the colour norm.

    Reads the module globals ``xyslice``, ``xzslice``, ``yzslice`` and
    ``xy2slice`` (filled by ``plot_box``).

    Returns ``(xy, xz, yz, xy2, cmin, cmax)``: the slices cast to ``dtype``
    and the colour limits transformed consistently with the data.  The limits
    are returned as new values; the caller's ``cmin``/``cmax`` are untouched.
    """
    raw = [
        globals()[key][itt] for key in ("xyslice", "xzslice", "yzslice", "xy2slice")
    ]
    if norm == "log":
        if "ln" in field:
            # The data are natural logarithms: log10(exp(a)) == a / ln(10).
            # Dividing directly avoids the float32 overflow of exp(a) for
            # a > ~88 that the exp/log10 round trip would hit.
            ln10 = np.float32(np.log(10.0))
            slices = [
                (np.asarray(a, dtype=np.float32) / ln10).astype(dtype) for a in raw
            ]
            cmin = (np.float32(cmin) / ln10).astype(dtype)
            cmax = (np.float32(cmax) / ln10).astype(dtype)
        else:
            slices = [np.log10(a.astype("f")).astype(dtype) for a in raw]
            cmin = np.log10(cmin)
            cmax = np.log10(cmax)
    else:
        if norm != "linear":
            print("WARNING: 'norm' undefined, applying 'linear'")
        slices = [a.astype(dtype) for a in raw]
        # Colour range symmetric about zero.
        cmax = max(-cmin, cmax)
        cmin = -cmax
    return (*slices, cmin, cmax)


def _slice_grids(xyz, height, width, depth):
    """Return the coordinate meshgrids ``(x_z, z_x, y_z, z_y, x_y, y_x)``.

    Without ``xyz`` the grids are index coordinates built from the slice
    shapes; otherwise ``xyz`` supplies the 1-D x, y and z coordinates.
    """
    if xyz is None:
        xs = np.linspace(0, depth - 1, depth)
        ys = np.linspace(0, width - 1, width)
        zs = np.linspace(0, height - 1, height)
    else:
        # python array order is z,y,x
        xs, ys, zs = xyz[0], xyz[1], xyz[2]
    x_z, z_x = np.meshgrid(xs, zs)
    y_z, z_y = np.meshgrid(ys, zs)
    x_y, y_x = np.meshgrid(xs, ys)
    return x_z, z_x, y_z, z_y, x_y, y_x


def _surface(x, y, z, surfacecolor, colorscale, cmin, cmax, colorbar, showscale=False):
    """Build one ``go.Surface`` trace sharing the box's colour settings."""
    return go.Surface(
        x=list(x),
        y=list(y),
        z=list(z),
        surfacecolor=surfacecolor,
        colorscale=colorscale,
        cmin=cmin,
        cmax=cmax,
        showscale=showscale,
        colorbar=colorbar,
    )


def plot(
    # field (or multiple todo) and 4 surfaces (extra todo)
    slice_obj,
    fields,
    xyzplane,
    itt,
    it,
    quiet=True,
    # save output
    figdir="./images/",
    imageformat="png",
    # set color parameters
    norm="linear",
    colorscale="brbg_r",
    cmin=0,
    cmax=1,
    dtype="f",
    # locate the box and axes
    viewpoint=(-1.35, -2.1, 0.5),
    offset=2.0,
    margin=(20, 20, 30, 0),
    autosize=False,
    image_dim=(800, 500),
    # handle axes properties
    visxyz=[True, True, True],
    axestitle=("x", "y", "z"),
    xyz=None,
    # color bar properties
    cbar_label="",
    cbar_loc=1.0,
    cbar_label_pos="right",
    cbar_thickness=10.0,
    cbar_borderwidth=None,
    # add text for time stamp
    time=0,
    textxy=(0, 0),
    str_unit="",
    tscale=1.0,
    isd=3,
    fontsize=25,
):
    """Render the box for time index ``itt`` and save one image per field.

    The slice data are not passed in as arguments.  They are read from the
    module globals ``xyslice``, ``xzslice``, ``yzslice`` and ``xy2slice``,
    which ``plot_box`` fills and which are indexed by time first.

    Arguments:
        slice_obj, xyzplane, it --- unused here; kept for the caller's interface.
        fields --- list of field names; one image ``<field>_<itt:04d>.<imageformat>``
            is written per entry, all from the currently loaded globals.
        itt --- time index into the global slice arrays.
        quiet --- if False, print a progress line.
        figdir --- output directory; created (with parents) if missing.
        imageformat --- file extension of the image, e.g. 'png', 'jpeg', 'webp',
            'svg' or 'pdf'.
        norm --- 'linear' (colour range made symmetric about zero) or 'log'
            (base-10 logarithm; fields whose name contains 'ln' are treated as
            natural logarithms and converted to base 10).  Other values fall
            back to 'linear' with a warning.
        colorscale --- plotly colour scale name, default 'brbg_r'.
        cmin, cmax --- colour limits in data units, before the ``norm`` transform.
        dtype --- dtype the plotted data are cast to.
        viewpoint --- (x, y, z) position of the camera eye.
        offset --- the xy slice is drawn below the box, separated from it by
            the box height divided by ``offset``; the xy2 slice is drawn on top.
        margin --- (left, right, bottom, top) figure margins.
        autosize, image_dim --- plotly autosize flag and (width, height) of the image.
        visxyz --- visibility of the x, y and z axes.
        axestitle --- titles of the x, y and z axes.
        xyz --- optional 1-D (x, y, z) coordinate arrays; index coordinates
            are used when None.
        cbar_label, cbar_loc, cbar_label_pos, cbar_thickness, cbar_borderwidth
            --- colour bar title, x position, title side, thickness and border width.
        time, tscale, isd, str_unit, textxy, fontsize --- the annotation shows
            ``round(time * tscale, isd)`` followed by ``str_unit`` at ``textxy``;
            ``fontsize`` sets the colour bar title font size.
    """
    if not quiet:
        print("Printing t={:.2g} box plot".format(time))
    for field in fields:
        height = globals()["xzslice"][itt].shape[0]
        width = globals()["xyslice"][itt].shape[0]
        depth = globals()["xyslice"][itt].shape[1]

        x_z, z_x, y_z, z_y, x_y, y_x = _slice_grids(xyz, height, width, depth)
        # Fresh names: cmin/cmax must not be re-transformed for each field.
        z1, y1, x1, z2, vmin, vmax = _transform_slices(
            field, itt, norm, dtype, cmin, cmax
        )

        # Constant coordinates placing each slice on its face of the box.
        if xyz is None:
            z1_level = -height / offset
            y1_level = 0
            x1_level = 0
            z2_level = height - 1
        else:
            z1_level = xyz[2].min() - (xyz[2].max() - xyz[2].min()) / offset
            y1_level = xyz[1].min()
            x1_level = xyz[0].min()
            z2_level = xyz[2].max()
        z1_offset = z1_level * np.ones(z1.shape)
        y1_offset = y1_level * np.ones(y1.shape)
        x1_offset = x1_level * np.ones(x1.shape)
        z2_offset = z2_level * np.ones(z2.shape)

        # NOTE(review): kept from the original; the z-axis range also takes
        # the x/y extents into account - confirm this padding is intended.
        zmax = max(x_y.max(), y_x.max(), z_x.max())
        zmin = min(x_y.min(), y_x.min(), z1_offset.min())

        # Only the xz surface carries the (shared) colour bar.
        trace_y1 = _surface(
            x_z,
            y1_offset,
            z_x,
            y1.tolist(),
            colorscale,
            vmin,
            vmax,
            colorbar=dict(
                x=cbar_loc, thickness=cbar_thickness, borderwidth=cbar_borderwidth
            ),
            showscale=True,
        )
        trace_y1.colorbar.title.side = cbar_label_pos
        trace_y1.colorbar.title.text = cbar_label
        trace_y1.colorbar.title.font.size = fontsize

        trace_x1 = _surface(
            x1_offset, y_z, z_y, x1.tolist(), colorscale, vmin, vmax,
            colorbar=dict(x=cbar_loc),
        )
        trace_z1 = _surface(
            x_y, y_x, z1_offset, z1.tolist(), colorscale, vmin, vmax,
            colorbar=dict(x=cbar_loc),
        )
        trace_z2 = _surface(
            x_y, y_x, z2_offset, z2.tolist(), colorscale, vmin, vmax,
            colorbar=dict(x=cbar_loc),
        )

        layout = go.Layout(
            annotations=[
                dict(
                    text=r"$t={}\,".format(round(time * tscale, isd)) + str_unit + r"$",
                    x=textxy[0],
                    y=textxy[1],
                    showarrow=False,
                )
            ],
            autosize=autosize,
            width=image_dim[0],
            height=image_dim[1],
            scene=dict(
                aspectmode="data",
                camera=dict(eye=dict(x=viewpoint[0], y=viewpoint[1], z=viewpoint[2])),
                xaxis=dict(
                    title=axestitle[0],
                    visible=visxyz[0],
                    autorange=False,
                    range=(x_y.min(), x_y.max()),
                    rangemode="normal",
                ),
                yaxis=dict(
                    title=axestitle[1],
                    visible=visxyz[1],
                    autorange=False,
                    range=(y_x.min(), y_x.max()),
                    rangemode="normal",
                ),
                zaxis=dict(
                    title=axestitle[2],
                    visible=visxyz[2],
                    autorange=False,
                    range=(zmin, zmax),
                    rangemode="normal",
                ),
            ),
            margin=dict(
                l=margin[0],
                r=margin[1],
                b=margin[2],
                t=margin[3],
            ),
        )
        fig = go.Figure(data=[trace_y1, trace_x1, trace_z1, trace_z2], layout=layout)
        filename = "{}_{:04d}.{}".format(field, itt, imageformat)
        # makedirs handles nested paths; join works with or without a
        # trailing separator on figdir.
        os.makedirs(figdir, exist_ok=True)
        pio.write_image(fig, os.path.join(figdir, filename))


def _slice_limits(xyzplane, index):
    """Return ``(cmin, cmax)`` over every loaded plane at time ``index``.

    ``index`` may be a single time index or an array of indices.  Values are
    read from the module globals ``<plane>slice`` filled by ``plot_box`` and
    converted to float32.
    """
    cmin, cmax = 1e38, -1e38
    for key in xyzplane:
        data = globals()[key + "slice"][index]
        cmax = max(cmax, data.max().astype("f"))
        cmin = min(cmin, data.min().astype("f"))
    return cmin, cmax


def plot_box(
    slice_obj,  # slice_obj=pcn.read.slices()
    # acquire the slice objects
    fields=[
        "uu1",
    ],
    datadir="./data/",
    proc=-1,
    xyzplane=[],
    quiet=True,
    oldfile=False,
    # select data to plot
    tstart=0.0,
    tend=1e38,
    islice=-1,
    # set image properties
    imageformat="png",
    figdir="./images/",
    # set color parameters
    colorscale="brbg_r",
    norm="linear",
    # color_range size 2 list cmin and cmax
    color_range=None,
    color_levels=None,
    # locate the box and axes
    viewpoint=(-1.35, -2.1, 0.5),
    offset=2.0,
    margin=(20, 20, 30, 0),
    autosize=False,
    image_dim=(800, 500),
    # handle axes properties
    visxyz=[True, True, True],
    axestitle=("x", "y", "z"),
    xyz=None,
    # color bar properties
    cbar_label=r"$u_x\,[{\rm km s}^{-1}]$",
    cbar_loc=1.0,
    cbar_label_pos="right",  # ['top', 'right', 'bottom']
    cbar_thickness=10.0,
    cbar_borderwidth=None,
    # add text for time stamp
    timestamp=False,
    textxy=(0, 0),
    tscale=1,
    str_unit="",
    isd=2,
    fontsize=25,
    # convert data to cgs from code units and rescale to cbar_label
    # par if present is a param object
    unit="unit_velocity",
    rescale=1.0,
    par=list(),
):
    """Plot 3-D boxes assembled from the xy, xy2, xz and yz slices and save them.

    For every field in ``fields``, the slice data are copied into the module
    globals ``<plane>slice``, which ``plot`` reads.  ``plot`` is then called
    once per selected snapshot.

    Arguments:
        slice_obj --- slice object (e.g. from pcn.read.slices()); it must have
            a time array ``t`` and one attribute per plane, each holding one
            array per field indexed by time first.
        fields --- list of field names to plot.
        datadir, proc, oldfile, timestamp --- currently unused.
        xyzplane --- planes to load; by default every known plane found on
            ``slice_obj``.  It must contain xy, xy2, xz and yz.  The list
            passed in is not modified.
        quiet --- if False, ``plot`` prints progress.
        tstart, tend --- only snapshots with tstart <= t <= tend are used.
        islice --- -1 plots every snapshot in the window; otherwise only the
            snapshot with index ``islice`` in ``slice_obj.t``.
        color_range --- a list [cmin, cmax] that fixes the colour limits.
        color_levels --- 'common' uses one colour range over all snapshots in
            the window; otherwise each snapshot gets its own range.
        unit, rescale, par --- if ``par`` is not a list and ``unit`` is non-empty,
            the data are multiplied by ``getattr(par, unit) * rescale``.  For
            fields whose name contains 'ln', the log of that factor is added.
        All remaining arguments are forwarded unchanged to ``plot``.

    Raises:
        ValueError --- if fewer than 4 planes are available or one of xy,
            xy2, xz, yz is missing.
    """
    times = np.asarray(slice_obj.t)
    # Snapshot indices inside [tstart, tend]; a single mask keeps the indices
    # valid even when the times are not sorted.
    it = np.where((times >= tstart) & (times <= tend))[0]

    # Work on a copy: appending to the shared default list (or the caller's)
    # would leak planes into later calls.
    xyzplane = list(xyzplane)
    if len(xyzplane) == 0:
        known_planes = ("xy", "xy2", "xy3", "xy4", "xz", "xz2", "yz", "r")
        xyzplane = [key for key in slice_obj.__dict__ if key in known_planes]
    if len(xyzplane) < 4:
        raise ValueError("xyzplane: rvid_box requires at least 4 surfaces.")
    # plot() always draws xy, xy2, xz and yz.  Without this check a missing
    # plane raises KeyError there, or silently reuses data left in the module
    # globals by an earlier call.
    missing = [key for key in ("xy", "xy2", "xz", "yz") if key not in xyzplane]
    if missing:
        raise ValueError(
            "xyzplane: rvid_box requires the planes xy, xy2, xz and yz; "
            "missing {}.".format(missing)
        )

    # avoid increasing memory: plot() casts to the dtype of the stored data
    dtype = type(getattr(getattr(slice_obj, xyzplane[0]), fields[0])[0, 0, 0])
    if dtype == np.float16 or dtype == "half":
        print("plot_box: caution dtype {} may cause under/overflow".format(dtype))

    if not isinstance(par, list) and len(unit) > 0:
        unitscale = getattr(par, unit) * rescale
    else:
        unitscale = 1.0

    fixed_range = isinstance(color_range, list)
    # Arguments identical for every call to plot().
    plot_kwargs = dict(
        quiet=quiet,
        figdir=figdir,
        imageformat=imageformat,
        norm=norm,
        colorscale=colorscale,
        viewpoint=viewpoint,
        offset=offset,
        margin=margin,
        autosize=autosize,
        image_dim=image_dim,
        visxyz=visxyz,
        axestitle=axestitle,
        xyz=xyz,
        cbar_label=cbar_label,
        cbar_loc=cbar_loc,
        cbar_label_pos=cbar_label_pos,
        cbar_thickness=cbar_thickness,
        cbar_borderwidth=cbar_borderwidth,
        textxy=textxy,
        str_unit=str_unit,
        isd=isd,
        fontsize=fontsize,
        tscale=tscale,
        dtype=dtype,
    )

    for field in fields:
        # Load the current field into the module globals read by plot().
        for key in xyzplane:
            data = getattr(getattr(slice_obj, key), field)
            if "ln" in field:
                globals()[key + "slice"] = data + np.log(unitscale)
            else:
                globals()[key + "slice"] = data.astype("f") * unitscale

        if fixed_range:
            cmin, cmax = color_range[0], color_range[1]
        elif color_levels == "common":
            # One colour range over the whole time window.
            cmin, cmax = _slice_limits(xyzplane, it)
        else:
            cmin, cmax = 1e38, -1e38

        # Only the current field is passed: the globals hold its data alone.
        if islice == -1:
            for itt in it:
                if not fixed_range and color_levels != "common":
                    cmin, cmax = _slice_limits(xyzplane, itt)
                plot(
                    slice_obj,
                    [field],
                    xyzplane,
                    itt,
                    it,
                    cmin=cmin,
                    cmax=cmax,
                    time=slice_obj.t[itt],
                    **plot_kwargs
                )
        else:
            if not fixed_range:
                smin, smax = _slice_limits(xyzplane, islice)
                cmin, cmax = min(cmin, smin), max(cmax, smax)
            plot(
                slice_obj,
                [field],
                xyzplane,
                islice,
                [islice],
                cmin=cmin,
                cmax=cmax,
                time=slice_obj.t[islice],
                **plot_kwargs
            )


"""
      Named colorscales accepted by plotly for the colorscale argument (append '_r' to reverse one, e.g. 'brbg_r'):
            ['aggrnyl', 'agsunset', 'algae', 'amp', 'armyrose', 'balance',
             'blackbody', 'bluered', 'blues', 'blugrn', 'bluyl', 'brbg',
             'brwnyl', 'bugn', 'bupu', 'burg', 'burgyl', 'cividis', 'curl',
             'darkmint', 'deep', 'delta', 'dense', 'earth', 'edge', 'electric',
             'emrld', 'fall', 'geyser', 'gnbu', 'gray', 'greens', 'greys',
             'haline', 'hot', 'hsv', 'ice', 'icefire', 'inferno', 'jet',
             'magenta', 'magma', 'matter', 'mint', 'mrybm', 'mygbm', 'oranges',
             'orrd', 'oryel', 'oxy', 'peach', 'phase', 'picnic', 'pinkyl',
             'piyg', 'plasma', 'plotly3', 'portland', 'prgn', 'pubu', 'pubugn',
             'puor', 'purd', 'purp', 'purples', 'purpor', 'rainbow', 'rdbu',
             'rdgy', 'rdpu', 'rdylbu', 'rdylgn', 'redor', 'reds', 'solar',
             'spectral', 'speed', 'sunset', 'sunsetdark', 'teal', 'tealgrn',
             'tealrose', 'tempo', 'temps', 'thermal', 'tropic', 'turbid',
             'turbo', 'twilight', 'viridis', 'ylgn', 'ylgnbu', 'ylorbr',
             'ylorrd'].

"""
