"""Functions for inverting a 1D (time) coordinate transform given as numpy array.

The inversion function internally requires an interpolation operator implementing the
RegularInterpolator interface, which is provided by the user. Use
make_shift_inverse_lagrange_numpy to create an inversion operator employing Lagrange
interpolation.
"""

from __future__ import annotations

from typing import Callable, Final

import numpy as np

from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d
from lisainstrument.regular_interpolators import (
    RegularInterpolator,
    make_regular_interpolator_lagrange,
)


class ShiftInverseNumpy:  # pylint: disable=too-few-public-methods
    """Invert coordinate transformation given as shift

    The inverse shift is obtained by fixed-point iteration: starting from the
    given shift, the padded shift array is repeatedly evaluated at positions
    displaced by the negative of the current estimate (using the user-provided
    interpolator) until two successive estimates agree to within the tolerance.
    """

    def __init__(
        self,
        max_abs_shift: float,
        interp: RegularInterpolator,
        max_iter: int,
        tolerance: float,
    ):
        """Set up interpolator.

        Arguments:
            max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space
            interp: Interpolation method
            max_iter: Maximum iterations before fail
            tolerance: Maximum absolute error of result
        """
        # Rounded up to a whole number of samples because it is used as a
        # margin size (see margin_left / margin_right).
        self._max_abs_shift: Final = int(np.ceil(max_abs_shift))
        self._interp_np: Final = interp
        self._max_iter = int(max_iter)
        self._tolerance = float(tolerance)

    @property
    def margin_left(self) -> int:
        """Left margin size.

        Specifies how many samples on the left have to be added by boundary conditions.
        """
        return self._interp_np.margin_left + self._max_abs_shift

    @property
    def margin_right(self) -> int:
        """Right margin size.

        Specifies how many samples on the right have to be added by boundary conditions.
        """
        return self._interp_np.margin_right + self._max_abs_shift

    def _fixed_point_iter(
        self,
        f: Callable[[NumpyArray1D], NumpyArray1D],
        ferr: Callable[[NumpyArray1D, NumpyArray1D], float],
        x: NumpyArray1D,
    ) -> NumpyArray1D:
        """Iterate x -> f(x) until successive iterates agree within tolerance.

        Arguments:
            f: Function to iterate
            ferr: Error measure for two successive iterates
            x: Initial value

        Returns:
            The first iterate for which ferr(previous, current) is below the tolerance

        Raises:
            RuntimeError: If the tolerance is not reached within max_iter iterations
        """
        # Pre-set so the failure message below is well-defined even if the
        # loop body never runs (max_iter <= 0).
        err = float("inf")
        for _ in range(self._max_iter):
            x_next = f(x)
            err = ferr(x, x_next)
            if err < self._tolerance:
                return x_next
            x = x_next
        msg = (
            f"ShiftInverseNumpy: iteration did not converge (error={err}, "
            f"tolerance={self._tolerance}), iterations={self._max_iter}"
        )
        # NOTE(review): exception type is RuntimeError here; confirm callers
        # do not expect a different type.
        raise RuntimeError(msg)

    def __call__(self, shift: np.ndarray, fsample: float) -> NumpyArray1D:
        """Find the shift for the inverse transformation given by a shift.

        Arguments:
            shift: 1D numpy array with shifts of the coordinate transform
            fsample: Factor converting shift into index units (shift * fsample);
                presumably the sample rate of the shift array

        Returns:
            1D numpy array with shift at transformed coordinate

        Raises:
            RuntimeError: If the fixed-point iteration does not converge
        """
        margin_left = self.margin_left
        margin_right = self.margin_right

        # Work in index space; the result is converted back at the end.
        shift_idx = shift * fsample
        dx = make_numpy_array_1d(shift_idx)
        # Extend both ends with the respective boundary value so that the
        # interpolator can be evaluated up to max_abs_shift beyond the data.
        dx_pad = np.pad(
            dx,
            (margin_left, margin_right),
            mode="constant",
            constant_values=(shift_idx[0], shift_idx[-1]),
        )

        def f_iter(x: NumpyArray1D) -> NumpyArray1D:
            return self._interp_np.apply_shift(dx_pad, -x, margin_left)

        def f_err(x1: NumpyArray1D, x2: NumpyArray1D) -> float:
            # Error is measured on the interior only, excluding margin-sized
            # regions at both ends. An explicit stop index is used instead of
            # a negative one so that margin_right == 0 keeps all points.
            diff = np.abs(x1 - x2)
            stop = max(diff.shape[0] - margin_right, 0)
            return float(np.max(diff[margin_left:stop]))

        shift_idx_inv = self._fixed_point_iter(f_iter, f_err, dx)
        shift_inv = shift_idx_inv / fsample

        return make_numpy_array_1d(shift_inv)


def make_shift_inverse_lagrange_numpy(
    order: int,
    max_abs_shift: float,
    max_iter: int,
    tolerance: float,
) -> ShiftInverseNumpy:
    """Create a ShiftInverseNumpy operator based on Lagrange interpolation.

    Arguments:
        order: Order of the Lagrange polynomials
        max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space
        max_iter: Maximum iterations before fail
        tolerance: Maximum absolute error of result

    Returns:
        Inversion function of type ShiftInverseNumpy
    """
    # The interpolator is built inline; everything else is passed through as is.
    return ShiftInverseNumpy(
        max_abs_shift=max_abs_shift,
        interp=make_regular_interpolator_lagrange(order),
        max_iter=max_iter,
        tolerance=tolerance,
    )