From 38920a6fad23ccb160d49a6dfe7ab35c39b516e4 Mon Sep 17 00:00:00 2001
From: Wolfgang Kastaun <wolfgang.kastaun@aei.mpg.de>
Date: Sat, 21 Dec 2024 20:55:01 +0100
Subject: [PATCH] minor streamlining of time shift algorithms

---
 lisainstrument/dynamic_delay_numpy.py   | 26 +++++++++++++++++--------
 lisainstrument/fixed_shift_dask.py      | 22 +++++++++++++++------
 lisainstrument/fixed_shift_numpy.py     | 22 +++++++++++++++------
 lisainstrument/regular_interpolators.py |  6 +++---
 4 files changed, 53 insertions(+), 23 deletions(-)

diff --git a/lisainstrument/dynamic_delay_numpy.py b/lisainstrument/dynamic_delay_numpy.py
index 6dc71f9..ee251be 100644
--- a/lisainstrument/dynamic_delay_numpy.py
+++ b/lisainstrument/dynamic_delay_numpy.py
@@ -135,14 +135,15 @@ class DynamicShiftNumpy:
         out_size = len(shift)
 
         npad_left = 0
+        padv_left = 0.0
         if self.margin_left > 0:
             match self._cfg.left_bound:
                 case ShiftBC.ZEROPAD:
                     npad_left = self.margin_left
-                    samples = np.concatenate([np.zeros(npad_left), samples])
+                    padv_left = 0.0
                 case ShiftBC.FLAT:
                     npad_left = self.margin_left
