Skip to content
Snippets Groups Projects
Commit 4233c507 authored by Wolfgang Kastaun's avatar Wolfgang Kastaun
Browse files

Make shift inversion interface tolerance parameter consistent

parent 708a48fd
No related branches found
No related tags found
No related merge requests found
Pipeline #385650 passed
...@@ -66,7 +66,9 @@ class ShiftInverseDask: ...@@ -66,7 +66,9 @@ class ShiftInverseDask:
""" """
return self._interp_np.margin_right + self._max_abs_shift return self._interp_np.margin_right + self._max_abs_shift
def _fixed_point_iter(self, dx_pad: np.ndarray, dx: np.ndarray) -> NumpyArray1D: def _fixed_point_iter(
self, dx_pad: np.ndarray, dx: np.ndarray, tol: float
) -> NumpyArray1D:
def f_iter(x: NumpyArray1D) -> NumpyArray1D: def f_iter(x: NumpyArray1D) -> NumpyArray1D:
return self._interp_np.apply_shift( return self._interp_np.apply_shift(
make_numpy_array_1d(dx_pad), -x, self.margin_left make_numpy_array_1d(dx_pad), -x, self.margin_left
...@@ -76,7 +78,7 @@ class ShiftInverseDask: ...@@ -76,7 +78,7 @@ class ShiftInverseDask:
return np.max(np.abs(x1 - x2)[self.margin_left : -self.margin_right]) return np.max(np.abs(x1 - x2)[self.margin_left : -self.margin_right])
return fixed_point_iter( return fixed_point_iter(
f_iter, f_err, make_numpy_array_1d(dx), self._tolerance, self._max_iter f_iter, f_err, make_numpy_array_1d(dx), tol, self._max_iter
) )
def __call__(self, shift: da.Array, fsample: float) -> DaskArray1D: def __call__(self, shift: da.Array, fsample: float) -> DaskArray1D:
...@@ -93,6 +95,8 @@ class ShiftInverseDask: ...@@ -93,6 +95,8 @@ class ShiftInverseDask:
shift_idx = shift * fsample shift_idx = shift * fsample
make_dask_array_1d(shift_idx) make_dask_array_1d(shift_idx)
tol_idx = self._tolerance * fsample
dx_pad = da.pad( dx_pad = da.pad(
shift_idx, shift_idx,
(self.margin_left, self.margin_right), (self.margin_left, self.margin_right),
...@@ -109,7 +113,7 @@ class ShiftInverseDask: ...@@ -109,7 +113,7 @@ class ShiftInverseDask:
n_size = self.margin_left + chunk_shape + self.margin_right n_size = self.margin_left + chunk_shape + self.margin_right
n_first = pos n_first = pos
samples_needed = dx_pad[n_first : n_first + n_size] samples_needed = dx_pad[n_first : n_first + n_size]
samples_shifted = delayed_op(samples_needed, chunk) samples_shifted = delayed_op(samples_needed, chunk, tol_idx)
delayed_chunk = da.from_delayed( delayed_chunk = da.from_delayed(
samples_shifted, (chunk_shape,), shift_idx.dtype samples_shifted, (chunk_shape,), shift_idx.dtype
) )
......
...@@ -132,7 +132,7 @@ class ShiftInverseNumpy: # pylint: disable=too-few-public-methods ...@@ -132,7 +132,7 @@ class ShiftInverseNumpy: # pylint: disable=too-few-public-methods
return np.max(np.abs(x1 - x2)[self.margin_left : -self.margin_right]) return np.max(np.abs(x1 - x2)[self.margin_left : -self.margin_right])
shift_idx_inv = fixed_point_iter( shift_idx_inv = fixed_point_iter(
f_iter, f_err, dx, self._tolerance, self._max_iter f_iter, f_err, dx, self._tolerance * fsample, self._max_iter
) )
shift_inv = shift_idx_inv / fsample shift_inv = shift_idx_inv / fsample
......
...@@ -25,8 +25,8 @@ def test_shift_inversion_numpy(): ...@@ -25,8 +25,8 @@ def test_shift_inversion_numpy():
fsample = 1 / dt fsample = 1 / dt
dxi = dx_from_x(xi) dxi = dx_from_x(xi)
op_np = make_shift_inverse_lagrange_numpy(order, a_mod * 1.01, max_it, tol / dt) op_np = make_shift_inverse_lagrange_numpy(order, a_mod * 1.01, max_it, tol)
op_da = make_shift_inverse_lagrange_dask(order, a_mod * 1.01, max_it, tol / dt) op_da = make_shift_inverse_lagrange_dask(order, a_mod * 1.01, max_it, tol)
op_na = numpyfy_dask_multi(op_da, chunks) op_na = numpyfy_dask_multi(op_da, chunks)
ai_np = op_np(dxi, fsample) ai_np = op_np(dxi, fsample)
......
...@@ -22,7 +22,7 @@ def test_shift_inversion_numpy(): ...@@ -22,7 +22,7 @@ def test_shift_inversion_numpy():
fsample = 1 / dt fsample = 1 / dt
dxi = dx_from_x(xi) dxi = dx_from_x(xi)
op_np = make_shift_inverse_lagrange_numpy(order, a_mod * 1.01, max_it, tol / dt) op_np = make_shift_inverse_lagrange_numpy(order, a_mod * 1.01, max_it, tol)
ai_np = op_np(dxi, fsample) ai_np = op_np(dxi, fsample)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment