
fix multinomial kernels to properly advance random states #38046

Closed

ngimel wants to merge 5 commits into pytorch:master from ngimel:multinomial

Conversation

ngimel (Collaborator) commented May 7, 2020

Before, multinomial kernels did not advance random states enough, which led to the same sequence being generated over and over with a shift of 4. This PR fixes that.
Fixes #37403
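
For readers coming from the linked issue: with curand's Philox generator, each thread initializes its state from a shared seed, a per-thread subsequence, and a host-tracked offset, and the host must advance that offset past everything a launch consumed before the next launch. A minimal sketch of this pattern (the kernel and its names are illustrative, not the PyTorch source; offsets are assumed to count individual 32-bit outputs):

```cuda
#include <curand_kernel.h>

// Illustrative sketch, not the actual PyTorch kernel. Each thread resumes
// its Philox subsequence at `offset`; if the host fails to advance `offset`
// past what a launch consumed, the next launch replays the same values.
// curand_uniform4 yields 4 uniforms per call, which is where the
// "same sequence with a shift of 4" symptom comes from.
__global__ void drawUniforms(float* out, int totalSamples,
                             unsigned long long seed,
                             unsigned long long offset) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = gridDim.x * blockDim.x;
    curandStatePhilox4_32_10_t state;
    curand_init(seed, tid, offset, &state);  // shared seed, per-thread subsequence
    for (int i = tid; i < totalSamples; i += stride) {
        float4 r = curand_uniform4(&state);
        out[i] = r.x;  // this thread keeps one of the four values
    }
}
```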

gchanan (Contributor) commented May 7, 2020

is it feasible to write a test?

ezyang (Contributor) commented May 7, 2020

Is there any reasonable way to test this?

dr-ci bot commented May 7, 2020

💊 CI failures summary and remediations

As of commit 350b490 (more details on the Dr. CI page):


  • 1/1 failures possibly introduced in this PR
    • 1/1 non-CircleCI failure(s)

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

Contributor

I can never remember how to determine the index formula XD

Collaborator Author

It's like multi-d tensor indexing XD

ezyang (Contributor) commented May 11, 2020

I attempted a review but apparently my CUDA has atrophied significantly. One thing I couldn't figure out was why it was necessary to switch from 2D-1D setup to 1D-2D.

ngimel (Collaborator, Author) commented May 11, 2020

It was not strictly necessary for fixing this bug; I did it to improve efficiency.
Before:
1D grid: each block was responsible for num_distributions/max_blocks distributions, processing one distribution at a time in a loop. If the number of distributions is small, we don't have enough blocks to fill the device.
2D blocks: each block is (32, 4) threads, but only the 4 threads with threadIdx.x == 0 actually did anything, and they were responsible for generating all the samples for a given distribution. So if we are talking about generating 100000 samples, that's 25000 iterations in the loop while the other threads idle, and we potentially don't even have enough blocks to fill the device. In total we had 128 * min(num_distributions, max_blocks) threads.
Now:
2D grid: the y-dimension of the grid is responsible for num_distributions/max_y_blocks distributions, processing one distribution at a time in a loop. The x-dimension of the grid and the x-dimension of the block are responsible for generating samples. Generating samples with replacement does not depend on previous results, so there is no point in serializing anything. Roughly speaking, we can launch num_distributions * num_samples threads (subject to some limits; exact formulas omitted for clarity). This allows us to get reasonable device utilization as long as the number of samples is not tiny (say, more than 128) and num_distributions * num_samples is large enough to fill the device (whereas previously we depended on num_distributions alone being large enough, and used only a quarter of the threads).
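
A hedged sketch of the two launch configurations described above, in CUDA host code (the helper names and exact clamping limits are illustrative assumptions, not the PR's code):

```cuda
#include <algorithm>

// Illustrative limits; the real kernels clamp against device properties.
constexpr int MAX_BLOCKS   = 1024;   // old 1D grid limit (assumed)
constexpr int MAX_BLOCKS_X = 1024;   // new grid.x limit (assumed)
constexpr int MAX_BLOCKS_Y = 65535;  // new grid.y limit (assumed)

// Before: 1D grid over distributions, (32, 4) blocks, but only the 4 threads
// with threadIdx.x == 0 drew samples, each looping over all the samples of
// one distribution.
dim3 oldGrid(int numDist)  { return dim3(std::min(numDist, MAX_BLOCKS)); }
dim3 oldBlock()            { return dim3(32, 4); }

// Now: grid.x * block.x threads cooperate on the samples of one distribution
// and grid.y walks over distributions, so roughly
// num_distributions * num_samples threads are in flight.
dim3 newGrid(int numDist, int numSamples, int blockX = 128) {
    int gridX = std::min((numSamples + blockX - 1) / blockX, MAX_BLOCKS_X);
    int gridY = std::min(numDist, MAX_BLOCKS_Y);
    return dim3(gridX, gridY);
}
dim3 newBlock()            { return dim3(128); }
```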

ezyang (Contributor), May 11, 2020

Wow, this is so obviously wrong it isn't even funny. Oh we used to only use one random

Contributor

So what, this is grid.x * ((numDist - 1) / grid.x + 1) * 4... so is this just a really long-winded way of saying numDist * 4? Or maybe with some extra slop at the end?

Collaborator Author

Oh, you are right, I messed this up: I changed the kernel but did not change this. Will fix now.

Collaborator Author

I fixed it and added a comment that hopefully makes things clearer (it's also simpler than it used to be, because each thread uses just 1 random in most cases).
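
Roughly what this thread converged on, as a hedged sketch (the helper and its ceil-division formula are mine, not the PR's exact code): the stored offset has to advance by 4 outputs per curand_uniform4 call, times the number of calls the busiest thread makes; when every thread draws only once, that is just a constant 4.

```cuda
// Illustrative host-side bookkeeping, not the PR's exact code. Each
// curand_uniform4 call consumes 4 Philox outputs, so the generator's stored
// offset must move forward by 4 * (calls made by the busiest thread).
void advancePhiloxOffset(unsigned long long* philox_offset,
                         long long totalSamples, long long totalThreads) {
    long long callsPerThread =
        (totalSamples + totalThreads - 1) / totalThreads;  // ceil division
    *philox_offset += 4ULL * callsPerThread;  // == 4 when each thread draws once
}
```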

ezyang (Contributor) left a comment

I won't claim to understand all the subtleties of the indexing arithmetic here, but I'm going to approve to move things along. I will admit that I spent half an hour trying to puzzle out the indexing computations and could not figure them out. If this were a paper, I'd ask for an argument for why the new offset calculation is correct (whereas the old one is not). But maybe this is not worth the effort. I left some comments on the bits that were puzzling me below.

Comment thread test/test_torch.py Outdated
Contributor

You probably didn't want this print here

Comment thread test/test_torch.py Outdated
Contributor

You modified the logic for replacement=False. Maybe that should be tested too?

Collaborator Author

Changes to replacement=False were very superficial, but I'll add the test.

facebook-github-bot (Contributor) left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


facebook-github-bot (Contributor)

@ngimel merged this pull request in 3d96808.

@ngimel ngimel added this to the 1.5.1 milestone May 15, 2020
gchanan pushed a commit to gchanan/pytorch that referenced this pull request May 28, 2020
fix multinomial kernels to properly advance random states (pytorch#38046)

Summary:
Before, multinomial kernels did not advance random states enough, which led to the same sequence being generated over and over with a shift of 4. This PR fixes that.
Fixes pytorch#37403
Pull Request resolved: pytorch#38046

Differential Revision: D21516542

Pulled By: ngimel

fbshipit-source-id: 23248a8c3a5c44316c4c35cd71a8c3b5f76c90f2
gchanan pushed a commit that referenced this pull request May 28, 2020
Summary:
Before, multinomial kernels did not advance random states enough, which led to the same sequence being generated over and over with a shift of 4. This PR fixes that.
Fixes #37403
Pull Request resolved: #38046

Differential Revision: D21516542

Pulled By: ngimel

fbshipit-source-id: 23248a8c3a5c44316c4c35cd71a8c3b5f76c90f2
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
fix multinomial kernels to properly advance random states (pytorch#38046)

Summary:
Before, multinomial kernels did not advance random states enough, which led to the same sequence being generated over and over with a shift of 4. This PR fixes that.
Fixes pytorch#37403
Pull Request resolved: pytorch#38046

Differential Revision: D21516542

Pulled By: ngimel

fbshipit-source-id: 23248a8c3a5c44316c4c35cd71a8c3b5f76c90f2

Development

Successfully merging this pull request may close these issues.

torch.multinomial behaves abnormally with CUDA tensor

5 participants