From e6ae24cf109656321b5d9e4d8c754a527b6f2755 Mon Sep 17 00:00:00 2001
From: Wolfgang Kastaun <wolfgang.kastaun@aei.mpg.de>
Date: Thu, 19 Dec 2024 16:06:47 +0100
Subject: [PATCH] improve padding logic in dask interpolator, make numpyfy_dask
 more flexible

---
 lisainstrument/dynamic_delay_dask.py | 56 ++++++++++++++++------------
 1 file changed, 32 insertions(+), 24 deletions(-)

diff --git a/lisainstrument/dynamic_delay_dask.py b/lisainstrument/dynamic_delay_dask.py
index cec8cf8..1962d03 100644
--- a/lisainstrument/dynamic_delay_dask.py
+++ b/lisainstrument/dynamic_delay_dask.py
@@ -5,22 +5,21 @@ Use make_dynamic_shift_lagrange_dask to create a Lagrange interpolator for dask
 
 from __future__ import annotations
 
-from typing import Callable, Final
+from typing import Any, Callable, Final
 
 import dask
 import dask.array as da
 import numpy as np
 from typing_extensions import assert_never
 
-from lisainstrument.dynamic_delay_numpy import (
-    DynShiftBC,
-    DynShiftCfg,
+from lisainstrument.dynamic_delay_numpy import DynShiftBC, DynShiftCfg
+from lisainstrument.fir_filters_dask import DaskArray1D, make_dask_array_1d
+from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d
+from lisainstrument.regular_interpolators import (
     RegularInterpolator,
     make_regular_interpolator_lagrange,
     make_regular_interpolator_linear,
 )
-from lisainstrument.fir_filters_dask import DaskArray1D, make_dask_array_1d
-from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d
 
 
 class DynamicShiftDask:
@@ -91,14 +90,17 @@ class DynamicShiftDask:
             Dask array with interpolated samples
         """
         npad_left = 0
+        samples_pad = samples
         if self.margin_left > 0:
             match self._cfg.left_bound:
                 case DynShiftBC.ZEROPAD:
                     npad_left = self.margin_left
-                    samples = da.concatenate([da.zeros(npad_left), samples])
+                    samples_pad = da.concatenate([da.zeros(npad_left), samples])
                 case DynShiftBC.FLAT:
                     npad_left = self.margin_left
-                    samples = da.concatenate([da.ones(npad_left) * samples[0], samples])
+                    samples_pad = da.concatenate(
+                        [da.ones(npad_left) * samples[0], samples]
+                    )
                 case DynShiftBC.EXCEPTION:
                     msg = (
                         f"DynamicShiftDask: left edge handling {self._cfg.left_bound.name} not "
@@ -111,10 +113,12 @@ class DynamicShiftDask:
         if self.margin_right > 0:
             match self._cfg.right_bound:
                 case DynShiftBC.ZEROPAD:
-                    samples = da.concatenate([samples, da.zeros(self.margin_right)])
+                    samples_pad = da.concatenate(
+                        [samples_pad, da.zeros(self.margin_right)]
+                    )
                 case DynShiftBC.FLAT:
-                    samples = da.concatenate(
-                        [samples, da.ones(self.margin_right) * samples[-1]]
+                    samples_pad = da.concatenate(
+                        [samples_pad, da.ones(self.margin_right) * samples_pad[-1]]
                     )
                 case DynShiftBC.EXCEPTION:
                     msg = (
@@ -125,14 +129,15 @@ class DynamicShiftDask:
                 case _ as unreachable:
                     assert_never(unreachable)
 
+        results = []
+
         chunks = shift.to_delayed()
         delayed_op = dask.delayed(self._interp_np.apply_shift)
-        results = []
         pos = 0
         for chunk, chunk_shape in zip(chunks, shift.chunks[0], strict=True):
             n_size = npad_left + chunk_shape + self.margin_right
             n_first = pos + npad_left - self.margin_left
-            samples_needed = samples[n_first : n_first + n_size]
+            samples_needed = samples_pad[n_first : n_first + n_size]
             samples_shifted = delayed_op(samples_needed, chunk, self.margin_left)
             delayed_chunk = da.from_delayed(
                 samples_shifted, (chunk_shape,), samples.dtype
@@ -189,13 +194,15 @@ def make_dynamic_shift_lagrange_dask(
     return DynamicShiftDask(cfg, interp)
 
 
-def numpyfy_dask_bivariate(
-    dafunc: Callable[[da.Array, da.Array], DaskArray1D], chunks: int
-) -> Callable[[np.ndarray, np.ndarray], NumpyArray1D]:
-    """Convert function operating on dask arrays to work with numpy arrays.
+def numpyfy_dask_multi(
+    dafunc: Callable[[*Any], DaskArray1D], chunks: int
+) -> Callable[[*Any], NumpyArray1D]:
+    """Convert function operating on 1D dask arrays to work with 1D numpy arrays.
 
-    The dask function will be called with two chunked dask arrays created from
-    the numpy arrays, and then evaluate the result.
+    Before calling the dask-based function, all arguments which are numpy
+    arrays are converted to 1D dask arrays chunked with the specified chunk length.
+    The result will be evaluated as 1D numpy array. Arguments which are not
+    numpy arrays will be passed unchanged
 
     Arguments:
         dafunc: function operating on two dask arrays returning single 1D dask array
@@ -204,10 +211,11 @@ def numpyfy_dask_bivariate(
         function operating on numpy arrays
     """
 
-    def func(x, y):
-        dx = da.from_array(x, chunks=chunks)
-        dy = da.from_array(y, chunks=chunks)
-        r = dafunc(make_dask_array_1d(dx), make_dask_array_1d(dy)).compute()
-        return make_numpy_array_1d(r)
+    def func(*args: Any) -> NumpyArray1D:
+        nargs = [
+            (da.from_array(x, chunks=chunks) if isinstance(x, np.ndarray) else x)
+            for x in args
+        ]
+        return make_numpy_array_1d(dafunc(*nargs).compute())
 
     return func
-- 
GitLab