From 50ee7bc7547e815fbece89f27796c9f6f09e4d9e Mon Sep 17 00:00:00 2001
From: Wolfgang Kastaun <wolfgang.kastaun@aei.mpg.de>
Date: Tue, 11 Feb 2025 16:41:36 +0100
Subject: [PATCH] Added unit test for numpy-based adaptive shift wrapper

---
 tests/test_fixed_shift_numpy.py | 47 +++++++++++++++++++++++++++++++--
 1 file changed, 45 insertions(+), 2 deletions(-)

diff --git a/tests/test_fixed_shift_numpy.py b/tests/test_fixed_shift_numpy.py
index 907cbd2..9a2cd39 100644
--- a/tests/test_fixed_shift_numpy.py
+++ b/tests/test_fixed_shift_numpy.py
@@ -4,10 +4,15 @@ import numpy as np
 import pytest
 
 from lisainstrument.dynamic_delay_numpy import ShiftBC
-from lisainstrument.fixed_shift_numpy import make_fixed_shift_lagrange_numpy
+from lisainstrument.fixed_shift_numpy import (
+    AdaptiveShiftNumpy,
+    make_fixed_shift_lagrange_numpy,
+)
 
 
-def test_fixed_shift_lagrange_dask() -> None:
+def test_fixed_shift_lagrange_numpy() -> None:
+    """Test Lagrange interpolation of numpy arrays specialized to fixed shift"""
+
     order = 5
 
     t, dt = np.linspace(-5.345, 10.345, 1003, retstep=True)
@@ -53,3 +58,41 @@ def test_fixed_shift_lagrange_dask() -> None:
 
         with pytest.raises(RuntimeError):
             op2_np(y, -d)
+
+
+def test_adaptive_shift_numpy() -> None:
+    """Test AdaptiveShiftNumpy wrapper
+
+    Only checks that it calls correct methods based on argument type combinations.
+    """
+
+    def mock_fix(d: np.ndarray, s: float) -> np.ndarray:
+        """Bogus but unique result to check this function was used"""
+        return d * s
+
+    def mock_dyn(d: np.ndarray, s: np.ndarray) -> np.ndarray:
+        """Bogus but unique result to check this function was used"""
+        return d / s
+
+    mock_fsamp = 13
+    op = AdaptiveShiftNumpy(mock_fix, mock_dyn, mock_fsamp)
+
+    y_fix = 12.0
+    y_dyn = np.ones(10) * y_fix
+    s_fix = 6.0
+    s_dyn = np.ones(10) * s_fix
+
+    r_dyn_fix = op(y_dyn, s_fix)
+    assert np.all(r_dyn_fix == mock_fix(y_dyn, s_fix * mock_fsamp))
+
+    r_dyn_dyn = op(y_dyn, s_dyn)
+    assert np.all(r_dyn_dyn == mock_dyn(y_dyn, s_dyn * mock_fsamp))
+
+    r_fix_fix = op(y_fix, s_fix)
+    assert r_fix_fix == y_fix
+
+    r_fix_zero = op(y_fix, 0.0)
+    assert r_fix_zero == y_fix
+
+    r_fix_dyn = op(y_fix, s_dyn)
+    assert r_fix_dyn == y_fix
-- 
GitLab