Skip to content
Snippets Groups Projects
shift_inversion_numpy.py 3.8 KiB
Newer Older
"""Functions for inverting a 1D (time)coordinate transform given as numpy array

The inversion function internally requires an interpolation operator implementing the 
RegularInterpolator interface, and which is provided by the user. Use 
make_shift_inverse_lagrange_numpy to create an inversion operator employing Lagrange 
interpolation.
"""

from __future__ import annotations

# ~ from dataclasses import dataclass

# ~ from enum import Enum
from typing import Final

import numpy as np

# ~ from typing_extensions import assert_never

from lisainstrument.fir_filters_numpy import NumpyArray1D
from lisainstrument.regular_interpolators import (
    RegularInterpolator,
    make_regular_interpolator_lagrange,
    # ~ make_regular_interpolator_linear,
)


class ShiftInverseNumpy:  # pylint: disable=too-few-public-methods
    """Invert a coordinate transformation given as a shift.

    The inverse shift is computed by fixed-point iteration: each step evaluates
    the forward shift (via the interpolator passed to the constructor) at
    positions offset by the current estimate of the inverse shift.
    """

    def __init__(
        self,
        max_abs_shift: float,
        interp: RegularInterpolator,
        max_iter: int,
        tolerance: float,
    ):
        """Set up interpolator.

        Arguments:
            max_abs_shift: Upper limit for absolute difference between coordinate
                frames w.r.t index space (finite, non-negative)
            interp: Interpolation method
            max_iter: Maximum iterations before fail (at least 1)
            tolerance: Maximum absolute error of result

        Raises:
            ValueError: If max_abs_shift is negative, NaN or infinite, or if
                max_iter is less than 1.
        """
        max_abs_shift = float(max_abs_shift)
        max_iter = int(max_iter)
        # The explicit isfinite check also rejects NaN, which compares False to 0.
        if not (np.isfinite(max_abs_shift) and max_abs_shift >= 0):
            msg = (
                "ShiftInverseNumpy: max_abs_shift must be finite and "
                f"non-negative, got {max_abs_shift}"
            )
            raise ValueError(msg)
        if max_iter < 1:
            msg = f"ShiftInverseNumpy: max_iter must be at least 1, got {max_iter}"
            raise ValueError(msg)

        self._max_abs_shift: Final = max_abs_shift
        # Whole number of samples needed to cover the largest possible shift.
        # np.pad requires integral pad widths, so this must be an int.
        self._shift_margin: Final = int(np.ceil(max_abs_shift))
        self._interp_np: Final = interp
        self._max_iter: Final = max_iter
        self._tolerance: Final = float(tolerance)

    @property
    def _margin_left(self) -> int:
        """Left margin size.

        Specifies how many samples on the left have to be added by boundary
        conditions: the interpolator's own left margin plus the number of
        whole samples covered by the maximum shift.
        """
        return self._interp_np.margin_left + self._shift_margin

    @property
    def _margin_right(self) -> int:
        """Right margin size.

        Specifies how many samples on the right have to be added by boundary
        conditions: the interpolator's own right margin plus the number of
        whole samples covered by the maximum shift.
        """
        return self._interp_np.margin_right + self._shift_margin

    def _fixed_point_iter(self, f, x):
        """Iterate ``x -> f(x)`` until successive iterates agree within tolerance.

        Arguments:
            f: Function mapping the current estimate to an updated estimate
            x: Initial estimate

        Returns:
            The first iterate whose maximum absolute change from its
            predecessor is below the tolerance.

        Raises:
            RuntimeError: If no such iterate is found within max_iter steps.
        """
        error = np.inf  # keeps the failure message well-defined in all cases
        for _ in range(self._max_iter):
            x_next = f(x)
            error = np.max(np.abs(x - x_next))
            if error < self._tolerance:
                return x_next
            x = x_next
        msg = (
            f"ShiftInverseNumpy: iteration did not converge ({error=}, "
            f"iterations={self._max_iter}, tolerance={self._tolerance})"
        )
        raise RuntimeError(msg)

    def __call__(self, shift: np.ndarray) -> NumpyArray1D:
        """Find the shift for the inverse transformation given by a shift.

        Arguments:
            shift: 1D numpy array with shifts of the coordinate transform

        Returns:
            1D numpy array with shift for inverse coordinate transform

        Raises:
            RuntimeError: If the fixed-point iteration does not converge.
        """
        shift = np.asarray(shift)
        if shift.size == 0:
            # Nothing to invert; also avoids indexing shift[0] below.
            return shift.copy()

        # Extend the data by repeating the edge values so the interpolator can
        # evaluate the shift at positions up to the margins outside the range.
        shift_pad = np.pad(
            shift,
            (self._margin_left, self._margin_right),
            mode="constant",
            constant_values=(shift[0], shift[-1]),
        )

        def f_iter(x: np.ndarray) -> np.ndarray:
            # Forward shift evaluated at the positions displaced by -x.
            return self._interp_np.apply_shift(shift_pad, -x, self._margin_left)

        return self._fixed_point_iter(f_iter, shift)


def make_shift_inverse_lagrange_numpy(
    order: int,
    max_abs_shift: float,
    max_iter: int,
    tolerance: float,
) -> ShiftInverseNumpy:
    """Build a ShiftInverseNumpy whose interpolation uses Lagrange polynomials.

    Arguments:
        order: Order of the Lagrange polynomials
        max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space
        max_iter: Maximum iterations before fail
        tolerance: Maximum absolute error of result

    Returns:
        Inversion function of type ShiftInverseNumpy
    """
    return ShiftInverseNumpy(
        max_abs_shift=max_abs_shift,
        interp=make_regular_interpolator_lagrange(order),
        max_iter=max_iter,
        tolerance=tolerance,
    )