From 38920a6fad23ccb160d49a6dfe7ab35c39b516e4 Mon Sep 17 00:00:00 2001 From: Wolfgang Kastaun <wolfgang.kastaun@aei.mpg.de> Date: Sat, 21 Dec 2024 20:55:01 +0100 Subject: [PATCH] minor streamlining of time shift algorithms --- lisainstrument/dynamic_delay_numpy.py | 26 +++++++++++++++++-------- lisainstrument/fixed_shift_dask.py | 22 +++++++++++++++------ lisainstrument/fixed_shift_numpy.py | 22 +++++++++++++++------ lisainstrument/regular_interpolators.py | 6 +++--- 4 files changed, 53 insertions(+), 23 deletions(-) diff --git a/lisainstrument/dynamic_delay_numpy.py b/lisainstrument/dynamic_delay_numpy.py index 6dc71f9..ee251be 100644 --- a/lisainstrument/dynamic_delay_numpy.py +++ b/lisainstrument/dynamic_delay_numpy.py @@ -135,14 +135,15 @@ class DynamicShiftNumpy: out_size = len(shift) npad_left = 0 + padv_left = 0.0 if self.margin_left > 0: match self._cfg.left_bound: case ShiftBC.ZEROPAD: npad_left = self.margin_left - samples = np.concatenate([np.zeros(npad_left), samples]) + padv_left = 0.0 case ShiftBC.FLAT: npad_left = self.margin_left - samples = np.concatenate([np.ones(npad_left) * samples[0], samples]) + padv_left = samples[0] case ShiftBC.EXCEPTION: msg = ( f"DynamicShiftNumpy: left edge handling {self._cfg.left_bound.name} not " @@ -152,14 +153,16 @@ class DynamicShiftNumpy: case _ as unreachable: assert_never(unreachable) + npad_right = 0 + padv_right = 0.0 if self.margin_right > 0: match self._cfg.right_bound: case ShiftBC.ZEROPAD: - samples = np.concatenate([samples, np.zeros(self.margin_right)]) + npad_right = self.margin_right + padv_right = 0.0 case ShiftBC.FLAT: - samples = np.concatenate( - [samples, np.ones(self.margin_right) * samples[-1]] - ) + npad_right = self.margin_right + padv_right = samples[-1] case ShiftBC.EXCEPTION: msg = ( f"DynamicShiftNumpy: right edge handling {self._cfg.right_bound.name} not " @@ -169,9 +172,16 @@ class DynamicShiftNumpy: case _ as unreachable: assert_never(unreachable) - pos = 0 + if npad_left > 0 or npad_right > 0: + samples = np.pad( + samples, + (npad_left, npad_right), + mode="constant", + constant_values=(padv_left, padv_right), + ) + n_size = npad_left + out_size + self.margin_right - n_first = pos + npad_left - self.margin_left + n_first = npad_left - self.margin_left samples_needed = samples[n_first : n_first + n_size] return self._interp_np.apply_shift(samples_needed, shift, self.margin_left) diff --git a/lisainstrument/fixed_shift_dask.py b/lisainstrument/fixed_shift_dask.py index 5b26022..23dad21 100644 --- a/lisainstrument/fixed_shift_dask.py +++ b/lisainstrument/fixed_shift_dask.py @@ -54,10 +54,15 @@ class FixedShiftDask: # pylint: disable=too-few-public-methods if margin_left > 0: match self._left_bound: case ShiftBC.ZEROPAD: - samples = da.concatenate([da.zeros(margin_left), samples]) + samples = da.pad( + samples, (margin_left, 0), mode="constant", constant_values=0.0 + ) case ShiftBC.FLAT: - samples = da.concatenate( - [da.ones(margin_left) * samples[0], samples] + samples = da.pad( + samples, + (margin_left, 0), + mode="constant", + constant_values=samples[0], ) case ShiftBC.EXCEPTION: msg = ( @@ -75,10 +80,15 @@ class FixedShiftDask: # pylint: disable=too-few-public-methods if margin_right > 0: match self._right_bound: case ShiftBC.ZEROPAD: - samples = da.concatenate([samples, da.zeros(margin_right)]) + samples = da.pad( + samples, (0, margin_right), mode="constant", constant_values=0.0 + ) case ShiftBC.FLAT: - samples = da.concatenate( - [samples, da.ones(margin_right) * samples[-1]] + samples = da.pad( + samples, + (0, margin_right), + mode="constant", + constant_values=samples[-1], ) case ShiftBC.EXCEPTION: msg = ( diff --git a/lisainstrument/fixed_shift_numpy.py b/lisainstrument/fixed_shift_numpy.py index 53e4756..399ab16 100644 --- a/lisainstrument/fixed_shift_numpy.py +++ b/lisainstrument/fixed_shift_numpy.py @@ -279,10 +279,15 @@ class FixedShiftNumpy: # pylint: disable=too-few-public-methods if margin_left > 0: match self._left_bound: case ShiftBC.ZEROPAD: - samples = np.concatenate([np.zeros(margin_left), samples]) + samples = np.pad( + samples, (margin_left, 0), mode="constant", constant_values=0.0 + ) case ShiftBC.FLAT: - samples = np.concatenate( - [np.ones(margin_left) * samples[0], samples] + samples = np.pad( + samples, + (margin_left, 0), + mode="constant", + constant_values=samples[0], ) case ShiftBC.EXCEPTION: msg = ( @@ -300,10 +305,15 @@ class FixedShiftNumpy: # pylint: disable=too-few-public-methods if margin_right > 0: match self._right_bound: case ShiftBC.ZEROPAD: - samples = np.concatenate([samples, np.zeros(margin_right)]) + samples = np.pad( + samples, (0, margin_right), mode="constant", constant_values=0.0 + ) case ShiftBC.FLAT: - samples = np.concatenate( - [samples, np.ones(margin_right) * samples[-1]] + samples = np.pad( + samples, + (0, margin_right), + mode="constant", + constant_values=samples[-1], ) case ShiftBC.EXCEPTION: msg = ( diff --git a/lisainstrument/regular_interpolators.py b/lisainstrument/regular_interpolators.py index e5691f5..6a99c1c 100644 --- a/lisainstrument/regular_interpolators.py +++ b/lisainstrument/regular_interpolators.py @@ -238,9 +238,9 @@ class RegularInterpLagrange(RegularInterpCore): msg = "RegularInterpLagrange: interpolation requires samples above provided range" raise RuntimeError(msg) - result = np.zeros(locations.shape[0], dtype=samples.dtype) - xpow = np.ones_like(loc_frac) - for fir in self._fir_filt: + result = self._fir_filt[0](samples)[k] + xpow = loc_frac.copy() + for fir in self._fir_filt[1:]: result[:] += fir(samples)[k] * xpow xpow[:] *= loc_frac -- GitLab