Skip to content
Snippets Groups Projects
Commit b7a306a7 authored by Wolfgang Kastaun's avatar Wolfgang Kastaun
Browse files

improving type annotations

parent 3d0c5073
No related branches found
No related tags found
No related merge requests found
Pipeline #384093 passed
...@@ -8,11 +8,11 @@ interpolation. ...@@ -8,11 +8,11 @@ interpolation.
from __future__ import annotations from __future__ import annotations
from typing import Final from typing import Callable, Final
import numpy as np import numpy as np
from lisainstrument.fir_filters_numpy import NumpyArray1D from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d
from lisainstrument.regular_interpolators import ( from lisainstrument.regular_interpolators import (
RegularInterpolator, RegularInterpolator,
make_regular_interpolator_lagrange, make_regular_interpolator_lagrange,
...@@ -37,7 +37,7 @@ class ShiftInverseNumpy: # pylint: disable=too-few-public-methods ...@@ -37,7 +37,7 @@ class ShiftInverseNumpy: # pylint: disable=too-few-public-methods
max_iter: Maximum iterations before fail max_iter: Maximum iterations before fail
tolerance: Maximum absolute error of result tolerance: Maximum absolute error of result
""" """
self._max_abs_shift: Final = float(max_abs_shift) self._max_abs_shift: Final = int(np.ceil(max_abs_shift))
self._interp_np: Final = interp self._interp_np: Final = interp
self._max_iter = int(max_iter) self._max_iter = int(max_iter)
self._tolerance = float(tolerance) self._tolerance = float(tolerance)
...@@ -58,7 +58,9 @@ class ShiftInverseNumpy: # pylint: disable=too-few-public-methods ...@@ -58,7 +58,9 @@ class ShiftInverseNumpy: # pylint: disable=too-few-public-methods
""" """
return self._interp_np.margin_right + self._max_abs_shift return self._interp_np.margin_right + self._max_abs_shift
def _fixed_point_iter(self, f, x): def _fixed_point_iter(
self, f: Callable[[NumpyArray1D], NumpyArray1D], x: NumpyArray1D
):
for _ in range(self._max_iter): for _ in range(self._max_iter):
x_next = f(x) x_next = f(x)
error = np.max(np.abs(x - x_next)) error = np.max(np.abs(x - x_next))
...@@ -81,18 +83,19 @@ class ShiftInverseNumpy: # pylint: disable=too-few-public-methods ...@@ -81,18 +83,19 @@ class ShiftInverseNumpy: # pylint: disable=too-few-public-methods
Returns: Returns:
1D numpy array with shift for inverse coordinate transform 1D numpy array with shift for inverse coordinate transform
""" """
dx = make_numpy_array_1d(shift)
shift_pad = np.pad( dx_pad = np.pad(
shift, dx,
(self._margin_left, self._margin_right), (self._margin_left, self._margin_right),
mode="constant", mode="constant",
constant_values=(shift[0], shift[-1]), constant_values=(shift[0], shift[-1]),
) )
def f_iter(x: np.ndarray) -> np.ndarray: def f_iter(x: NumpyArray1D) -> NumpyArray1D:
return self._interp_np.apply_shift(shift_pad, -x, self._margin_left) return self._interp_np.apply_shift(dx_pad, -x, self._margin_left)
return self._fixed_point_iter(f_iter, shift) return self._fixed_point_iter(f_iter, dx)
def make_shift_inverse_lagrange_numpy( def make_shift_inverse_lagrange_numpy(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment