Skip to content
Snippets Groups Projects
shift_inversion_dask.py 7.22 KiB
Newer Older
"""
Dask-based time coordinate inversion
====================================

Functions for inverting a 1D (time) coordinate transform given as a dask array.
This is the same functionality as provided by module shift_inversion_numpy for
numpy arrays.


.. autoclass:: ShiftInverseDask
    :members:
    :private-members:
    :special-members:


There are two choices for the interpolator used internally. One is the
newly written Lagrange interpolator and the other is the Lagrange interpolator
from dsp.timeshift.

.. autofunction:: make_shift_inverse_lagrange_dask
.. autofunction:: make_shift_inverse_dsp_dask

"""

from __future__ import annotations

from typing import Final

import dask
import dask.array as da
import numpy as np

Wolfgang Kastaun's avatar
Wolfgang Kastaun committed
from lisainstrument.dynamic_delay_dsp import make_regular_interpolator_dsp
from lisainstrument.fir_filters_dask import DaskArray1D, make_dask_array_1d
from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d
from lisainstrument.regular_interpolators import (
    RegularInterpolator,
    make_regular_interpolator_lagrange,
)
from lisainstrument.shift_inversion_numpy import fixed_point_iter


class ShiftInverseDask:
    """Invert time coordinate transformation given as dask array

    The shift is processed chunk by chunk: every chunk of the input is handed,
    together with the margin points it needs, to a fixed point iteration that
    runs on plain numpy arrays.

    See shift_inversion_numpy.ShiftInverseNumpy for details.
    """

    def __init__(
        self,
        fsample: float,
        max_abs_shift: float,
        interp: RegularInterpolator,
        max_iter: int,
        tolerance: float,
    ):
        r"""Set up the coordinate inversion operator.

        See shift_inversion_numpy.ShiftInverseNumpy for details

        Arguments:
            fsample: Sample rate :math:`f_s > 0` [Hz]
            max_abs_shift: Upper limit :math:`S_\mathrm{max} \ge 0` [s] for coordinate shift
            interp: Interpolation operator
            max_iter: Maximum iterations before fail
            tolerance: Maximum absolute error [s] of result

        Raises:
            ValueError: If fsample is not strictly positive, max_abs_shift is
                negative, or max_iter is not strictly positive.
        """
        fsample_f = float(fsample)
        max_iter_i = int(max_iter)

        # Validate before deriving any quantity from the arguments.
        if fsample_f <= 0:
            msg = f"ShiftInverseDask: fsample must be strictly positive, got {fsample}"
            raise ValueError(msg)

        if max_abs_shift < 0:
            msg = (
                f"ShiftInverseDask: max_abs_shift must be positive, got {max_abs_shift}"
            )
            raise ValueError(msg)

        if max_iter_i <= 0:
            msg = f"ShiftInverseDask: max_iter must be strictly positive integer, got {max_iter}"
            raise ValueError(msg)

        self._fsample: Final = fsample_f
        # Largest shift expressed in samples, rounded up to a whole sample.
        self._max_abs_shift_idx: Final = int(np.ceil(max_abs_shift * self._fsample))

        self._interp_np: Final = interp
        self._max_iter = max_iter_i
        # The iteration works in index space, so the tolerance is converted
        # from seconds to samples as well.
        self._tolerance_idx = float(tolerance * self._fsample)

    @property
    def margin_left(self) -> int:
        """Left margin size.

        Specifies how many samples on the left have to be added by boundary conditions.
        This is the interpolator's left margin plus the maximum shift in samples.
        """
        return self._interp_np.margin_left + self._max_abs_shift_idx

    @property
    def margin_right(self) -> int:
        """Right margin size.

        Specifies how many samples on the right have to be added by boundary conditions.
        This is the interpolator's right margin plus the maximum shift in samples.
        """
        return self._interp_np.margin_right + self._max_abs_shift_idx

    def _fixed_point_iter(self, dx_pad: np.ndarray, dx: np.ndarray) -> NumpyArray1D:
        """Call the fixed point iteration on a single chunk

        The input shift includes margin points in addition to the corresponding
        output points. One needs to specify initial data for the iteration, with
        same shape as the result. Both are shorter than the input shift by the
        left plus right margin size.

        Arguments:
            dx_pad: numpy array with shift
            dx: initial value for iteration

        Returns:
            Converged iteration result
        """

        def f_iter(x: NumpyArray1D) -> NumpyArray1D:
            # One iteration step: interpolate the padded shift at positions
            # displaced by -x; margin_left is passed as the offset of the
            # first output sample within dx_pad.
            return self._interp_np.apply_shift(
                make_numpy_array_1d(dx_pad), -x, self.margin_left
            )

        def f_err(x1: NumpyArray1D, x2: NumpyArray1D) -> float:
            # Maximum absolute difference between two iterates, as a plain float.
            return float(np.max(np.abs(x1 - x2)))

        return fixed_point_iter(
            f_iter, f_err, make_numpy_array_1d(dx), self._tolerance_idx, self._max_iter
        )

    def __call__(self, shift: da.Array) -> DaskArray1D:
        """Compute the inverse coordinate transform

        See shift_inversion_numpy.ShiftInverseNumpy for details

        Arguments:
            shift: 1D dask array with shifts of the coordinate transform [s]

        Returns:
            1D dask array with shift [s] at transformed coordinate
        """
        # Work in units of samples internally.
        shift_idx = shift * self._fsample
        # Result intentionally unused; presumably called so that input which
        # is not a 1D dask array is rejected early.
        make_dask_array_1d(shift_idx)

        # Extend the data on both sides by repeating the edge values.
        dx_pad = da.pad(
            shift_idx,
            (self.margin_left, self.margin_right),
            mode="edge",
        )

        results = []

        chunks = shift_idx.to_delayed()
        delayed_op = dask.delayed(self._fixed_point_iter)
        pos = 0
        for chunk, chunk_shape in zip(chunks, shift_idx.chunks[0], strict=True):
            # dx_pad is offset by margin_left relative to shift_idx, so the
            # window starting at pos covers the chunk plus both margins.
            n_size = self.margin_left + chunk_shape + self.margin_right
            samples_needed = dx_pad[pos : pos + n_size]
            # The chunk itself serves as the initial guess of the iteration.
            samples_shifted = delayed_op(samples_needed, chunk)
            delayed_chunk = da.from_delayed(
                samples_shifted, (chunk_shape,), shift_idx.dtype
            )
            results.append(delayed_chunk)
            pos += chunk_shape

        shift_idx_inv = da.concatenate(results, axis=0)
        # Convert back from samples to seconds.
        shift_inv = shift_idx_inv / self._fsample

        return make_dask_array_1d(shift_inv)


