
Min and Max with complex inputs exhibit behavior incompatible with NumPy #36374

@mruberry


As the title says, PyTorch's min and max are incompatible with NumPy's. For example:

import numpy as np
import torch

a = np.array((0 + 4j, 4 + 0j, -2 - 2j, 1 + 1j, 2 + 2j, 3 + 3j))
t = torch.from_numpy(a)

np_result = np.min(a)
torch_result = torch.min(t)

np_result is (-2 - 2j) while torch_result is (1. + 1.j). The latter seems hard to justify under any notion of min.

Max is no better:

np_result = np.max(a)
torch_result = torch.max(t)

np_result is (4 + 0j) while torch_result is (-inf + 0.j). Again, the torch_result seems hard to justify under any notion of max.
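
For reference, the NumPy results above are consistent with a lexicographic ordering of complex values: compare real parts first, and use the imaginary part only to break ties. The following is a minimal pure-Python sketch of that ordering, not NumPy's actual implementation; `lex_key`, `lex_min`, and `lex_max` are hypothetical names introduced only for illustration.

```python
# Sketch: lexicographic ordering of complex numbers
# (real part first, imaginary part as tie-breaker).
# The names lex_key, lex_min, and lex_max are illustrative only.
values = [0 + 4j, 4 + 0j, -2 - 2j, 1 + 1j, 2 + 2j, 3 + 3j]


def lex_key(z):
    """Return a sort key that orders complex numbers by (real, imag)."""
    return (z.real, z.imag)


lex_min = min(values, key=lex_key)  # smallest real part is -2
lex_max = max(values, key=lex_key)  # largest real part is 4

print(lex_min, lex_max)
```

Under this ordering the minimum is (-2 - 2j) and the maximum is (4 + 0j), matching the np_result values reported for both examples.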

When a dim is supplied to max, the PyTorch behavior changes but is still incompatible with NumPy's:

torch_result = torch.max(t, dim=0)

Here the maximum value in torch_result is (3. + 3.j), which is at least a reasonable max value.
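
One ordering under which (3 + 3j) is the maximum of this input is comparison by magnitude. Whether PyTorch's dim code path actually compares magnitudes is an assumption, not something established in this report. The short pure-Python sketch below only shows that a magnitude-based max picks the same element; `mag_max` is a hypothetical name.

```python
# Sketch: maximum of the same values when compared by magnitude (abs).
# This illustrates one possible ordering; it is not confirmed to be
# what torch.max(t, dim=0) does internally.
values = [0 + 4j, 4 + 0j, -2 - 2j, 1 + 1j, 2 + 2j, 3 + 3j]

# abs(3 + 3j) is sqrt(18), about 4.24, which exceeds abs(4 + 0j) == 4.0
mag_max = max(values, key=abs)

print(mag_max)
```

Note that this differs from the lexicographic result, which is why (3 + 3j) is a plausible maximum yet still does not match NumPy's (4 + 0j).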

cc @ezyang @anjali411 @dylanbespalko

Labels: module: complex, module: numpy, triaged
