#!/usr/bin/env python3

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

from pywarpx import picmi

# Number of time steps
max_steps = 100

# Grid
nx = 128
nz = 128

# Domain
xmin = 0.0e-6
zmin = 0.0e-6
xmax = 50.0e-6
zmax = 50.0e-6

# Cell size
dx = (xmax - xmin) / nx
dz = (zmax - zmin) / nz

# Domain decomposition
max_grid_size_x = 64
max_grid_size_z = 64

# PML
nxpml = 10
nzpml = 10
field_boundary = ["open", "open"]

# Spectral order
nox = 8
noz = 8

# Guard cells
nxg = 8
nzg = 8

# Initialize grid
grid = picmi.Cartesian2DGrid(
    number_of_cells=[nx, nz],
    lower_bound=[xmin, zmin],
    upper_bound=[xmax, zmax],
    lower_boundary_conditions=field_boundary,
    upper_boundary_conditions=field_boundary,
    guard_cells=[nxg, nzg],
    moving_window_velocity=[0.0, 0.0, 0],
    warpx_max_grid_size_x=max_grid_size_x,
    warpx_max_grid_size_y=max_grid_size_z,
)

# Initialize field solver
solver = picmi.ElectromagneticSolver(
    grid=grid,
    cfl=0.95,
    method="PSATD",
    stencil_order=[nox, noz],
    divE_cleaning=1,
    divB_cleaning=1,
    pml_divE_cleaning=1,
    pml_divB_cleaning=1,
    warpx_psatd_update_with_rho=True,
)

# Initialize diagnostics
diag_field_list = ["E", "B"]
particle_diag = picmi.ParticleDiagnostic(
    name="diag1",
    period=10,
    data_list=diag_field_list,
)
field_diag = picmi.FieldDiagnostic(
    name="diag1",
    grid=grid,
    period=10,
    data_list=diag_field_list,
)

# Initialize simulation
sim = picmi.Simulation(
    solver=solver,
    max_steps=max_steps,
    verbose=1,
    particle_shape="cubic",
    warpx_current_deposition_algo="direct",
    warpx_particle_pusher_algo="boris",
    warpx_field_gathering_algo="energy-conserving",
    warpx_use_filter=1,
)

# Add diagnostics to simulation
sim.add_diagnostic(particle_diag)
sim.add_diagnostic(field_diag)

# Write input file to run with compiled version
sim.write_input_file(file_name="inputs_2d")


# Compute min and max of fields data
def compute_minmax(data):
    vmax = np.abs(data).max()
    vmin = -vmax
    return vmin, vmax


