From 50b0b411541128b1254087b17c1da6aaa130f06e Mon Sep 17 00:00:00 2001
From: Wolfgang Kastaun <wolfgang.kastaun@aei.mpg.de>
Date: Fri, 20 Dec 2024 18:18:58 +0100
Subject: [PATCH] added more unit tests for fixed and dynamic shift

---
 tests/test_dynamic_delay_dsp_dask.py | 120 ++++++++++++++++++++++++---
 tests/test_fixed_shift_dask.py       |  70 +++++++++++++---
 2 files changed, 165 insertions(+), 25 deletions(-)

diff --git a/tests/test_dynamic_delay_dsp_dask.py b/tests/test_dynamic_delay_dsp_dask.py
index 8be00af..e69eb5d 100644
--- a/tests/test_dynamic_delay_dsp_dask.py
+++ b/tests/test_dynamic_delay_dsp_dask.py
@@ -1,29 +1,43 @@
 import numpy as np
 import pytest
 
-from lisainstrument.dynamic_delay_dask import numpyfy_dask_multi
+from lisainstrument.dynamic_delay_dask import (
+    make_dynamic_shift_lagrange_dask,
+    numpyfy_dask_multi,
+)
 from lisainstrument.dynamic_delay_dsp_dask import (
     make_dynamic_shift_dsp_dask,
     make_dynamic_shift_dsp_numpy,
 )
 from lisainstrument.dynamic_delay_numpy import DynShiftBC
+from lisainstrument.fixed_shift_dask import make_fixed_shift_lagrange_dask
+from lisainstrument.fixed_shift_dsp import make_fixed_shift_dsp_dask
+
+
+def g(x):
+    """This example functionis a polynomial of order 5, which should be reproduced
+    exactly when being interpolated with Lagrange polynomials of order 5.
+    """
+    return (
+        4.32546
+        + 3.34324 * x
+        + 4.342 * x**2
+        + 0.46 * x**3
+        + 1.43598 * x**4
+        + 3.3456 * x**5
+    )
 
 
 def test_dynamic_shift_dsp_dask() -> None:
-    order = 5
+    """Check basic functioning of dynamic shift of dask arrays
 
-    t, dt = np.linspace(-5.345, 10.345, 1003, retstep=True)
+    Compare against expected result (except margin points affected by BC) and
+    compare against pure numpy-based version.
+    """
 
-    def g(x):
-        return (
-            4.32546
-            + 3.34324 * x
-            + 4.342 * x**2
-            + 0.46 * x**3
-            + 1.43598 * x**4
-            + 3.3456 * x**5
-        )
+    order = 5
 
+    t, dt = np.linspace(-5.345, 10.345, 1003, retstep=True)
     y = g(t)
 
     d = (0.93456456 + 0.0235345 * np.cos(4.3354 * t)) / dt
@@ -47,3 +61,85 @@ def test_dynamic_shift_dsp_dask() -> None:
     s_np = op_np(y, d)
 
     assert np.all(s_np == s_da)
+
+
+def test_dynamic_shift_dsp_dask_orders() -> None:
+    """Test dsp dynamic shift wrapper for different orders
+
+    Compare against other lagrange interpolator"""
+
+    for order in (5, 7, 13, 31):
+        length = order + 1
+        t, dt = np.linspace(-5.345, 10.345, 1003, retstep=True)
+        y = g(t)
+
+        d = 17 + (0.93456456 + 0.0235345 * np.cos(4.3354 * t)) / dt
+
+        op1_da = make_dynamic_shift_dsp_dask(
+            order, d.min(), d.max(), DynShiftBC.FLAT, DynShiftBC.EXCEPTION
+        )
+        op1_na = numpyfy_dask_multi(op1_da, chunks=19)
+
+        s1_na = op1_na(y, d)
+
+        op2_da = make_dynamic_shift_lagrange_dask(
+            length, d.min(), d.max(), DynShiftBC.FLAT, DynShiftBC.EXCEPTION
+        )
+        op2_na = numpyfy_dask_multi(op2_da, chunks=19)
+
+        s2_na = op2_na(y, d)
+
+        assert s1_na == pytest.approx(s2_na, abs=0, rel=5e-15)
+
+
+def test_fixed_shift_dsp_dask() -> None:
+    """Test dsp fixed shift wrapper for positive and negative shifts
+
+    Compares against other fixed shift lagrange interpolator.
+
+    Opportunistic test: validate boundary condition checks raise exception when
+    they should
+    """
+    order = 5
+    length = order + 1
+
+    t, dt = np.linspace(-5.345, 10.345, 1003, retstep=True)
+
+    y = g(t)
+
+    for d in (0.93456456 / dt, 3.1, 3.5, 3.9, 4.0, 4.1, 4.5, 4.9, 116.0):
+        op1_da = make_fixed_shift_lagrange_dask(
+            DynShiftBC.FLAT, DynShiftBC.EXCEPTION, length
+        )
+        op1_na = numpyfy_dask_multi(op1_da, chunks=113)
+
+        op2_da = make_fixed_shift_dsp_dask(
+            DynShiftBC.FLAT, DynShiftBC.EXCEPTION, length
+        )
+        op2_na = numpyfy_dask_multi(op2_da, chunks=113)
+
+        s1_na = op1_na(y, d)
+        s2_na = op2_na(y, d)
+
+        assert s1_na == pytest.approx(s2_na, rel=1e-14, abs=0)
+
+        with pytest.raises(RuntimeError):
+            op2_na(y, -d)
+
+        op3_da = make_fixed_shift_lagrange_dask(
+            DynShiftBC.EXCEPTION, DynShiftBC.FLAT, length
+        )
+        op3_na = numpyfy_dask_multi(op3_da, chunks=113)
+
+        op4_da = make_fixed_shift_dsp_dask(
+            DynShiftBC.EXCEPTION, DynShiftBC.FLAT, length
+        )
+        op4_na = numpyfy_dask_multi(op4_da, chunks=113)
+
+        s3_na = op3_na(y, -d)
+        s4_na = op4_na(y, -d)
+
+        assert s3_na == pytest.approx(s4_na, rel=1e-14, abs=0)
+
+        with pytest.raises(RuntimeError):
+            op4_na(y, d)
diff --git a/tests/test_fixed_shift_dask.py b/tests/test_fixed_shift_dask.py
index 95296b5..3bd8072 100644
--- a/tests/test_fixed_shift_dask.py
+++ b/tests/test_fixed_shift_dask.py
@@ -1,28 +1,41 @@
+"""Unit tests for module fixed_shift_dask"""
+
 import numpy as np
 import pytest
 
