From b365b1d9c0ad063611cdd622c601f425f7ca269e Mon Sep 17 00:00:00 2001
From: Jean-Baptiste Bayle <j2b.bayle@gmail.com>
Date: Wed, 21 Sep 2022 20:43:13 +0100
Subject: [PATCH] Update containers to allow for concurrency

---
 lisainstrument/containers.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/lisainstrument/containers.py b/lisainstrument/containers.py
index 50d4af3..d26b720 100644
--- a/lisainstrument/containers.py
+++ b/lisainstrument/containers.py
@@ -9,6 +9,7 @@ Authors:
 
 import abc
 import logging
+from concurrent.futures import ThreadPoolExecutor
 import h5py
 import numpy as np
 import matplotlib.pyplot as plt
@@ -19,16 +20,23 @@ logger = logging.getLogger(__name__)
 class ForEachObject(abc.ABC):
     """Abstract class which represents a dictionary holding a value for each object."""
 
-    def __init__(self, values):
+    def __init__(self, values, concurrent=False):
         """Initialize an object with a value or a function of the mosa index.
 
         Args:
             values: a value, a dictionary of values, or a function (mosa -> value)
+            concurrent (bool): whether to parallelize using a thread pool
         """
         if isinstance(values, dict):
             self.dict = {mosa: values[mosa] for mosa in self.indices()}
         elif callable(values):
-            self.dict = {mosa: values(mosa) for mosa in self.indices()}
+            if concurrent:
+                with ThreadPoolExecutor() as executor:
+                    indices = self.indices()
+                    computed_values = executor.map(values, indices)
+                    self.dict = dict(zip(indices, computed_values))
+            else:
+                self.dict = {mosa: values(mosa) for mosa in self.indices()}
         elif isinstance(values, h5py.Dataset):
             self.dict = {mosa: values[mosa] for mosa in self.indices()}
         else:
@@ -40,13 +48,14 @@ class ForEachObject(abc.ABC):
         """Return list of object indices."""
         raise NotImplementedError
 
-    def transformed(self, transformation):
+    def transformed(self, transformation, concurrent=False):
         """Return a new dictionary from transformed objects.
 
         Args:
             transformation: function (mosa, value -> new_value)
+            concurrent (bool): whether to parallelize using a thread pool
         """
-        return self.__class__(lambda mosa: transformation(mosa, self[mosa]))
+        return self.__class__(lambda mosa: transformation(mosa, self[mosa]), concurrent)
 
     def collapsed(self):
         """Turn a numpy arrays containing identical elements into a scalar.
-- 
GitLab