Newer
Older
"""Functions for inverting coordinate transformations given as shifts, for dask arrays.

Use make_shift_inverse_lagrange_dask (Lagrange interpolation) or
make_shift_inverse_dsp_dask (dsp.timeshift interpolation) to create a
ShiftInverseDask instance.
"""
from __future__ import annotations
from typing import Final
import dask
import dask.array as da
import numpy as np
from lisainstrument.fir_filters_dask import DaskArray1D, make_dask_array_1d
from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d
from lisainstrument.regular_interpolators import (
RegularInterpolator,
make_regular_interpolator_lagrange,
)
from lisainstrument.dynamic_delay_dsp import make_regular_interpolator_dsp
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from lisainstrument.shift_inversion_numpy import fixed_point_iter
class ShiftInverseDask:
    """Invert a coordinate transformation given as a shift, for dask arrays.

    The inverse shift is found by fixed point iteration (``fixed_point_iter``):
    each iteration evaluates the input shift, via the interpolator, at positions
    displaced by the current estimate. Every dask chunk is processed
    independently; the input is padded with edge values so each chunk sees
    enough neighbouring samples for both the interpolation stencil and the
    maximum expected shift.
    """

    def __init__(
        self,
        max_abs_shift: float,
        interp: RegularInterpolator,
        max_iter: int,
        tolerance: float,
    ):
        """Set up interpolator.

        Arguments:
            max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space
            interp: Interpolation method
            max_iter: Maximum iterations before fail
            tolerance: Maximum absolute error of result (same units as the
                shift; converted to index space with the sample rate)

        Raises:
            ValueError: if max_abs_shift is negative
        """
        # Validate the raw value: np.ceil maps negatives in (-1, 0) to 0, so a
        # check after rounding would let them through.
        if max_abs_shift < 0:
            msg = (
                f"ShiftInverseDask: max_abs_shift must be positive, got {max_abs_shift}"
            )
            raise ValueError(msg)
        # Margins are whole samples, so round the limit up.
        self._max_abs_shift: Final = int(np.ceil(max_abs_shift))
        self._interp_np: Final = interp
        self._max_iter: Final = int(max_iter)
        self._tolerance: Final = float(tolerance)

    @property
    def margin_left(self) -> int:
        """Left margin size.

        Specifies how many samples on the left have to be added by boundary conditions.
        """
        return self._interp_np.margin_left + self._max_abs_shift

    @property
    def margin_right(self) -> int:
        """Right margin size.

        Specifies how many samples on the right have to be added by boundary conditions.
        """
        return self._interp_np.margin_right + self._max_abs_shift

    def _fixed_point_iter(
        self, dx_pad: np.ndarray, dx: np.ndarray, tol: float
    ) -> NumpyArray1D:
        """Invert the shift of a single chunk by fixed point iteration.

        Arguments:
            dx_pad: shift of the chunk in index units, extended by margin_left
                samples on the left and margin_right samples on the right
            dx: shift of the chunk in index units; starting point passed to
                fixed_point_iter
            tol: convergence tolerance in index units

        Returns:
            Inverse shift of the chunk in index units
        """

        def f_iter(x: NumpyArray1D) -> NumpyArray1D:
            # Evaluate the padded shift at positions displaced by -x; the
            # margin_left offset maps chunk indices into dx_pad.
            return self._interp_np.apply_shift(
                make_numpy_array_1d(dx_pad), -x, self.margin_left
            )

        def f_err(x1: NumpyArray1D, x2: NumpyArray1D) -> float:
            # Maximum absolute change between successive iterates, compared
            # against tol (the documented "maximum absolute error").
            return float(np.max(np.abs(x1 - x2)))

        return fixed_point_iter(
            f_iter, f_err, make_numpy_array_1d(dx), tol, self._max_iter
        )

    def __call__(self, shift: da.Array, fsample: float) -> DaskArray1D:
        """Find the shift for the inverse transformation given by a shift.

        Arguments:
            shift: 1D dask array with shifts of the coordinate transform
            fsample: sample rate of shift array

        Returns:
            1D dask array with shift at transformed coordinate
        """
        # Work in index space (samples) rather than in the units of `shift`.
        shift_idx = shift * fsample
        # Called for its check of the input only (presumably raises if not 1D).
        make_dask_array_1d(shift_idx)
        tol_idx = self._tolerance * fsample

        # Edge padding repeats the first/last sample, giving the outermost
        # chunks the margins needed by the interpolator and the maximum shift.
        dx_pad = da.pad(
            shift_idx,
            (self.margin_left, self.margin_right),
            mode="edge",
        )

        delayed_op = dask.delayed(self._fixed_point_iter)
        results = []
        pos = 0
        for chunk, chunk_size in zip(
            shift_idx.to_delayed(), shift_idx.chunks[0], strict=True
        ):
            # Index pos of dx_pad corresponds to index pos - margin_left of
            # shift_idx, so this window covers the chunk plus both margins.
            n_size = self.margin_left + chunk_size + self.margin_right
            samples_needed = dx_pad[pos : pos + n_size]
            samples_shifted = delayed_op(samples_needed, chunk, tol_idx)
            results.append(
                da.from_delayed(samples_shifted, (chunk_size,), shift_idx.dtype)
            )
            pos += chunk_size

        shift_idx_inv = da.concatenate(results, axis=0)
        shift_inv = shift_idx_inv / fsample
        return make_dask_array_1d(shift_inv)
def make_shift_inverse_lagrange_dask(
    order: int,
    max_abs_shift: float,
    max_iter: int,
    tolerance: float,
) -> ShiftInverseDask:
    """Set up ShiftInverseDask instance with Lagrange interpolation method.

    Arguments:
        order: Order of the Lagrange polynomials
        max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space
        max_iter: Maximum iterations before fail
        tolerance: Maximum absolute error of result

    Returns:
        Inversion function of type ShiftInverseDask

    Raises:
        ValueError: propagated from ShiftInverseDask when it rejects max_abs_shift
    """
    interp = make_regular_interpolator_lagrange(order)
    return ShiftInverseDask(max_abs_shift, interp, max_iter, tolerance)
def make_shift_inverse_dsp_dask(
    order: int,
    max_abs_shift: float,
    max_iter: int,
    tolerance: float,
) -> ShiftInverseDask:
    """Set up ShiftInverseDask instance using dsp.timeshift as interpolation method.

    Arguments:
        order: Interpolation order passed to make_regular_interpolator_dsp
        max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space
        max_iter: Maximum iterations before fail
        tolerance: Maximum absolute error of result

    Returns:
        Inversion function of type ShiftInverseDask

    Raises:
        ValueError: propagated from ShiftInverseDask when it rejects max_abs_shift
    """
    interp = make_regular_interpolator_dsp(order)
    return ShiftInverseDask(max_abs_shift, interp, max_iter, tolerance)