Newer
Older
"""
Dask-based time coordinate inversion
====================================

Functions for inverting a 1D (time) coordinate transform given as a dask array.
This is the same functionality as provided by module shift_inversion_numpy for
numpy arrays.


.. autoclass:: ShiftInverseDask
   :members:
   :private-members:
   :special-members:

There are two choices for the interpolator used internally: the newly
written Lagrange interpolator, and the Lagrange interpolator from
dsp.timeshift.

.. autofunction:: make_shift_inverse_lagrange_dask
.. autofunction:: make_shift_inverse_dsp_dask
"""
from __future__ import annotations
from typing import Final
import dask
import dask.array as da
import numpy as np
from lisainstrument.dynamic_delay_dsp import make_regular_interpolator_dsp
from lisainstrument.fir_filters_dask import DaskArray1D, make_dask_array_1d
from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d
from lisainstrument.regular_interpolators import (
RegularInterpolator,
make_regular_interpolator_lagrange,
)
from lisainstrument.shift_inversion_numpy import fixed_point_iter
class ShiftInverseDask:
    """Invert time coordinate transformation given as dask array

    See shift_inversion_numpy.ShiftInverseNumpy for details.
    """

    def __init__(
        self,
        fsample: float,
        max_abs_shift: float,
        interp: RegularInterpolator,
        max_iter: int,
        tolerance: float,
    ):
        r"""Set up the coordinate inversion operator.

        See shift_inversion_numpy.ShiftInverseNumpy for details

        Arguments:
            fsample: Sample rate :math:`f_s > 0` [Hz]
            max_abs_shift: Upper limit :math:`S_\mathrm{max} \ge 0` [s] for coordinate shift
            interp: Interpolation operator
            max_iter: Maximum iterations before fail
            tolerance: Maximum absolute error [s] of result, must be non-negative

        Raises:
            ValueError: if fsample is not strictly positive, max_abs_shift or
                tolerance is negative (or NaN), or max_iter is not strictly
                positive.
        """
        self._fsample: Final = float(fsample)
        self._max_iter = int(max_iter)

        # Validate before deriving anything from the inputs. The comparisons are
        # written as "not (x > 0)" / "not (x >= 0)" so that NaN is rejected too.
        if not self._fsample > 0:
            msg = f"ShiftInverseDask: fsample must be strictly positive, got {fsample}"
            raise ValueError(msg)
        if not max_abs_shift >= 0:
            msg = (
                f"ShiftInverseDask: max_abs_shift must be positive, got {max_abs_shift}"
            )
            raise ValueError(msg)
        if self._max_iter <= 0:
            msg = f"ShiftInverseDask: max_iter must be strictly positive integer, got {max_iter}"
            raise ValueError(msg)
        if not tolerance >= 0:
            # An absolute error can never drop below a negative tolerance.
            msg = f"ShiftInverseDask: tolerance must be non-negative, got {tolerance}"
            raise ValueError(msg)

        # Shift bound in sample units, rounded up so the margins are never too small.
        self._max_abs_shift_idx: Final = int(np.ceil(max_abs_shift * self._fsample))
        self._interp_np: Final = interp
        # The iteration works in sample units, so the tolerance is converted too.
        self._tolerance_idx = float(tolerance * self._fsample)

    @property
    def margin_left(self) -> int:
        """Left margin size.

        Specifies how many samples on the left have to be added by boundary conditions.
        """
        return self._interp_np.margin_left + self._max_abs_shift_idx

    @property
    def margin_right(self) -> int:
        """Right margin size.

        Specifies how many samples on the right have to be added by boundary conditions.
        """
        return self._interp_np.margin_right + self._max_abs_shift_idx

    def _fixed_point_iter(self, dx_pad: np.ndarray, dx: np.ndarray) -> NumpyArray1D:
        """Call the fixed point iteration on a single chunk

        The input shift includes margin points in addition to the corresponding
        output points. One needs to specify initial data for the iteration, with
        same shape as the result. Both are shorter than the input shift by the
        left plus right margin size.

        Arguments:
            dx_pad: numpy array with shift (in sample units), including margins
            dx: initial value for iteration

        Returns:
            Converged iteration result
        """
        # Convert/validate the padded shift once, not on every iteration step.
        dx_pad_np = make_numpy_array_1d(dx_pad)

        def f_iter(x: NumpyArray1D) -> NumpyArray1D:
            # One iteration step: interpolate the padded shift with offsets -x.
            return self._interp_np.apply_shift(dx_pad_np, -x, self.margin_left)

        def f_err(x1: NumpyArray1D, x2: NumpyArray1D) -> float:
            # Maximum absolute difference between successive iterates;
            # initial=0.0 keeps this well-defined for empty chunks.
            return float(np.max(np.abs(x1 - x2), initial=0.0))

        return fixed_point_iter(
            f_iter, f_err, make_numpy_array_1d(dx), self._tolerance_idx, self._max_iter
        )

    def __call__(self, shift: da.Array) -> DaskArray1D:
        """Compute the inverse coordinate transform

        See shift_inversion_numpy.ShiftInverseNumpy for details

        Arguments:
            shift: 1D dask array with shifts of the coordinate transform [s]

        Returns:
            1D dask array with shift [s] at transformed coordinate
        """
        # Work in sample units; make_dask_array_1d validates and returns the array.
        shift_idx = make_dask_array_1d(shift * self._fsample)

        # Extend with edge values so every chunk has the margin samples the
        # interpolator needs, also at the array boundaries.
        dx_pad = da.pad(shift_idx, (self.margin_left, self.margin_right), mode="edge")

        delayed_op = dask.delayed(self._fixed_point_iter)
        results = []
        pos = 0
        for chunk, chunk_size in zip(
            shift_idx.to_delayed(), shift_idx.chunks[0], strict=True
        ):
            # Output samples [pos, pos + chunk_size) depend on padded samples
            # [pos, pos + margin_left + chunk_size + margin_right).
            n_size = self.margin_left + chunk_size + self.margin_right
            samples_needed = dx_pad[pos : pos + n_size]
            samples_shifted = delayed_op(samples_needed, chunk)
            results.append(
                da.from_delayed(samples_shifted, (chunk_size,), shift_idx.dtype)
            )
            pos += chunk_size

        shift_idx_inv = da.concatenate(results, axis=0)
        return make_dask_array_1d(shift_idx_inv / self._fsample)
