From dc9a79113dcd3c834862df8a223c9b4df7aed9b9 Mon Sep 17 00:00:00 2001
From: Wolfgang Kastaun <wolfgang.kastaun@aei.mpg.de>
Date: Thu, 30 Jan 2025 17:07:41 +0100
Subject: [PATCH] Use new shift inversion methods in instrument class

Note results can change within the inversion tolerance. The reason is that
each dask chunk is inverted separately, which may lead to fewer iterations on the
faster converging chunks of data.
The boundary conditions during the interpolation within the fixed point iteration are
flat. All points are used in the convergence measure. In contrast, the old scheme excluded
a hardcoded margin.

TODO: the case where interpolation is set to None is currently not implemented.

All unit tests are passing.
---
 lisainstrument/instrument.py | 117 ++++++++++++++---------------------
 1 file changed, 45 insertions(+), 72 deletions(-)

diff --git a/lisainstrument/instrument.py b/lisainstrument/instrument.py
index 0a6bd1c..ad99ca5 100755
--- a/lisainstrument/instrument.py
+++ b/lisainstrument/instrument.py
@@ -27,7 +27,7 @@ from scipy.integrate import cumulative_trapezoid
 from scipy.interpolate import InterpolatedUnivariateSpline
 from scipy.signal import firwin, kaiserord, lfilter
 
-from . import dsp, noises
+from . import noises
 from .containers import ForEachMOSA, ForEachSC
 from .dynamic_delay_dask import make_dynamic_shift_lagrange_dask, numpyfy_dask_multi
 from .dynamic_delay_dsp import make_dynamic_shift_dsp_dask
@@ -36,6 +36,10 @@ from .fixed_shift_dask import make_fixed_shift_lagrange_dask
 from .fixed_shift_dsp import make_fixed_shift_dsp_dask
 from .fixed_shift_numpy import AdaptiveShiftNumpy
 from .noises import generate_subseed
