"""Functions for inverting a 1D (time)coordinate transform given as numpy array The inversion function internally requires an interpolation operator implementing the RegularInterpolator interface, and which is provided by the user. Use make_shift_inverse_lagrange_numpy to create an inversion operator employing Lagrange interpolation. """ from __future__ import annotations from typing import Callable, Final import numpy as np from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d from lisainstrument.regular_interpolators import ( RegularInterpolator, make_regular_interpolator_lagrange, ) def fixed_point_iter( f: Callable[[NumpyArray1D], NumpyArray1D], ferr: Callable[[NumpyArray1D, NumpyArray1D], float], x: NumpyArray1D, tolerance: float, max_iter: int, ) -> NumpyArray1D: r"""Perform a fixed point iteration for functions operating on a 1D array. This uses fixed-point iteration to find a solution for $$ x = f(x) $$ where $x$ is a 1D array and $f$ returns an array of the same size. The convergence criterion is provided by the user via a function $r(x)$ that returns a scalar error measure. The iteration is performed until $r(x) < \epsilon$. If convergence is not achieved after a given number of iterations, an exception is raised. Arguments: f: The function $f(x)$ ferr: The error measure $r(x)$ x: The initial data for the iteration. tolerance: The tolerance $\epsilon$ max_iter: Maximum number of iterations Returns: Array with solution """ for _ in range(max_iter): x_next = f(x) err = ferr(x, x_next) if err < tolerance: return x_next x = x_next msg = ( f"ShiftInverseNumpy: iteration did not converge (error={err}, " f"tolerance={tolerance}), iterations={max_iter}" ) raise RuntimeError(msg) class ShiftInverseNumpy: # pylint: disable=too-few-public-methods """Invert coordinate transformation given as shift""" def __init__( self, max_abs_shift: float, interp: RegularInterpolator, max_iter: int, tolerance: float, ): """Set up interpolator. Arguments: max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space interp: Interpolation method max_iter: Maximum iterations before fail tolerance: Maximum absolute error of result """ self._max_abs_shift: Final = int(np.ceil(max_abs_shift)) self._interp_np: Final = interp self._max_iter = int(max_iter) self._tolerance = float(tolerance) @property def margin_left(self) -> int: """Left margin size. Specifies how many samples on the left have to be added by boundary conditions. """ return self._interp_np.margin_left + self._max_abs_shift @property def margin_right(self) -> int: """Right margin size. Specifies how many samples on the right have to be added by boundary conditions. """ return self._interp_np.margin_right + self._max_abs_shift def __call__(self, shift: np.ndarray, fsample: float) -> NumpyArray1D: """Find the shift for the inverse transformation given by a shift. Arguments: shift: 1D numpy array with shifts of the coordinate transform fsample: sample rate of shift array Returns: 1D numpy array with shift at transformed coordinate """ shift_idx = shift * fsample dx = make_numpy_array_1d(shift_idx) dx_pad = np.pad( dx, (self.margin_left, self.margin_right), mode="constant", constant_values=(shift_idx[0], shift_idx[-1]), ) def f_iter(x: NumpyArray1D) -> NumpyArray1D: return self._interp_np.apply_shift(dx_pad, -x, self.margin_left) def f_err(x1: NumpyArray1D, x2: NumpyArray1D) -> float: return np.max(np.abs(x1 - x2)[self.margin_left : -self.margin_right]) shift_idx_inv = fixed_point_iter( f_iter, f_err, dx, self._tolerance, self._max_iter ) shift_inv = shift_idx_inv / fsample return make_numpy_array_1d(shift_inv) def make_shift_inverse_lagrange_numpy( order: int, max_abs_shift: float, max_iter: int, tolerance: float, ) -> ShiftInverseNumpy: """Set up ShiftInverseNumpy instance with Lagrange interpolation method. Arguments: order: Order of the Lagrange polynomials max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space max_iter: Maximum iterations before fail tolerance: Maximum absolute error of result Returns: Inversion function of type ShiftInverseNumpy """ interp = make_regular_interpolator_lagrange(order) return ShiftInverseNumpy(max_abs_shift, interp, max_iter, tolerance)