"""Functions for applying dynamic real-valued shifts to dask arrays using Lagrange interpolation Use make_dynamic_shift_lagrange_dask to create a Lagrange interpolator for dask arrays. """ from __future__ import annotations from typing import Final import dask import dask.array as da import numpy as np from lisainstrument.fir_filters_dask import DaskArray1D, make_dask_array_1d from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d from lisainstrument.regular_interpolators import ( RegularInterpolator, make_regular_interpolator_lagrange, ) from lisainstrument.shift_inversion_numpy import fixed_point_iter class ShiftInverseDask: """Invert coordinate transformation given as shift""" def __init__( self, max_abs_shift: float, interp: RegularInterpolator, max_iter: int, tolerance: float, ): """Set up interpolator. Arguments: max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space interp: Interpolation method max_iter: Maximum iterations before fail tolerance: Maximum absolute error of result """ self._max_abs_shift: Final = int(np.ceil(max_abs_shift)) if self._max_abs_shift < 0: msg = ( f"ShiftInverseDask: max_abs_shift must be positive, got {max_abs_shift}" ) raise ValueError(msg) self._interp_np: Final = interp self._max_iter = int(max_iter) self._tolerance = float(tolerance) @property def margin_left(self) -> int: """Left margin size. Specifies how many samples on the left have to be added by boundary conditions. """ return self._interp_np.margin_left + self._max_abs_shift @property def margin_right(self) -> int: """Right margin size. Specifies how many samples on the right have to be added by boundary conditions. """ return self._interp_np.margin_right + self._max_abs_shift def _fixed_point_iter(self, dx_pad: np.ndarray, dx: np.ndarray) -> NumpyArray1D: def f_iter(x: NumpyArray1D) -> NumpyArray1D: return self._interp_np.apply_shift( make_numpy_array_1d(dx_pad), -x, self.margin_left ) def f_err(x1: NumpyArray1D, x2: NumpyArray1D) -> float: return np.max(np.abs(x1 - x2)[self.margin_left : -self.margin_right]) return fixed_point_iter( f_iter, f_err, make_numpy_array_1d(dx), self._tolerance, self._max_iter ) def __call__(self, shift: da.Array, fsample: float) -> DaskArray1D: """Find the shift for the inverse transformation given by a shift. Arguments: shift: 1D dask array with shifts of the coordinate transform fsample: sample rate of shift array Returns: 1D dask array with shift at transformed coordinate """ shift_idx = shift * fsample make_dask_array_1d(shift_idx) dx_pad = da.pad( shift_idx, (self.margin_left, self.margin_right), mode="edge", # ~ constant_values=(shift_idx[0], shift_idx[-1]), ) results = [] chunks = shift_idx.to_delayed() delayed_op = dask.delayed(self._fixed_point_iter) pos = 0 for chunk, chunk_shape in zip(chunks, shift.chunks[0], strict=True): n_size = self.margin_left + chunk_shape + self.margin_right n_first = pos samples_needed = dx_pad[n_first : n_first + n_size] samples_shifted = delayed_op(samples_needed, chunk) delayed_chunk = da.from_delayed( samples_shifted, (chunk_shape,), shift_idx.dtype ) results.append(delayed_chunk) pos += chunk_shape shift_idx_inv = da.concatenate(results, axis=0) shift_inv = shift_idx_inv / fsample return make_dask_array_1d(shift_inv) def make_shift_inverse_lagrange_dask( order: int, max_abs_shift: float, max_iter: int, tolerance: float, ) -> ShiftInverseDask: """Set up ShiftInverseDask instance with Lagrange interpolation method. Arguments: order: Order of the Lagrange polynomials max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space max_iter: Maximum iterations before fail tolerance: Maximum absolute error of result Returns: Inversion function of type ShiftInverseNumpy """ interp = make_regular_interpolator_lagrange(order) return ShiftInverseDask(max_abs_shift, interp, max_iter, tolerance)