+from .shift_inversion_dask import (
+    make_shift_inverse_dsp_dask,
+    make_shift_inverse_lagrange_dask,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -492,10 +496,6 @@ class Instrument:
         # MOC time correlation
         self.moc_time_correlation_asds = ForEachSC(moc_time_correlation_asds)
 
-        # Clock-noise inversion
-        self.clockinv_tolerance = float(clockinv_tolerance)
-        self.clockinv_maxiter = int(clockinv_maxiter)
-
         # Ranging noise
         self.ranging_biases = ForEachMOSA(ranging_biases)
         self.ranging_asds = ForEachMOSA(ranging_asds)
@@ -620,8 +620,7 @@ class Instrument:
         else:
             self.mosa_angles = ForEachMOSA(mosa_angles)
 
-        # Interpolation and antialiasing filter
-        self.init_interpolation(interpolation)
+        # Interpolation and clock-noise inversion
 
         if delay_clock_max < 0:
             msg = (
@@ -639,7 +638,17 @@ class Instrument:
             msg = f"Maximum interspacecraft delay {delay_isc_min} below minimum delay {delay_isc_min} specified"
             raise ValueError(msg)
 
-        self.init_delays(interpolation, delay_isc_min, delay_isc_max, delay_clock_max)
+        self.init_delays(
+            interpolation,
+            delay_isc_min,
+            delay_isc_max,
+            delay_clock_max,
+            clockinv_tolerance,
+            clockinv_maxiter,
+        )
+
+        # Antialiasing filter
+
         self.init_aafilter(aafilter)
 
         # Electronic delays
@@ -748,6 +757,8 @@ class Instrument:
         delay_isc_min: float,
         delay_isc_max: float,
         delay_clock_max: float,
+        clockinv_tolerance: float,
+        clockinv_maxiter: int,
     ) -> None:
         """Initialize or design the interpolation functions for the delays
 
@@ -774,13 +785,22 @@ class Instrument:
         Warning: for delays smaller than the interpolation stencil, this
         will also cause invalid results near the left boundary.
 
+        This also sets up the time frame inversion methods, using the same
+        interpolator in the internal inversion algorithm. The delay_clock_max
+        parameter is also used there during the internal interpolation.
+
         Args:
             interpolation: see `interpolation` docstring in `__init__()`
             delay_isc_min: Minimum allowed interspacecraft delay [s]
             delay_isc_max: Maximum allowed interspacecraft delay [s]
             delay_clock_max: Maximum allowed absolute delay [s] between clocks/tpc/tcp
+            clockinv_tolerance: Tolerance [s] for inverting time coordinate transforms
+            clockinv_maxiter: Maximum iterations for time frame inversion scheme
         """
 
+        self.clockinv_tolerance = float(clockinv_tolerance)
+        self.clockinv_maxiter = int(clockinv_maxiter)
+
         match interpolation:
             case None:
                 self._apply_shift = lambda x, _: x
@@ -800,12 +820,18 @@ class Instrument:
                 self.apply_shift_clock = self.make_adaptive_shift_lagrange(
                     order, -delay_clock_max, delay_clock_max, ShiftBC.FLAT, ShiftBC.FLAT
                 )
+                shift_inversion_da = make_shift_inverse_lagrange_dask(
+                    order,
+                    delay_clock_max,
+                    self.clockinv_maxiter,
+                    self.clockinv_tolerance,
+                )
+                self.interpolation_order = order
 
             case (
                 "lagrange_dsp",
                 int(order),
             ):
-
                 self.apply_shift = self.make_adaptive_shift_dsp(
                     order,
                     delay_isc_min,
@@ -817,6 +843,13 @@ class Instrument:
                 self.apply_shift_clock = self.make_adaptive_shift_dsp(
                     order, -delay_clock_max, delay_clock_max, ShiftBC.FLAT, ShiftBC.FLAT
                 )
+                shift_inversion_da = make_shift_inverse_dsp_dask(
+                    order,
+                    delay_clock_max,
+                    self.clockinv_maxiter,
+                    self.clockinv_tolerance,
+                )
+                self.interpolation_order = order
 
             case ("lagrange", int(order)):
                 msg = (
@@ -829,44 +862,8 @@ class Instrument:
             case _:
                 msg = f"Invalid interpolation parameters {interpolation}"
                 raise RuntimeError(msg)
-            # ~ self.interpolation_order =
-            # ~ self.delay_isc_min = delay_isc_min
-            # ~ self.delay_isc_max = delay_isc_max
-
-    def init_interpolation(self, interpolation):
-        """Initialize or design the interpolation function.
 
-        We support no interpolation, a custom interpolation function, or Lagrange interpolation.
-
-        Args:
-            parameters: see `interpolation` docstring in `__init__()`
-        """
-        if interpolation is None:
-            logger.info("Disabling interpolation")
-            self.interpolation_order = None
-            self.interpolate = lambda x, _: x
-        elif callable(interpolation):
-            logger.info("Using user-provided interpolation function")
-            self.interpolation_order = None
-            self.interpolate = lambda x, shift: (
-                x if np.isscalar(x) else interpolation(x, shift * self.physics_fs)
-            )
-        else:
-            method = str(interpolation[0])
-            if method == "lagrange":
-                self.interpolation_order = int(interpolation[1])
-                logger.debug(
-                    "Using Lagrange interpolation of order %s", self.interpolation_order
-                )
-                self.interpolate = lambda x, shift: (
-                    x
-                    if np.isscalar(x)
-                    else dsp.timeshift(
-                        x, shift * self.physics_fs, self.interpolation_order
-                    )
-                )
-            else:
-                raise ValueError(f"invalid interpolation parameters '{interpolation}'")
+        self.shift_inversion = self.numpyfy_dask_multi(shift_inversion_da)
 
     def init_aafilter(self, aafilter):
         """Initialize antialiasing filter and downsampling."""
@@ -3047,32 +3044,8 @@ class Instrument:
             self.clockinv_maxiter,
         )
 
-        # Drop samples at the edges to compute error
-        edge = min(100, len(scet_wrt_tps) // 2 - 1)
-        error = 0
-
-        niter = 0
-        next_inverse = scet_wrt_tps
-        while not niter or error > self.clockinv_tolerance:
-            if niter >= self.clockinv_maxiter:
-                logger.warning(
-                    "Maximum number of iterations '%s' reached for SC %s (error=%.2E)",
-                    niter,
-                    sc,
-                    error,
-                )
-                break
-            logger.debug("Starting iteration #%s", niter)
-            inverse = next_inverse
-            next_inverse = self.interpolate(scet_wrt_tps, -inverse)
-            error = np.max(np.abs((inverse - next_inverse)[edge:-edge]))
-            logger.debug("End of iteration %s, with an error of %.2E s", niter, error)
-            niter += 1
-        logger.debug(
-            "End of SCET with respect to TCB inversion after %s iterations with an error of %.2E s",
-            niter,
-            error,
-        )
+        inverse = self.shift_inversion(scet_wrt_tps, self.physics_fs)
+        logger.debug("End of SCET with respect to TCB inversion")
         return inverse
 
     def _write_attr(self, hdf5, *names):
-- 
GitLab