Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions torch/_tensor_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(self, tensor):
def width(self):
return self.max_width

def format(self, value):
def format(self, value, has_non_zero_decimal_val=False):
if self.floating_dtype:
if self.sci_mode:
ret = ('{{:{}.{}e}}').format(self.max_width, PRINT_OPTS.precision).format(value)
Expand All @@ -146,6 +146,9 @@ def format(self, value):
elif self.complex_dtype:
p = PRINT_OPTS.precision
ret = '({{:.{}f}} {{}} {{:.{}f}}j)'.format(p, p).format(value.real, '+-'[value.imag < 0], abs(value.imag))
if not has_non_zero_decimal_val:
# complex tensor contains integer elements only
ret = "({{:.0f}} {{}} {{:.0f}}.j)".format(p, p).format(value.real, '+-'[value.imag < 0], abs(value.imag))
else:
ret = '{}'.format(value)
return (self.max_width - len(ret)) * ' ' + ret
Expand All @@ -166,7 +169,14 @@ def _vector_str(self, indent, formatter, summarize):
[' ...'] +
[formatter.format(val) for val in self[-PRINT_OPTS.edgeitems:].tolist()])
else:
data = [formatter.format(val) for val in self.tolist()]
# variable to keep track of complex float tensors
contains_decimal = False
for val in self.tolist():
if isinstance(val, float) and not val.is_integer():
contains_decimal = True
if isinstance(val, complex) and not val.imag.is_integer():
contains_decimal = True
data = [formatter.format(val, has_non_zero_decimal_val=contains_decimal) for val in self.tolist()]

data_lines = [data[i:i + elements_per_line] for i in range(0, len(data), elements_per_line)]
lines = [', '.join(line) for line in data_lines]
Expand Down