Skip to content
Snippets Groups Projects
shift_inversion_dask.py 5.42 KiB
Newer Older
"""Functions for inverting coordinate transformations given as shifts, for dask arrays

Use make_shift_inverse_lagrange_dask (Lagrange interpolation) or
make_shift_inverse_dsp_dask (dsp.timeshift-based interpolation) to create a
ShiftInverseDask instance.
"""

from __future__ import annotations

from typing import Final

import dask
import dask.array as da
import numpy as np

from lisainstrument.fir_filters_dask import DaskArray1D, make_dask_array_1d
from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d
from lisainstrument.regular_interpolators import (
    RegularInterpolator,
    make_regular_interpolator_lagrange,
)
from lisainstrument.dynamic_delay_dsp import make_regular_interpolator_dsp
from lisainstrument.shift_inversion_numpy import fixed_point_iter


class ShiftInverseDask:
    """Invert a coordinate transformation given as a shift, for dask arrays.

    The inverse shift is computed chunk by chunk with a fixed-point iteration
    that uses a regular interpolator to evaluate the shift at shifted
    positions.
    """

    def __init__(
        self,
        max_abs_shift: float,
        interp: RegularInterpolator,
        max_iter: int,
        tolerance: float,
    ):
        """Set up interpolator.

        Arguments:
            max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space
            interp: Interpolation method
            max_iter: Maximum iterations before fail
            tolerance: Maximum absolute error of result (same units as the
                shift passed to __call__; converted to index units there)

        Raises:
            ValueError: if max_abs_shift is negative or not finite
        """
        # Validate the raw value: rounding up first would let values in
        # (-1, 0) slip through as 0, and int() of inf raises OverflowError.
        if not np.isfinite(max_abs_shift) or max_abs_shift < 0:
            msg = (
                "ShiftInverseDask: max_abs_shift must be finite and non-negative, "
                f"got {max_abs_shift}"
            )
            raise ValueError(msg)

        # Margins are counted in whole samples, so round the allowed shift up.
        self._max_abs_shift: Final = int(np.ceil(max_abs_shift))

        self._interp_np: Final = interp
        self._max_iter = int(max_iter)
        self._tolerance = float(tolerance)

    @property
    def margin_left(self) -> int:
        """Left margin size.

        Specifies how many samples on the left have to be added by boundary conditions.
        It is the interpolator's own left margin plus the rounded-up max_abs_shift.
        """
        return self._interp_np.margin_left + self._max_abs_shift

    @property
    def margin_right(self) -> int:
        """Right margin size.

        Specifies how many samples on the right have to be added by boundary conditions.
        It is the interpolator's own right margin plus the rounded-up max_abs_shift.
        """
        return self._interp_np.margin_right + self._max_abs_shift

    def _fixed_point_iter(
        self, dx_pad: np.ndarray, dx: np.ndarray, tol: float
    ) -> NumpyArray1D:
        """Compute the inverse shift for one chunk by fixed-point iteration.

        Arguments:
            dx_pad: Shift samples (index units) of the chunk, extended by
                margin_left samples on the left and margin_right on the right
            dx: Shift samples (index units) of the chunk itself; passed to
                fixed_point_iter as the starting value of the iteration
            tol: Convergence tolerance in index units, compared against the
                maximum absolute change between successive iterates

        Returns:
            Inverse shift for the chunk, in index units
        """

        def f_iter(x: NumpyArray1D) -> NumpyArray1D:
            # Evaluate the padded shift at positions displaced by -x.
            return self._interp_np.apply_shift(
                make_numpy_array_1d(dx_pad), -x, self.margin_left
            )

        def f_err(x1: NumpyArray1D, x2: NumpyArray1D) -> float:
            return np.max(np.abs(x1 - x2))

        return fixed_point_iter(
            f_iter, f_err, make_numpy_array_1d(dx), tol, self._max_iter
        )

    def __call__(self, shift: da.Array, fsample: float) -> DaskArray1D:
        """Find the shift for the inverse transformation given by a shift.

        Arguments:
            shift: 1D dask array with shifts of the coordinate transform
            fsample: sample rate of shift array

        Returns:
            1D dask array with shift at transformed coordinate, chunked like
            the input
        """
        # Work in index space: shift and tolerance are converted to samples.
        shift_idx = shift * fsample
        # NOTE(review): return value unused; presumably called to check that
        # the array is 1D.
        make_dask_array_1d(shift_idx)

        tol_idx = self._tolerance * fsample

        # Extend by edge values so every chunk, including those at the array
        # boundaries, has the margin samples the interpolation needs.
        dx_pad = da.pad(
            shift_idx,
            (self.margin_left, self.margin_right),
            mode="edge",
        )

        delayed_op = dask.delayed(self._fixed_point_iter)
        results = []
        pos = 0
        for chunk, chunk_size in zip(
            shift_idx.to_delayed(), shift_idx.chunks[0], strict=True
        ):
            # The chunk starting at `pos` in shift_idx starts at
            # pos + margin_left in dx_pad, so this window holds the chunk
            # together with its left and right margins.
            n_size = self.margin_left + chunk_size + self.margin_right
            samples_needed = dx_pad[pos : pos + n_size]
            samples_shifted = delayed_op(samples_needed, chunk, tol_idx)
            results.append(
                da.from_delayed(samples_shifted, (chunk_size,), shift_idx.dtype)
            )
            pos += chunk_size

        shift_idx_inv = da.concatenate(results, axis=0)
        shift_inv = shift_idx_inv / fsample

        return make_dask_array_1d(shift_inv)


def make_shift_inverse_lagrange_dask(
    order: int,
    max_abs_shift: float,
    max_iter: int,
    tolerance: float,
) -> ShiftInverseDask:
    """Set up ShiftInverseDask instance with Lagrange interpolation method.

    Arguments:
        order: Order of the Lagrange polynomials
        max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space
        max_iter: Maximum iterations before fail
        tolerance: Maximum absolute error of result, in the same units as the
            shift passed to the returned ShiftInverseDask

    Returns:
        Inversion function of type ShiftInverseDask
    """
    interp = make_regular_interpolator_lagrange(order)
    return ShiftInverseDask(max_abs_shift, interp, max_iter, tolerance)

def make_shift_inverse_dsp_dask(
    order: int,
    max_abs_shift: float,
    max_iter: int,
    tolerance: float,
) -> ShiftInverseDask:
    """Set up ShiftInverseDask instance using dsp.timeshift as interpolation method.

    Arguments:
        order: Interpolation order passed to make_regular_interpolator_dsp
            (presumably the Lagrange order used by dsp.timeshift — confirm)
        max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space
        max_iter: Maximum iterations before fail
        tolerance: Maximum absolute error of result, in the same units as the
            shift passed to the returned ShiftInverseDask

    Returns:
        Inversion function of type ShiftInverseDask
    """
    interp = make_regular_interpolator_dsp(order)
    return ShiftInverseDask(max_abs_shift, interp, max_iter, tolerance)