"""Functions for inverting a 1D (time) coordinate transform given as a numpy array.

The inversion function internally requires an interpolation operator implementing the
RegularInterpolator interface, which is provided by the user. Use
make_shift_inverse_lagrange_numpy to create an inversion operator employing Lagrange
interpolation.
"""

from __future__ import annotations

from typing import Callable, Final

import numpy as np

from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d
from lisainstrument.regular_interpolators import (
    RegularInterpolator,
    make_regular_interpolator_lagrange,
)


class ShiftInverseNumpy:  # pylint: disable=too-few-public-methods
    """Invert a coordinate transformation that is given as a shift.

    The inverse shift is computed by fixed-point iteration. Each step passes the
    padded input shift and the negated current estimate to the user-supplied
    interpolator's ``apply_shift``. Iteration stops once successive estimates
    differ by less than ``tolerance`` in the maximum norm. Samples outside the
    input range are obtained by repeating the first/last shift value.
    """

    def __init__(
        self,
        max_abs_shift: float,
        interp: RegularInterpolator,
        max_iter: int,
        tolerance: float,
    ):
        """Set up the inversion operator.

        Arguments:
            max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space
            interp: Interpolation method
            max_iter: Maximum iterations before fail
            tolerance: Maximum absolute error of result
        """
        # Rounded up to whole samples because it is used as a padding width.
        self._max_abs_shift: Final = int(np.ceil(max_abs_shift))
        self._interp_np: Final = interp
        self._max_iter = int(max_iter)
        self._tolerance = float(tolerance)

    @property
    def _margin_left(self) -> int:
        """Left margin size.

        Specifies how many samples on the left have to be added by boundary conditions:
        the interpolator's own left margin plus the (rounded-up) maximum shift.
        """
        return self._interp_np.margin_left + self._max_abs_shift

    @property
    def _margin_right(self) -> int:
        """Right margin size.

        Specifies how many samples on the right have to be added by boundary conditions:
        the interpolator's own right margin plus the (rounded-up) maximum shift.
        """
        return self._interp_np.margin_right + self._max_abs_shift

    def _fixed_point_iter(
        self, f: Callable[[NumpyArray1D], NumpyArray1D], x: NumpyArray1D
    ) -> NumpyArray1D:
        """Iterate ``x -> f(x)`` until successive iterates agree within tolerance.

        Arguments:
            f: Map whose fixed point is sought
            x: Initial guess

        Returns:
            The first iterate ``f(x)`` whose maximum absolute difference to its
            predecessor is below the tolerance.

        Raises:
            RuntimeError: If convergence is not reached within ``max_iter``
                iterations. This includes ``max_iter < 1``, and the case where
                the error is NaN, since a NaN never compares below the tolerance.
        """
        # Bound up front so the failure message below is well-defined even when
        # the loop body never executes (max_iter < 1).
        error = np.inf
        for _ in range(self._max_iter):
            x_next = f(x)
            error = np.max(np.abs(x - x_next))
            if error < self._tolerance:
                return x_next
            x = x_next
        msg = (
            f"ShiftInverseNumpy: iteration did not converge ({error=}, "
            f"iterations={self._max_iter}, tolerance={self._tolerance})"
        )
        raise RuntimeError(msg)

    def __call__(self, shift: np.ndarray) -> NumpyArray1D:
        """Find the shift for the inverse transformation given by a shift.

        Arguments:
            shift: 1D numpy array with shifts of the coordinate transform

        Returns:
            1D numpy array with shift for inverse coordinate transform. An
            empty input yields an empty result.

        Raises:
            RuntimeError: If the fixed-point iteration does not converge.
        """
        dx = make_numpy_array_1d(shift)
        if dx.size == 0:
            # Nothing to invert; padding by edge values is undefined for empty input.
            return dx

        # Repeat the first/last shift value so the interpolator has the margin
        # samples it needs near both boundaries.
        dx_pad = np.pad(dx, (self._margin_left, self._margin_right), mode="edge")

        def f_iter(x: NumpyArray1D) -> NumpyArray1D:
            return self._interp_np.apply_shift(dx_pad, -x, self._margin_left)

        # The forward shift itself serves as the initial guess.
        return self._fixed_point_iter(f_iter, dx)


def make_shift_inverse_lagrange_numpy(
    order: int,
    max_abs_shift: float,
    max_iter: int,
    tolerance: float,
) -> ShiftInverseNumpy:
    """Create a ShiftInverseNumpy operator that interpolates with Lagrange polynomials.

    Arguments:
        order: Order of the Lagrange polynomials used by the interpolator
        max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space
        max_iter: Maximum number of fixed-point iterations before failing
        tolerance: Maximum absolute error of the result

    Returns:
        Inversion function of type ShiftInverseNumpy
    """
    return ShiftInverseNumpy(
        max_abs_shift=max_abs_shift,
        interp=make_regular_interpolator_lagrange(order),
        max_iter=max_iter,
        tolerance=tolerance,
    )