From 405d4ae2eb3f167007a627dccab370e04f1343e4 Mon Sep 17 00:00:00 2001 From: Marten van Kerkwijk Date: Sat, 12 Oct 2013 12:11:37 -0400 Subject: [PATCH 1/3] Ensure NotImplemented is passed on in MaskedArray ufunc's --- numpy/ma/core.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 8f2dbedeb61d..68c5d55951e3 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -930,6 +930,9 @@ def __call__ (self, a, b, *args, **kwargs): with np.errstate(): np.seterr(divide='ignore', invalid='ignore') result = self.f(da, db, *args, **kwargs) + # check it worked + if result is NotImplemented: + return NotImplemented # Case 1. : scalar if not result.ndim: if m: @@ -999,6 +1002,9 @@ def outer (self, a, b): return masked (da, db) = (getdata(a), getdata(b)) d = self.f.outer(da, db) + # check it worked + if d is NotImplemented: + return NotImplemented if m is not nomask: np.copyto(d, da, where=m) if d.shape: @@ -1065,6 +1071,9 @@ def __call__(self, a, b, *args, **kwargs): with np.errstate(): np.seterr(divide='ignore', invalid='ignore') result = self.f(da, db, *args, **kwargs) + # check it worked + if result is NotImplemented: + return NotImplemented # Get the mask as a combination of ma, mb and invalid m = ~umath.isfinite(result) m |= ma From 301d075e1d8b2471b7ecaf0182622568b54b9b68 Mon Sep 17 00:00:00 2001 From: Marten van Kerkwijk Date: Sat, 12 Oct 2013 16:06:14 -0400 Subject: [PATCH 2/3] Add test cases to ensure NotImplemented is passed on --- numpy/ma/tests/test_core.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 6c7baa48ec75..2d45a079db72 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -1662,6 +1662,26 @@ def test_ndarray_mask(self): assert_equal(test.mask, control.mask) self.assertTrue(not isinstance(test.mask, MaskedArray)) + def test_treatment_of_NotImplemented(self): + "Check we return NotImplemented if ufunc cannot deal with other" + a = masked_array([1., 2.], mask=[1, 0]) + # basic test + assert a.__mul__('abc') == NotImplemented # _MaskedBinaryOperation + assert multiply.outer(a, 'abc') == NotImplemented + assert a.__div__('abc') == NotImplemented # _DomainedBinaryOperation + + # also check that rmul of another class can be accessed + class MyClass(str): + def __mul__(self, other): + return "My mul" + + def __rmul__(self, other): + return "My rmul" + + me = MyClass() + assert me * a == "My mul" + assert a * me == "My rmul" + #------------------------------------------------------------------------------ class TestMaskedArrayInPlaceArithmetics(TestCase): From 50f33ad0873057a3cd7e673d54e8fc34260049f5 Mon Sep 17 00:00:00 2001 From: Marten van Kerkwijk Date: Sat, 12 Oct 2013 19:50:43 -0400 Subject: [PATCH 3/3] Small corrections to tests --- numpy/ma/tests/test_core.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 2d45a079db72..a4c0766e1a44 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -1663,14 +1663,15 @@ def test_ndarray_mask(self): self.assertTrue(not isinstance(test.mask, MaskedArray)) def test_treatment_of_NotImplemented(self): - "Check we return NotImplemented if ufunc cannot deal with other" + "Check any NotImplemented returned by umath. is passed on" a = masked_array([1., 2.], mask=[1, 0]) - # basic test - assert a.__mul__('abc') == NotImplemented # _MaskedBinaryOperation - assert multiply.outer(a, 'abc') == NotImplemented - assert a.__div__('abc') == NotImplemented # _DomainedBinaryOperation + # basic tests for _MaskedBinaryOperation + assert_(a.__mul__('abc') is NotImplemented) + assert_(multiply.outer(a, 'abc') is NotImplemented) + # and for _DomainedBinaryOperation + assert_(a.__div__('abc') is NotImplemented) - # also check that rmul of another class can be accessed + # also check explicitly that rmul of another class can be accessed class MyClass(str): def __mul__(self, other): return "My mul" @@ -1679,8 +1680,8 @@ def __rmul__(self, other): return "My rmul" me = MyClass() - assert me * a == "My mul" - assert a * me == "My rmul" + assert_(me * a == "My mul") + assert_(a * me == "My rmul") #------------------------------------------------------------------------------