From e6ae24cf109656321b5d9e4d8c754a527b6f2755 Mon Sep 17 00:00:00 2001 From: Wolfgang Kastaun <wolfgang.kastaun@aei.mpg.de> Date: Thu, 19 Dec 2024 16:06:47 +0100 Subject: [PATCH] improve padding logic in dask interpolator, make numpyfy_dask more flexible --- lisainstrument/dynamic_delay_dask.py | 56 ++++++++++++++++------------ 1 file changed, 32 insertions(+), 24 deletions(-) diff --git a/lisainstrument/dynamic_delay_dask.py b/lisainstrument/dynamic_delay_dask.py index cec8cf8..1962d03 100644 --- a/lisainstrument/dynamic_delay_dask.py +++ b/lisainstrument/dynamic_delay_dask.py @@ -5,22 +5,21 @@ Use make_dynamic_shift_lagrange_dask to create a Lagrange interpolator for dask from __future__ import annotations -from typing import Callable, Final +from typing import Any, Callable, Final import dask import dask.array as da import numpy as np from typing_extensions import assert_never -from lisainstrument.dynamic_delay_numpy import ( - DynShiftBC, - DynShiftCfg, +from lisainstrument.dynamic_delay_numpy import DynShiftBC, DynShiftCfg +from lisainstrument.fir_filters_dask import DaskArray1D, make_dask_array_1d +from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d +from lisainstrument.regular_interpolators import ( RegularInterpolator, make_regular_interpolator_lagrange, make_regular_interpolator_linear, ) -from lisainstrument.fir_filters_dask import DaskArray1D, make_dask_array_1d -from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d class DynamicShiftDask: @@ -91,14 +90,17 @@ class DynamicShiftDask: Dask array with interpolated samples """ npad_left = 0 + samples_pad = samples if self.margin_left > 0: match self._cfg.left_bound: case DynShiftBC.ZEROPAD: npad_left = self.margin_left - samples = da.concatenate([da.zeros(npad_left), samples]) + samples_pad = da.concatenate([da.zeros(npad_left), samples]) case DynShiftBC.FLAT: npad_left = self.margin_left - samples = da.concatenate([da.ones(npad_left) * samples[0], samples]) + samples_pad = da.concatenate( + [da.ones(npad_left) * samples[0], samples] + ) case DynShiftBC.EXCEPTION: msg = ( f"DynamicShiftDask: left edge handling {self._cfg.left_bound.name} not " @@ -111,10 +113,12 @@ class DynamicShiftDask: if self.margin_right > 0: match self._cfg.right_bound: case DynShiftBC.ZEROPAD: - samples = da.concatenate([samples, da.zeros(self.margin_right)]) + samples_pad = da.concatenate( + [samples_pad, da.zeros(self.margin_right)] + ) case DynShiftBC.FLAT: - samples = da.concatenate( - [samples, da.ones(self.margin_right) * samples[-1]] + samples_pad = da.concatenate( + [samples_pad, da.ones(self.margin_right) * samples_pad[-1]] ) case DynShiftBC.EXCEPTION: msg = ( @@ -125,14 +129,15 @@ class DynamicShiftDask: case _ as unreachable: assert_never(unreachable) + results = [] + chunks = shift.to_delayed() delayed_op = dask.delayed(self._interp_np.apply_shift) - results = [] pos = 0 for chunk, chunk_shape in zip(chunks, shift.chunks[0], strict=True): n_size = npad_left + chunk_shape + self.margin_right n_first = pos + npad_left - self.margin_left - samples_needed = samples[n_first : n_first + n_size] + samples_needed = samples_pad[n_first : n_first + n_size] samples_shifted = delayed_op(samples_needed, chunk, self.margin_left) delayed_chunk = da.from_delayed( samples_shifted, (chunk_shape,), samples.dtype @@ -189,13 +194,15 @@ def make_dynamic_shift_lagrange_dask( return DynamicShiftDask(cfg, interp) -def numpyfy_dask_bivariate( - dafunc: Callable[[da.Array, da.Array], DaskArray1D], chunks: int -) -> Callable[[np.ndarray, np.ndarray], NumpyArray1D]: - """Convert function operating on dask arrays to work with numpy arrays. +def numpyfy_dask_multi( + dafunc: Callable[[*Any], DaskArray1D], chunks: int +) -> Callable[[*Any], NumpyArray1D]: + """Convert function operating on 1D dask arrays to work with 1D numpy arrays. - The dask function will be called with two chunked dask arrays created from - the numpy arrays, and then evaluate the result. + Before calling the dask-based function, all arguments which are numpy + arrays are converted to 1D dask arrays chunked with the specified chunk length. + The result will be evaluated as 1D numpy array. Arguments which are not + numpy arrays will be passed unchanged Arguments: dafunc: function operating on two dask arrays returning single 1D dask array @@ -204,10 +211,11 @@ def numpyfy_dask_bivariate( function operating on numpy arrays """ - def func(x, y): - dx = da.from_array(x, chunks=chunks) - dy = da.from_array(y, chunks=chunks) - r = dafunc(make_dask_array_1d(dx), make_dask_array_1d(dy)).compute() - return make_numpy_array_1d(r) + def func(*args: Any) -> NumpyArray1D: + nargs = [ + (da.from_array(x, chunks=chunks) if isinstance(x, np.ndarray) else x) + for x in args + ] + return make_numpy_array_1d(dafunc(*nargs).compute()) return func -- GitLab