"""Functions for applying fixed real-valued shifts to dask arrays using Lagrange interpolation

Use make_fixed_shift_lagrange_dask to create a Lagrange interpolator for dask arrays.
"""

from __future__ import annotations

from typing import Final

import dask
import dask.array as da
import numpy as np
from typing_extensions import assert_never

from lisainstrument.dynamic_delay_numpy import DynShiftBC
from lisainstrument.fir_filters_dask import DaskArray1D, make_dask_array_1d
from lisainstrument.fixed_shift_numpy import FixedShiftFactory, FixedShiftLagrange


class FixedShiftDask:  # pylint: disable=too-few-public-methods
    """Interpolate Dask arrays to locations given by a const shift.

    This allows to interpolate samples in a Dask array to locations specified
    by a fixed shift. The shift is specified in units of the array index, i.e.
    there is no separate coordinate array. A positive shift refers to values
    left of a given sample, negative shifts to values on the right.

    The boundary treatment can be specified for each boundary in terms of
    DynShiftBC enums.

    The interpolation method is not fixed but provided via an interpolator
    instance implementing the RegularInterpMethod protocol.

    """

    def __init__(
        self,
        left_bound: DynShiftBC,
        right_bound: DynShiftBC,
        interp_fac: FixedShiftFactory,
    ):
        """Not intended for direct use, employ named constructors instead

        Arguments:
            left_bound: boundary treatment on the left
            right_bound: boundary treatment on the right
            interp_fac: Function to create interpolator engine for given shift
        """
        self._left_bound: Final = left_bound
        self._right_bound: Final = right_bound
        self._interp_fac: Final = interp_fac

    def _padding_left(self, samples: da.Array, margin_left: int) -> da.Array:
        if margin_left > 0:
            match self._left_bound:
                case DynShiftBC.ZEROPAD:
                    samples = da.concatenate([da.zeros(margin_left), samples])
                case DynShiftBC.FLAT:
                    samples = da.concatenate(
                        [da.ones(margin_left) * samples[0], samples]
                    )
                case DynShiftBC.EXCEPTION:
                    msg = (
                        f"FixedShiftDask: left edge handling {self._left_bound.name} "
                        f"impossible for given delay, would need margin of {margin_left}."
                    )
                    raise RuntimeError(msg)
                case _ as unreachable:
                    assert_never(unreachable)
        elif margin_left < 0:
            samples = samples[-margin_left:]
        return samples

    def _padding_right(self, samples: da.Array, margin_right: int) -> da.Array:
        if margin_right > 0:
            match self._right_bound:
                case DynShiftBC.ZEROPAD:
                    samples = da.concatenate([samples, da.zeros(margin_right)])
                case DynShiftBC.FLAT:
                    samples = da.concatenate(
                        [samples, da.ones(margin_right) * samples[-1]]
                    )
                case DynShiftBC.EXCEPTION:
                    msg = (
                        f"FixedShiftDask: right edge handling {self._right_bound.name} "
                        f"impossible for given delay, would need margin of {margin_right}."
                    )
                    raise RuntimeError(msg)
                case _ as unreachable:
                    assert_never(unreachable)
        elif margin_right < 0:
            samples = samples[:margin_right]

        return samples

    def __call__(self, samples: da.Array, shift: float) -> DaskArray1D:
        r"""Apply shift $s$ to samples in Dask array.

        Denoting the input data as $y_i$ with $i=0 \ldots N-1$, and the interpolated
        input data as $y(t)$, such that $y(i)=y_i$, the output $z_k$ is given by
        $z_k = y(k-s), k=0 \ ldots N - 1$.

        Required data outside the provided samples is created if the specified
        boundary condition allows it, or an exception is raised if the BC disallows it.
        The output has same length and chunking as the input.


        Arguments:
            samples: 1D Dask array with data samples
            shift: The shift $s$

        Returns:
            Dask array with interpolated samples
        """

        loc = -shift
        loc_int = int(np.floor(loc))
        loc_frac = loc - loc_int

        interp = self._interp_fac(-loc_frac)

        margin_left = interp.margin_left - loc_int
        margin_right = interp.margin_right + loc_int

        samples_checked = make_dask_array_1d(samples)
        samples_padleft = self._padding_left(samples_checked, margin_left)
        samples_padded = self._padding_right(samples_padleft, margin_right)

        delayed_op = dask.delayed(interp.apply)
        results = []
        pos = 0
        for chunk_shape in samples.chunks[0]:
            n_size = interp.margin_left + chunk_shape + interp.margin_right
            samples_shifted = delayed_op(samples_padded[pos : pos + n_size])
            delayed_chunk = da.from_delayed(
                samples_shifted, (chunk_shape,), samples.dtype
            )
            results.append(delayed_chunk)
            pos += chunk_shape

        return make_dask_array_1d(da.concatenate(results, axis=0))


def make_fixed_shift_lagrange_dask(
    left_bound: DynShiftBC, right_bound: DynShiftBC, length: int
) -> FixedShiftDask:
    """Create a FixedShiftDask instance that uses Lagrange interpolator

    Arguments:
        left_bound: boundary treatment on the left
        right_bound: boundary treatment on the right
        length: number of Lagrange polynomials (=order + 1)

    Returns:
        Fixed shift interpolator
    """
    fac = FixedShiftLagrange.factory(length)
    return FixedShiftDask(left_bound, right_bound, fac)