@@ -857,6 +857,7 @@ def __str__ (self):
857857#####--------------------------------------------------------------------------
858858#---- --- Mask creation functions ---
859859#####--------------------------------------------------------------------------
860+
860861def _recursive_make_descr (datatype , newtype = bool_ ):
861862 "Private function allowing recursion in make_descr."
862863 # Do we have some name fields ?
@@ -1134,6 +1135,7 @@ def masked_where(condition, a, copy=True):
11341135 result ._mask = cond
11351136 return result
11361137
1138+
11371139def masked_greater (x , value , copy = True ):
11381140 """
11391141 Return the array `x` masked where (x > value).
@@ -1142,22 +1144,27 @@ def masked_greater(x, value, copy=True):
11421144 """
11431145 return masked_where (greater (x , value ), x , copy = copy )
11441146
1147+
11451148def masked_greater_equal (x , value , copy = True ):
11461149 "Shortcut to masked_where, with condition = (x >= value)."
11471150 return masked_where (greater_equal (x , value ), x , copy = copy )
11481151
1152+
11491153def masked_less (x , value , copy = True ):
11501154 "Shortcut to masked_where, with condition = (x < value)."
11511155 return masked_where (less (x , value ), x , copy = copy )
11521156
1157+
11531158def masked_less_equal (x , value , copy = True ):
11541159 "Shortcut to masked_where, with condition = (x <= value)."
11551160 return masked_where (less_equal (x , value ), x , copy = copy )
11561161
1162+
11571163def masked_not_equal (x , value , copy = True ):
11581164 "Shortcut to masked_where, with condition = (x != value)."
11591165 return masked_where (not_equal (x , value ), x , copy = copy )
11601166
1167+
11611168def masked_equal (x , value , copy = True ):
11621169 """
11631170 Shortcut to masked_where, with condition = (x == value). For
@@ -1171,6 +1178,7 @@ def masked_equal(x, value, copy=True):
11711178 # return array(d, mask=m, copy=copy)
11721179 return masked_where (equal (x , value ), x , copy = copy )
11731180
1181+
11741182def masked_inside (x , v1 , v2 , copy = True ):
11751183 """
11761184 Shortcut to masked_where, where ``condition`` is True for x inside
@@ -1188,6 +1196,7 @@ def masked_inside(x, v1, v2, copy=True):
11881196 condition = (xf >= v1 ) & (xf <= v2 )
11891197 return masked_where (condition , x , copy = copy )
11901198
1199+
11911200def masked_outside (x , v1 , v2 , copy = True ):
11921201 """
11931202 Shortcut to ``masked_where``, where ``condition`` is True for x outside
@@ -1205,7 +1214,7 @@ def masked_outside(x, v1, v2, copy=True):
12051214 condition = (xf < v1 ) | (xf > v2 )
12061215 return masked_where (condition , x , copy = copy )
12071216
1208- #
1217+
12091218def masked_object (x , value , copy = True , shrink = True ):
12101219 """
12111220 Mask the array `x` where the data are exactly equal to value.
@@ -1234,6 +1243,7 @@ def masked_object(x, value, copy=True, shrink=True):
12341243 mask = mask_or (mask , make_mask (condition , shrink = shrink ))
12351244 return masked_array (x , mask = mask , copy = copy , fill_value = value )
12361245
1246+
12371247def masked_values (x , value , rtol = 1.e-5 , atol = 1.e-8 , copy = True , shrink = True ):
12381248 """
12391249 Mask the array x where the data are approximately equal in
@@ -1271,6 +1281,7 @@ def masked_values(x, value, rtol=1.e-5, atol=1.e-8, copy=True, shrink=True):
12711281 mask = mask_or (mask , make_mask (condition , shrink = shrink ))
12721282 return masked_array (xnew , mask = mask , copy = copy , fill_value = value )
12731283
1284+
12741285def masked_invalid (a , copy = True ):
12751286 """
12761287 Mask the array for invalid values (NaNs or infs).
@@ -1292,6 +1303,7 @@ def masked_invalid(a, copy=True):
12921303#####--------------------------------------------------------------------------
12931304#---- --- Printing options ---
12941305#####--------------------------------------------------------------------------
1306+
12951307class _MaskedPrintOption :
12961308 """
12971309 Handle the string used to represent missing data in a masked array.
@@ -1372,6 +1384,20 @@ def _recursive_printoption(result, mask, printopt):
13721384#---- --- MaskedArray class ---
13731385#####--------------------------------------------------------------------------
13741386
1387+ def _recursive_filled (a , mask , fill_value ):
1388+ """
1389+ Recursively fill `a` with `fill_value`.
1390+ Private function
1391+ """
1392+ names = a .dtype .names
1393+ for name in names :
1394+ current = a [name ]
1395+ print "Name: %s : %s" % (name , current )
1396+ if current .dtype .names :
1397+ _recursive_filled (current , mask [name ], fill_value [name ])
1398+ else :
1399+ np .putmask (current , mask [name ], fill_value [name ])
1400+
13751401#...............................................................................
13761402class _arraymethod (object ):
13771403 """
@@ -2013,6 +2039,7 @@ def _getrecordmask(self):
20132039 try :
20142040 return _mask .view ((bool_ , len (self .dtype ))).all (axis )
20152041 except ValueError :
2042+ # In case we have nested fields...
20162043 return np .all ([[f [n ].all () for n in _mask .dtype .names ]
20172044 for f in _mask ], axis = axis )
20182045
@@ -2106,6 +2133,7 @@ def set_fill_value(self, value=None):
21062133 fill_value = property (fget = get_fill_value , fset = set_fill_value ,
21072134 doc = "Filling value." )
21082135
2136+
21092137 def filled (self , fill_value = None ):
21102138 """Return a copy of self._data, where masked values are filled
21112139 with fill_value.
@@ -2140,9 +2168,10 @@ def filled(self, fill_value=None):
21402168 #
21412169 if m .dtype .names :
21422170 result = self ._data .copy ()
2143- for n in result .dtype .names :
2144- field = result [n ]
2145- np .putmask (field , self ._mask [n ], fill_value [n ])
2171+ _recursive_filled (result , self ._mask , fill_value )
2172+ # for n in result.dtype.names:
2173+ # field = result[n]
2174+ # np.putmask(field, self._mask[n], fill_value[n])
21462175 elif not m .any ():
21472176 return self ._data
21482177 else :
@@ -2287,6 +2316,58 @@ def __repr__(self):
22872316 return _print_templates ['short' ] % parameters
22882317 return _print_templates ['long' ] % parameters
22892318 #............................................
2319+ def __eq__ (self , other ):
2320+ "Check whether other equals self elementwise"
2321+ omask = getattr (other , '_mask' , nomask )
2322+ if omask is nomask :
2323+ check = ndarray .__eq__ (self .filled (0 ), other ).view (type (self ))
2324+ check ._mask = self ._mask
2325+ else :
2326+ odata = filled (other , 0 )
2327+ check = ndarray .__eq__ (self .filled (0 ), odata ).view (type (self ))
2328+ if self ._mask is nomask :
2329+ check ._mask = omask
2330+ else :
2331+ mask = mask_or (self ._mask , omask )
2332+ if mask .dtype .names :
2333+ if mask .size > 1 :
2334+ axis = 1
2335+ else :
2336+ axis = None
2337+ try :
2338+ mask = mask .view ((bool_ , len (self .dtype ))).all (axis )
2339+ except ValueError :
2340+ mask = np .all ([[f [n ].all () for n in mask .dtype .names ]
2341+ for f in mask ], axis = axis )
2342+ check ._mask = mask
2343+ return check
2344+ #
2345+ def __ne__ (self , other ):
2346+ "Check whether other doesn't equal self elementwise"
2347+ omask = getattr (other , '_mask' , nomask )
2348+ if omask is nomask :
2349+ check = ndarray .__ne__ (self .filled (0 ), other ).view (type (self ))
2350+ check ._mask = self ._mask
2351+ else :
2352+ odata = filled (other , 0 )
2353+ check = ndarray .__ne__ (self .filled (0 ), odata ).view (type (self ))
2354+ if self ._mask is nomask :
2355+ check ._mask = omask
2356+ else :
2357+ mask = mask_or (self ._mask , omask )
2358+ if mask .dtype .names :
2359+ if mask .size > 1 :
2360+ axis = 1
2361+ else :
2362+ axis = None
2363+ try :
2364+ mask = mask .view ((bool_ , len (self .dtype ))).all (axis )
2365+ except ValueError :
2366+ mask = np .all ([[f [n ].all () for n in mask .dtype .names ]
2367+ for f in mask ], axis = axis )
2368+ check ._mask = mask
2369+ return check
2370+ #
22902371 def __add__ (self , other ):
22912372 "Add other to self, and return a new masked array."
22922373 return add (self , other )
0 commit comments