diff --git a/lisainstrument/instrument.py b/lisainstrument/instrument.py index 0a6bd1c539679a20bb6c4c72f68e1b2586b0178c..ad99ca58c237ef91583cb27298a40a5651214e3f 100755 --- a/lisainstrument/instrument.py +++ b/lisainstrument/instrument.py @@ -27,7 +27,7 @@ from scipy.integrate import cumulative_trapezoid from scipy.interpolate import InterpolatedUnivariateSpline from scipy.signal import firwin, kaiserord, lfilter -from . import dsp, noises +from . import noises from .containers import ForEachMOSA, ForEachSC from .dynamic_delay_dask import make_dynamic_shift_lagrange_dask, numpyfy_dask_multi from .dynamic_delay_dsp import make_dynamic_shift_dsp_dask @@ -36,6 +36,10 @@ from .fixed_shift_dask import make_fixed_shift_lagrange_dask from .fixed_shift_dsp import make_fixed_shift_dsp_dask from .fixed_shift_numpy import AdaptiveShiftNumpy from .noises import generate_subseed +from .shift_inversion_dask import ( + make_shift_inverse_dsp_dask, + make_shift_inverse_lagrange_dask, +) logger = logging.getLogger(__name__) @@ -492,10 +496,6 @@ class Instrument: # MOC time correlation self.moc_time_correlation_asds = ForEachSC(moc_time_correlation_asds) - # Clock-noise inversion - self.clockinv_tolerance = float(clockinv_tolerance) - self.clockinv_maxiter = int(clockinv_maxiter) - # Ranging noise self.ranging_biases = ForEachMOSA(ranging_biases) self.ranging_asds = ForEachMOSA(ranging_asds) @@ -620,8 +620,7 @@ class Instrument: else: self.mosa_angles = ForEachMOSA(mosa_angles) - # Interpolation and antialiasing filter - self.init_interpolation(interpolation) + # Interpolation and clock-noise inversion if delay_clock_max < 0: msg = ( @@ -639,7 +638,17 @@ class Instrument: msg = f"Maximum interspacecraft delay {delay_isc_min} below minimum delay {delay_isc_min} specified" raise ValueError(msg) - self.init_delays(interpolation, delay_isc_min, delay_isc_max, delay_clock_max) + self.init_delays( + interpolation, + delay_isc_min, + delay_isc_max, + delay_clock_max, + clockinv_tolerance, + clockinv_maxiter, + ) + + # Antialiasing filter + self.init_aafilter(aafilter) # Electronic delays @@ -748,6 +757,8 @@ class Instrument: delay_isc_min: float, delay_isc_max: float, delay_clock_max: float, + clockinv_tolerance: float, + clockinv_maxiter: int, ) -> None: """Initialize or design the interpolation functions for the delays @@ -774,13 +785,22 @@ class Instrument: Warning: for delays smaller than the interpolation stencil, this will also cause invalid results near the left boundary. + This also sets up the time frame inversion methods, using the same + interpolator in the internal inversion algorithm. The delay_clock_max + parameter is also used there during the internal interpolation. + Args: interpolation: see `interpolation` docstring in `__init__()` delay_isc_min: Minimum allowed interspacecraft delay [s] delay_isc_max: Maximum allowed interspacecraft delay [s] delay_clock_max: Maximum allowed absolute delay [s] between clocks/tpc/tcp + clockinv_tolerance: Tolerance [s] for inverting time coordinate transforms + clockinv_maxiter: Maximum iterations for time frame inversion scheme """ + self.clockinv_tolerance = float(clockinv_tolerance) + self.clockinv_maxiter = int(clockinv_maxiter) + match interpolation: case None: self._apply_shift = lambda x, _: x @@ -800,12 +820,18 @@ class Instrument: self.apply_shift_clock = self.make_adaptive_shift_lagrange( order, -delay_clock_max, delay_clock_max, ShiftBC.FLAT, ShiftBC.FLAT ) + shift_inversion_da = make_shift_inverse_lagrange_dask( + order, + delay_clock_max, + self.clockinv_maxiter, + self.clockinv_tolerance, + ) + self.interpolation_order = order case ( "lagrange_dsp", int(order), ): - self.apply_shift = self.make_adaptive_shift_dsp( order, delay_isc_min, @@ -817,6 +843,13 @@ class Instrument: self.apply_shift_clock = self.make_adaptive_shift_dsp( order, -delay_clock_max, delay_clock_max, ShiftBC.FLAT, ShiftBC.FLAT ) + shift_inversion_da = make_shift_inverse_dsp_dask( + order, + delay_clock_max, + self.clockinv_maxiter, + self.clockinv_tolerance, + ) + self.interpolation_order = order case ("lagrange", int(order)): msg = ( @@ -829,44 +862,8 @@ class Instrument: case _: msg = f"Invalid interpolation parameters {interpolation}" raise RuntimeError(msg) - # ~ self.interpolation_order = - # ~ self.delay_isc_min = delay_isc_min - # ~ self.delay_isc_max = delay_isc_max - - def init_interpolation(self, interpolation): - """Initialize or design the interpolation function. - We support no interpolation, a custom interpolation function, or Lagrange interpolation. - - Args: - parameters: see `interpolation` docstring in `__init__()` - """ - if interpolation is None: - logger.info("Disabling interpolation") - self.interpolation_order = None - self.interpolate = lambda x, _: x - elif callable(interpolation): - logger.info("Using user-provided interpolation function") - self.interpolation_order = None - self.interpolate = lambda x, shift: ( - x if np.isscalar(x) else interpolation(x, shift * self.physics_fs) - ) - else: - method = str(interpolation[0]) - if method == "lagrange": - self.interpolation_order = int(interpolation[1]) - logger.debug( - "Using Lagrange interpolation of order %s", self.interpolation_order - ) - self.interpolate = lambda x, shift: ( - x - if np.isscalar(x) - else dsp.timeshift( - x, shift * self.physics_fs, self.interpolation_order - ) - ) - else: - raise ValueError(f"invalid interpolation parameters '{interpolation}'") + self.shift_inversion = self.numpyfy_dask_multi(shift_inversion_da) def init_aafilter(self, aafilter): """Initialize antialiasing filter and downsampling.""" @@ -3047,32 +3044,8 @@ class Instrument: self.clockinv_maxiter, ) - # Drop samples at the edges to compute error - edge = min(100, len(scet_wrt_tps) // 2 - 1) - error = 0 - - niter = 0 - next_inverse = scet_wrt_tps - while not niter or error > self.clockinv_tolerance: - if niter >= self.clockinv_maxiter: - logger.warning( - "Maximum number of iterations '%s' reached for SC %s (error=%.2E)", - niter, - sc, - error, - ) - break - logger.debug("Starting iteration #%s", niter) - inverse = next_inverse - next_inverse = self.interpolate(scet_wrt_tps, -inverse) - error = np.max(np.abs((inverse - next_inverse)[edge:-edge])) - logger.debug("End of iteration %s, with an error of %.2E s", niter, error) - niter += 1 - logger.debug( - "End of SCET with respect to TCB inversion after %s iterations with an error of %.2E s", - niter, - error, - ) + inverse = self.shift_inversion(scet_wrt_tps, self.physics_fs) + logger.debug("End of SCET with respect to TCB inversion") return inverse def _write_attr(self, hdf5, *names):