Skip to content

Commit d2b06fe

Browse files
committed
ENH: str/repr fixed for 0d-arrays
0d arrays now use the arrayprint.py formatters to print themselves. Deprecates 'style' argument to array2string. Integer scalars are no longer printed using the function set with ``np.set_print_function``.
1 parent c610444 commit d2b06fe

5 files changed

Lines changed: 83 additions & 39 deletions

File tree

doc/release/1.14.0-notes.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,14 @@ Improvements
4343

4444
Changes
4545
=======
46+
47+
0d arrays now print their elements like other arrays
48+
----------------------------------------------------
49+
0d arrays now use the array2string formatters to print their elements, like
50+
other arrays. The `style` argument of array2string is now non-functional.
51+
52+
integer scalars are now unaffected by ``np.set_string_function``
53+
----------------------------------------------------------------
54+
Previously the str/repr of integer scalars could be controlled by
55+
``np.set_string_function``, unlike most other numpy scalars. This is no longer
56+
the case.

numpy/core/arrayprint.py

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@
1515
# and by Perry Greenfield 2000-4-1 for numarray
1616
# and by Travis Oliphant 2005-8-22 for numpy
1717

18+
19+
# Note: Both scalartypes.c.src and arrayprint.py implement strs for numpy
20+
# scalars but for different purposes. scalartypes.c.src has str/reprs for when
21+
# the scalar is printed on its own, while arrayprint.py has strs for when
22+
# scalars are printed inside an ndarray. Only the latter strs are currently
23+
# user-customizable.
24+
1825
import sys
1926
import functools
2027
if sys.version_info[0] >= 3:
@@ -28,12 +35,14 @@
2835
except ImportError:
2936
from dummy_thread import get_ident
3037

38+
import numpy as np
3139
from . import numerictypes as _nt
3240
from .umath import maximum, minimum, absolute, not_equal, isnan, isinf
3341
from .multiarray import (array, format_longfloat, datetime_as_string,
3442
datetime_data, dtype)
3543
from .fromnumeric import ravel
3644
from .numeric import asarray
45+
import warnings
3746

3847
if sys.version_info[0] >= 3:
3948
_MAXINT = sys.maxsize
@@ -399,7 +408,7 @@ def wrapper(self, *args, **kwargs):
399408
@_recursive_guard()
400409
def array2string(a, max_line_width=None, precision=None,
401410
suppress_small=None, separator=' ', prefix="",
402-
style=repr, formatter=None):
411+
style=np._NoValue, formatter=None):
403412
"""
404413
Return a string representation of an array.
405414
@@ -425,9 +434,10 @@ def array2string(a, max_line_width=None, precision=None,
425434
426435
The length of the prefix string is used to align the
427436
output correctly.
428-
style : function, optional
429-
A function that accepts an ndarray and returns a string. Used only
430-
when the shape of `a` is equal to ``()``, i.e. for 0-D arrays.
437+
style : _NoValue, optional
438+
Has no effect, do not use.
439+
440+
.. deprecated:: 1.14.0
431441
formatter : dict of callables, optional
432442
If not None, the keys should indicate the type(s) that the respective
433443
formatting function applies to. Callables should return a string.
@@ -494,6 +504,11 @@ def array2string(a, max_line_width=None, precision=None,
494504
495505
"""
496506

507+
# Deprecation 05-16-2017 v1.14
508+
if style is not np._NoValue:
509+
warnings.warn("'style' argument is deprecated and no longer functional",
510+
DeprecationWarning, stacklevel=3)
511+
497512
if max_line_width is None:
498513
max_line_width = _line_width
499514

