fix multinomial kernels to properly advance random states #38046
ngimel wants to merge 5 commits into pytorch:master
Conversation
is it feasible to write a test? |
Is there any reasonable way to test this? |
💊 CI failures summary (as of commit 350b490): ci.pytorch.org: 1 failed. This comment was automatically generated by Dr. CI.
I can never remember how to determine the index formula XD
It's like multi-d tensor indexing XD
I attempted a review but apparently my CUDA has atrophied significantly. One thing I couldn't figure out was why it was necessary to switch from 2D-1D setup to 1D-2D. |
It was not directly necessary to fix this bug, I did it to improve efficiency. |
Wow, this is so obviously wrong it isn't even funny. Oh, we used to only use one random number.
So what, this is grid.x * ((numDist-1)/grid.x+1)*4... so is this just a really long-winded way of saying numDist * 4? Or maybe with some extra slop at the end?
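For context, the ceil-division arithmetic in question can be checked numerically. This is a hypothetical illustration (function name and values are made up, not the actual kernel code):

```python
# grid_x * ((num_dist - 1) // grid_x + 1) rounds num_dist up to the next
# multiple of grid_x, so the old expression is indeed num_dist * 4 "with
# some extra slop at the end" whenever num_dist isn't a multiple of grid_x.
def old_offset(num_dist, grid_x):
    return grid_x * ((num_dist - 1) // grid_x + 1) * 4

print(old_offset(1000, 128))  # 1024 * 4 = 4096, vs. num_dist * 4 = 4000
print(old_offset(128, 128))   # exact multiple: 128 * 4 = 512, no slop
```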
Oh, you are right, I messed this up: I changed the kernel but did not change this. Will fix now.
I fixed it and added a comment that hopefully makes things clearer (it's also simpler than it used to be, because each thread is using just 1 random in most cases)
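A minimal, hypothetical sketch of the offset bookkeeping described in this thread (illustrative only; names and structure are not PyTorch's actual code). In a grid-stride loop, each thread handles several distributions and, per the comment above, consumes roughly one random value per distribution; the next launch must advance the counter-based RNG offset past everything already consumed, or the same sequence is replayed:

```python
# Sketch: advancing a Philox-style offset between kernel launches.
# Each thread processes ceil(num_dist / threads) distributions in a
# grid-stride loop, consuming `per_dist` random values per distribution.
def next_offset(current_offset, num_dist, threads, per_dist=1):
    iters = (num_dist + threads - 1) // threads  # grid-stride iterations
    return current_offset + iters * per_dist

offset = 0
offset = next_offset(offset, num_dist=1000, threads=256)  # each thread used 4 values
offset = next_offset(offset, num_dist=1000, threads=256)  # next launch starts past them
```

The point of the fix: if the offset is not advanced by at least the per-thread consumption, the next launch re-reads the same counter values and generates the same randoms again.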
ezyang
left a comment
I won't claim to understand all the subtleties of the indexing arithmetic here, but I'm going to approve to move things along here. I will admit that I spent half an hour trying to puzzle out the indexing computations, and could not figure it out. If this were a paper, I'd ask for some argument of correctness for why the new offset calculation is correct (whereas the old is not). But maybe this is not worth the effort. I left some comments on bits that were puzzling me below.
You probably didn't want this print here
You modified the logic for replacement=False. Maybe that should be tested too?
Changes to replacement=False were very superficial, but I'll add the test.
facebook-github-bot
left a comment
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Before, multinomial kernels did not advance random states enough, which led to the same sequence being generated over and over with a shift of 4. This PR fixes that.
Fixes #37403
Pull Request resolved: #38046
Differential Revision: D21516542
Pulled By: ngimel
fbshipit-source-id: 23248a8c3a5c44316c4c35cd71a8c3b5f76c90f2