From 3c9a4fc89a5413788c8b4480caff0c78f1043a13 Mon Sep 17 00:00:00 2001
From: Wolfgang Kastaun <wolfgang.kastaun@aei.mpg.de>
Date: Thu, 19 Dec 2024 17:31:44 +0100
Subject: [PATCH] added some unit tests for fixed shift with numpy or dask

---
 tests/test_fixed_shift_dask.py  | 88 +++++++++++++++++++++++++++++++++
 tests/test_fixed_shift_numpy.py | 63 +++++++++++++++++++++++
 2 files changed, 151 insertions(+)
 create mode 100644 tests/test_fixed_shift_dask.py
 create mode 100644 tests/test_fixed_shift_numpy.py

diff --git a/tests/test_fixed_shift_dask.py b/tests/test_fixed_shift_dask.py
new file mode 100644
index 0000000..bc4177c
--- /dev/null
+++ b/tests/test_fixed_shift_dask.py
@@ -0,0 +1,88 @@
+import numpy as np
+import pytest
+
+
+from lisainstrument.dynamic_delay_numpy import DynShiftBC
+from lisainstrument.dynamic_delay_dask import numpyfy_dask_multi
+from lisainstrument.fixed_shift_numpy import make_fixed_shift_lagrange_numpy
+from lisainstrument.fixed_shift_dask import make_fixed_shift_lagrange_dask
+
+
+def test_fixed_shift_lagrange_dask() -> None:
+    order = 5
+    length = order + 1
+
+    t, dt = np.linspace(-5.345, 10.345, 1003, retstep=True)
+
+    def g(x):
+        return (
+            4.32546
+            + 3.34324 * x
+            - 4.342 * x**2
+            - 0.46 * x**3
+            + 1.43598 * x**4
+            - np.pi * x**5
+        )
+
+    y = g(t)
+
+    for d in (0.93456456 / dt, 3.1, 3.5, 3.9, 4.0, 4.1, 4.5, 4.9, 116.0):
+        
+        op_np = make_fixed_shift_lagrange_numpy(
+            DynShiftBC.FLAT, DynShiftBC.EXCEPTION, length
+        )
+        op_da = make_fixed_shift_lagrange_dask(
+            DynShiftBC.FLAT, DynShiftBC.EXCEPTION, length
+        )
+        op_na = numpyfy_dask_multi(op_da, chunks=113)
+
+        s_np = op_np(y, d)
+        s_na = op_na(y, d)
+        s_ex = g(t - d * dt)
+
+        margin_ex = -int(np.floor(-d)) + order // 2
+
+        
+        assert s_na == pytest.approx(s_np, rel=1e-14, abs=0)
+        assert s_ex[margin_ex:] == pytest.approx(s_na[margin_ex:], abs=1e-15, rel=5e-13)
+
+        
+        with pytest.raises(RuntimeError):
+            op_na(y, -d)
+
+        op2_np = make_fixed_shift_lagrange_numpy(
+            DynShiftBC.EXCEPTION, DynShiftBC.FLAT, length
+        )
+        op2_da = make_fixed_shift_lagrange_dask(
+            DynShiftBC.EXCEPTION, DynShiftBC.FLAT, length
+        )
+        op2_na = numpyfy_dask_multi(op2_da, chunks=113)
+
+        s2_np = op2_np(y, -d)
+        s2_na = op2_na(y, -d)
+        s2_ex = g(t + d * dt)
+
+        margin2_ex = int(np.floor(d)) + (length - 1 - order // 2)
+
+        assert s2_na == pytest.approx(s2_np, rel=1e-14, abs=0)
+        assert s2_ex[:-margin2_ex] == pytest.approx(
+            s2_na[:-margin2_ex], abs=1e-15, rel=5e-13
+        )
+        
+        with pytest.raises(RuntimeError):
+            op2_na(y, d)
+            
+        op3_da = make_fixed_shift_lagrange_dask(
+            DynShiftBC.FLAT, DynShiftBC.FLAT, length
+        )
+        op3_na = numpyfy_dask_multi(op3_da, chunks=113)    
+            
+        s3_na = op3_na(y, d)  
+        assert s3_na == pytest.approx(s_np, rel=1e-14, abs=0)
+        
+        s4_na = op3_na(y, -d)  
+        assert s4_na == pytest.approx(s2_np, rel=1e-14, abs=0)
+        
+        
+        
+        
diff --git a/tests/test_fixed_shift_numpy.py b/tests/test_fixed_shift_numpy.py
new file mode 100644
index 0000000..4b3a3a9
--- /dev/null
+++ b/tests/test_fixed_shift_numpy.py
@@ -0,0 +1,63 @@
+import numpy as np
+import pytest
+
+
+from lisainstrument.dynamic_delay_numpy import DynShiftBC
+
+from lisainstrument.fixed_shift_numpy import make_fixed_shift_lagrange_numpy
+
+
+def test_fixed_shift_lagrange_dask() -> None:
+    order = 5
+    length = order + 1
+
+    t, dt = np.linspace(-5.345, 10.345, 1003, retstep=True)
+
+    def g(x):
+        return (
+            4.32546
+            + 3.34324 * x
+            - 4.342 * x**2
+            - 0.46 * x**3
+            + 1.43598 * x**4
+            - 3.3456 * x**5
+        )
+
+    y = g(t)
+
+    for d in (0.93456456/dt, 3.1, 3.5, 3.9, 4.0, 4.1, 4.5, 4.9):
+    # ~ d = 0.93456456 / dt
+        
+        op_np = make_fixed_shift_lagrange_numpy(DynShiftBC.FLAT, DynShiftBC.EXCEPTION, length)
+        
+        s_np = op_np(y, d)
+        s_ex = g(t - d * dt)
+
+        margin_ex = -int(np.floor(-d)) + order // 2 
+        
+        print(d, d*dt, margin_ex)
+        
+        assert s_ex[margin_ex:] == pytest.approx(
+            s_np[margin_ex:], abs=1e-15, rel=5e-13
+        )
+
+        # ~ assert np.min(np.abs(s_ex[:margin_ex] - s_np[:margin_ex])) > np.max(np.abs(y[:margin_ex]))*1e-10
+        
+        with pytest.raises(RuntimeError):
+            op_np(y, -d)
+            
+        
+        op2_np = make_fixed_shift_lagrange_numpy(DynShiftBC.EXCEPTION, DynShiftBC.FLAT, length)
+        
+        s2_np = op2_np(y, -d)
+        s2_ex = g(t + d * dt)
+
+        margin2_ex = int(np.floor(d)) + (length - 1 - order // 2)
+
+        assert s2_ex[:-margin2_ex] == pytest.approx(
+            s2_np[:-margin2_ex], abs=1e-15, rel=5e-13
+        )
+        # ~ assert np.min(np.abs(s2_ex[-margin2_ex:] - s2_np[-margin2_ex:])) > np.max(np.abs(y[-margin2_ex:]))*1e-10
+
+        with pytest.raises(RuntimeError):
+            op2_np(y, d)
-- 
GitLab