From 898e15ddc87e43f2d52229f683b7bba16337e26c Mon Sep 17 00:00:00 2001
From: Jean-Baptiste Bayle <j2b.bayle@gmail.com>
Date: Wed, 3 Feb 2021 09:24:36 +0100
Subject: [PATCH] Add a method to plot a `ForEachObject`

---
 lisainstrument/containers.py | 36 ++++++++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/lisainstrument/containers.py b/lisainstrument/containers.py
index 434a14e..f16c584 100644
--- a/lisainstrument/containers.py
+++ b/lisainstrument/containers.py
@@ -11,6 +11,7 @@ import abc
 import logging
 import numpy
 import h5py
+import matplotlib.pyplot
 
 
 class ForEachObject(abc.ABC):
@@ -93,6 +94,41 @@ class ForEachObject(abc.ABC):
         """Return dictionary items."""
         return self.dict.items()
 
+    def plot(self, output=None, dt=1, t0=0, size='auto', title='Signals'):
+        """Plot signals for each object.
+
+        Args:
+            output: output file, None to show the plots
+            dt: sampling period [s]
+            t0: initial time [s]
+            size: duration of time series, or 'auto' [samples]
+            title: plot title
+        """
+        if size == 'auto':
+            size = len(self) if len(self) > 1 else 100
+        t = t0 + numpy.arange(size) * dt
+        # Plot signals
+        logging.info("Plotting signals for each object")
+        matplotlib.pyplot.figure(figsize=(12, 4))
+        for key, signal in self.items():
+            matplotlib.pyplot.plot(t, numpy.broadcast_to(signal, size), label=key)
+        matplotlib.pyplot.grid()
+        matplotlib.pyplot.legend()
+        matplotlib.pyplot.xlabel("Time [s]")
+        matplotlib.pyplot.ylabel("Signal")
+        matplotlib.pyplot.title(title)
+        # Save or show glitch
+        if output is not None:
+            logging.info("Saving plot to %s", output)
+            matplotlib.pyplot.savefig(output, bbox_inches='tight')
+        else:
+            matplotlib.pyplot.show()
+
+    def __len__(self):
+        """Return maximum size of signals."""
+        sizes = [1 if numpy.isscalar(signal) else len(signal) for signal in self.values()]
+        return numpy.max(sizes)
+
     def __eq__(self, other):
         if isinstance(other, self.__class__):
             return self.dict == other.dict
-- 
GitLab