"""Functions for inverting a 1D (time) coordinate transform given as a numpy array.

Internally, the inversion function requires an interpolation operator, provided by
the user, that implements the RegularInterpolator interface. Use
make_shift_inverse_lagrange_numpy to create an inversion operator that employs
Lagrange interpolation.
"""

from __future__ import annotations

from typing import Callable, Final

import numpy as np

from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d
from lisainstrument.regular_interpolators import (
    RegularInterpolator,
    make_regular_interpolator_lagrange,
)


def fixed_point_iter(
    f: Callable[[NumpyArray1D], NumpyArray1D],
    ferr: Callable[[NumpyArray1D, NumpyArray1D], float],
    x: NumpyArray1D,
    tolerance: float,
    max_iter: int,
) -> NumpyArray1D:
    r"""Perform a fixed point iteration for functions operating on a 1D array.

    This uses fixed-point iteration to find a solution for
    $$ x = f(x) $$
    where $x$ is a 1D array and $f$ returns an array of the same size.

    Starting from $x_0$, the iterates are $x_{k+1} = f(x_k)$. The convergence
    criterion is provided by the user via a function $r(x_k, x_{k+1})$ that
    returns a scalar error measure for two consecutive iterates. The iteration
    stops as soon as $r(x_k, x_{k+1}) < \epsilon$ and returns $x_{k+1}$. If
    convergence is not achieved within the given number of iterations, an
    exception is raised.

    Arguments:
        f: The function $f(x)$
        ferr: The error measure $r(x_k, x_{k+1})$
        x: The initial data $x_0$ for the iteration.
        tolerance: The tolerance $\epsilon$
        max_iter: Maximum number of iterations (i.e. calls to f)

    Returns:
        Array with solution (the last iterate)

    Raises:
        RuntimeError: If the error measure did not drop below the tolerance
            within max_iter iterations. This includes max_iter < 1, in which
            case no iteration is performed at all.
    """
    # Defined before the loop so that the failure message below is well-formed
    # even when the loop body never runs (max_iter < 1); previously this path
    # raised UnboundLocalError instead of the documented RuntimeError.
    err = float("inf")
    for _ in range(max_iter):
        x_next = f(x)
        err = ferr(x, x_next)
        if err < tolerance:
            return x_next
        x = x_next
    msg = (
        f"ShiftInverseNumpy: iteration did not converge (error={err}, "
        f"tolerance={tolerance}), iterations={max_iter}"
    )
    raise RuntimeError(msg)


class ShiftInverseNumpy:  # pylint: disable=too-few-public-methods
    """Invert coordinate transformation given as shift.

    The inverse shift is found by a fixed-point iteration. Each step applies
    the user-provided RegularInterpolator to the boundary-padded input shift
    (in index units), using minus the current estimate as the interpolation
    shift.
    """

    def __init__(
        self,
        max_abs_shift: float,
        interp: RegularInterpolator,
        max_iter: int,
        tolerance: float,
    ):
        """Set up interpolator.

        Arguments:
            max_abs_shift: Upper limit for absolute difference between coordinate
                frames w.r.t index space. Must be non-negative.
            interp: Interpolation method
            max_iter: Maximum iterations before fail. Must be at least 1.
            tolerance: Maximum absolute error of result. Must be positive.

        Raises:
            ValueError: If max_abs_shift, max_iter or tolerance is out of range.
        """
        # `not x >= 0` / `not x > 0` also reject NaN.
        if not max_abs_shift >= 0:
            msg = (
                "ShiftInverseNumpy: max_abs_shift must be non-negative, "
                f"got {max_abs_shift}"
            )
            raise ValueError(msg)
        max_iter = int(max_iter)
        if max_iter < 1:
            msg = f"ShiftInverseNumpy: max_iter must be at least 1, got {max_iter}"
            raise ValueError(msg)
        tolerance = float(tolerance)
        # The error measure is non-negative, so a tolerance <= 0 can never be met.
        if not tolerance > 0:
            msg = f"ShiftInverseNumpy: tolerance must be positive, got {tolerance}"
            raise ValueError(msg)

        self._max_abs_shift: Final = int(np.ceil(max_abs_shift))
        self._interp_np: Final = interp
        self._max_iter: Final = max_iter
        self._tolerance: Final = tolerance

    @property
    def margin_left(self) -> int:
        """Left margin size.

        Specifies how many samples on the left have to be added by boundary conditions.
        It is the interpolator's own left margin plus the rounded-up maximum shift.
        """
        return self._interp_np.margin_left + self._max_abs_shift

    @property
    def margin_right(self) -> int:
        """Right margin size.

        Specifies how many samples on the right have to be added by boundary conditions.
        It is the interpolator's own right margin plus the rounded-up maximum shift.
        """
        return self._interp_np.margin_right + self._max_abs_shift

    def __call__(self, shift: np.ndarray, fsample: float) -> NumpyArray1D:
        """Find the shift for the inverse transformation given by a shift.

        Arguments:
            shift: 1D numpy array with shifts of the coordinate transform
            fsample: sample rate of shift array

        Returns:
            1D numpy array with shift at transformed coordinate (same length
            as `shift`; empty input yields empty output)

        Raises:
            RuntimeError: If the fixed-point iteration does not converge within
                max_iter iterations.
        """
        # Work in index units: time shift multiplied by the sample rate.
        dx = make_numpy_array_1d(shift * fsample)

        if dx.size == 0:
            # Nothing to invert; also avoids indexing dx[0] / dx[-1] below.
            return make_numpy_array_1d(dx / fsample)

        margin_left = self.margin_left
        margin_right = self.margin_right

        # Extend the shift with its boundary values so the interpolator can be
        # evaluated at displaced locations near both ends of the array.
        dx_pad = np.pad(
            dx,
            (margin_left, margin_right),
            mode="constant",
            constant_values=(dx[0], dx[-1]),
        )

        def f_iter(x: NumpyArray1D) -> NumpyArray1D:
            return self._interp_np.apply_shift(dx_pad, -x, margin_left)

        def f_err(x1: NumpyArray1D, x2: NumpyArray1D) -> float:
            # The iterates have the length of the *unpadded* dx (the first
            # comparison is between dx and f_iter(dx)), so the error is taken
            # over every sample. Slicing them by the padding margins would
            # exclude valid edge samples from the convergence check, and would
            # produce an empty array (np.max error) when margin_right == 0 or
            # the input is shorter than the margins.
            return float(np.max(np.abs(x1 - x2)))

        shift_idx_inv = fixed_point_iter(
            f_iter, f_err, dx, self._tolerance, self._max_iter
        )

        return make_numpy_array_1d(shift_idx_inv / fsample)


def make_shift_inverse_lagrange_numpy(
    order: int,
    max_abs_shift: float,
    max_iter: int,
    tolerance: float,
) -> ShiftInverseNumpy:
    """Build a ShiftInverseNumpy operator backed by Lagrange interpolation.

    This is a convenience factory. It creates the Lagrange interpolator of the
    requested order and wires it into a ShiftInverseNumpy instance.

    Arguments:
        order: Order of the Lagrange polynomials
        max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space
        max_iter: Maximum iterations before fail
        tolerance: Maximum absolute error of result

    Returns:
        Inversion function of type ShiftInverseNumpy
    """
    lagrange_interp = make_regular_interpolator_lagrange(order)
    inverter = ShiftInverseNumpy(
        max_abs_shift=max_abs_shift,
        interp=lagrange_interp,
        max_iter=max_iter,
        tolerance=tolerance,
    )
    return inverter