[scan] Make sure inputs into `fn` are not device_data IR nodes #8769
Merged
Force-pushed from deed400 to 20e63b8
bhavya01 reviewed on Feb 28, 2025
> is an in-place operation as opposed to a transform like found in JAX.
>
> When `fn` contains a `mark_sharding` and the `mark_sharding` operates on one
> of the carry or xs fake tensors, the original device data will be discarded
Collaborator
Can we elaborate on how the device data gets discarded? What about `mark_sharding` creates this illusion of a new tensor being created? Is it setting a new IR value or creating a new XLA node?
Probably too much LTC detail, but it would be very helpful for understanding why this happens. I am also trying to learn more about this and will let you know if I find anything.
Collaborator (Author)
Yes. I've updated the explanation:
When `fn` contains a `mark_sharding` and the `mark_sharding` operates on one
of the carry or xs fake tensors, the original device data will be discarded
and a new one will be created in its place. That's because `mark_sharding` takes
different code paths depending on whether the IR node holds device data.
If the IR is an intermediate operation like add or matmul, `mark_sharding`
updates the sharding annotation. If the IR holds data, `mark_sharding`
transfers the data to the TPU in a sharded manner and updates the data object
in the IR to point to the sharded data object, as can be seen in [2].
As a result, `fn` will appear to create a few empty tensors internally that
are unrelated to the carry and xs fake tensors, and the carry and xs will
appear completely unused.
Collaborator (Author)
[2]: https://github.com/pytorch/xla/blob/2675e6892c6f955fc2baf88d85dfdfa72062273c/torch_xla/csrc/xla_sharding_util.cpp#L799-L846
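To make the failure mode concrete, here is a minimal repro sketch (an editor's illustration, not the PR's test; the mesh axis name `model` and the tensor shapes are arbitrary choices):

```python
# A minimal sketch of the failure described above: `fn` calls
# `mark_sharding` on an input whose IR node is raw device data.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.experimental.scan import scan

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('model',))


def fn(carry, x):
  # `x` is a leaf tensor freshly transferred to the device, so its IR node
  # holds device data. `mark_sharding` then swaps out that data object for a
  # sharded one instead of annotating an intermediate op, and the traced
  # graph no longer references `x`, so it appears completely unused.
  xs.mark_sharding(x, mesh, ('model',))
  return carry + x.sum(), x * 2


device = xm.xla_device()
init = torch.zeros(1, device=device)
inputs = torch.ones(4, num_devices, device=device)  # scanned over dim 0
# Before this PR, graph capture inside `scan` would fail here (#8742).
final_carry, ys = scan(fn, init, inputs)
```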
Force-pushed from 20e63b8 to 20d5575
Fixes #8742.

This works around a limitation of `mark_sharding`, which replaces the innards of the tensors it operates on. In other words, `mark_sharding` is an in-place operation as opposed to a transform like those found in JAX.

When `fn` contains a `mark_sharding` and the `mark_sharding` operates on one of the carry or xs fake tensors, the original device data will be discarded and a new one will be created in its place. As a result, `fn` will appear to create a few empty tensors internally that are unrelated to the carry and xs fake tensors, and the carry and xs will appear completely unused.

See #8742 for the bug. In short, if an input into the layer to be scanned is a device data IR node, and that layer does a `mark_sharding` on said input, then the graph capturing in `scan` will fail.

The workaround here is simple and cursed: multiply any `device_data` by 1. This ensures these tensors don't hold device data IR nodes and defeats the device data replacement of `mark_sharding`. Fortunately, XLA simplifies away the multiplication (see [1]), so this should become a no-op by the time it hits the TPU.

[1]: https://github.com/openxla/xla/blob/869f57d0082d7adbb9efc10cc18f51a562fc7bf3/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc#L4755-L4770
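The core of the workaround, as a hedged sketch (the helper name `_defeat_device_data` and the `tree_map` plumbing are illustrative, not necessarily the PR's actual code):

```python
# Sketch of the multiply-by-1 trick; `_defeat_device_data` is a
# hypothetical helper name.
import torch
from torch.utils._pytree import tree_map


def _defeat_device_data(t: torch.Tensor) -> torch.Tensor:
  # `t * 1` makes the tensor's IR node a multiplication whose *operand*
  # holds the device data. A later `mark_sharding` inside `fn` then takes
  # the "update the sharding annotation" code path instead of replacing the
  # underlying data object. XLA's algebraic simplifier folds the
  # multiply-by-one away (see [1]), so the TPU executes a no-op.
  return t * 1


# Applied to the scan inputs before tracing `fn`, e.g.:
# init = tree_map(_defeat_device_data, init)
# xs = tree_map(_defeat_device_data, xs)
```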
Force-pushed from 20d5575 to 1403694
qihqi approved these changes on Feb 28, 2025
zpcore pushed a commit that referenced this pull request on Mar 1, 2025
pgmoka pushed a commit that referenced this pull request on Mar 5, 2025