Improve bernoulli rng-bit-generation memory footprint (#5581)
Merged
Conversation
Force-pushed from 5199b04 to c33fbaf
JackCaoG
reviewed
Sep 15, 2023
  xla::One(probability.builder(), probability_shape.element_type());
- xla::XlaOp noise = RngUniform(seed, probability_shape, zero, one);
+ xla::XlaOp noise =
+     RngUniform(seed, probability_shape, zero, one, /*downcast=*/true);
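For context, a rough Python/numpy sketch of what the new `/*downcast=*/true` flag amounts to. The names here are illustrative stand-ins, not the actual XLA API: the idea is simply to generate the uniform noise in a narrower float type, halving the noise buffer.

```python
import numpy as np

def rng_uniform(seed, shape, low, high, downcast=False):
    # Illustrative stand-in for xla::RngUniform (not the real XLA API):
    # generate uniform noise, optionally in a narrower float type so the
    # noise buffer (and the random bits backing it) is half the size.
    rng = np.random.default_rng(seed)
    dtype = np.float16 if downcast else np.float32
    return rng.uniform(low, high, size=shape).astype(dtype)

noise_full = rng_uniform(0, (1024, 1024), 0.0, 1.0)
noise_down = rng_uniform(0, (1024, 1024), 0.0, 1.0, downcast=True)
print(noise_full.nbytes // noise_down.nbytes)  # 2
```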
Collaborator
Is there a reason we only downcast for Bernoulli? Would it be better if we always downcast?
Contributor
Author
Yeah, I think so. For instance, we could do it for multinomial too, since it also uses RngUniform to create a random mask, and I don't think the extra precision matters there. I guess we can extend the application when it becomes an issue or surfaces in an actual use case, as it did for Bernoulli, since that makes the benefit easier to verify and test :)
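The point about the extra precision not mattering for a random mask can be sketched numerically (a numpy illustration under the assumption that downcast means roughly f32 → f16): thresholding at `p` only changes a decision when a sample lies within half-precision rounding distance of `p`, so the mask's statistics are essentially unchanged.

```python
import numpy as np

rng = np.random.default_rng(42)
p = 0.3
noise = rng.uniform(0.0, 1.0, 1_000_000).astype(np.float32)

# Bernoulli-style mask from full-precision vs downcast uniform noise.
mask_f32 = noise < np.float32(p)
mask_f16 = noise.astype(np.float16) < np.float16(p)

# Only samples within ~2^-12 of p can flip, so the two masks agree
# almost everywhere and both empirical rates stay close to p.
print(abs(float(mask_f16.mean()) - p))        # small
print(float((mask_f32 != mask_f16).mean()))   # tiny disagreement rate
```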
JackCaoG
approved these changes
Sep 15, 2023
ManfeiBai
pushed a commit
that referenced
this pull request
Sep 15, 2023
* Allow downcasting RngUniform generation for Bernoulli
will-cromar
pushed a commit
that referenced
this pull request
Sep 15, 2023
zpcore
pushed a commit
that referenced
this pull request
Sep 18, 2023
* Allow downcasting RngUniform generation for Bernoulli
will-cromar
pushed a commit
that referenced
this pull request
Sep 18, 2023
will-cromar
added a commit
that referenced
this pull request
Sep 19, 2023
* Handle dynamo function without input (#5565) (#5577)
* Make cpu tensor on XLA dynamo backend a warning instead of error (#5549) (#5576)
* [author: jluntamazon] Adding more explicit HLO lowering control by exposing LoweringContext… (#5431) (#5580)
  * Adding more explicit HLO lowering control by exposing LoweringContext (and utilities) to python for Neuron
  * fixing linter issues
  * fixing spacing
  * apply comments and fix compilation errors
  * add test for new apis
  * fix linter
  * update test
  * update test
  * modify test
  * reverse back to GetIrValue()
  * update test inputs with random numbers
  * skip unittest because it only fails in CI

  Co-authored-by: aws-kingrj <78175353+aws-kingrj@users.noreply.github.com>
  Co-authored-by: Ubuntu <ubuntu@ip-172-31-3-186.us-west-2.compute.internal>
  Co-authored-by: seanlatias <seanlatias@gmail.com>
* fixing num_local_processes typo (#5573) (#5579)

  Co-authored-by: aws-kingrj <78175353+aws-kingrj@users.noreply.github.com>
* Move where clear pending IR is called to avoid crash (#5552) (#5582)
  * Move where clear pending IR is called to avoid crash
  * fix CI
  * fix CI and add some debugging messages
* Fix release branch and tag patterns for GitHub Actions (#5587) (#5590)
* Improve bernoulli rng-bit-generation memory footprint (#5581) (#5589)
  * Allow downcasting RngUniform generation for Bernoulli

  Co-authored-by: Yeounoh Chung <yeounoh@google.com>
* Enable xla:gpu autocast for bfloat16 if not restricted (#5570) (#5591)
  * Enable autocast for XLA:GPU
  * linter fix
  * XLA autocast test for GPU and TPU
  * linter fix
  * Ensure that xla autocast is properly enabled for GPU and does not crash when torch cuda is not available.
  * linter fix
  * Add tests
  * Support bf16
  * linter fix
  * exclude unsupported test cases
  * increase GPU test timeout to 300

  Co-authored-by: Yeounoh Chung <yeounoh@google.com>
* Cherry-pick: Don't trigger CI build on release tag push (#5595)
  Copy of #5594 on release branch
* formatting

Co-authored-by: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Co-authored-by: Wonjoo Lee <wonjoo@google.com>
Co-authored-by: aws-kingrj <78175353+aws-kingrj@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-3-186.us-west-2.compute.internal>
Co-authored-by: seanlatias <seanlatias@gmail.com>
Co-authored-by: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Co-authored-by: Yeounoh Chung <yeounoh@google.com>
jeffhataws
added a commit
to jeffhataws/xla
that referenced
this pull request
Dec 17, 2023
Revert "Improve bernoulli rng-bit-generation memory footprint (pytorch#5581) (pytorch#5589)" This reverts commit fa5d132.
sssrijan-amazon
added a commit
to jeffhataws/xla
that referenced
this pull request
Dec 18, 2023
Revert "Improve bernoulli rng-bit-generation memory footprint (pytorch#5581)…
Bernoulli rng-bit generation uses f32 precision (backed by u32 random bits), which is more than needed when the sampled range is small. This change lets us opt in to lower precision to save memory where applicable. We don't compute the dynamic range on the fly, to avoid/minimize extra computation.
Tested with a GPT-2 benchmark: same loss/convergence, with a lower memory profile.
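A back-of-envelope sketch of the expected savings (illustrative arithmetic, not measured from XLA): generating uniform noise for a Bernoulli mask materializes the raw random bits plus the uniform floats, and downcasting u32/f32 to u16/f16 halves both intermediate buffers.

```python
import numpy as np

# Intermediate footprint for an n-element Bernoulli mask:
# raw random bits + uniform noise, before thresholding against p.
n = 2048 * 2048
full_bytes = n * (np.dtype(np.uint32).itemsize + np.dtype(np.float32).itemsize)
down_bytes = n * (np.dtype(np.uint16).itemsize + np.dtype(np.float16).itemsize)
print(full_bytes / down_bytes)  # 2.0
```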