Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""Functions for inverting coordinate transformations given as shifts stored in dask arrays

Use make_shift_inverse_lagrange_dask to create a shift inversion operator based on Lagrange interpolation.
"""
from __future__ import annotations
from typing import Final
import dask
import dask.array as da
import numpy as np
from lisainstrument.fir_filters_dask import DaskArray1D, make_dask_array_1d
from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d
from lisainstrument.regular_interpolators import (
RegularInterpolator,
make_regular_interpolator_lagrange,
)
from lisainstrument.shift_inversion_numpy import fixed_point_iter
class ShiftInverseDask:
    """Invert coordinate transformation given as shift.

    The shift is provided as a chunked 1D dask array. Each chunk is processed
    by its own delayed task, which runs a fixed-point iteration on the chunk
    together with a margin of neighbouring samples taken from an edge-padded
    copy of the full shift array.
    """

    def __init__(
        self,
        max_abs_shift: float,
        interp: RegularInterpolator,
        max_iter: int,
        tolerance: float,
    ):
        """Set up interpolator.

        Arguments:
            max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space
            interp: Interpolation method
            max_iter: Maximum iterations before fail
            tolerance: Maximum absolute error of result

        Raises:
            ValueError: If max_abs_shift is negative.
        """
        # Validate the raw value. Checking after rounding up would map every
        # value in (-1, 0) to zero and silently accept a negative limit.
        if max_abs_shift < 0:
            msg = (
                f"ShiftInverseDask: max_abs_shift must be positive, got {max_abs_shift}"
            )
            raise ValueError(msg)
        # Margins are counted in whole samples, hence round up.
        self._max_abs_shift: Final = int(np.ceil(max_abs_shift))
        self._interp_np: Final = interp
        self._max_iter = int(max_iter)
        self._tolerance = float(tolerance)

    @property
    def margin_left(self) -> int:
        """Left margin size.

        Specifies how many samples on the left have to be added by boundary conditions.
        This is the interpolator's left margin plus the rounded-up maximum shift.
        """
        return self._interp_np.margin_left + self._max_abs_shift

    @property
    def margin_right(self) -> int:
        """Right margin size.

        Specifies how many samples on the right have to be added by boundary conditions.
        This is the interpolator's right margin plus the rounded-up maximum shift.
        """
        return self._interp_np.margin_right + self._max_abs_shift

    def _fixed_point_iter(self, dx_pad: np.ndarray, dx: np.ndarray) -> NumpyArray1D:
        """Run the fixed-point iteration for a single chunk.

        Arguments:
            dx_pad: Shift samples of the chunk including margin_left samples
                before and margin_right samples after it
            dx: Shift samples of the chunk itself, used as initial guess

        Returns:
            Result of fixed_point_iter for this chunk
        """

        def f_iter(x: NumpyArray1D) -> NumpyArray1D:
            # Evaluate the padded shift at positions displaced by -x; the
            # offset margin_left accounts for the padding in front of the chunk.
            return self._interp_np.apply_shift(
                make_numpy_array_1d(dx_pad), -x, self.margin_left
            )

        def f_err(x1: NumpyArray1D, x2: NumpyArray1D) -> float:
            # Maximum absolute difference, ignoring margin_left leading and
            # margin_right trailing samples.
            # NOTE(review): assuming the iterates have the length of dx, this
            # slice is empty (and np.max raises) when margin_right is 0 or the
            # chunk has no more than margin_left + margin_right samples --
            # confirm chunks are always larger than the margins.
            return np.max(np.abs(x1 - x2)[self.margin_left : -self.margin_right])

        return fixed_point_iter(
            f_iter, f_err, make_numpy_array_1d(dx), self._tolerance, self._max_iter
        )

    def __call__(self, shift: da.Array, fsample: float) -> DaskArray1D:
        """Find the shift for the inverse transformation given by a shift.

        Arguments:
            shift: 1D dask array with shifts of the coordinate transform
            fsample: sample rate of shift array

        Returns:
            1D dask array with shift at transformed coordinate
        """
        # Convert to index-space units (samples).
        shift_idx = shift * fsample
        # Return value intentionally discarded; presumably this call only
        # validates that the input is one-dimensional -- TODO confirm.
        make_dask_array_1d(shift_idx)

        # Extend both ends by repeating the edge values so that every chunk,
        # including the first and last, has a full margin available.
        dx_pad = da.pad(
            shift_idx,
            (self.margin_left, self.margin_right),
            mode="edge",
        )

        results = []
        chunks = shift_idx.to_delayed()
        delayed_op = dask.delayed(self._fixed_point_iter)
        # Start index of the current chunk in the unpadded array. Because of
        # the left padding, this is also where the chunk's window (chunk plus
        # both margins) begins in dx_pad.
        pos = 0
        for chunk, chunk_shape in zip(chunks, shift.chunks[0], strict=True):
            n_size = self.margin_left + chunk_shape + self.margin_right
            n_first = pos
            samples_needed = dx_pad[n_first : n_first + n_size]
            samples_shifted = delayed_op(samples_needed, chunk)
            delayed_chunk = da.from_delayed(
                samples_shifted, (chunk_shape,), shift_idx.dtype
            )
            results.append(delayed_chunk)
            pos += chunk_shape

        shift_idx_inv = da.concatenate(results, axis=0)
        # Convert back from index-space units to the units of the input.
        shift_inv = shift_idx_inv / fsample
        return make_dask_array_1d(shift_inv)
def make_shift_inverse_lagrange_dask(
    order: int,
    max_abs_shift: float,
    max_iter: int,
    tolerance: float,
) -> ShiftInverseDask:
    """Create a ShiftInverseDask instance that uses Lagrange interpolation.

    Arguments:
        order: Order of the Lagrange polynomials
        max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space
        max_iter: Maximum iterations before fail
        tolerance: Maximum absolute error of result

    Returns:
        Inversion function of type ShiftInverseDask
    """
    return ShiftInverseDask(
        max_abs_shift=max_abs_shift,
        interp=make_regular_interpolator_lagrange(order),
        max_iter=max_iter,
        tolerance=tolerance,
    )