Skip to content
Snippets Groups Projects
Commit 38920a6f authored by Wolfgang Kastaun's avatar Wolfgang Kastaun
Browse files

minor streamlining of time shift algorithms

parent 508c1e7c
No related branches found
No related tags found
No related merge requests found
......@@ -135,14 +135,15 @@ class DynamicShiftNumpy:
out_size = len(shift)
npad_left = 0
padv_left = 0.0
if self.margin_left > 0:
match self._cfg.left_bound:
case ShiftBC.ZEROPAD:
npad_left = self.margin_left
samples = np.concatenate([np.zeros(npad_left), samples])
padv_left = 0.0
case ShiftBC.FLAT:
npad_left = self.margin_left
samples = np.concatenate([np.ones(npad_left) * samples[0], samples])
padv_left = samples[0]
case ShiftBC.EXCEPTION:
msg = (
f"DynamicShiftNumpy: left edge handling {self._cfg.left_bound.name} not "
......@@ -152,14 +153,16 @@ class DynamicShiftNumpy:
case _ as unreachable:
assert_never(unreachable)
npad_right = 0
padv_right = 0.0
if self.margin_right > 0:
match self._cfg.right_bound:
case ShiftBC.ZEROPAD:
samples = np.concatenate([samples, np.zeros(self.margin_right)])
npad_right = self.margin_right
padv_right = 0.0
case ShiftBC.FLAT:
samples = np.concatenate(
[samples, np.ones(self.margin_right) * samples[-1]]
)
npad_right = self.margin_right
padv_right = samples[-1]
case ShiftBC.EXCEPTION:
msg = (
f"DynamicShiftNumpy: right edge handling {self._cfg.right_bound.name} not "
......@@ -169,9 +172,16 @@ class DynamicShiftNumpy:
case _ as unreachable:
assert_never(unreachable)
pos = 0
if npad_left > 0 or npad_right > 0:
samples = np.pad(
samples,
(npad_left, npad_right),
mode="constant",
constant_values=(padv_left, padv_right),
)
n_size = npad_left + out_size + self.margin_right
n_first = pos + npad_left - self.margin_left
n_first = npad_left - self.margin_left
samples_needed = samples[n_first : n_first + n_size]
return self._interp_np.apply_shift(samples_needed, shift, self.margin_left)
......
......@@ -54,10 +54,15 @@ class FixedShiftDask: # pylint: disable=too-few-public-methods
if margin_left > 0:
match self._left_bound:
case ShiftBC.ZEROPAD:
samples = da.concatenate([da.zeros(margin_left), samples])
samples = da.pad(
samples, (margin_left, 0), mode="constant", constant_values=0.0
)
case ShiftBC.FLAT:
samples = da.concatenate(
[da.ones(margin_left) * samples[0], samples]
samples = da.pad(
samples,
(margin_left, 0),
mode="constant",
constant_values=samples[0],
)
case ShiftBC.EXCEPTION:
msg = (
......@@ -75,10 +80,15 @@ class FixedShiftDask: # pylint: disable=too-few-public-methods
if margin_right > 0:
match self._right_bound:
case ShiftBC.ZEROPAD:
samples = da.concatenate([samples, da.zeros(margin_right)])
samples = da.pad(
samples, (0, margin_right), mode="constant", constant_values=0.0
)
case ShiftBC.FLAT:
samples = da.concatenate(
[samples, da.ones(margin_right) * samples[-1]]
samples = da.pad(
samples,
(0, margin_right),
mode="constant",
constant_values=samples[-1],
)
case ShiftBC.EXCEPTION:
msg = (
......
......@@ -279,10 +279,15 @@ class FixedShiftNumpy: # pylint: disable=too-few-public-methods
if margin_left > 0:
match self._left_bound:
case ShiftBC.ZEROPAD:
samples = np.concatenate([np.zeros(margin_left), samples])
samples = np.pad(
samples, (margin_left, 0), mode="constant", constant_values=0.0
)
case ShiftBC.FLAT:
samples = np.concatenate(
[np.ones(margin_left) * samples[0], samples]
samples = np.pad(
samples,
(margin_left, 0),
mode="constant",
constant_values=samples[0],
)
case ShiftBC.EXCEPTION:
msg = (
......@@ -300,10 +305,15 @@ class FixedShiftNumpy: # pylint: disable=too-few-public-methods
if margin_right > 0:
match self._right_bound:
case ShiftBC.ZEROPAD:
samples = np.concatenate([samples, np.zeros(margin_right)])
samples = np.pad(
samples, (0, margin_right), mode="constant", constant_values=0.0
)
case ShiftBC.FLAT:
samples = np.concatenate(
[samples, np.ones(margin_right) * samples[-1]]
samples = np.pad(
samples,
(0, margin_right),
mode="constant",
constant_values=samples[-1],
)
case ShiftBC.EXCEPTION:
msg = (
......
......@@ -238,9 +238,9 @@ class RegularInterpLagrange(RegularInterpCore):
msg = "RegularInterpLagrange: interpolation requires samples above provided range"
raise RuntimeError(msg)
result = np.zeros(locations.shape[0], dtype=samples.dtype)
xpow = np.ones_like(loc_frac)
for fir in self._fir_filt:
result = self._fir_filt[0](samples)[k]
xpow = loc_frac.copy()
for fir in self._fir_filt[1:]:
result[:] += fir(samples)[k] * xpow
xpow[:] *= loc_frac
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment