diff --git a/lisainstrument/instrument.py b/lisainstrument/instrument.py
index dc9c898760a83a6e713526836c34f997767c45d4..20424c8fbb255a101f4f0b937356e64dd38dab5e 100755
--- a/lisainstrument/instrument.py
+++ b/lisainstrument/instrument.py
@@ -29,6 +29,11 @@ from scipy.signal import firwin, kaiserord, lfilter
 from . import dsp, noises
 from .containers import ForEachMOSA, ForEachSC
+from .dynamic_delay_dask import make_dynamic_shift_lagrange_dask, numpyfy_dask_multi
+from .dynamic_delay_dsp import make_dynamic_shift_dsp_dask
+from .dynamic_delay_numpy import ShiftBC
+from .fixed_shift_dask import make_fixed_shift_lagrange_dask
+from .fixed_shift_dsp import make_fixed_shift_dsp_dask
 from .noises import generate_subseed
 
 logger = logging.getLogger(__name__)
 
@@ -59,9 +64,10 @@ class Instrument:
         gws: path to gravitational-wave file, or dictionary of gravitational-wave responses;
             if ``orbit_dataset`` is ``'tps/ppr'``, we try to read link responses as functions
             of the TPS instead of link responses in the TCB (fallback behavior)
-        interpolation: interpolation function or interpolation method and parameters;
-            use a tuple ('lagrange', order) with `order` the odd Lagrange interpolation order;
-            an arbitrary function should take (x, shift [number of samples]) as parameter
+        interpolation: interpolation method and parameters;
+            use a tuple ('lagrange', order, min_delay_time, max_delay_time) with `order` the odd
+            Lagrange interpolation order; min_delay_time and max_delay_time are the minimum and
+            maximum delays that can occur in the entire simulation.
         glitches: path to glitch file, or dictionary of glitch signals per injection point
         lock: pre-defined laser locking configuration (e.g., 'N1-12' is configuration N1 with
             12 as primary laser), or 'six' for 6 lasers locked on cavities, or a dictionary of locking
@@ -124,6 +130,7 @@ class Instrument:
             None or 0 for no ambiguities
         electro_delays: tuple (isi, tmi, rfi) of dictionaries for electronic delays [s]
         concurrent (bool): whether to use multiprocessing
+        chunking_size (int): size of chunks when using dask
     """
 
     # pylint: disable=attribute-defined-outside-init
@@ -219,7 +226,7 @@ class Instrument:
         orbits="static",
         orbit_dataset="tps/ppr",
         gws=None,
-        interpolation=("lagrange", 31),
+        interpolation=("lagrange", 31, 6.0, 12.0),
         # Artifacts
         glitches=None,
         # Laser locking and frequency plan
@@ -275,6 +282,7 @@ class Instrument:
         electro_delays=(0, 0, 0),
         # Concurrency
         concurrent=False,
+        chunking_size=256,
     ) -> None:
         # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-statements,too-many-locals,too-many-branches
         logger.info("Initializing instrumental simulation")
@@ -282,6 +290,7 @@ class Instrument:
         self.version = importlib_metadata.version("lisainstrument")
         self.simulated = False
         self.concurrent = bool(concurrent)
+        self.chunking_size = int(chunking_size)
 
         if seed is None:
             seed = noises.generate_random_seed()
@@ -607,6 +616,7 @@ class Instrument:
         # Interpolation and antialiasing filter
         self.init_interpolation(interpolation)
+        self.init_delays(interpolation)
         self.init_aafilter(aafilter)
 
         # Electronic delays
         self.electro_delays_isis = ForEachMOSA(electro_delays[0])
@@ -614,6 +624,127 @@ class Instrument:
         self.electro_delays_tmis = ForEachMOSA(electro_delays[1])
         self.electro_delays_rfis = ForEachMOSA(electro_delays[2])
 
+    def numpyfy_dask_multi(self, op):
+        """Turn dask-based functions into functions accepting numpy arrays
+
+        Turn functions operating on one or more dask array(s) and returning
+        one dask array into functions accepting numpy arrays in place of dask arrays,
+        turning each into a dask array with the given chunk size, and returning
+        a numpy array obtained by evaluating the dask result.
+
+        This is a temporary measure during refactoring.
+
+        Args:
+            op: The dask function to convert
+        """
+        return numpyfy_dask_multi(op, chunks=self.chunking_size)
+
+    def apply_shift(self, x, shift_time):
+        """Apply a time shift using interpolators set up in init_delays()
+
+        The shift is given in time units, and the data is assumed to be sampled
+        with rate self.physics_fs.
+
+        Both data and time shift can be scalar or 1D numpy arrays. Scalars
+        are interpreted as constant arrays. In case of scalar data, the same
+        scalar is returned. In case of scalar shift, a more efficient algorithm
+        is used, which should yield the same results as a constant shift array.
+
+        The properties of the interpolation algorithms are set in init_delays();
+        see there for details.
+
+        Args:
+            x: the data to be shifted, scalar or 1D array
+            shift_time: the shift in time units, scalar or 1D array
+
+        Returns:
+            The shifted data
+        """
+        if np.isscalar(x):
+            return x
+        shift_samples = shift_time * self.physics_fs
+        if np.isscalar(shift_time):
+            return self._delay_const(x, float(shift_samples))
+        return self._delay_dynamic(x, shift_samples)
+
+    def init_delays(self, interpolation):
+        """Initialize or design the interpolation functions for the delays
+
+        We support no interpolation or Lagrange interpolation.
+        This sets up two interpolation methods, one for dynamic delays and one
+        for constant delays.
+
+        There are two implementations of Lagrange interpolation. The "lagrange"
+        method is written from scratch, whereas the "lagrange_dsp" method is a
+        wrapper around the existing dsp.timeshift(). Internally, both are called
+        through a common interface by a function providing the dask integration (also
+        responsible for the boundary conditions).
+
+        For the duration of the refactoring, the dask-based interpolators
+        are wrapped into functions that work on numpy arrays and return
+        a numpy array.
+
+        The interpolation tuple now has two more members giving the minimum
+        and maximum delay that the timeshift has to expect. This is a technical
+        constraint from working with dask arrays.
+
+        The boundary conditions are set to flat on the left side. On the right
+        side, the delays must be large enough to allow computing the shifted data
+        without using boundary conditions. Otherwise, an exception will be raised.
+
+        Args:
+            interpolation: see `interpolation` docstring in `__init__()`
+        """
+
+        match interpolation:
+            case None:
+                self._delay_dynamic = lambda x, _: x
+                self._delay_const = self._delay_dynamic
+            case ("lagrange", int(order), float(min_delay_time), float(max_delay_time)):
+                # ~ self.interpolation_order = order
+                op_dyn = make_dynamic_shift_lagrange_dask(
+                    order,
+                    min_delay_time * self.physics_fs,
+                    max_delay_time * self.physics_fs,
+                    ShiftBC.FLAT,
+                    ShiftBC.EXCEPTION,
+                )
+                op_fix = make_fixed_shift_lagrange_dask(
+                    ShiftBC.FLAT, ShiftBC.EXCEPTION, order
+                )
+                self._delay_dynamic = self.numpyfy_dask_multi(op_dyn)
+                self._delay_const = self.numpyfy_dask_multi(op_fix)
+            case (
+                "lagrange_dsp",
+                int(order),
+                float(min_delay_time),
+                float(max_delay_time),
+            ):
+                # ~ self.interpolation_order = order
+                op_dyn = make_dynamic_shift_dsp_dask(
+                    order,
+                    min_delay_time * self.physics_fs,
+                    max_delay_time * self.physics_fs,
+                    ShiftBC.FLAT,
+                    ShiftBC.EXCEPTION,
+                )
+                op_fix = make_fixed_shift_dsp_dask(
+                    ShiftBC.FLAT, ShiftBC.EXCEPTION, order
+                )
+                self._delay_dynamic = self.numpyfy_dask_multi(op_dyn)
+                self._delay_const = self.numpyfy_dask_multi(op_fix)
+            case ("lagrange", int(order)):
+                msg = (
+                    "interpolation parameter tuples need 4 entries since daskification"
+                )
+                raise RuntimeError(msg)
+            case func if callable(func):
+                msg = "Custom callables as interpolation method temporarily forbidden during daskification"
+                raise RuntimeError(msg)
+            case _:
+                msg = f"Invalid interpolation parameters {interpolation}"
+                raise RuntimeError(msg)
+
     def init_interpolation(self, interpolation):
         """Initialize or design the interpolation function.
 
@@ -1610,7 +1741,7 @@ class Instrument:
         logger.debug("Propagating carrier offsets to distant MOSAs")
         delayed_distant_carrier_offsets = (
             self.local_carrier_offsets.distant().transformed(
-                lambda mosa, x: self.interpolate(x, -self.pprs[mosa]),
+                lambda mosa, x: self.apply_shift(x, -self.pprs[mosa]),
                 concurrent=self.concurrent,
             )
         )
@@ -1629,7 +1760,7 @@ class Instrument:
         propagated_carrier_fluctuations = (
             1 - self.d_pprs
         ) * carrier_fluctuations.distant().transformed(
-            lambda mosa, x: self.interpolate(x, -self.pprs[mosa]),
+            lambda mosa, x: self.apply_shift(x, -self.pprs[mosa]),
             concurrent=self.concurrent,
         )
         self.distant_carrier_fluctuations = (
@@ -1642,7 +1773,7 @@ class Instrument:
 
         logger.debug("Propagating upper sideband offsets to distant MOSAs")
         delayed_distant_usb_offsets = self.local_usb_offsets.distant().transformed(
-            lambda mosa, x: self.interpolate(x, -self.pprs[mosa]),
+            lambda mosa, x: self.apply_shift(x, -self.pprs[mosa]),
             concurrent=self.concurrent,
         )
         self.distant_usb_offsets = (
@@ -1660,7 +1791,7 @@ class Instrument:
         propagated_usb_fluctuations = (
             1 - self.d_pprs
         ) * usb_fluctuations.distant().transformed(
-            lambda mosa, x: self.interpolate(x, -self.pprs[mosa]),
+            lambda mosa, x: self.apply_shift(x, -self.pprs[mosa]),
             concurrent=self.concurrent,
         )
         self.distant_usb_fluctuations = (
@@ -1676,7 +1807,7 @@ class Instrument:
             self.scet_wrt_tps_local.for_each_mosa()
             .distant()
             .transformed(
-                lambda mosa, x: self.interpolate(x, -self.pprs[mosa]) - self.pprs[mosa]
+                lambda mosa, x: self.apply_shift(x, -self.pprs[mosa]) - self.pprs[mosa]
             )
         )
 
@@ -2712,7 +2843,7 @@ class Instrument:
                 distant(mosa),
             )
            carrier_offsets = self.local_carrier_offsets[distant(mosa)]
-            delayed_distant_carrier_offsets = self.interpolate(
+            delayed_distant_carrier_offsets = self.apply_shift(
                carrier_offsets, -self.pprs[mosa]
            )
            distant_carrier_offsets = (
@@ -2737,7 +2868,7 @@ class Instrument:
            )
            distant_carrier_fluctuations = (
                (1 - self.d_pprs[mosa])
-                * self.interpolate(carrier_fluctuations, -self.pprs[mosa])
+                * self.apply_shift(carrier_fluctuations, -self.pprs[mosa])
                - (self.central_freq + delayed_distant_carrier_offsets) * self.gws[mosa]
                - (self.central_freq + delayed_distant_carrier_offsets)
                * (self.local_ttls[mosa] - self.mosa_jitter_xs[mosa])
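
Reviewer note (not part of the patch): a minimal usage sketch of the two new keyword arguments introduced here, assuming the public lisainstrument API (Instrument, size, simulate()); the size value is illustrative, and the interpolation and chunking_size values shown are simply the new defaults from __init__().

    from lisainstrument import Instrument

    # The interpolation tuple now carries, besides the odd Lagrange order,
    # the minimum and maximum delay [s] expected anywhere in the simulation
    # (see the init_delays() docstring for the boundary-condition implications).
    instrument = Instrument(
        size=2400,  # illustrative simulation length in samples
        interpolation=("lagrange", 31, 6.0, 12.0),
        chunking_size=256,  # dask chunk size used by numpyfy_dask_multi()
    )
    instrument.simulate()

    # The dsp.timeshift()-backed variant uses the same 4-entry layout:
    # interpolation=("lagrange_dsp", 31, 6.0, 12.0)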