#!/usr/bin/python3
# -*- coding: utf-8 -*-   vim: set fileencoding=utf-8 :

"""Test reading data files from Python"""


import numpy as np
import os
from typing import Any, Tuple
import pytest

from test_utils import (
    assert_equal,
    assert_true,
    _assert_close,
    _assert_equal_tuple,
    cmp_extracted,
    require_sample,
)

import pencil as pc
from pencil.read.timeseries import ts
from pencil.read.dims import dim
from pencil.read.varfile import var
from pencil.read.params import param
from pencil.read.powers import power


# Directory holding the stored output of the "serial-1" sample run that the
# tests below read. `static_data_location` is not a standard pytest attribute;
# presumably the test harness (e.g. conftest.py) attaches it — TODO confirm.
DATA_DIR = pytest.static_data_location/"serial-1"

def test_read_time_series() -> None:
    """Read time series.

    Each listed column of the serial-1 time series is compared element-wise
    with ``np.allclose`` (default tolerances). A few single entries are then
    re-checked individually with ``_assert_close``.
    """
    time_series = ts(datadir=DATA_DIR, quiet=True)
    reference = dict(
        it=np.array([0, 50, 100, 150]),
        t=np.array([0.000, 0.441, 0.939, 1.480]),
        dt=np.array([8.63e-3, 9.18e-3, 1.10e-2, 1.02e-2]),
        urms=np.array([0.7071, 0.5917, 0.3064, 0.26]),
        rhom=np.array([1.0, 1.0, 1.0, 1.0]),
        ecrm=np.array([1.0, 1.020, 1.059, 1.058]),
        ecrmax=np.array([1.000, 1.930, 2.492, 1.835]),
    )
    for column, wanted in reference.items():
        got = getattr(time_series, column)
        assert_true(
            np.allclose(wanted, got),
            f"time_series.{column}: expected {wanted}, got {got}",
        )

    # Individual spot checks: (column, row index, expected value).
    spot_checks = (
        ("rhom", 2, 1.0),
        ("urms", 3, 0.26),
        ("ecrmax", 3, 1.835),
    )
    for column, row, wanted in spot_checks:
        _assert_close(
            getattr(time_series, column)[row], wanted, f"{column}[{row}]"
        )


def test_read_dim() -> None:
    """Read dim.dat file.

    Checks the global dimensions of serial-1. It then verifies that the
    processor-0 dim file agrees with the global one on every attribute the
    two have in common.
    """
    global_dim = dim(DATA_DIR)
    expected_global = (
        ("mx", 10),
        ("my", 12),
        ("mz", 11),
        ("mvar", 5),
        ("precision", "S"),
        ("nghostx", 3),
        ("nghosty", 3),
        ("nghostz", 3),
        ("nprocx", 1),
        ("nprocy", 1),
        ("nprocz", 1),
    )
    for name, value in expected_global:
        assert_equal(getattr(global_dim, name), value)

    proc_dim = dim(DATA_DIR, 0)
    # Dim defines no __eq__(), so compare the shared attributes one by one.
    shared_attributes = (
        "mx",
        "my",
        "mz",
        "mvar",
        "precision",
        "nghostx",
        "nghosty",
        "nghostz",
    )
    for name in shared_attributes:
        global_value = getattr(global_dim, name)
        proc_value = getattr(proc_dim, name)
        assert_equal(
            global_value,
            proc_value,
            f"global_dim.{name} = {global_value} ≠ proc_dim.{name} = {proc_value}",
        )


def test_read_param() -> None:
    """Read param.nml file and check a selection of run parameters."""
    params = param(DATA_DIR)
    expected_params = [
        ("coord_system", "cartesian"),
        ("lcollective_io", False),
        ("gamma", 1.666_666_6),
        ("kx_uu", 1.0),
        ("cs2top", 1.0),
        ("lhydro", True),
        ("ldensity", True),
        ("lentropy", True),
        ("ltemperature", False),
    ]
    for name, value in expected_params:
        assert_equal(getattr(params, name), value)