def make_shift_inverse_lagrange_dask(
    order: int,
    fsample: float,
    max_abs_shift: float,
    max_iter: int,
    tolerance: float,
) -> ShiftInverseDask:
    r"""Set up ShiftInverseDask instance with new Lagrange interpolation method.

    Arguments:
        order: Order of the Lagrange polynomials
        fsample: Sample rate :math:`f_s > 0` [Hz]
        max_abs_shift: Upper limit :math:`S_\mathrm{max} \ge 0` [s] for coordinate shift
        max_iter: Maximum iterations before fail
        tolerance: Maximum absolute error [s] of result

    Returns:
        Inversion function of type ShiftInverseDask
    """
    interp = make_regular_interpolator_lagrange(order)
    return ShiftInverseDask(
        fsample=fsample,
        max_abs_shift=max_abs_shift,
        interp=interp,
        max_iter=max_iter,
        tolerance=tolerance,
    )
def make_shift_inverse_dsp_dask(
    order: int,
    fsample: float,
    max_abs_shift: float,
    max_iter: int,
    tolerance: float,
) -> ShiftInverseDask:
    r"""Set up ShiftInverseDask instance using dsp.timeshift as interpolation method.

    Arguments:
        order: Order of the Lagrange polynomials
        fsample: Sample rate :math:`f_s > 0` [Hz]
        max_abs_shift: Upper limit :math:`S_\mathrm{max} \ge 0` [s] for coordinate shift
        max_iter: Maximum iterations before fail
        tolerance: Maximum absolute error [s] of result

    Returns:
        Inversion function of type ShiftInverseDask
    """
    interp = make_regular_interpolator_dsp(order)
    return ShiftInverseDask(
        fsample=fsample,
        max_abs_shift=max_abs_shift,
        interp=interp,
        max_iter=max_iter,
        tolerance=tolerance,
    )