# Plot fields data either in valid domain or in PML
def plot_data(data, comp, pml, title, name):
    fig, ax = plt.subplots(
        nrows=1, ncols=1, gridspec_kw=dict(wspace=0.5), figsize=[6, 5]
    )
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad="5%")
    lw = 0.8
    ls = "--"
    if pml:
        # Draw PMLs and ghost regions
        ax.axvline(x=0, linewidth=lw, linestyle=ls)
        ax.axvline(x=0 + nxg, linewidth=lw, linestyle=ls)
        ax.axvline(x=-nxpml, linewidth=lw, linestyle=ls)
        ax.axvline(x=nx, linewidth=lw, linestyle=ls)
        ax.axvline(x=nx - nxg, linewidth=lw, linestyle=ls)
        ax.axvline(x=nx + nxpml, linewidth=lw, linestyle=ls)
        ax.axhline(y=0, linewidth=lw, linestyle=ls)
        ax.axhline(y=0 + nzg, linewidth=lw, linestyle=ls)
        ax.axhline(y=-nzpml, linewidth=lw, linestyle=ls)
        ax.axhline(y=nz, linewidth=lw, linestyle=ls)
        ax.axhline(y=nz - nzg, linewidth=lw, linestyle=ls)
        ax.axhline(y=nz + nzpml, linewidth=lw, linestyle=ls)
        # Annotations
        ax.annotate(
            "PML",
            xy=(-nxpml // 2, nz // 2),
            rotation="vertical",
            ha="center",
            va="center",
        )
        ax.annotate(
            "PML",
            xy=(nx + nxpml // 2, nz // 2),
            rotation="vertical",
            ha="center",
            va="center",
        )
        ax.annotate(
            "PML",
            xy=(nx // 2, -nzpml // 2),
            rotation="horizontal",
            ha="center",
            va="center",
        )
        ax.annotate(
            "PML",
            xy=(nx // 2, nz + nzpml // 2),
            rotation="horizontal",
            ha="center",
            va="center",
        )
        ax.annotate(
            "PML ghost",
            xy=(nxg // 2, nz // 2),
            rotation="vertical",
            ha="center",
            va="center",
        )
        ax.annotate(
            "PML ghost",
            xy=(-nxpml - nxg // 2, nz // 2),
            rotation="vertical",
            ha="center",
            va="center",
        )
        ax.annotate(
            "PML ghost",
            xy=(nx - nxg // 2, nz // 2),
            rotation="vertical",
            ha="center",
            va="center",
        )
        ax.annotate(
            "PML ghost",
            xy=(nx + nxpml + nxg // 2, nz // 2),
            rotation="vertical",
            ha="center",
            va="center",
        )
        ax.annotate(
            "PML ghost",
            xy=(nx // 2, nzg // 2),
            rotation="horizontal",
            ha="center",
            va="center",
        )
        ax.annotate(
            "PML ghost",
            xy=(nx // 2, -nzpml - nzg // 2),
            rotation="horizontal",
            ha="center",
            va="center",
        )
        ax.annotate(
            "PML ghost",
            xy=(nx // 2, nz - nzg // 2),
            rotation="horizontal",
            ha="center",
            va="center",
        )
        ax.annotate(
            "PML ghost",
            xy=(nx // 2, nz + nzpml + nzg // 2),
            rotation="horizontal",
            ha="center",
            va="center",
        )
        # Set extent and sliced data
        extent = np.array(
            [-nxg - nxpml, nx + nxpml + nxg, -nzg - nzpml, nz + nzpml + nzg]
        )
    else:
        # Draw ghost regions
        ax.axvline(x=0, linewidth=lw, linestyle=ls)
        ax.axvline(x=nx, linewidth=lw, linestyle=ls)
        ax.axhline(y=0, linewidth=lw, linestyle=ls)
        ax.axhline(y=nz, linewidth=lw, linestyle=ls)
        # Annotations
        ax.annotate(
            "ghost",
            xy=(-nxg // 2, nz // 2),
            rotation="vertical",
            ha="center",
            va="center",
        )
        ax.annotate(
            "ghost",
            xy=(nx + nxg // 2, nz // 2),
            rotation="vertical",
            ha="center",
            va="center",
        )
        ax.annotate(
            "ghost",
            xy=(nx // 2, -nzg // 2),
            rotation="horizontal",
            ha="center",
            va="center",
        )
        ax.annotate(
            "ghost",
            xy=(nx // 2, nz + nzg // 2),
            rotation="horizontal",
            ha="center",
            va="center",
        )
        # Set extent and sliced data
        extent = np.array([-nxg, nx + nxg, -nzg, nz + nzg])
    X = data[(), (), comp].transpose()
    # Min and max for colorbar
    vmin, vmax = compute_minmax(X)
    # Display data as image
    im = ax.imshow(
        X=X, origin="lower", extent=extent, vmin=vmin, vmax=vmax, cmap="seismic"
    )
    # Add colorbar to plot
    fig.colorbar(im, cax=cax)
    # Set label for x- and y-axis, set title
    ax.set_xlabel("x")
    ax.set_ylabel("z")
    ax.set_title(title)
    # Set plot title
    suptitle = "PML in (x,z), 4 grids 64 x 64"
    plt.suptitle(suptitle)
    # Save figure
    figname = "figure_" + name + ".png"
    fig.savefig(figname, dpi=100)


# Initialize fields data (unit pulse) and apply smoothing
def init_data(data):
    impulse_1d = np.array([1.0 / 4.0, 1.0 / 2.0, 1.0 / 4.0])
    impulse = np.outer(impulse_1d, impulse_1d)
    data[nx // 2 - 1 : nx // 2 + 2, nz // 2 - 1 : nz // 2 + 2] = impulse


# Initialize inputs and WarpX instance
sim.initialize_inputs()
sim.initialize_warpx()

# Get fields data using Python wrappers
Ex = sim.fields.get("Efield_fp", dir="x", level=0)
Ey = sim.fields.get("Efield_fp", dir="y", level=0)
Ez = sim.fields.get("Efield_fp", dir="z", level=0)
Bx = sim.fields.get("Bfield_fp", dir="x", level=0)
By = sim.fields.get("Bfield_fp", dir="y", level=0)
Bz = sim.fields.get("Bfield_fp", dir="z", level=0)
F = sim.fields.get("F_fp", level=0)
G = sim.fields.get("G_fp", level=0)
Expml = sim.fields.get("pml_E_fp", dir="x", level=0)
Eypml = sim.fields.get("pml_E_fp", dir="y", level=0)
Ezpml = sim.fields.get("pml_E_fp", dir="z", level=0)
Bxpml = sim.fields.get("pml_B_fp", dir="x", level=0)
Bypml = sim.fields.get("pml_B_fp", dir="y", level=0)
Bzpml = sim.fields.get("pml_B_fp", dir="z", level=0)
Fpml = sim.fields.get("pml_F_fp", level=0)
Gpml = sim.fields.get("pml_G_fp", level=0)

# Initialize fields data in valid domain
init_data(Ex)
init_data(Ey)
init_data(Ez)
init_data(Bx)
init_data(By)
init_data(Bz)
init_data(F)
init_data(G)

# Advance simulation until last time step
sim.step(max_steps)

# Plot E
plot_data(Ex, 0, pml=False, title="Ex", name="Ex")
plot_data(Ey, 0, pml=False, title="Ey", name="Ey")
plot_data(Ez, 0, pml=False, title="Ez", name="Ez")

# Plot B
plot_data(Bx, 0, pml=False, title="Bx", name="Bx")
plot_data(By, 0, pml=False, title="By", name="By")
plot_data(Bz, 0, pml=False, title="Bz", name="Bz")

# F and G
plot_data(F, 0, pml=False, title="F", name="F")
plot_data(G, 0, pml=False, title="G", name="G")

# Plot E in PML
plot_data(Expml, 0, pml=True, title="Exy in PML", name="Exy")
plot_data(Expml, 1, pml=True, title="Exz in PML", name="Exz")
plot_data(Expml, 2, pml=True, title="Exx in PML", name="Exx")
plot_data(Eypml, 0, pml=True, title="Eyz in PML", name="Eyz")
plot_data(Eypml, 1, pml=True, title="Eyx in PML", name="Eyx")
plot_data(Eypml, 2, pml=True, title="Eyy in PML", name="Eyy")  # zero
plot_data(Ezpml, 0, pml=True, title="Ezx in PML", name="Ezx")
plot_data(Ezpml, 1, pml=True, title="Ezy in PML", name="Ezy")  # zero
plot_data(Ezpml, 2, pml=True, title="Ezz in PML", name="Ezz")

# Plot B in PML
plot_data(Bxpml, 0, pml=True, title="Bxy in PML", name="Bxy")
plot_data(Bxpml, 1, pml=True, title="Bxz in PML", name="Bxz")
plot_data(Bxpml, 2, pml=True, title="Bxx in PML", name="Bxx")
plot_data(Bypml, 0, pml=True, title="Byz in PML", name="Byz")
plot_data(Bypml, 1, pml=True, title="Byx in PML", name="Byx")
plot_data(Bypml, 2, pml=True, title="Byy in PML", name="Byy")  # zero
plot_data(Bzpml, 0, pml=True, title="Bzx in PML", name="Bzx")
plot_data(Bzpml, 1, pml=True, title="Bzy in PML", name="Bzy")  # zero
plot_data(Bzpml, 2, pml=True, title="Bzz in PML", name="Bzz")

# Plot F and G in PML
plot_data(Fpml, 0, pml=True, title="Fx in PML", name="Fx")
plot_data(Fpml, 1, pml=True, title="Fy in PML", name="Fy")
plot_data(Fpml, 2, pml=True, title="Fz in PML", name="Fz")
plot_data(Gpml, 0, pml=True, title="Gx in PML", name="Gx")
plot_data(Gpml, 1, pml=True, title="Gy in PML", name="Gy")
plot_data(Gpml, 2, pml=True, title="Gz in PML", name="Gz")


# Check values with benchmarks (precomputed from the same Python arrays)
def check_values(benchmark, data, comp, rtol, atol):
    passed = np.allclose(
        benchmark, np.sum(np.abs(data[(), (), comp])), rtol=rtol, atol=atol
    )
    assert passed


rtol = 5e-08
atol = 1e-12

# E
check_values(1013263608.6369569, Ex, 0, rtol, atol)
check_values(717278256.7957529, Ey, 0, rtol, atol)
check_values(717866566.5718911, Ez, 0, rtol, atol)
# B
check_values(3.0214509313437636, Bx, 0, rtol, atol)
check_values(3.0242765102729985, By, 0, rtol, atol)
check_values(3.0214509326970465, Bz, 0, rtol, atol)
# F and G
check_values(3.0188584528062377, F, 0, rtol, atol)
check_values(1013672631.8764204, G, 0, rtol, atol)
# E in PML
check_values(364287936.1526477, Expml, 0, rtol, atol)
check_values(183582352.20753333, Expml, 1, rtol, atol)
check_values(190065766.41491824, Expml, 2, rtol, atol)
check_values(440581907.0828975, Eypml, 0, rtol, atol)
check_values(178117294.05871135, Eypml, 1, rtol, atol)
check_values(0.0, Eypml, 2, rtol, atol)
check_values(430277101.26568377, Ezpml, 0, rtol, atol)
check_values(0.0, Ezpml, 1, rtol, atol)
check_values(190919663.2167449, Ezpml, 2, rtol, atol)
# B in PML
check_values(1.0565189315366146, Bxpml, 0, rtol, atol)
check_values(0.46181913800643065, Bxpml, 1, rtol, atol)
check_values(0.6849858305343736, Bxpml, 2, rtol, atol)
check_values(1.7228584190213505, Bypml, 0, rtol, atol)
check_values(0.47697332248020935, Bypml, 1, rtol, atol)
check_values(0.0, Bypml, 2, rtol, atol)
check_values(1.518338068658267, Bzpml, 0, rtol, atol)
check_values(0.0, Bzpml, 1, rtol, atol)
check_values(0.6849858291863835, Bzpml, 2, rtol, atol)
# F and G in PML
check_values(1.7808748509425263, Fpml, 0, rtol, atol)
check_values(0.0, Fpml, 1, rtol, atol)
check_values(0.4307845604625681, Fpml, 2, rtol, atol)
check_values(536552745.42701197, Gpml, 0, rtol, atol)
check_values(0.0, Gpml, 1, rtol, atol)
check_values(196016270.97767758, Gpml, 2, rtol, atol)