@@ -506,16 +521,7 @@ def array2string(a, max_line_width=None, precision=None,
506521
if formatter is None:
507522
formatter = _formatter
508523

509-
if a.shape == ():
510-
x = a.item()
511-
if a.dtype.fields is not None:
512-
arr = array([x], dtype=a.dtype)
513-
format_function = _get_format_function(
514-
arr, precision, suppress_small, formatter)
515-
lst = format_function(arr[0])
516-
else:
517-
lst = style(x)
518-
elif functools.reduce(product, a.shape) == 0:
524+
if a.size == 0:
519525
# treat as a null array if any of shape elements == 0
520526
lst = "[]"
521527
else:
@@ -542,7 +548,7 @@ def _formatArray(a, format_function, rank, max_line_len,
542548
543549
"""
544550
if rank == 0:
545-
raise ValueError("rank shouldn't be zero.")
551+
return format_function(a[()]) + '\n'
546552

547553
if summary_insert and 2*edge_items < len(a):
548554
leading_items = edge_items
@@ -809,22 +815,21 @@ def __call__(self, x):
809815

810816
class TimedeltaFormat(object):
811817
def __init__(self, data):
812-
if data.dtype.kind == 'm':
813-
nat_value = array(['NaT'], dtype=data.dtype)[0]
814-
int_dtype = dtype(data.dtype.byteorder + 'i8')
815-
int_view = data.view(int_dtype)
816-
v = int_view[not_equal(int_view, nat_value.view(int_dtype))]
817-
if len(v) > 0:
818-
# Max str length of non-NaT elements
819-
max_str_len = max(len(str(maximum.reduce(v))),
820-
len(str(minimum.reduce(v))))
821-
else:
822-
max_str_len = 0
823-
if len(v) < len(data):
824-
# data contains a NaT
825-
max_str_len = max(max_str_len, 5)
826-
self.format = '%' + str(max_str_len) + 'd'
827-
self._nat = "'NaT'".rjust(max_str_len)
818+
nat_value = array(['NaT'], dtype=data.dtype)[0]
819+
int_dtype = dtype(data.dtype.byteorder + 'i8')
820+
int_view = data.view(int_dtype)
821+
v = int_view[not_equal(int_view, nat_value.view(int_dtype))]
822+
if len(v) > 0:
823+
# Max str length of non-NaT elements
824+
max_str_len = max(len(str(maximum.reduce(v))),
825+
len(str(minimum.reduce(v))))
826+
else:
827+
max_str_len = 0
828+
if len(v) < len(data):
829+
# data contains a NaT
830+
max_str_len = max(max_str_len, 5)
831+
self.format = '%' + str(max_str_len) + 'd'
832+
self._nat = "'NaT'".rjust(max_str_len)
828833

829834
def __call__(self, x):
830835
# TODO: After NAT == NAT deprecation should be simplified:

numpy/core/numeric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1936,7 +1936,7 @@ def array_str(a, max_line_width=None, precision=None, suppress_small=None):
19361936
'[0 1 2]'
19371937
19381938
"""
1939-
return array2string(a, max_line_width, precision, suppress_small, ' ', "", str)
1939+
return array2string(a, max_line_width, precision, suppress_small, ' ', "")
19401940

19411941

19421942
def set_string_function(f, repr=True):

numpy/core/src/multiarray/scalartypes.c.src

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,6 @@ gentype_str(PyObject *self)
338338
return ret;
339339
}
340340

341-
342341
static PyObject *
343342
gentype_repr(PyObject *self)
344343
{
@@ -353,6 +352,20 @@ gentype_repr(PyObject *self)
353352
return ret;
354353
}
355354

355+
static PyObject *
356+
genint_type_str(PyObject *self)
357+
{
358+
PyObject *item, *item_str;
359+
item = gentype_generic_method(self, NULL, NULL, "item");
360+
if (item == NULL) {
361+
return NULL;
362+
}
363+
364+
item_str = PyObject_Str(item);
365+
Py_DECREF(item);
366+
return item_str;
367+
}
368+
356369
/*
357370
* The __format__ method for PEP 3101.
358371
*/
@@ -4185,6 +4198,19 @@ initialize_numeric_types(void)
41854198

41864199
/**end repeat**/
41874200

4201+
4202+
/**begin repeat
4203+
* #Type = Bool, Byte, UByte, Short, UShort, Int, UInt, Long,
4204+
* ULong, LongLong, ULongLong#
4205+
*/
4206+
4207+
/* both str/repr use genint_type_str to avoid trailing "L" of longs */
4208+
Py@Type@ArrType_Type.tp_str = genint_type_str;
4209+
Py@Type@ArrType_Type.tp_repr = genint_type_str;
4210+
4211+
/**end repeat**/
4212+
4213+
41884214
PyHalfArrType_Type.tp_print = halftype_print;
41894215
PyFloatArrType_Type.tp_print = floattype_print;
41904216
PyDoubleArrType_Type.tp_print = doubletype_print;

numpy/core/tests/test_arrayprint.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,6 @@ def test_basic(self):
115115
assert_(np.array2string(a) == '[0 1 2]')
116116
assert_(np.array2string(a, max_line_width=4) == '[0 1\n 2]')
117117

118-
def test_style_keyword(self):
119-
"""This should only apply to 0-D arrays. See #1218."""
120-
stylestr = np.array2string(np.array(1.5),
121-
style=lambda x: "Value in 0-D array: " + str(x))
122-
assert_(stylestr == 'Value in 0-D array: 1.5')
123-
124118
def test_format_function(self):
125119
"""Test custom format function for each element in array."""
126120
def _format_function(x):
@@ -242,6 +236,14 @@ def test_formatter_reset(self):
242236
np.set_printoptions(formatter={'float_kind':None})
243237
assert_equal(repr(x), "array([ 0., 1., 2.])")
244238

239+
def test_0d_arrays(self):
240+
assert_equal(repr(np.datetime64('2005-02-25')[...]),
241+
"array('2005-02-25', dtype='datetime64[D]')")
242+
243+
x = np.array(1)
244+
np.set_printoptions(formatter={'all':lambda x: "test"})
245+
assert_equal(repr(x), "array(test)")
246+
245247
def test_unicode_object_array():
246248
import sys
247249
if sys.version_info[0] >= 3:

0 commit comments

Comments
 (0)