From 50ee7bc7547e815fbece89f27796c9f6f09e4d9e Mon Sep 17 00:00:00 2001 From: Wolfgang Kastaun <wolfgang.kastaun@aei.mpg.de> Date: Tue, 11 Feb 2025 16:41:36 +0100 Subject: [PATCH] Added unit test for numpy-based adaptive shift wrapper --- tests/test_fixed_shift_numpy.py | 47 +++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/tests/test_fixed_shift_numpy.py b/tests/test_fixed_shift_numpy.py index 907cbd2..9a2cd39 100644 --- a/tests/test_fixed_shift_numpy.py +++ b/tests/test_fixed_shift_numpy.py @@ -4,10 +4,15 @@ import numpy as np import pytest from lisainstrument.dynamic_delay_numpy import ShiftBC -from lisainstrument.fixed_shift_numpy import make_fixed_shift_lagrange_numpy +from lisainstrument.fixed_shift_numpy import ( + AdaptiveShiftNumpy, + make_fixed_shift_lagrange_numpy, +) -def test_fixed_shift_lagrange_dask() -> None: +def test_fixed_shift_lagrange_numpy() -> None: + """Test Lagrange interpolation of numpy arrays specialized to fixed shift""" + order = 5 t, dt = np.linspace(-5.345, 10.345, 1003, retstep=True) @@ -53,3 +58,41 @@ def test_fixed_shift_lagrange_dask() -> None: with pytest.raises(RuntimeError): op2_np(y, -d) + + +def test_adaptive_shift_numpy() -> None: + """Test AdaptiveShiftNumpy wrapper + + Only checks that it calls correct methods based on argument type combinations. + """ + + def mock_fix(d: np.ndarray, s: float) -> np.ndarray: + """Bogus but unique result to check this function was used""" + return d * s + + def mock_dyn(d: np.ndarray, s: np.ndarray) -> np.ndarray: + """Bogus but unique result to check this function was used""" + return d / s + + mock_fsamp = 13 + op = AdaptiveShiftNumpy(mock_fix, mock_dyn, mock_fsamp) + + y_fix = 12.0 + y_dyn = np.ones(10) * y_fix + s_fix = 6.0 + s_dyn = np.ones(10) * s_fix + + r_dyn_fix = op(y_dyn, s_fix) + assert np.all(r_dyn_fix == mock_fix(y_dyn, s_fix * mock_fsamp)) + + r_dyn_dyn = op(y_dyn, s_dyn) + assert np.all(r_dyn_dyn == mock_dyn(y_dyn, s_dyn * mock_fsamp)) + + r_fix_fix = op(y_fix, s_fix) + assert r_fix_fix == y_fix + + r_fix_zero = op(y_fix, 0.0) + assert r_fix_zero == y_fix + + r_fix_dyn = op(y_fix, s_dyn) + assert r_fix_dyn == y_fix -- GitLab