"""Functions for inverting a 1D (time) coordinate transform given as a numpy array.

The inversion function internally requires an interpolation operator
implementing the RegularInterpolator interface, which is provided by the user.
Use make_shift_inverse_lagrange_numpy to create an inversion operator employing
Lagrange interpolation.
"""

from __future__ import annotations

from typing import Final

import numpy as np

from lisainstrument.fir_filters_numpy import NumpyArray1D
from lisainstrument.regular_interpolators import (
    RegularInterpolator,
    make_regular_interpolator_lagrange,
)


class ShiftInverseNumpy:  # pylint: disable=too-few-public-methods
    """Invert a coordinate transformation given as a shift.

    The inverse shift is computed by fixed-point iteration. Each step evaluates
    the (boundary-padded) forward shift through the user-provided interpolator,
    displaced by minus the current estimate of the inverse shift.
    """

    def __init__(
        self,
        max_abs_shift: float,
        interp: RegularInterpolator,
        max_iter: int,
        tolerance: float,
    ):
        """Set up the inversion operator.

        Arguments:
            max_abs_shift: Upper limit for absolute difference between
                coordinate frames w.r.t index space
            interp: Interpolation method
            max_iter: Maximum iterations before fail
            tolerance: Maximum absolute error of result

        Raises:
            ValueError: if max_abs_shift is negative
        """
        if max_abs_shift < 0:
            msg = f"ShiftInverseNumpy: max_abs_shift must be >= 0, got {max_abs_shift}"
            raise ValueError(msg)
        self._max_abs_shift: Final = float(max_abs_shift)
        self._interp_np: Final = interp
        self._max_iter: Final = int(max_iter)
        self._tolerance: Final = float(tolerance)

    @property
    def _shift_margin(self) -> int:
        """Number of extra samples needed on each side to cover the shift.

        max_abs_shift may be fractional, so it is rounded up. The result must
        be an integer because it is used as a padding width and index offset.
        """
        return int(np.ceil(self._max_abs_shift))

    @property
    def _margin_left(self) -> int:
        """Left margin size.

        Specifies how many samples on the left have to be added by boundary
        conditions.
        """
        return int(self._interp_np.margin_left) + self._shift_margin

    @property
    def _margin_right(self) -> int:
        """Right margin size.

        Specifies how many samples on the right have to be added by boundary
        conditions.
        """
        return int(self._interp_np.margin_right) + self._shift_margin

    def _fixed_point_iter(self, f, x):
        """Iterate x -> f(x) until successive iterates agree within tolerance.

        Arguments:
            f: Function to iterate; its result must be broadcast-compatible
                with its argument
            x: Initial guess

        Returns:
            The first iterate whose maximum absolute change from the previous
            iterate is below the tolerance

        Raises:
            RuntimeError: if there is no convergence within max_iter iterations
        """
        # Initialized so the error message below is well-defined even when
        # max_iter < 1 and the loop body never executes.
        error = np.inf
        for _ in range(self._max_iter):
            x_next = f(x)
            error = np.max(np.abs(x - x_next))
            if error < self._tolerance:
                return x_next
            x = x_next
        msg = (
            f"ShiftInverseNumpy: iteration did not converge ({error=}, "
            f"iterations={self._max_iter}, tolerance={self._tolerance})"
        )
        raise RuntimeError(msg)

    def __call__(self, shift: np.ndarray) -> NumpyArray1D:
        """Find the shift of the inverse of a coordinate transform given by a shift.

        Arguments:
            shift: 1D numpy array with shifts of the coordinate transform

        Returns:
            1D numpy array with shift for inverse coordinate transform

        Raises:
            ValueError: if shift is not one-dimensional
            RuntimeError: if the fixed-point iteration does not converge
        """
        shift = np.asarray(shift)
        if shift.ndim != 1:
            msg = f"ShiftInverseNumpy: expected 1D shift array, got {shift.ndim=}"
            raise ValueError(msg)
        if shift.size == 0:
            # Nothing to invert; also avoids indexing shift[0] below.
            return np.zeros(0, dtype=np.float64)

        margin_left = self._margin_left
        # Extend the shift beyond both ends with its boundary values, providing
        # the margin samples required by the interpolator and the shift itself.
        shift_pad = np.pad(
            shift,
            (margin_left, self._margin_right),
            mode="constant",
            constant_values=(shift[0], shift[-1]),
        )

        def f_iter(x: np.ndarray) -> np.ndarray:
            return self._interp_np.apply_shift(shift_pad, -x, margin_left)

        return self._fixed_point_iter(f_iter, shift)


def make_shift_inverse_lagrange_numpy(
    order: int,
    max_abs_shift: float,
    max_iter: int,
    tolerance: float,
) -> ShiftInverseNumpy:
    """Set up ShiftInverseNumpy instance with Lagrange interpolation method.

    Arguments:
        order: Order of the Lagrange polynomials
        max_abs_shift: Upper limit for absolute difference between coordinate
            frames w.r.t index space
        max_iter: Maximum iterations before fail
        tolerance: Maximum absolute error of result

    Returns:
        Inversion function of type ShiftInverseNumpy
    """
    interp = make_regular_interpolator_lagrange(order)
    return ShiftInverseNumpy(max_abs_shift, interp, max_iter, tolerance)