"""Functions for applying dynamic real-valued shifts to dask arrays using Lagrange interpolation Use make_dynamic_shift_lagrange_dask to create a Lagrange interpolator for dask arrays. """ from __future__ import annotations from typing import Final import dask import dask.array as da import numpy as np from lisainstrument.dynamic_delay_dsp import make_regular_interpolator_dsp from lisainstrument.fir_filters_dask import DaskArray1D, make_dask_array_1d from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d from lisainstrument.regular_interpolators import ( RegularInterpolator, make_regular_interpolator_lagrange, ) from lisainstrument.shift_inversion_numpy import fixed_point_iter class ShiftInverseDask: """Invert time coordinate transformation given as dask array See shift_inversion_numpy.ShiftInverseNumpy for details. """ def __init__( self, fsample: float, max_abs_shift: float, interp: RegularInterpolator, max_iter: int, tolerance: float, ): r"""Set up the coordinate inversion operator. See shift_inversion_numpy.ShiftInverseNumpy for details Arguments: fsample: Sample rate $f_s > 0$ [s] max_abs_shift: Upper limit $S_\mathrm{max} \ge 0 $ [s] for coordinate shift interp: Interpolation operator max_iter: Maximum iterations before fail tolerance: Maximum absolute error [s] of result """ self._fsample: Final = float(fsample) self._max_abs_shift_idx: Final = int(np.ceil(max_abs_shift * self._fsample)) self._interp_np: Final = interp self._max_iter = int(max_iter) self._tolerance_idx = float(tolerance * self._fsample) if self._fsample <= 0: msg = f"ShiftInverseDask: fsample must be strictly positive, got {fsample}" raise ValueError(msg) if max_abs_shift < 0: msg = ( f"ShiftInverseDask: max_abs_shift must be positive, got {max_abs_shift}" ) raise ValueError(msg) if self._max_iter <= 0: msg = f"ShiftInverseDask: max_iter must be strictly positive integer, got {max_iter}" raise ValueError(msg) @property def margin_left(self) -> int: """Left margin size. Specifies how many samples on the left have to be added by boundary conditions. """ return self._interp_np.margin_left + self._max_abs_shift_idx @property def margin_right(self) -> int: """Right margin size. Specifies how many samples on the right have to be added by boundary conditions. """ return self._interp_np.margin_right + self._max_abs_shift_idx def _fixed_point_iter(self, dx_pad: np.ndarray, dx: np.ndarray) -> NumpyArray1D: """Call the fixed point iteration on a single chunk The input shift includes margin points in addition to the corresponding output points. One needs to specify initial data for the iteration, with same shape as the result. Both are shorter than the input shift by the left plus right margin size. Arguments: dx_pad: numpy array with shift dx: initial value for iteration Returns: Converged iteration result """ def f_iter(x: NumpyArray1D) -> NumpyArray1D: return self._interp_np.apply_shift( make_numpy_array_1d(dx_pad), -x, self.margin_left ) def f_err(x1: NumpyArray1D, x2: NumpyArray1D) -> float: return np.max(np.abs(x1 - x2)) return fixed_point_iter( f_iter, f_err, make_numpy_array_1d(dx), self._tolerance_idx, self._max_iter ) def __call__(self, shift: da.Array) -> DaskArray1D: """Compute the inverse coordinate transform See shift_inversion_numpy.ShiftInverseNumpy for details Arguments: shift: 1D dask array with shifts of the coordinate transform [s] Returns: 1D dask array with shift [s] at transformed coordinate """ shift_idx = shift * self._fsample make_dask_array_1d(shift_idx) dx_pad = da.pad( shift_idx, (self.margin_left, self.margin_right), mode="edge", # ~ constant_values=(shift_idx[0], shift_idx[-1]), ) results = [] chunks = shift_idx.to_delayed() delayed_op = dask.delayed(self._fixed_point_iter) pos = 0 for chunk, chunk_shape in zip(chunks, shift.chunks[0], strict=True): n_size = self.margin_left + chunk_shape + self.margin_right n_first = pos samples_needed = dx_pad[n_first : n_first + n_size] samples_shifted = delayed_op(samples_needed, chunk) delayed_chunk = da.from_delayed( samples_shifted, (chunk_shape,), shift_idx.dtype ) results.append(delayed_chunk) pos += chunk_shape shift_idx_inv = da.concatenate(results, axis=0) shift_inv = shift_idx_inv / self._fsample return make_dask_array_1d(shift_inv) def make_shift_inverse_lagrange_dask( order: int, fsample: float, max_abs_shift: float, max_iter: int, tolerance: float, ) -> ShiftInverseDask: r"""Set up ShiftInverseDask instance with new Lagrange interpolation method. Arguments: order: Order of the Lagrange polynomials fsample: Sample rate $f_s > 0$ [s] max_abs_shift: Upper limit $S_\mathrm{max} \ge 0 $ [s] for coordinate shift max_iter: Maximum iterations before fail tolerance: Maximum absolute error of result Returns: Inversion function of type ShiftInverseDask """ interp = make_regular_interpolator_lagrange(order) return ShiftInverseDask( fsample=fsample, max_abs_shift=max_abs_shift, interp=interp, max_iter=max_iter, tolerance=tolerance, ) def make_shift_inverse_dsp_dask( order: int, fsample: float, max_abs_shift: float, max_iter: int, tolerance: float, ) -> ShiftInverseDask: r"""Set up ShiftInverseDask instance using dsp.timeshift as interpolation method. Arguments: order: Order of the Lagrange polynomials fsample: Sample rate $f_s > 0$ [s] max_abs_shift: Upper limit $S_\mathrm{max} \ge 0 $ [s] for coordinate shift max_iter: Maximum iterations before fail tolerance: Maximum absolute error of result Returns: Inversion function of type ShiftInverseDask """ interp = make_regular_interpolator_dsp(order) return ShiftInverseDask( fsample=fsample, max_abs_shift=max_abs_shift, interp=interp, max_iter=max_iter, tolerance=tolerance, )