"""Functions for inverting a 1D (time) coordinate transform given as a numpy array.

The inversion function internally requires an interpolation operator implementing the
RegularInterpolator interface, which is provided by the user. Use
make_shift_inverse_lagrange_numpy to create an inversion operator employing Lagrange
interpolation.
"""

from __future__ import annotations

from typing import Callable, Final

import numpy as np

from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d
from lisainstrument.regular_interpolators import (
    RegularInterpolator,
    make_regular_interpolator_lagrange,
)


def fixed_point_iter(
    f: Callable[[NumpyArray1D], NumpyArray1D],
    ferr: Callable[[NumpyArray1D, NumpyArray1D], float],
    x: NumpyArray1D,
    tolerance: float,
    max_iter: int,
) -> NumpyArray1D:
    r"""Perform a fixed point iteration for functions operating on a 1D array.

    This uses fixed-point iteration to find a solution for
    $$ x = f(x) $$
    where $x$ is a 1D array and $f$ returns an array of the same size.

    Starting from the initial data $x_0$, the iterates are $x_{k+1} = f(x_k)$.
    Convergence is judged by a user-provided error measure $r(x_k, x_{k+1})$
    comparing two consecutive iterates. The iteration stops as soon as
    $r(x_k, x_{k+1}) < \epsilon$, and $x_{k+1}$ is returned. If convergence is
    not achieved within the given number of iterations, an exception is raised.

    Arguments:
        f: The function $f(x)$
        ferr: The error measure $r(x_k, x_{k+1})$, called as ``ferr(x_prev, x_next)``
        x: The initial data for the iteration.
        tolerance: The tolerance $\epsilon$
        max_iter: Maximum number of iterations (evaluations of $f$)

    Returns:
        Array with solution

    Raises:
        RuntimeError: If the iteration did not converge within ``max_iter``
            iterations. This includes the case ``max_iter <= 0``.
    """
    # Initialise the error so the failure message below is well-defined even
    # when the loop body never runs (max_iter <= 0). Without this, an
    # UnboundLocalError would be raised instead of the documented RuntimeError.
    err = float("inf")
    for _ in range(max_iter):
        x_next = f(x)
        err = ferr(x, x_next)
        if err < tolerance:
            return x_next
        x = x_next
    msg = (
        f"ShiftInverseNumpy: iteration did not converge (error={err}, "
        f"tolerance={tolerance}), iterations={max_iter}"
    )
    raise RuntimeError(msg)


class ShiftInverseNumpy:  # pylint: disable=too-few-public-methods
    """Invert a coordinate transformation given as a shift.

    The forward transformation is described by a shift sampled on a regular
    grid. Calling an instance returns the shift of the inverse transformation,
    sampled on the same grid. It is obtained by fixed-point iteration, which
    evaluates the forward shift at off-grid locations using the interpolator
    passed to the constructor.
    """

    def __init__(
        self,
        max_abs_shift: float,
        interp: RegularInterpolator,
        max_iter: int,
        tolerance: float,
    ):
        """Set up interpolator.

        Arguments:
            max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space
            interp: Interpolation method
            max_iter: Maximum iterations before fail
            tolerance: Maximum absolute error of result
        """
        # Rounded up to whole samples because it is used as a padding width.
        self._max_abs_shift: Final = int(np.ceil(max_abs_shift))
        self._interp_np: Final = interp
        self._max_iter = int(max_iter)
        self._tolerance = float(tolerance)

    @property
    def margin_left(self) -> int:
        """Left margin size.

        Specifies how many samples on the left have to be added by boundary conditions.
        """
        return self._interp_np.margin_left + self._max_abs_shift

    @property
    def margin_right(self) -> int:
        """Right margin size.

        Specifies how many samples on the right have to be added by boundary conditions.
        """
        return self._interp_np.margin_right + self._max_abs_shift

    def __call__(self, shift: np.ndarray, fsample: float) -> NumpyArray1D:
        """Find the shift for the inverse transformation given by a shift.

        The shift is converted to index space (multiplied by ``fsample``) and
        padded on both sides by repeating its first/last value. The inverse
        shift is then found by fixed-point iteration of the map
        ``x -> interp.apply_shift(padded_shift, -x, margin_left)``, starting
        from the shift itself.

        Arguments:
            shift: 1D numpy array with shifts of the coordinate transform
            fsample: Sample rate of the shift array, used to convert the shift
                to index space and the result back

        Returns:
            1D numpy array with shift at transformed coordinate, in the same
            units as ``shift``

        Raises:
            RuntimeError: If the fixed-point iteration does not converge.
        """
        shift_idx = shift * fsample
        dx = make_numpy_array_1d(shift_idx)
        # Extend the shift beyond both ends by constant extrapolation so the
        # interpolator can be evaluated up to the margins outside the range.
        dx_pad = np.pad(
            dx,
            (self.margin_left, self.margin_right),
            mode="constant",
            constant_values=(shift_idx[0], shift_idx[-1]),
        )

        def f_iter(x: NumpyArray1D) -> NumpyArray1D:
            # margin_left compensates for the left padding of dx_pad.
            # NOTE(review): presumably samples dx_pad at index
            # i + margin_left - x[i]; confirm against RegularInterpolator.apply_shift.
            return self._interp_np.apply_shift(dx_pad, -x, self.margin_left)

        def f_err(x1: NumpyArray1D, x2: NumpyArray1D) -> float:
            diff = np.abs(x1 - x2)
            # Explicit stop index: a `-margin_right` slice end would produce
            # an empty slice when margin_right == 0.
            stop = max(len(diff) - self.margin_right, 0)
            inner = diff[self.margin_left : stop]
            if inner.size == 0:
                # Input shorter than the margins: judge convergence on all
                # samples rather than failing on np.max of an empty array.
                inner = diff
            return float(np.max(inner))

        shift_idx_inv = fixed_point_iter(
            f_iter, f_err, dx, self._tolerance, self._max_iter
        )
        # Convert back from index space to the units of the input shift.
        return shift_idx_inv / fsample


def make_shift_inverse_lagrange_numpy(
    order: int,
    max_abs_shift: float,
    max_iter: int,
    tolerance: float,
) -> ShiftInverseNumpy:
    """Create a ShiftInverseNumpy that interpolates with Lagrange polynomials.

    This is a convenience factory: it builds a Lagrange regular interpolator of
    the requested order and hands it to ShiftInverseNumpy together with the
    remaining settings.

    Arguments:
        order: Order of the Lagrange polynomials
        max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space
        max_iter: Maximum iterations before fail
        tolerance: Maximum absolute error of result

    Returns:
        Inversion function of type ShiftInverseNumpy
    """
    return ShiftInverseNumpy(
        max_abs_shift,
        interp=make_regular_interpolator_lagrange(order),
        max_iter=max_iter,
        tolerance=tolerance,
    )