
Add division overload with rounding_mode selection#50280

Closed
peterbell10 wants to merge 26 commits into gh/peterbell10/36/base from gh/peterbell10/36/head

Conversation

@peterbell10
Collaborator

@peterbell10 peterbell10 commented Jan 8, 2021

Stack from ghstack:

As mentioned in gh-43874, this adds a `rounding_mode={'true', 'trunc', 'floor'}`
argument so `torch.div` can be used as a replacement for `floor_divide` during
the transitional period.

I've included dedicated kernels for truncated and floor division which
aren't strictly necessary for float, but do perform significantly better (~2x) than
doing true division followed by a separate rounding kernel.

Note: I introduce new overloads for `aten::div` instead of just adding a default
`rounding_mode` because various JIT passes rely on the exact operator schema.

Differential Revision: D26123271
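
The semantics of the three modes can be illustrated with a plain-Python sketch (the function name and structure here are illustrative stand-ins for the proposed `torch.div` overload, not the actual ATen implementation):

```python
import math

def div(a, b, rounding_mode='true'):
    """Illustrative stand-in for torch.div with rounding_mode."""
    q = a / b                    # exact quotient
    if rounding_mode == 'true':
        return q                 # true division: no rounding
    if rounding_mode == 'trunc':
        return math.trunc(q)     # round toward zero (C-style integer division)
    if rounding_mode == 'floor':
        return math.floor(q)     # round toward -infinity (Python's // operator)
    raise ValueError(f"unknown rounding_mode: {rounding_mode}")

# 'trunc' and 'floor' only differ when the quotient is negative:
# div(-7, 2, 'trunc') -> -3, while div(-7, 2, 'floor') -> -4
```

This is why `rounding_mode='floor'` can stand in for `floor_divide` while `'trunc'` matches C-style integer division.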

@facebook-github-bot
Contributor

facebook-github-bot commented Jan 8, 2021

💊 CI failures summary and remediations

As of commit 71b0cfe (more details on the Dr. CI page):


  • 4/4 failures possibly* introduced in this PR
    • 1/4 non-CircleCI failure(s)

🕵️ 3 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_build (1/3)

Step: "(Optional) Merge target branch" (full log | diagnosis details | 🔁 rerun)

Automatic merge failed; fix conflicts and then commit the result.
CONFLICT (add/add): Merge conflict in .jenkins/caffe2/test.sh
Auto-merging .jenkins/caffe2/test.sh
CONFLICT (add/add): Merge conflict in .circleci/docker/ubuntu-rocm/Dockerfile
Auto-merging .circleci/docker/ubuntu-rocm/Dockerfile
CONFLICT (add/add): Merge conflict in .circleci/docker/common/install_rocm.sh
Auto-merging .circleci/docker/common/install_rocm.sh
CONFLICT (add/add): Merge conflict in .circleci/config.yml
Auto-merging .circleci/config.yml
CONFLICT (add/add): Merge conflict in .circleci/cimodel/data/dimensions.py
Auto-merging .circleci/cimodel/data/dimensions.py
Automatic merge failed; fix conflicts and then commit the result.


Exited with code exit status 1

See CircleCI builds pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc5_4_build (2/3) and pytorch_linux_xenial_py3_6_gcc5_4_build (3/3)

Step: "(Optional) Merge target branch" (full log | diagnosis details | 🔁 rerun)

Both failed with the same merge conflicts and exit status as build (1/3) above.


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

@facebook-github-bot facebook-github-bot added the oncall: jit label Jan 8, 2021
@peterbell10 peterbell10 requested a review from mruberry January 8, 2021 17:05
BINARY_POINTWISE(mul);
BINARY_POINTWISE(div);
{
using Binop = Tensor (*)(const Tensor&, const Tensor&, std::string);
Collaborator

@rzou would you take a look here?

Collaborator Author

@mruberry I think you got the wrong user. Was that meant to be @zou3519?

Collaborator

It was, thanks @peterbell10. Darn autocomplete!

cc @zou3519

Contributor

this lgtm!

Comment thread aten/src/ATen/native/BinaryOps.cpp
Comment thread aten/src/ATen/native/BinaryOps.cpp Outdated
Comment thread aten/src/ATen/native/BinaryOps.cpp Outdated
Comment thread aten/src/ATen/native/BinaryOps.cpp Outdated
Comment thread aten/src/ATen/native/cpu/BinaryOpsKernel.cpp Outdated
Comment thread aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
} else if (isIntegralType(dtype, /*includeBool*/ false)) {
// There's no SIMD integer division, so don't try to vectorize it.
// TODO: if the divisor is a scalar, rewrite as multiplication by a constant.
AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "div_floor_cpu", [&]() {
Collaborator

This is inconsistent between using dtype and iter.common_dtype().

Collaborator Author

I have removed all uses of iter.dtype(). If instead you meant the variable dtype, then I would note that it's assigned from iter.common_dtype() above. Just a bit less to type.

Collaborator

I realize the value is the same, just for readability the code might want to stick to either dtype or iter.common_dtype(). No big deal either way.

});
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "div_floor_cpu", [&]() {
Collaborator

Same dtype vs iter.common_dtype here, too.
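
For integer inputs, the correction a dedicated floor-division kernel applies on top of truncating division can be sketched in Python (the helper names are hypothetical; the actual ATen kernel performs this arithmetic elementwise in C++):

```python
def trunc_div(a: int, b: int) -> int:
    # C-style integer division: rounds the quotient toward zero.
    q = abs(a) // abs(b)
    return q if (a < 0) == (b < 0) else -q

def floor_div(a: int, b: int) -> int:
    # Floor division built from truncation: step down by one when the
    # operands have opposite signs and the division is inexact.
    q = trunc_div(a, b)
    if (a % b != 0) and ((a < 0) != (b < 0)):
        q -= 1
    return q

# Matches Python's floor-division operator, e.g. floor_div(-7, 2) == -7 // 2 == -4
```

Because there is no SIMD integer division (as the comment in the kernel notes), this scalar loop is the form the dispatch above ends up generating for integral dtypes.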

Comment thread aten/src/ATen/native/cuda/BinaryMulDivKernel.cu
Comment thread aten/src/ATen/native/cuda/BinaryMulDivKernel.cu Outdated
Comment thread aten/src/ATen/native/cuda/BinaryMulDivKernel.cu
Comment thread aten/src/ATen/native/cuda/BinaryMulDivKernel.cu
@JackCaoG
Collaborator

Hi @mruberry, I think I can get pt/xla pr ready this week. I will ping you when that is ready.

@peterbell10
Collaborator Author

@mruberry PTAL when you can. Have addressed your comments and rebased.

Comment thread torch/_torch_docs.py Outdated
* ``"true"`` - default behavior. Performs no rounding and, if both :attr:`input` and
:attr:`other` are integer types, promotes the inputs to the default scalar type.
Equivalent to true division in Python (the ``/`` operator) and NumPy's ``np.true_divide``.
* ``"trunc"`` - rounds the results of the division down.
Collaborator

The description for trunc and floor are identical, "rounds the results of the division down".

For trunc I think we can say rounds towards zero?


torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(saved_dtype)
try:
Collaborator

Thank you for fixing this.

Collaborator

@mruberry mruberry left a comment

One doc nit, otherwise looks awesome.

@JackCaoG let us know when this is safe to merge.

peterbell10 added a commit to peterbell10/pytorch that referenced this pull request Feb 1, 2021
ghstack-source-id: 372c4be
Pull Request resolved: pytorch#50280
@mruberry
Collaborator

mruberry commented Feb 2, 2021

One of the test failures is real: test_div_rounding_numpy_cuda_bfloat16.

We can skip the test for simplicity to unblock this landing. @JackCaoG's PR is ready to go so I'd like to land this during PST business hours on Tuesday, February 2nd.

@peterbell10
Collaborator Author

In hindsight, the BFloat16 comparison with random test data is unlikely to work perfectly. It's sensitive to exact rounding and since NumPy is rounding to float32 precision, it will occasionally get different answers.

@mruberry
Collaborator

mruberry commented Feb 2, 2021

In hindsight, the BFloat16 comparison with random test data is unlikely to work perfectly. It's sensitive to exact rounding and since NumPy is rounding to float32 precision, it will occasionally get different answers.

We can disable the test for now. In the future we could use a fixture to ensure there are no rounding issues.
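
The precision gap can be reproduced without PyTorch: bfloat16 keeps only the top 16 bits of a float32, so a sketch of the conversion (assuming IEEE-754 float32 layout and round-to-nearest-even; not PyTorch's own code, and not handling NaN/overflow) shows how much resolution the NumPy comparison loses:

```python
import struct

def to_bfloat16(x: float) -> float:
    # Emulate float32 -> bfloat16: keep the high 16 bits of the
    # float32 bit pattern, rounding to nearest-even at the cut.
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack('<f', struct.pack('<I', bits))[0]

# With only 8 mantissa bits, pi collapses to 3.140625. A quotient that
# lands just below an integer in float32 can round to just above it in
# bfloat16, flipping the floor/trunc result by one.
```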

}
return floordiv;
},
[](Vec256<scalar_t> a, Vec256<scalar_t> b) -> Vec256<scalar_t>{
Collaborator

This is triggering some internal build issues. Adding a vectorized function can be a little tricky because we often have to stub them out on some platforms, like Android.

Since we're so close to the branch cut, I propose removing the copysign implementation and this vectorized implementation. We can file an issue and add them back in a later PR where we can take our time and focus on that issue.

Collaborator Author

@mruberry this should be good now. Removed all Vec256 changes and unvectorized floor_divide.

peterbell10 and others added 2 commits February 2, 2021 19:31
@mruberry
Collaborator

mruberry commented Feb 3, 2021

FYI there's another internal issue that I'm reviewing now. I'll keep this updated.

@mruberry
Collaborator

mruberry commented Feb 4, 2021

Update: hacking through internal infra issues still.

@mruberry
Collaborator

mruberry commented Feb 4, 2021

Update: blocking infra team confirms it's fixed its issue. This should land today.

@facebook-github-bot
Contributor

@mruberry merged this pull request in b150f15.

@mruberry
Collaborator

mruberry commented Feb 4, 2021

Landed. Some changes had to be made internally:

  • div and floor_divide had to call the stub directly and not re-dispatch to the out= variant (unclear how necessary this was, there were some performance failures but they may have been flaky)
  • the kernel name "div_cpu" couldn't change because mobile manifests list kernel names for export, and they didn't understand the new name

@facebook-github-bot facebook-github-bot deleted the gh/peterbell10/36/head branch February 8, 2021 15:21
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
Pull Request resolved: pytorch#51706

Pull Request resolved: pytorch#50280

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D26123271

Pulled By: mruberry

fbshipit-source-id: 51a83717602114597ec9c4d946e35a392eb01d46

Labels

cla signed, Merged, oncall: jit, open source

7 participants