-from lisainstrument.dynamic_delay_dask import numpyfy_dask_multi
+from lisainstrument.dynamic_delay_dask import (
+    make_dynamic_shift_lagrange_dask,
+    numpyfy_dask_multi,
+)
 from lisainstrument.dynamic_delay_numpy import DynShiftBC
 from lisainstrument.fixed_shift_dask import make_fixed_shift_lagrange_dask
 from lisainstrument.fixed_shift_numpy import make_fixed_shift_lagrange_numpy
 
 
+def g(x):
+    """This example functionis a polynomial of order 5, which should be reproduced
+    exactly when being interpolated with Lagrange polynomials of order 5.
+    """
+    return (
+        4.32546
+        + 3.34324 * x
+        - 4.342 * x**2
+        - 0.46 * x**3
+        + 1.43598 * x**4
+        - np.pi * x**5
+    )
+
+
 def test_fixed_shift_lagrange_dask() -> None:
+    """Check that dask-based Lagrange interpolator for fixed shifts works.
+
+    Compares dask-based interpolator against numpy based version and against
+    expected result (excluding points affected by boundary treatment).
+    """
     order = 5
     length = order + 1
 
-    t, dt = np.linspace(-5.345, 10.345, 1003, retstep=True)
-
-    def g(x):
-        return (
-            4.32546
-            + 3.34324 * x
-            - 4.342 * x**2
-            - 0.46 * x**3
-            + 1.43598 * x**4
-            - np.pi * x**5
-        )
-
+    t, dt = np.linspace(-5.345, 10.345, 1003, retstep=True, dtype=float)
     y = g(t)
 
     for d in (0.93456456 / dt, 3.1, 3.5, 3.9, 4.0, 4.1, 4.5, 4.9, 116.0):
@@ -79,3 +92,34 @@ def test_fixed_shift_lagrange_dask() -> None:
 
         s4_na = op3_na(y, -d)
         assert s4_na == pytest.approx(s2_np, rel=1e-14, abs=0)
+
+
+def test_shift_lagrange_dask_fixed_vs_dyn() -> None:
+    """Check that dask-based fixed and dynamic lagrange shift functions yields identical results
+
+    Uses different interpolation orders (odd, even, high) and shifts. Shift values chosen
+    to take into account use of round and floor functions in the code to be tested.
+    """
+    t, dt = np.linspace(-5.345, 10.345, 1003, retstep=True, dtype=float)
+    y = g(t)
+
+    for order in (5, 6, 31):
+        length = order + 1
+
+        for d in (0.93456456 / dt, 23.1, 23.5, 23.9, 24.0, 116.0):
+            d1d = np.ones_like(y) * d
+
+            op1_da = make_dynamic_shift_lagrange_dask(
+                length, d, d, DynShiftBC.FLAT, DynShiftBC.EXCEPTION
+            )
+            op1_na = numpyfy_dask_multi(op1_da, chunks=113)
+
+            op2_da = make_fixed_shift_lagrange_dask(
+                DynShiftBC.FLAT, DynShiftBC.EXCEPTION, length
+            )
+            op2_na = numpyfy_dask_multi(op2_da, chunks=113)
+
+            s1_np = op1_na(y, d1d)
+            s2_np = op2_na(y, d)
+
+            assert s1_np == pytest.approx(s2_np, rel=1e-14, abs=0)
-- 
GitLab