
Port remainder from TH to ATen (CPU and CUDA) #34136

Closed

kurtamohler wants to merge 16 commits into pytorch:master from kurtamohler:port-remainder-aten-24753

Conversation

@kurtamohler
Collaborator

CPU issue #24753
CUDA issue #24615

@VitalyFedyunin
Contributor

Please add performance benchmarks: CPU/CUDA, single-thread/multi-thread.
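For what it's worth, numbers like the ones requested can be collected with a small script along these lines (my own sketch, not the script actually used for the tables below; the helper name and operand choices are illustrative):

```python
import timeit
import torch

def bench_remainder(device, dtype, n, threads=1, reps=5):
    # Limit intra-op parallelism for the single-thread measurement.
    torch.set_num_threads(threads)
    a = torch.arange(1, n + 1, dtype=dtype, device=device)
    b = torch.full((n,), 3, dtype=dtype, device=device)

    def run():
        torch.remainder(a, b)
        if device == "cuda":
            torch.cuda.synchronize()  # time the kernel, not just the launch

    run()  # warm-up
    return min(timeit.repeat(run, number=1, repeat=reps))

print(bench_remainder("cpu", torch.float32, 1_000_000, threads=1))
```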

@dr-ci

dr-ci Bot commented Mar 3, 2020

💊 CircleCI build failures summary and remediations

As of commit 80c15c0 (more details on the Dr. CI page):


  • 4/4 failures introduced in this PR

🕵️ 3 new failures recognized by patterns

The following build failures do not appear to be due to upstream breakages (reran 2 jobs to discount flakiness):

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_build (1/3)

Step: "Build" (full log | pattern match details) <confirmed not flaky by 2 failures>

```
Automatic merge failed; fix conflicts and then commit the result.
CONFLICT (add/add): Merge conflict in .circleci/verbatim-sources/workflows-binary-builds-smoke-subset.yml
Auto-merging .circleci/verbatim-sources/workflows-binary-builds-smoke-subset.yml
CONFLICT (add/add): Merge conflict in .circleci/verbatim-sources/job-specs-custom.yml
Auto-merging .circleci/verbatim-sources/job-specs-custom.yml
CONFLICT (add/add): Merge conflict in .circleci/scripts/should_run_job.py
Auto-merging .circleci/scripts/should_run_job.py
CONFLICT (add/add): Merge conflict in .circleci/scripts/binary_ios_build.sh
Auto-merging .circleci/scripts/binary_ios_build.sh
CONFLICT (add/add): Merge conflict in .circleci/config.yml
Auto-merging .circleci/config.yml
Automatic merge failed; fix conflicts and then commit the result.
```

See CircleCI build pytorch_linux_xenial_py3_clang5_mobile_custom_build_dynamic (2/3)

Step: "Build" (full log | pattern match details) <confirmed not flaky by 2 failures>

```
Error response from daemon: manifest for 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595 not found
DOCKER_IMAGE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595
Error response from daemon: manifest for 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595 not found
```

See CircleCI build pytorch_linux_xenial_py3_clang5_mobile_custom_build_static (3/3)

Step: "Build" (full log | pattern match details) <confirmed not flaky by 2 failures>

```
Error response from daemon: manifest for 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:07597f23-fa81-474c-8bef-5c8a91b50595 not found
DOCKER_IMAGE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:07597f23-fa81-474c-8bef-5c8a91b50595
Error response from daemon: manifest for 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:07597f23-fa81-474c-8bef-5c8a91b50595 not found
```

This comment was automatically generated by Dr. CI.

Please report bugs/suggestions on the GitHub issue tracker.

This comment has been revised 35 times.

@kurtamohler
Collaborator Author

CPU performance:

| dtype | num_elements | TH CPU single-core time | TH CPU multi-core time | ATen CPU single-core time | ATen CPU multi-core time | Single-core speedup | Multi-core speedup |
| --- | --- | --- | --- | --- | --- | --- | --- |
| torch.float32 | 1000 | 0.0000404 | 0.000037 | 0.000045 | 0.0000484 | 0.8977777778 | 0.7644628099 |
| torch.float32 | 10000 | 0.0001216 | 0.0001082 | 0.0002138 | 0.0002374 | 0.5687558466 | 0.4557708509 |
| torch.float32 | 100000 | 0.0009386 | 0.0004288 | 0.001894 | 0.0008242 | 0.4955649419 | 0.5202620723 |
| torch.float32 | 1000000 | 0.0102662 | 0.0007216 | 0.020459 | 0.0019244 | 0.5017938316 | 0.3749740179 |
| torch.float32 | 10000000 | 0.1067806 | 0.0098368 | 0.2163726 | 0.02088 | 0.4935033364 | 0.4711111111 |
| torch.float64 | 1000 | 0.0000462 | 0.0000418 | 0.000062 | 0.000082 | 0.7451612903 | 0.5097560976 |
| torch.float64 | 10000 | 0.0001254 | 0.0001188 | 0.0003942 | 0.000523 | 0.3181126332 | 0.2271510516 |
| torch.float64 | 100000 | 0.0008038 | 0.0001702 | 0.0037252 | 0.0013866 | 0.2157736497 | 0.1227462859 |
| torch.float64 | 1000000 | 0.0078996 | 0.0004958 | 0.0390346 | 0.0037106 | 0.2023743038 | 0.1336172048 |
| torch.float64 | 10000000 | 0.1048736 | 0.0136762 | 0.3988242 | 0.0384804 | 0.262956962 | 0.3554069085 |
| torch.int32 | 1000 | 0.0000456 | 0.00006 | 0.0000978 | 0.0001258 | 0.4662576687 | 0.4769475358 |
| torch.int32 | 10000 | 0.0001824 | 0.0001844 | 0.000755 | 0.0009836 | 0.241589404 | 0.1874745832 |
| torch.int32 | 100000 | 0.0015858 | 0.0002876 | 0.0073278 | 0.0028862 | 0.2164087448 | 0.09964659414 |
| torch.int32 | 1000000 | 0.0150388 | 0.0008232 | 0.0732916 | 0.006207 | 0.205191318 | 0.1326244563 |
| torch.int32 | 10000000 | 0.1661234 | 0.012534 | 0.7425 | 0.0699036 | 0.2237352189 | 0.1793040702 |
| torch.int64 | 1000 | 0.0000424 | 0.000043 | 0.0000932 | 0.0001218 | 0.4549356223 | 0.3530377668 |
| torch.int64 | 10000 | 0.0001238 | 0.0001258 | 0.0007162 | 0.0009254 | 0.1728567439 | 0.1359412146 |
| torch.int64 | 100000 | 0.0010318 | 0.0002146 | 0.0069764 | 0.002493 | 0.1478986297 | 0.08608102688 |
| torch.int64 | 1000000 | 0.0085434 | 0.0006556 | 0.0704638 | 0.0063514 | 0.1212452351 | 0.103221337 |
| torch.int64 | 10000000 | 0.1148014 | 0.0144832 | 0.7106772 | 0.064825 | 0.1615380372 | 0.2234199769 |

GPU performance:

| dtype | num_elements | TH GPU time | ATen GPU time | Speedup |
| --- | --- | --- | --- | --- |
| torch.float32 | 1000000 | 0.0000588 | 0.0000466 | 1.261802575 |
| torch.float32 | 10000000 | 0.0000608 | 0.0000474 | 1.282700422 |
| torch.float32 | 100000000 | 0.000063 | 0.0000456 | 1.381578947 |
| torch.float64 | 1000000 | 0.0000572 | 0.0000482 | 1.186721992 |
| torch.float64 | 10000000 | 0.0000566 | 0.0000476 | 1.18907563 |
| torch.float64 | 100000000 | 0.0000608 | 0.000048 | 1.266666667 |
| torch.int32 | 1000000 | 0.0000556 | 0.0000462 | 1.203463203 |
| torch.int32 | 10000000 | 0.0000566 | 0.0000472 | 1.199152542 |
| torch.int32 | 100000000 | 0.0000546 | 0.0000456 | 1.197368421 |
| torch.int64 | 1000000 | 0.0000566 | 0.0000472 | 1.199152542 |
| torch.int64 | 10000000 | 0.0000584 | 0.0000468 | 1.247863248 |
| torch.int64 | 100000000 | 0.0000564 | 0.0000474 | 1.189873418 |

As you can see, CPU performance decreases with my changes (speedup is TH time divided by ATen time, so values below 1 indicate a slowdown). My comment in the CPU issue explains why I think that is, and why I currently think it should be fixed in a separate issue: #24753 (comment). But if someone thinks differently, my mind is open to being changed!

@kurtamohler
Collaborator Author

The test failure from the Dr. CI comment above is caused by incorrect handling of datatypes in my implementation when the divisor argument is a scalar: the scalar needs to be converted to match the datatype of the tensor. This was the old behavior:

```python
>>> torch.remainder(torch.tensor([4, 7, 0]), 2.55)
tensor([0, 1, 0])
```

This is the behavior with my changes:

```python
>>> torch.remainder(torch.tensor([4, 7, 0]), 2.55)
tensor([1.4500, 1.9000, 0.0000])
```

So I will figure out how to fix that.
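For reference, a minimal sketch of the coercion the old behavior implies (my own illustration, not the ATen fix itself): cast the Python scalar to the tensor's dtype before computing the remainder, so 2.55 becomes 2 for an integer tensor.

```python
import torch

def remainder_with_old_scalar_coercion(tensor, scalar):
    # Truncate the scalar to the tensor's dtype first (2.55 -> 2 for int64),
    # then take the remainder; this reproduces the old TH output above.
    coerced = torch.tensor(scalar).to(tensor.dtype)
    return torch.remainder(tensor, coerced)

print(remainder_with_old_scalar_coercion(torch.tensor([4, 7, 0]), 2.55))
# tensor([0, 1, 0])
```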

@kurtamohler
Collaborator Author

Looks like I'm getting errors for the XLA device tests:

```
Mar 03 21:48:11 FAIL [0.043s]: test_remainder_inplace_negative_tensor_xla_int16 (__main__.TestTensorDeviceOpsXLA)
Mar 03 21:48:11 ----------------------------------------------------------------------
Mar 03 21:48:11 Traceback (most recent call last):
Mar 03 21:48:11   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 198, in instantiated_test
Mar 03 21:48:11     result = test(self, device_arg, dtype)
Mar 03 21:48:11   File "/var/lib/jenkins/workspace/xla/test/../../test/test_torch.py", line 15314, in fn
Mar 03 21:48:11     self.assertEqual(cpu_tensor, device_tensor, prec=precision)
Mar 03 21:48:11   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 336, in assertEqual
Mar 03 21:48:11     return DeviceTypeTestBase.assertEqual(self, x, y, prec, message, allow_inf, **kwargs)
Mar 03 21:48:11   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 872, in assertEqual
Mar 03 21:48:11     assertTensorsEqual(x, y)
Mar 03 21:48:11   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 842, in assertTensorsEqual
Mar 03 21:48:11     self.assertLessEqual(max_err, prec, message)
Mar 03 21:48:11 AssertionError: tensor(9, dtype=torch.int16) not less than or equal to 0.001 :
```

I'm not sure how to fix that.
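If I had to guess at a minimal repro (the max error for int16 hints at a sign-handling difference on negative operands; ailzhang confirms this below):

```python
import torch

# Hypothetical repro, assuming the failure comes from a negative divisor:
a = torch.tensor(6, dtype=torch.int16)
print(torch.remainder(a, -3))  # expected tensor(0); the patched kernel returns the divisor instead
```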

@kurtamohler requested a review from ezyang on March 4, 2020 17:26
Comment thread: aten/src/ATen/Declarations.cwrap (Outdated)

Contributor

Is it the case that currently non-broadcastable but "same number of elements" tensors work?

i.e. (2,3) × (3,2)

Collaborator Author

No, that didn't work in TH and it doesn't work in my implementation either. They both fail with the same error:

```python
>>> torch.ones([2,3]).remainder(torch.ones([3,2]))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1
```
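For completeness, shapes that do broadcast work as usual; only same-element-count but incompatible shapes like (2,3) vs (3,2) are rejected:

```python
import torch

# Standard broadcasting rules apply: trailing dimensions must match or be 1.
print(torch.ones(2, 3).remainder(torch.ones(3)))     # OK: (3,) broadcasts to (2, 3)
print(torch.ones(2, 3).remainder(torch.ones(2, 1)))  # OK: (2, 1) broadcasts to (2, 3)
```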

Comment thread: aten/src/ATen/native/BinaryOps.cpp (Outdated)
@gchanan added the module: bc-breaking (Related to a BC-breaking change) label on Mar 4, 2020
@ailzhang self-requested a review on March 4, 2020 19:52
Comment thread: aten/src/ATen/native/cpu/BinaryOpsKernel.cpp (Outdated)
Comment thread: aten/src/ATen/native/cpu/BinaryOpsKernel.cpp (Outdated)
Comment thread: aten/src/ATen/native/cpu/BinaryOpsKernel.cpp (Outdated)
Contributor

I'm a bit confused about what is going on here.

In my nightly version of PyTorch, the following happens when the divisor is 0:
On CPU, I get Floating point exception (core dumped)
On CUDA, I get a tensor filled with 4294967295.

Is that consistent with what you are seeing? Why the change to make (at least) CUDA an error?

Collaborator Author

Yeah I do see that behavior. Shouldn't we change that though? It seems better to have a catchable exception rather than a core dump. And if the CPU version raises an exception, shouldn't the CUDA version as well?

Contributor

Changing the crash into an exception for div was considered high pri: #327

I suspect that changing this behaviour for remainder (and also fmod) is also desirable.
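A quick check of the behavior under discussion (a sketch, assuming integer remainder by zero raises a catchable exception rather than crashing, which is what this PR moves toward; the exact error type and message depend on the PyTorch version):

```python
import torch

try:
    torch.remainder(torch.tensor([1, 2, 3]), torch.tensor([0, 0, 0]))
except RuntimeError as e:
    # A catchable exception instead of a core dump or garbage values.
    print("caught:", e)
```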

Contributor

@ailzhang left a comment

Hi, thanks for the PR!
I took a look at the XLA test failure and I think it's real.
For example:

```python
In [2]: a = torch.tensor(6)

In [3]: torch.remainder(a, -3)
Out[3]: tensor(0)
```

But it returns -3 after this patch. The CPU and CUDA tests passed since both backends were changed together in this PR.
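This edge case is easy to hit when building Python-style remainder from C-style fmod; a minimal sketch of the pattern and the guard it needs (my own illustration, not the actual kernel code):

```python
import math

def py_style_remainder(a, b):
    # Python-style remainder takes the sign of the divisor. Porting it via
    # C-style fmod needs a sign correction, but only when the remainder is
    # nonzero; otherwise exact multiples return b instead of 0, which is
    # exactly the remainder(6, -3) == -3 bug seen above.
    r = math.fmod(a, b)
    if r != 0 and (r < 0) != (b < 0):
        r += b
    return r

print(py_style_remainder(6, -3))  # 0, thanks to the `r != 0` guard
print(py_style_remainder(7, -3))  # -2
```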

@ezyang removed their request for review on March 5, 2020 15:33
@kurtamohler
Collaborator Author

@ailzhang , do you have any suggestions for fixing the XLA support? I don't know anything about how operations are dispatched to that device.

@ailzhang
Contributor

ailzhang commented Mar 5, 2020

@kurtamohler The XLA failure is not specific to XLA. I put an example in the comment above showing that this PR changed the returned result on CPU and CUDA. If you fix the example above, XLA will pass as well.

@kurtamohler
Collaborator Author

@ailzhang , oh I see! Sorry, I misunderstood. I'll go ahead and fix that.

@ailzhang
Contributor

ailzhang commented Mar 5, 2020

@kurtamohler No worries! Please also add a test case for the example after your fix! ;) We definitely should have had test cases for those edge cases.

@kurtamohler
Collaborator Author

@ailzhang, yeah, there seems to be a fair number of missing tests for various cases. I will add them today.
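Something along these lines, roughly (names and structure are my own sketch; the real tests live in test_torch.py and run through the device-generic test framework):

```python
import torch

def check_remainder_edge_cases(device):
    # An exact multiple with a negative divisor must give 0, not the divisor.
    assert torch.remainder(torch.tensor(6, device=device), -3).item() == 0
    # The result takes the sign of the divisor (Python-style semantics).
    assert torch.remainder(torch.tensor(-7, device=device), 3).item() == 2
    assert torch.remainder(torch.tensor(7, device=device), -3).item() == -2

check_remainder_edge_cases("cpu")
```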

Comment thread: aten/src/ATen/native/cuda/BinaryArithmeticKernel.cu (Outdated)
@kurtamohler force-pushed the port-remainder-aten-24753 branch from 360d69f to fd41b34 on March 6, 2020 16:19
@kurtamohler force-pushed the port-remainder-aten-24753 branch from 57c49fe to 5985e03 on March 6, 2020 16:52
@kurtamohler
Collaborator Author

@ailzhang , the XLA build is working now. Thanks for the help!

@yf225 added the module: porting (Issues related to porting TH/THNN legacy to ATen native) and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels on Mar 6, 2020
Contributor

@facebook-github-bot left a comment


@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@kurtamohler
Collaborator Author

@ezyang, I just pushed the commit to remove the CUDA device-side assert. I think it came in right after you approved. Will Phabricator pick it up?

@ezyang
Contributor

ezyang commented Mar 11, 2020

Yeah, we got it.

@facebook-github-bot
Contributor

@ezyang merged this pull request in fbbeee0.

@zou3519
Contributor

zou3519 commented Apr 9, 2020

I think this is marked as BC-breaking because torch.remainder(a, b) now has different behavior on CUDA.
Previously, `torch.remainder(torch.tensor(1).cuda(), torch.tensor(0).cuda())` gave `tensor(4294967295, device='cuda:0')`, but now it asserts:

```
/scratch/rzou/pt/cudnn_double_Bwd/aten/src/ATen/native/cuda/BinaryArithmeticKernel.cu:80: lambda [](signed long, signed long)->signed long::operator()(signed long, signed long)->signed long: block: [0,0,0], thread: [0,0,0] Assertion `b != 0` failed.
```

EDIT: edited for correctness

gchanan added a commit that referenced this pull request Apr 16, 2020
If you look at #34136, you will notice a commit (80c15c0) that didn't get merged.
This is to address that, to avoid crashing on remainder when the rhs is 0.

[ghstack-poisoned]
gchanan added a commit that referenced this pull request Apr 16, 2020
If you look at #34136, you will notice a commit (80c15c0) that didn't get merged.
This is to address that, to avoid crashing on remainder when the rhs is 0.

ghstack-source-id: e805e29
Pull Request resolved: #36760
facebook-github-bot pushed a commit that referenced this pull request Apr 17, 2020
Summary:
Pull Request resolved: #36760

If you look at #34136, you will notice a commit (80c15c0) that didn't get merged.
This is to address that, to avoid crashing on remainder when the rhs is 0.

Test Plan: Imported from OSS

Differential Revision: D21078776

Pulled By: gchanan

fbshipit-source-id: 0ac138cbafac28cf8d696a2a413d3c542138cff9
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
CPU issue pytorch#24753
CUDA issue pytorch#24615
Pull Request resolved: pytorch#34136

Differential Revision: D20375458

Pulled By: ezyang

fbshipit-source-id: 1a9fb39a7e2d17a0d31bd14b211eaacea060e834

Labels

- Merged
- module: bc-breaking (Related to a BC-breaking change)
- module: porting (Issues related to porting TH/THNN legacy to ATen native)
- open source
- triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
