Improve bernoulli rng-bit-generation memory footprint #5581

Merged
yeounoh merged 3 commits into master from yeounoh_bernoulli on Sep 15, 2023

Conversation

@yeounoh (Contributor) commented Sep 14, 2023

Bernoulli rng-bit-generation uses f32 -> u32 precision, which is not needed when the sampled range is small. We should be able to opt in to lower precision to save memory where applicable. To avoid/minimize extra computation, we don't compute the dynamic range on the fly; the lower precision is an explicit opt-in instead.

Tested with the GPT-2 benchmark: loss/convergence is unchanged and the memory profile is lower.
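
To make the precision argument concrete, here is a standalone C++ sketch of why 16 random bits per element suffice for a Bernoulli mask. This is illustrative only, not the XLA lowering; Uniform16 and the loop are invented for the example.

// Standalone illustration: a Bernoulli mask thresholds a uniform sample,
// so 16 random bits already give 2^-16 resolution on the comparison --
// half the intermediate footprint of 32-bit bit generation.
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

// Map 16 raw random bits to a uniform float in [0, 1).
float Uniform16(uint16_t bits) { return bits / 65536.0f; }

int main() {
  std::mt19937 rng(42);
  std::uniform_int_distribution<uint32_t> bits(0, 0xFFFF);
  const float p = 0.1f;  // Bernoulli probability, e.g. a dropout mask
  std::vector<uint8_t> mask(16);
  for (auto& m : mask) {
    // A 32-bit uniform could shift each comparison by at most ~2^-16,
    // which is statistically irrelevant for a random mask.
    m = Uniform16(static_cast<uint16_t>(bits(rng))) < p ? 1 : 0;
  }
  for (int m : mask) std::cout << m;
  std::cout << '\n';
}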

yeounoh self-assigned this on Sep 14, 2023
      xla::One(probability.builder(), probability_shape.element_type());
- xla::XlaOp noise = RngUniform(seed, probability_shape, zero, one);
+ xla::XlaOp noise =
+     RngUniform(seed, probability_shape, zero, one, /*downcast=*/true);
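
For context, here is a hypothetical sketch of how the downcast flag might be handled inside the RngUniform helper. The signature mirrors the call site in the diff above, but the body is an assumption, not the merged torch_xla implementation.

// Hypothetical downcast path (assumed, not the actual merged code).
// Include paths vary across XLA versions.
#include "xla/client/xla_builder.h"
#include "xla/shape_util.h"

xla::XlaOp RngUniform(xla::XlaOp seed, const xla::Shape& shape,
                      xla::XlaOp minval, xla::XlaOp maxval,
                      bool downcast = false) {
  // Seed threading into the rng state is elided for brevity.
  // With downcast, sample at half width: the intermediate uniform (and
  // the random bits behind it) become 16-bit instead of 32-bit.
  xla::PrimitiveType gen_type = downcast ? xla::F16 : shape.element_type();
  xla::Shape gen_shape =
      xla::ShapeUtil::MakeShape(gen_type, shape.dimensions());
  xla::XlaOp sample =
      xla::RngUniform(xla::ConvertElementType(minval, gen_type),
                      xla::ConvertElementType(maxval, gen_type), gen_shape);
  // Widen back so callers still see the element type they requested.
  return xla::ConvertElementType(sample, shape.element_type());
}

Because the flag defaults to false, existing call sites keep full precision; only opted-in ops like Bernoulli trade the (negligible) precision for the memory savings.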
Collaborator commented:

Is there a reason we only downcast for Bernoulli? Would it be better if we always downcast?

Contributor Author replied:

Yeah, I think so. For instance, we could do it for multinomial, since it also uses this to create a random mask, and I don't think the extra precision matters there. I guess we could extend the application when it becomes an issue or surfaces in an actual use case like Bernoulli here, where it's easier to verify and test the benefit :)
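
If multinomial's lowering builds its mask the same way, the opt-in would presumably be a one-line change at that call site as well. Hypothetical, mirroring the diff above; noise_shape is a placeholder name:

xla::XlaOp noise =
    RngUniform(seed, noise_shape, zero, one, /*downcast=*/true);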

yeounoh merged commit 26a81a1 into master on Sep 15, 2023
ManfeiBai pushed a commit that referenced this pull request Sep 15, 2023
* Allow downcasting RngUniform generation for Bernoulli
will-cromar pushed a commit that referenced this pull request Sep 15, 2023
* Allow downcasting RngUniform generation for Bernoulli

Co-authored-by: Yeounoh Chung <yeounoh@google.com>
zpcore pushed a commit that referenced this pull request Sep 18, 2023
* Allow downcasting RngUniform generation for Bernoulli
will-cromar pushed a commit that referenced this pull request Sep 18, 2023
* Allow downcasting RngUniform generation for Bernoulli

Co-authored-by: Yeounoh Chung <yeounoh@google.com>
will-cromar added a commit that referenced this pull request Sep 19, 2023
* Handle dynamo function without input (#5565) (#5577)

* Make cpu tensor on XLA dynamo backend a warning instead of error (#5549) (#5576)

* [author: jluntamazon] Adding more explicit HLO lowering control by exposing LoweringContext… (#5431) (#5580)

* Adding more explicit HLO lowering control by exposing LoweringContext (and utilities) to python for Neuron

* fixing linter issues

* fixing spacing

* apply comments and fix compilation errors

* add test for new apis

* fix linter

* update test

* update test

* modify test

* reverse back to GetIrValue()

* update test inputs with random numbers

* skip unittest because it only fails in CI

---------

Co-authored-by: aws-kingrj <78175353+aws-kingrj@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-3-186.us-west-2.compute.internal>
Co-authored-by: seanlatias <seanlatias@gmail.com>

* fixing num_local_processes typo (#5573) (#5579)

Co-authored-by: aws-kingrj <78175353+aws-kingrj@users.noreply.github.com>

* Move where clear pending IR is called to avoid crash (#5552) (#5582)

* Move where clear pending IR is called to avoid crash

* fix CI

* fix CI and add some debugging messages

* Fix release branch and tag patterns for GitHub Actions (#5587) (#5590)

* Improve bernoulli rng-bit-generation memory footprint (#5581) (#5589)

* Allow downcasting RngUniform generation for Bernoulli

Co-authored-by: Yeounoh Chung <yeounoh@google.com>

* Enable xla:gpu autocast for bfloat16 if not restricted (#5570) (#5591)

* Enable autocast for XLA:GPU

* linter fix

* XLA autocast test for GPU and TPU

* linter fix

* Ensure that xla autocast is properly enabled for GPU and does not crash when torch cuda is not available.

* linter fix

* Add tests

* Support bf16

* linter fix

* exclude unsupported test cases

* increase GPU test timeout to 300

Co-authored-by: Yeounoh Chung <yeounoh@google.com>

* Cherry-pick: Don't trigger CI build on release tag push (#5595)

Copy of #5594 on release branch

* formatting

---------

Co-authored-by: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Co-authored-by: Wonjoo Lee <wonjoo@google.com>
Co-authored-by: aws-kingrj <78175353+aws-kingrj@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-3-186.us-west-2.compute.internal>
Co-authored-by: seanlatias <seanlatias@gmail.com>
Co-authored-by: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Co-authored-by: Yeounoh Chung <yeounoh@google.com>
jeffhataws added a commit to jeffhataws/xla that referenced this pull request Dec 17, 2023
sssrijan-amazon added a commit to jeffhataws/xla that referenced this pull request Dec 18, 2023
Revert "Improve bernoulli rng-bit-generation memory footprint (pytorch#5581)…