Skip to content
Snippets Groups Projects
Commit e6ae24cf authored by Wolfgang Kastaun's avatar Wolfgang Kastaun
Browse files

improve padding logic in dask interpolator, make numpyfy_dask more flexible

parent 497b9b8c
No related branches found
No related tags found
No related merge requests found
......@@ -5,22 +5,21 @@ Use make_dynamic_shift_lagrange_dask to create a Lagrange interpolator for dask
from __future__ import annotations
from typing import Callable, Final
from typing import Any, Callable, Final
import dask
import dask.array as da
import numpy as np
from typing_extensions import assert_never
from lisainstrument.dynamic_delay_numpy import (
DynShiftBC,
DynShiftCfg,
from lisainstrument.dynamic_delay_numpy import DynShiftBC, DynShiftCfg
from lisainstrument.fir_filters_dask import DaskArray1D, make_dask_array_1d
from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d
from lisainstrument.regular_interpolators import (
RegularInterpolator,
make_regular_interpolator_lagrange,
make_regular_interpolator_linear,
)
from lisainstrument.fir_filters_dask import DaskArray1D, make_dask_array_1d
from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d
class DynamicShiftDask:
......@@ -91,14 +90,17 @@ class DynamicShiftDask:
Dask array with interpolated samples
"""
npad_left = 0
samples_pad = samples
if self.margin_left > 0:
match self._cfg.left_bound:
case DynShiftBC.ZEROPAD:
npad_left = self.margin_left
samples = da.concatenate([da.zeros(npad_left), samples])
samples_pad = da.concatenate([da.zeros(npad_left), samples])
case DynShiftBC.FLAT:
npad_left = self.margin_left
samples = da.concatenate([da.ones(npad_left) * samples[0], samples])
samples_pad = da.concatenate(
[da.ones(npad_left) * samples[0], samples]
)
case DynShiftBC.EXCEPTION:
msg = (
f"DynamicShiftDask: left edge handling {self._cfg.left_bound.name} not "
......@@ -111,10 +113,12 @@ class DynamicShiftDask:
if self.margin_right > 0:
match self._cfg.right_bound:
case DynShiftBC.ZEROPAD:
samples = da.concatenate([samples, da.zeros(self.margin_right)])
samples_pad = da.concatenate(
[samples_pad, da.zeros(self.margin_right)]
)
case DynShiftBC.FLAT:
samples = da.concatenate(
[samples, da.ones(self.margin_right) * samples[-1]]
samples_pad = da.concatenate(
[samples_pad, da.ones(self.margin_right) * samples_pad[-1]]
)
case DynShiftBC.EXCEPTION:
msg = (
......@@ -125,14 +129,15 @@ class DynamicShiftDask:
case _ as unreachable:
assert_never(unreachable)
results = []
chunks = shift.to_delayed()
delayed_op = dask.delayed(self._interp_np.apply_shift)
results = []
pos = 0
for chunk, chunk_shape in zip(chunks, shift.chunks[0], strict=True):
n_size = npad_left + chunk_shape + self.margin_right
n_first = pos + npad_left - self.margin_left
samples_needed = samples[n_first : n_first + n_size]
samples_needed = samples_pad[n_first : n_first + n_size]
samples_shifted = delayed_op(samples_needed, chunk, self.margin_left)
delayed_chunk = da.from_delayed(
samples_shifted, (chunk_shape,), samples.dtype
......@@ -189,13 +194,15 @@ def make_dynamic_shift_lagrange_dask(
return DynamicShiftDask(cfg, interp)
def numpyfy_dask_bivariate(
dafunc: Callable[[da.Array, da.Array], DaskArray1D], chunks: int
) -> Callable[[np.ndarray, np.ndarray], NumpyArray1D]:
"""Convert function operating on dask arrays to work with numpy arrays.
def numpyfy_dask_multi(
dafunc: Callable[[*Any], DaskArray1D], chunks: int
) -> Callable[[*Any], NumpyArray1D]:
"""Convert function operating on 1D dask arrays to work with 1D numpy arrays.
The dask function will be called with two chunked dask arrays created from
the numpy arrays, and then evaluate the result.
Before calling the dask-based function, all arguments which are numpy
arrays are converted to 1D dask arrays chunked with the specified chunk length.
The result will be evaluated as 1D numpy array. Arguments which are not
numpy arrays will be passed unchanged
Arguments:
dafunc: function operating on two dask arrays returning single 1D dask array
......@@ -204,10 +211,11 @@ def numpyfy_dask_bivariate(
function operating on numpy arrays
"""
def func(x, y):
dx = da.from_array(x, chunks=chunks)
dy = da.from_array(y, chunks=chunks)
r = dafunc(make_dask_array_1d(dx), make_dask_array_1d(dy)).compute()
return make_numpy_array_1d(r)
def func(*args: Any) -> NumpyArray1D:
nargs = [
(da.from_array(x, chunks=chunks) if isinstance(x, np.ndarray) else x)
for x in args
]
return make_numpy_array_1d(dafunc(*nargs).compute())
return func
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment