From acf339eb96e93d830a03944e286747fa3383ebab Mon Sep 17 00:00:00 2001
From: Jean-Baptiste Bayle <j2b.bayle@gmail.com>
Date: Mon, 8 Mar 2021 19:35:04 +0100
Subject: [PATCH] Add arithmetics to `ForEachObject`

---
 lisainstrument/containers.py |  60 ++++++++++++++
 tests/test_containers.py     | 154 +++++++++++++++++++++++++++++++++++
 2 files changed, 214 insertions(+)

diff --git a/lisainstrument/containers.py b/lisainstrument/containers.py
index e8c9679..2c794d5 100644
--- a/lisainstrument/containers.py
+++ b/lisainstrument/containers.py
@@ -137,6 +137,66 @@ class ForEachObject(abc.ABC):
             return self.dict == other
         return numpy.all([self[index] == other for index in self.indices()])
 
+    def __abs__(self):
+        return self.transformed(lambda index, value: abs(value))
+
+    def __neg__(self):
+        return self.transformed(lambda index, value: -value)
+
+    def __add__(self, other):
+        if numpy.isscalar(other):
+            return self.transformed(lambda _, value: value + other)
+        if isinstance(other, type(self)):
+            return self.transformed(lambda index, value: value + other[index])
+        raise TypeError(f"unsupported operand type for +: '{type(self)}' and '{type(other)}'")
+
+    def __radd__(self, other):
+        return self + other
+
+    def __sub__(self, other):
+        if numpy.isscalar(other):
+            return self.transformed(lambda _, value: value - other)
+        if isinstance(other, type(self)):
+            return self.transformed(lambda index, value: value - other[index])
+        raise TypeError(f"unsupported operand type for -: '{type(self)}' and '{type(other)}'")
+
+    def __rsub__(self, other):
+        return -(self - other)
+
+    def __mul__(self, other):
+        if numpy.isscalar(other):
+            return self.transformed(lambda _, value: value * other)
+        if isinstance(other, type(self)):
+            return self.transformed(lambda index, value: value * other[index])
+        raise TypeError(f"unsupported operand type for *: '{type(self)}' and '{type(other)}'")
+
+    def __rmul__(self, other):
+        return self * other
+
+    def __floordiv__(self, other):
+        if numpy.isscalar(other):
+            return self.transformed(lambda _, value: value // other)
+        if isinstance(other, type(self)):
+            return self.transformed(lambda index, value: value // other[index])
+        raise TypeError(f"unsupported operand type for //: '{type(self)}' and '{type(other)}'")
+
+    def __truediv__(self, other):
+        if numpy.isscalar(other):
+            return self.transformed(lambda _, value: value / other)
+        if isinstance(other, type(self)):
+            return self.transformed(lambda index, value: value / other[index])
+        raise TypeError(f"unsupported operand type for /: '{type(self)}' and '{type(other)}'")
+
+    def __rtruediv__(self, other):
+        if numpy.isscalar(other):
+            return self.transformed(lambda _, value: other / value)
+        raise TypeError(f"unsupported operand type for /: '{type(self)}' and '{type(other)}'")
+
+    def __pow__(self, other):
+        if numpy.isscalar(other):
+            return self.transformed(lambda _, value: value**other)
+        raise TypeError(f"unsupported operand type for **: '{type(self)}' and '{type(other)}'")
+
     def __repr__(self):
         return repr(self.dict)
 
diff --git a/tests/test_containers.py b/tests/test_containers.py
index 094ced5..e9878ce 100644
--- a/tests/test_containers.py
+++ b/tests/test_containers.py
@@ -120,6 +120,141 @@ def test_equality():
     assert object_1 == object_2
 
 
+def test_addition():
+    """Check that we can add two ForEachObject subclasses of the same type, or a scalar."""
+
+    object_1 = ForEachAB({'A': 1, 'B': 2})
+    object_2 = ForEachAB({'A': 3, 'B': -4})
+    object_3 = ForEachMOSA(0)
+
+    result = object_1 + object_2
+    assert result['A'] == 4
+    assert result['B'] == -2
+
+    result = object_1 + 10
+    assert result['A'] == 11
+    assert result['B'] == 12
+
+    result = -5 + object_2
+    assert result['A'] == -2
+    assert result['B'] == -9
+
+    with raises(TypeError):
+        result = object_1 + object_3
+    with raises(TypeError):
+        result = object_2 + object_3
+
+
+def test_subtraction():
+    """Check that we can subtract two ForEachObject subclasses of the same type, or a scalar."""
+
+    object_1 = ForEachAB({'A': 1, 'B': 2})
+    object_2 = ForEachAB({'A': 3, 'B': -4})
+    object_3 = ForEachMOSA(0)
+
+    result = object_1 - object_2
+    assert result['A'] == -2
+    assert result['B'] == 6
+
+    result = object_1 - 10
+    assert result['A'] == -9
+    assert result['B'] == -8
+
+    result = 10 - object_2
+    assert result['A'] == 7
+    assert result['B'] == 14
+
+    with raises(TypeError):
+        result = object_1 - object_3
+    with raises(TypeError):
+        result = object_2 - object_3
+
+
+def test_multiplication():
+    """Check that we can multiply two ForEachObject subclasses of the same type, or a scalar."""
+
+    object_1 = ForEachAB({'A': 1, 'B': 2})
+    object_2 = ForEachAB({'A': 3, 'B': -4})
+    object_3 = ForEachMOSA(0)
+
+    result = object_1 * object_2
+    assert result['A'] == 3
+    assert result['B'] == -8
+
+    result = object_1 * 2
+    assert result['A'] == 2
+    assert result['B'] == 4
+
+    result = 10 * object_2
+    assert result['A'] == 30
+    assert result['B'] == -40
+
+    with raises(TypeError):
+        result = object_1 * object_3
+    with raises(TypeError):
+        result = object_2 * object_3
+
+
+def test_floor_division():
+    """Check that we can apply floor division on two ForEachObject subclasses, or a scalar."""
+
+    object_1 = ForEachAB({'A': 1, 'B': 2})
+    object_2 = ForEachAB({'A': 3, 'B': -4})
+    object_3 = ForEachMOSA(0)
+
+    result = object_2 // object_1
+    assert result['A'] == 3
+    assert result['B'] == -2
+
+    result = object_2 // 2
+    assert result['A'] == 1
+    assert result['B'] == -2
+
+    with raises(TypeError):
+        result = object_1 // object_3
+    with raises(TypeError):
+        result = 20 // object_1
+
+
+def test_real_division():
+    """Check that we can divide two ForEachObject subclasses of the same type, or a scalar."""
+
+    object_1 = ForEachAB({'A': 1, 'B': 2})
+    object_2 = ForEachAB({'A': 3, 'B': -4})
+    object_3 = ForEachMOSA(0)
+
+    result = object_1 / object_2
+    assert result['A'] == 1/3
+    assert result['B'] == -1/2
+
+    result = object_1 / 5
+    assert result['A'] == 1/5
+    assert result['B'] == 2/5
+
+    result = 10 / object_2
+    assert result['A'] == 10/3
+    assert result['B'] == -10/4
+
+    with raises(TypeError):
+        result = object_1 / object_3
+    with raises(TypeError):
+        result = object_2 / object_3
+
+
+def test_power():
+    """Check that we take the power of a ForEachObject instance."""
+
+    object_1 = ForEachAB({'A': 1, 'B': -2})
+
+    result = object_1**1
+    assert result['A'] == 1
+    assert result['B'] == -2
+
+    result = object_1**2
+    assert result['A'] == 1
+    assert result['B'] == 4
+
+
 def test_transformed():
     """Check that we transformation is correctly applied to ForEachObject instances."""
 
@@ -132,6 +267,25 @@ def test_transformed():
     assert my_object['A'] == 'a'
     assert my_object['B'] == 'b'
 
+
+def test_abs():
+    """Check that we can take the absolute value of ForEachObject instances."""
+
+    my_object = ForEachAB({'A': -1, 'B': 2})
+    my_object = abs(my_object)
+    assert my_object['A'] == 1
+    assert my_object['B'] == 2
+
+
+def test_neg():
+    """Check that we can take the negative value of ForEachObject instances."""
+
+    my_object = ForEachAB({'A': -1, 'B': 2})
+    my_object = -my_object
+    assert my_object['A'] == 1
+    assert my_object['B'] == -2
+
+
 def test_write():
     """Check that we can write a ForEachObject instance."""
 
-- 
GitLab