def make_shift_inverse_lagrange_dask(
    order: int,
    fsample: float,
    max_abs_shift: float,
    max_iter: int,
    tolerance: float,
) -> ShiftInverseDask:
    r"""Set up ShiftInverseDask instance with new Lagrange interpolation method.

    The interpolator is created by make_regular_interpolator_lagrange and
    passed on to ShiftInverseDask together with the remaining parameters.

    Arguments:
        order: Order of the Lagrange polynomials
        fsample: Sample rate :math:`f_s > 0` [Hz]
        max_abs_shift: Upper limit :math:`S_\mathrm{max} \ge 0` [s] for coordinate shift
        max_iter: Maximum iterations before fail
        tolerance: Maximum absolute error [s] of result

    Returns:
        Inversion function of type ShiftInverseDask
    """
    interp = make_regular_interpolator_lagrange(order)
    return ShiftInverseDask(
        fsample=fsample,
        max_abs_shift=max_abs_shift,
        interp=interp,
        max_iter=max_iter,
        tolerance=tolerance,
    )
Wolfgang Kastaun's avatar
Wolfgang Kastaun committed

def make_shift_inverse_dsp_dask(
    order: int,
    fsample: float,
    max_abs_shift: float,
    max_iter: int,
    tolerance: float,
) -> ShiftInverseDask:
    r"""Set up ShiftInverseDask instance using dsp.timeshift as interpolation method.

    The interpolator is created by make_regular_interpolator_dsp and passed on
    to ShiftInverseDask together with the remaining parameters.

    Arguments:
        order: Order of the Lagrange polynomials
        fsample: Sample rate :math:`f_s > 0` [Hz]
        max_abs_shift: Upper limit :math:`S_\mathrm{max} \ge 0` [s] for coordinate shift
        max_iter: Maximum iterations before fail
        tolerance: Maximum absolute error [s] of result

    Returns:
        Inversion function of type ShiftInverseDask
    """
    interp = make_regular_interpolator_dsp(order)
    return ShiftInverseDask(
        fsample=fsample,
        max_abs_shift=max_abs_shift,
        interp=interp,
        max_iter=max_iter,
        tolerance=tolerance,
    )