def test_read_var() -> None:
    """Read var.dat (data cube) file.

    Reads processor 0's snapshot of serial-1 and checks the f-array shape. It
    then compares scalar quantities, grid statistics and per-variable
    mean / standard deviation of ``f`` against reference values, each within
    its own absolute tolerance.
    """
    data = var("var.dat", DATA_DIR, proc=0, quiet=True)
    _assert_equal_tuple(data.f.shape, (5, 11, 12, 10))

    def identity(value: Any) -> Any:
        return value

    def of_variable(stat: Any, ivar: int) -> Any:
        """Return an extractor applying *stat* to variable *ivar* of f."""
        return lambda f: stat(f[ivar, :, :, :])

    checks = [
        # (attribute, extractor, expected value, tolerance)
        ("t", identity, 3.865971, 1.0e-6),
        ("dx", identity, 1.333333, 1.0e-6),
        ("x", np.mean, 0.0, 1.0e-6),
        # NOTE(review): dx is checked a second time here; perhaps dy or dz
        # was intended — confirm against the sample data.
        ("dx", identity, 1.3333334, 1.0e-6),
        ("z", np.mean, 1.6332519, 1.0e-6),
        ("z", np.std, 5.918408, 1.0e-6),
        ("f", of_variable(np.mean, 0), 0.0, 1.0e-6),
        ("f", of_variable(np.mean, 1), -1.668_489e-16, 1.0e-22),
        ("f", of_variable(np.mean, 2), -7.817_168e-11, 1.0e-17),
        ("f", of_variable(np.mean, 3), 1.763_629e-9, 1.0e-15),
        ("f", of_variable(np.mean, 4), 2.544_411e-19, 1.0e-25),
        ("f", of_variable(np.std, 0), 0.0, 1.0e-6),
        ("f", of_variable(np.std, 1), 1.705_128e-9, 1.0e-15),
        ("f", of_variable(np.std, 2), 1.171_468e-9, 1.0e-15),
        ("f", of_variable(np.std, 3), 2.497_441e-9, 1.0e-15),
        ("f", of_variable(np.std, 4), 2.047_645e-19, 1.0e-25),
    ]
    for attribute, extractor, target, tolerance in checks:
        cmp_extracted(
            getattr(data, attribute), extractor, target, attribute, tolerance
        )


def test_read_power() -> None:
    """Read power spectra and compare selected fields with reference values."""
    spectra = power(datadir=DATA_DIR, quiet=True)

    reference = {
        "t": np.array([1.0477389, 2.0494874]),
        "krms": np.array([0.0, 1.29]),
        "kin": np.array([[1.88e-10, 1.41e-07], [8.16e-10, 1.40e-06]]),
        "hel_kin": np.array([[-2.08e-15, 1.14e-08], [2.26e-15, 2.66e-07]]),
        # TODO: test reading complex 'spectra' as well.
    }
    for field, wanted in reference.items():
        got = getattr(spectra, field)
        assert_true(
            np.allclose(wanted, got),
            f"power.{field}: expected {wanted}, got {got}",
        )

@require_sample("samples/helical-MHDturb")
def test_read_var_2(datadir_helical_MHDTurb):
    """Read the helical-MHDturb snapshot without trimming ghost zones.

    Persistent data and the ``bb`` magic variable are requested as well.
    Checks grid lengths, a few field samples and one persistent value.
    """
    snapshot = pc.read.var(
        datadir=datadir_helical_MHDTurb,
        trimall=False,
        lpersist=True,
        magic=["bb"],
    )

    # Untrimmed axes include ghost cells. 38 here vs 32 when trimmed; the
    # index offset of 3 in test_read_var_2_trim suggests 3 per side.
    for axis in (snapshot.x, snapshot.y, snapshot.z):
        assert len(axis) == 38

    assert np.isclose(snapshot.f[0, 6, 8, 9], 0.02334117174211011)
    assert np.isclose(snapshot.uy[3, 9, 5], -0.05910974500656841)
    assert np.isclose(snapshot.uz[8, 13, 30], -0.02635602018447831)

    assert np.isclose(snapshot.persist.forcing_tsforce, 0.3999999999999999)

@require_sample("samples/helical-MHDturb")
def test_read_var_2_trim(datadir_helical_MHDTurb):
    """Read the helical-MHDturb snapshot with ghost zones trimmed.

    Checks grid lengths, a few field samples and one persistent value.
    """
    trimmed = pc.read.var(
        datadir=datadir_helical_MHDTurb,
        trimall=True,
        lpersist=True,
    )

    for axis in (trimmed.x, trimmed.y, trimmed.z):
        assert len(axis) == 32

    # The same sample values as in test_read_var_2, with every spatial index
    # reduced by 3 because the ghost cells were trimmed away.
    assert np.isclose(trimmed.f[0, 3, 5, 6], 0.02334117174211011)
    assert np.isclose(trimmed.uy[0, 6, 2], -0.05910974500656841)
    assert np.isclose(trimmed.uz[5, 10, 27], -0.02635602018447831)

    assert np.isclose(trimmed.persist.forcing_tsforce, 0.3999999999999999)

@require_sample("samples/conv-slab_cp_2")
def test_read_var_3_local(datadir_conv_slab_cp_2):
    """Compare the global conv-slab_cp_2 snapshot with processors 0 and 1.

    The assertions imply the domain is split in y across two processors of
    equal y length, with x unsplit. Values read globally must match those
    read from the owning processor exactly.
    """
    read_options = dict(
        datadir=datadir_conv_slab_cp_2,
        trimall=True,
        lpersist=True,
    )

    whole = pc.read.var(**read_options)
    proc0 = pc.read.var(**read_options, proc=0)
    proc1 = pc.read.var(**read_options, proc=1)

    ny_local = len(proc0.y)
    assert len(whole.y) == 2 * ny_local
    assert len(proc1.y) == ny_local
    assert len(whole.x) == len(proc0.x)

    # Global y index 17 maps to local index 1 on processor 1, which presumably
    # means 16 y points per processor (contiguous split) — confirm on data.
    assert whole.uz[13, 17, 5] == proc1.uz[13, 1, 5]
    assert whole.lnrho[13, 11, 5] == proc0.lnrho[13, 11, 5]
