import argparse
import configparser
import logging
import pathlib

import numpy as np
from scipy import interpolate

from .lambert import Lambert

# Command-line interface. Each -v lowers the logging threshold by one level:
# INFO (20) without flags, DEBUG (10) from a single -v onward, never below DEBUG.
parser = argparse.ArgumentParser(description="Pre-process bathymetry")
parser.add_argument("-v", "--verbose", action="count", default=0)
args = parser.parse_args()

logging.basicConfig(level=max(logging.DEBUG, logging.INFO - 10 * args.verbose))
log = logging.getLogger("bathy")

log.info("Starting bathymetry pre-processing")
config = configparser.ConfigParser()
# ConfigParser.read silently skips missing/unreadable files and returns the
# list of files it actually parsed. Fail early with an explicit message
# instead of a confusing NoSectionError on the first config.get below.
if not config.read("config.ini"):
    raise FileNotFoundError("Could not read configuration file 'config.ini'")

inp_root = pathlib.Path(config.get("inp", "root"))
out_root = pathlib.Path(config.get("out", "root"))
# The scattered bathymetry is read from the *output* tree ([out] sub);
# presumably it is produced by an earlier processing step (TODO confirm).
bathy_inp = out_root.joinpath(config.get("out", "sub"))
hires_inp = inp_root.joinpath(config.get("inp", "hires"))
# NOTE(review): bathy_out is not used anywhere in this script (the outputs
# are written directly under out_root), and it is rooted at inp_root even
# though its name comes from the [out] section -- confirm this is intended.
bathy_out = inp_root.joinpath(config.get("out", "out"))

log.info(f"Loading bathymetry from {bathy_inp}")
bathy_curvi = np.load(bathy_inp)

# Project the scattered samples onto the Lambert plane: columns 0 and 1 are
# passed to `cartesian` in the same (lon, lat) order used for the transect end
# points below; column 2 is carried through unchanged. Result is (N, 3).
projection = Lambert()
bathy = np.column_stack(
    (
        *projection.cartesian(bathy_curvi[:, 0], bathy_curvi[:, 1]),
        bathy_curvi[:, 2],
    )
)
log.debug(f"Cartesian bathy: {bathy}")

# Transect end points in curvilinear (lon, lat) coordinates, read from the
# [artha] and [buoy] configuration sections.
artha_curvi = np.array([config.getfloat("artha", key) for key in ("lon", "lat")])
buoy_curvi = np.array([config.getfloat("buoy", key) for key in ("lon", "lat")])

# The same end points in projected (cartesian) coordinates.
artha = np.asarray(projection.cartesian(*artha_curvi))
buoy = np.asarray(projection.cartesian(*buoy_curvi))

# Displacement artha -> buoy, kept as a (1, 2) row vector.
D = (buoy - artha)[np.newaxis, :]

# Abscissa along the transect: from [out] left up to (excluding) the
# artha-buoy distance plus [out] right, every [out] step (defaults 0, 0, 1).
x = np.arange(
    config.getfloat("out", "left", fallback=0),
    np.sqrt(np.sum(np.square(D))) + config.getfloat("out", "right", fallback=0),
    config.getfloat("out", "step", fallback=1),
)

# Heading of the transect, shape (1,). arctan2(dy, dx) is what np.angle
# evaluates for the complex number dx + 1j * dy.
theta = np.arctan2(D[:, 1], D[:, 0])

# Sample points along the transect, shape (len(x), 2): artha + x * (cos, sin).
coords = artha + np.outer(x, np.stack((np.cos(theta), np.sin(theta))))

log.info("Interpolating bathymetry in 1D")
# Linearly interpolate the scattered (x, y) -> z samples onto the transect
# points; points outside the samples' convex hull come back as NaN.
z = interpolate.griddata(
    points=bathy[:, :2],
    values=bathy[:, 2],
    xi=coords,
    method="linear",
)
log.debug(f"z: {z}")

# High-resolution profile. The file presumably holds a single column of
# elevations (TODO confirm); it is reversed, then paired with an abscissa that
# starts at 0 and is spaced by [inp] hires_step, giving an (M, 2) array.
_hires_elev = np.loadtxt(hires_inp)[::-1]
_hires_count = _hires_elev.size
_hires_abscissa = np.linspace(
    0, (_hires_count - 1) * config.getfloat("inp", "hires_step"), _hires_count
)
bathy_hires = np.column_stack((_hires_abscissa, _hires_elev))
# Drop the temporaries so only bathy_hires stays at module level.
del _hires_elev, _hires_count, _hires_abscissa
log.debug(f"Bathy hires: {bathy_hires}")

def _last_crossing(values, level, label):
    """Return the index ``i`` of the last sign change of ``values - level``.

    The crossing lies between ``values[i]`` and ``values[i + 1]``. Neighbour
    pairs involving a NaN are ignored. ``griddata`` fills points outside the
    convex hull with NaN, and ``signbit(NaN)`` is False, so without this mask
    a hull boundary would be reported as a spurious crossing.

    Raises:
        ValueError: if the profile never crosses ``level``.
    """
    values = np.asarray(values, dtype=float)
    below = np.signbit(values - level)
    finite = np.isfinite(values)
    changes = (below[1:] != below[:-1]) & finite[1:] & finite[:-1]
    indices = np.flatnonzero(changes)
    if indices.size == 0:
        raise ValueError(f"{label} profile never crosses z = {level}")
    return indices[-1]


# Elevation at which the coarse transect and the hi-res profile are aligned.
# Optional [inp] z_cr; defaults to the previously hard-coded value 5.
z_cr = config.getfloat("inp", "z_cr", fallback=5)
hires_crossing = _last_crossing(bathy_hires[:, 1], z_cr, "High-resolution")
log.debug(f"Hires crossing: {hires_crossing}")
z_crossing = _last_crossing(z, z_cr, "Interpolated")
log.debug(f"Z crossing: {z_crossing}")

# Extent of the hi-res profile once shifted so that its crossing coincides
# with the crossing of the interpolated transect.
x_min_hires = x[z_crossing] + (
    bathy_hires[:, 0].min() - bathy_hires[hires_crossing, 0]
)
x_max_hires = x[z_crossing] + (
    bathy_hires[:, 0].max() - bathy_hires[hires_crossing, 0]
)
log.debug(f"Replacing range: [{x_min_hires},{x_max_hires}]")

# Overwrite the transect inside that extent with the (1-D linearly
# interpolated) hi-res profile, mapped back to hi-res abscissae.
flt_x = (x > x_min_hires) & (x < x_max_hires)
z[flt_x] = interpolate.griddata(
    (bathy_hires[:, 0],),
    bathy_hires[:, 1],
    (x[flt_x] - x[z_crossing] + bathy_hires[hires_crossing, 0]),
)

# Write the transect elevations in reverse order, space-separated (newline=" "),
# plus all-zero companion files of the same length, all under out_root.
# NOTE(review): the bathy_out path built from [out] out is not used here.
np.savetxt(out_root.joinpath("bathy.dat"), z[::-1], newline=" ")
_zeros = np.zeros(z.shape)
for _fname in ("hstru.dat", "poro.dat", "psize.dat"):
    np.savetxt(out_root.joinpath(_fname), _zeros, newline=" ")
del _zeros, _fname