Skip to content
Snippets Groups Projects
Commit 83e289d3 authored by Wolfgang Kastaun's avatar Wolfgang Kastaun
Browse files

ShiftInverseNumpy: include sample rate to specify shifts in physical units

Also, exclude invalid margin points from error measure. The iteration stop criterion
differs from the original one, which used hardcoded margins for this.
parent b7a306a7
No related branches found
No related tags found
No related merge requests found
......@@ -43,7 +43,7 @@ class ShiftInverseNumpy: # pylint: disable=too-few-public-methods
self._tolerance = float(tolerance)
@property
def _margin_left(self) -> int:
def margin_left(self) -> int:
"""Left margin size.
Specifies how many samples on the left have to be added by boundary conditions.
......@@ -51,7 +51,7 @@ class ShiftInverseNumpy: # pylint: disable=too-few-public-methods
return self._interp_np.margin_left + self._max_abs_shift
@property
def _margin_right(self) -> int:
def margin_right(self) -> int:
"""Right margin size.
Specifies how many samples on the right have to be added by boundary conditions.
......@@ -59,43 +59,59 @@ class ShiftInverseNumpy: # pylint: disable=too-few-public-methods
return self._interp_np.margin_right + self._max_abs_shift
def _fixed_point_iter(
self, f: Callable[[NumpyArray1D], NumpyArray1D], x: NumpyArray1D
):
self,
f: Callable[[NumpyArray1D], NumpyArray1D],
ferr: Callable[[NumpyArray1D, NumpyArray1D], float],
x: NumpyArray1D,
) -> NumpyArray1D:
for _ in range(self._max_iter):
x_next = f(x)
error = np.max(np.abs(x - x_next))
if error < self._tolerance:
err = ferr(x, x_next)
if err < self._tolerance:
return x_next
x = x_next
msg = (
f"ShiftInverseNumpy: iteration did not converge ({error=}, "
f"iterations={self._max_iter}, tolerance={self._tolerance})"
f"ShiftInverseNumpy: iteration did not converge (error={err}, "
f"tolerance={self._tolerance}), iterations={self._max_iter}"
)
raise RuntimeError(msg)
def __call__(self, shift: np.ndarray) -> NumpyArray1D:
def __call__(self, shift: np.ndarray, fsample: float) -> NumpyArray1D:
"""Find the shift for the inverse transformation given by a shift.
Arguments:
shift: 1D numpy array with shifts of the coordinate transform
fsample: sample rate of shift array
Returns:
1D numpy array with shift for inverse coordinate transform
1D numpy array with shift at transformed coordinate
"""
dx = make_numpy_array_1d(shift)
shift_idx = shift * fsample
dx = make_numpy_array_1d(shift_idx)
dx_pad = np.pad(
dx,
(self._margin_left, self._margin_right),
(self.margin_left, self.margin_right),
mode="constant",
constant_values=(shift[0], shift[-1]),
constant_values=(shift_idx[0], shift_idx[-1]),
)
def f_iter(x: NumpyArray1D) -> NumpyArray1D:
return self._interp_np.apply_shift(dx_pad, -x, self._margin_left)
return self._interp_np.apply_shift(dx_pad, -x, self.margin_left)
def f_err(x1: NumpyArray1D, x2: NumpyArray1D) -> float:
return np.max(np.abs(x1 - x2)[self.margin_left : -self.margin_right])
shift_idx_inv = self._fixed_point_iter(f_iter, f_err, dx)
shift_inv = shift_idx_inv / fsample
return self._fixed_point_iter(f_iter, dx)
return make_numpy_array_1d(shift_inv)
def make_shift_inverse_lagrange_numpy(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment