[PT_BREAK] implement NumPy-like functionality maximum, minimum #2424

@RockingJavaBean

Description

From PR: pytorch/pytorch#42579

CircleCI error

The pytorch_xla_linux_bionic_py3_6_clang9_build CI job failed on the PyTorch PR that implements the torch.maximum and torch.minimum functions.

Below is the error message from the XLA build CI job:

Aug 14 10:05:46 Generated 1499 wrappers for /var/lib/jenkins/workspace/xla/scripts/../../torch/csrc/autograd/generated/RegistrationDeclarations.h
Aug 14 10:05:46 AtenXlaType function missed override: Tensor max(const Tensor& self, const Tensor& other); // max(Tensor,Tensor)->Tensor
Aug 14 10:05:46 AtenXlaType function missed override: Tensor min(const Tensor& self, const Tensor& other); // min(Tensor,Tensor)->Tensor
Aug 14 10:05:46 Traceback (most recent call last):
Aug 14 10:05:46   File "/var/lib/jenkins/workspace/xla/scripts/gen.py", line 1137, in <module>
Aug 14 10:05:46     generate(args)
Aug 14 10:05:46   File "/var/lib/jenkins/workspace/xla/scripts/gen.py", line 1107, in generate
Aug 14 10:05:46     assert check_overrides(overrides, overridden)
Aug 14 10:05:46 AssertionError
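Reading the log, the generator appears to check that every function AtenXlaType overrides still matches a declaration generated from PyTorch. The binary max/min overloads no longer match, so the assertion fires. Below is a minimal C++ sketch of such a check, assuming plain string signatures; `missed_overrides` and this `check_overrides` are hypothetical illustrations of the idea, not a port of gen.py.

```cpp
#include <iostream>
#include <set>
#include <string>
#include <vector>

// Hypothetical sketch: every signature the backend overrides must still be
// declared upstream. Data layout (plain strings) is an assumption.

// Returns the overridden signatures that no longer appear in `declared`.
std::vector<std::string> missed_overrides(
    const std::set<std::string>& declared,         // e.g. parsed from RegistrationDeclarations.h
    const std::vector<std::string>& overridden) {  // e.g. parsed from the AtenXlaType header
  std::vector<std::string> missed;
  for (const std::string& sig : overridden) {
    if (declared.find(sig) == declared.end()) {
      missed.push_back(sig);
    }
  }
  return missed;
}

// Reports each miss in the same style as the CI log and returns true only
// when every override still has a matching upstream declaration.
bool check_overrides(const std::set<std::string>& declared,
                     const std::vector<std::string>& overridden) {
  const std::vector<std::string> missed = missed_overrides(declared, overridden);
  for (const std::string& sig : missed) {
    std::cerr << "AtenXlaType function missed override: " << sig << '\n';
  }
  return missed.empty();
}
```

Fed the two signatures from the log as overrides while only maximum/minimum are declared, this sketch prints two "missed override" lines and returns false, which is what a generator asserting on the result would trip over.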

Binary Ops change

pytorch/pytorch#42579 aims to promote torch.maximum and torch.minimum over the binary overloads of torch.max and torch.min (max(Tensor, Tensor) and min(Tensor, Tensor)), which were previously hidden behind those overloaded names.
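To illustrate why a dedicated name helps, here is a minimal C++ sketch on `std::vector<double>` rather than `at::Tensor` (everything in namespace `sketch` is hypothetical). `max` carries both a reduction overload and a binary elementwise overload, while `maximum` names only the binary operation, and the binary `max` simply forwards to it.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

namespace sketch {

using Vec = std::vector<double>;  // stand-in for at::Tensor

// Reduction overload: largest element of a non-empty vector (like torch.max(input)).
double max(const Vec& v) {
  assert(!v.empty());
  return *std::max_element(v.begin(), v.end());
}

// Binary elementwise maximum: the operation that gets its own name.
Vec maximum(const Vec& a, const Vec& b) {
  assert(a.size() == b.size());  // no broadcasting in this sketch
  Vec out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = std::max(a[i], b[i]);
  }
  return out;
}

// Binary max kept only as an alias that forwards to maximum.
Vec max(const Vec& a, const Vec& b) { return maximum(a, b); }

}  // namespace sketch
```

Only `maximum` holds the loop; the forwarding overload keeps existing callers of the binary `max` working while the unambiguous name becomes the primary one.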

The relevant change to aten/src/ATen/native/BinaryOps.cpp in this PR:

Tensor& maximum_out(Tensor& result, const Tensor& self, const Tensor& other) {
  TORCH_CHECK(!self.is_complex() && !other.is_complex(), "maximum does not support complex inputs.");

  auto iter = TensorIterator::binary_op(result, self, other,
                                        /*check_mem_overlap=*/true);
  maximum_stub(iter.device_type(), iter);
  return result;
}

Tensor maximum(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(!self.is_complex() && !other.is_complex(), "maximum does not support complex inputs.");

  Tensor result;
  auto iter = TensorIterator::binary_op(result, self, other);
  maximum_stub(iter.device_type(), iter);
  return iter.output();
}

// binary max, alias for maximum
Tensor& max_out(Tensor& result, const Tensor& self, const Tensor& other) {
  return at::maximum_out(result, self, other);
}

Tensor max(const Tensor& self, const Tensor& other) {
  return at::maximum(self, other);
}
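The excerpt pairs an out variant (maximum_out, which writes into a caller-provided result) with a functional variant (maximum), and turns max_out/max into thin forwarding aliases. The following is a minimal sketch of that layering on `std::vector<double>` instead of `at::Tensor`. The namespace `sketch_out` and the helper `maximum_scalar` are hypothetical; there is no broadcasting, dtype promotion, or TensorIterator, and the NaN handling assumes the behavior documented for torch.maximum (a NaN in either input is returned). Unlike the PR, where maximum builds its own TensorIterator, the functional variant here simply reuses the out variant.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

namespace sketch_out {

using Vec = std::vector<double>;  // stand-in for at::Tensor

// NaN-propagating scalar maximum: if either input is NaN, return it.
double maximum_scalar(double x, double y) {
  if (std::isnan(x)) return x;
  if (std::isnan(y)) return y;
  return x > y ? x : y;
}

// Out variant: fills a caller-provided buffer (resized if needed),
// loosely analogous to maximum_out(result, self, other).
Vec& maximum_out(Vec& result, const Vec& self, const Vec& other) {
  if (self.size() != other.size()) {
    throw std::invalid_argument("maximum: size mismatch (no broadcasting in this sketch)");
  }
  result.resize(self.size());
  for (std::size_t i = 0; i < self.size(); ++i) {
    result[i] = maximum_scalar(self[i], other[i]);
  }
  return result;
}

// Functional variant: allocates the result, then fills it via the out variant.
Vec maximum(const Vec& self, const Vec& other) {
  Vec result;
  maximum_out(result, self, other);
  return result;
}

// Forwarding aliases, mirroring max_out and max in the excerpt.
Vec& max_out(Vec& result, const Vec& self, const Vec& other) {
  return maximum_out(result, self, other);
}
Vec max(const Vec& self, const Vec& other) { return maximum(self, other); }

}  // namespace sketch_out
```

Keeping the kernel behind one pair of functions means a backend only has to provide maximum/maximum_out; the aliases inherit the behavior without needing their own implementation.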

Correction: my previous description stated that max.other and min.other are removed from native_functions.yaml. That was my mistake; please refer to @mruberry's comment below.

@ailzhang / @dlibenzi / @JackCaoG.
Cc @mruberry

Labels: pytorch breaking (Upstream PyTorch breakage w.r.t. PyTorch/XLA)