-                    samples = np.concatenate([np.ones(npad_left) * samples[0], samples])
+                    padv_left = samples[0]
                 case ShiftBC.EXCEPTION:
                     msg = (
                         f"DynamicShiftNumpy: left edge handling {self._cfg.left_bound.name} not "
@@ -152,14 +153,16 @@ class DynamicShiftNumpy:
                 case _ as unreachable:
                     assert_never(unreachable)
 
+        npad_right = 0
+        padv_right = 0.0
         if self.margin_right > 0:
             match self._cfg.right_bound:
                 case ShiftBC.ZEROPAD:
-                    samples = np.concatenate([samples, np.zeros(self.margin_right)])
+                    npad_right = self.margin_right
+                    padv_right = 0.0
                 case ShiftBC.FLAT:
-                    samples = np.concatenate(
-                        [samples, np.ones(self.margin_right) * samples[-1]]
-                    )
+                    npad_right = self.margin_right
+                    padv_right = samples[-1]
                 case ShiftBC.EXCEPTION:
                     msg = (
                         f"DynamicShiftNumpy: right edge handling {self._cfg.right_bound.name} not "
@@ -169,9 +172,16 @@ class DynamicShiftNumpy:
                 case _ as unreachable:
                     assert_never(unreachable)
 
-        pos = 0
+        if npad_left > 0 or npad_right > 0:
+            samples = np.pad(
+                samples,
+                (npad_left, npad_right),
+                mode="constant",
+                constant_values=(padv_left, padv_right),
+            )
+
         n_size = npad_left + out_size + self.margin_right
-        n_first = pos + npad_left - self.margin_left
+        n_first = npad_left - self.margin_left
         samples_needed = samples[n_first : n_first + n_size]
         return self._interp_np.apply_shift(samples_needed, shift, self.margin_left)
 
diff --git a/lisainstrument/fixed_shift_dask.py b/lisainstrument/fixed_shift_dask.py
index 5b26022..23dad21 100644
--- a/lisainstrument/fixed_shift_dask.py
+++ b/lisainstrument/fixed_shift_dask.py
@@ -54,10 +54,15 @@ class FixedShiftDask:  # pylint: disable=too-few-public-methods
         if margin_left > 0:
             match self._left_bound:
                 case ShiftBC.ZEROPAD:
-                    samples = da.concatenate([da.zeros(margin_left), samples])
+                    samples = da.pad(
+                        samples, (margin_left, 0), mode="constant", constant_values=0.0
+                    )
                 case ShiftBC.FLAT:
-                    samples = da.concatenate(
-                        [da.ones(margin_left) * samples[0], samples]
+                    samples = da.pad(
+                        samples,
+                        (margin_left, 0),
+                        mode="constant",
+                        constant_values=samples[0],
                     )
                 case ShiftBC.EXCEPTION:
                     msg = (
@@ -75,10 +80,15 @@ class FixedShiftDask:  # pylint: disable=too-few-public-methods
         if margin_right > 0:
             match self._right_bound:
                 case ShiftBC.ZEROPAD:
-                    samples = da.concatenate([samples, da.zeros(margin_right)])
+                    samples = da.pad(
+                        samples, (0, margin_right), mode="constant", constant_values=0.0
+                    )
                 case ShiftBC.FLAT:
-                    samples = da.concatenate(
-                        [samples, da.ones(margin_right) * samples[-1]]
+                    samples = da.pad(
+                        samples,
+                        (0, margin_right),
+                        mode="constant",
+                        constant_values=samples[-1],
                     )
                 case ShiftBC.EXCEPTION:
                     msg = (
diff --git a/lisainstrument/fixed_shift_numpy.py b/lisainstrument/fixed_shift_numpy.py
index 53e4756..399ab16 100644
--- a/lisainstrument/fixed_shift_numpy.py
+++ b/lisainstrument/fixed_shift_numpy.py
@@ -279,10 +279,15 @@ class FixedShiftNumpy:  # pylint: disable=too-few-public-methods
         if margin_left > 0:
             match self._left_bound:
                 case ShiftBC.ZEROPAD:
-                    samples = np.concatenate([np.zeros(margin_left), samples])
+                    samples = np.pad(
+                        samples, (margin_left, 0), mode="constant", constant_values=0.0
+                    )
                 case ShiftBC.FLAT:
-                    samples = np.concatenate(
-                        [np.ones(margin_left) * samples[0], samples]
+                    samples = np.pad(
+                        samples,
+                        (margin_left, 0),
+                        mode="constant",
+                        constant_values=samples[0],
                     )
                 case ShiftBC.EXCEPTION:
                     msg = (
@@ -300,10 +305,15 @@ class FixedShiftNumpy:  # pylint: disable=too-few-public-methods
         if margin_right > 0:
             match self._right_bound:
                 case ShiftBC.ZEROPAD:
-                    samples = np.concatenate([samples, np.zeros(margin_right)])
+                    samples = np.pad(
+                        samples, (0, margin_right), mode="constant", constant_values=0.0
+                    )
                 case ShiftBC.FLAT:
-                    samples = np.concatenate(
-                        [samples, np.ones(margin_right) * samples[-1]]
+                    samples = np.pad(
+                        samples,
+                        (0, margin_right),
+                        mode="constant",
+                        constant_values=samples[-1],
                     )
                 case ShiftBC.EXCEPTION:
                     msg = (
diff --git a/lisainstrument/regular_interpolators.py b/lisainstrument/regular_interpolators.py
index e5691f5..6a99c1c 100644
--- a/lisainstrument/regular_interpolators.py
+++ b/lisainstrument/regular_interpolators.py
@@ -238,9 +238,9 @@ class RegularInterpLagrange(RegularInterpCore):
             msg = "RegularInterpLagrange: interpolation requires samples above provided range"
             raise RuntimeError(msg)
 
-        result = np.zeros(locations.shape[0], dtype=samples.dtype)
-        xpow = np.ones_like(loc_frac)
-        for fir in self._fir_filt:
+        result = self._fir_filt[0](samples)[k]
+        xpow = loc_frac.copy()
+        for fir in self._fir_filt[1:]:
             result[:] += fir(samples)[k] * xpow
             xpow[:] *= loc_frac
 
-- 
GitLab