
[scan] Make sure inputs into fn are not device_data IR nodes #8769

Merged: tengyifei merged 1 commit into master from yifeit/scan-mark-sharding (Feb 28, 2025)
Conversation

@tengyifei (Collaborator)

Fixes #8742.

This works around a limitation of `mark_sharding`, which replaces the innards of the tensors it operates on. In other words, `mark_sharding` is an in-place operation, as opposed to a functional transform like those found in JAX.
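For contrast, here is a hedged JAX sketch of the functional style this paragraph alludes to (illustrative only, not part of this PR; it assumes the leading array dimension divides evenly across `jax.devices()`):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D device mesh with a single 'data' axis.
mesh = Mesh(np.array(jax.devices()), ('data',))

@jax.jit
def constrain(x):
  # Returns a *new* value carrying the sharding constraint; `x` itself is
  # never mutated, unlike torch_xla's in-place mark_sharding.
  return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('data')))

y = constrain(jnp.ones((8, 4)))
```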

When `fn` contains a `mark_sharding` and the `mark_sharding` operates on one of the carry or xs fake tensors, the original device data will be discarded and a new one will be created in its place. As a result, `fn` will appear to create a few empty tensors internally that are unrelated to the carry and xs fake tensors, and the carry and xs will appear completely unused.

See #8742 for the bug. In short, if an input into the layer to be scanned is device data, and that layer does a `mark_sharding` on said input, then the graph capture in `scan` will fail.
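A hedged repro sketch of that failure mode (the shapes, mesh layout, and sharding spec are illustrative assumptions, not the exact repro from #8742):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.scan import scan

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

def fn(carry, x):
  # `x` is a scan input. Before this fix it could be backed by a bare
  # device_data IR node, and mark_sharding would swap out its data object,
  # breaking graph capture inside scan.
  xs.mark_sharding(x, mesh, ('data', None))
  return carry + x, x

init = torch.zeros(4, 4, device=xm.xla_device())
inputs = torch.ones(8, 4, 4, device=xm.xla_device())
final_carry, ys = scan(fn, init, inputs)
```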

The workaround here is simple and cursed: multiply any `device_data` by 1. This makes sure these tensors don't hold device data IR nodes and defeats the device data replacement of `mark_sharding`.
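A minimal sketch of what the workaround amounts to (the helper name is made up for illustration; the real change lives inside `scan`'s input handling):

```python
import torch

def defeat_device_data(t: torch.Tensor) -> torch.Tensor:
  # `t * 1` wraps the bare device_data IR node in a `mul` op, so a later
  # mark_sharding inside `fn` takes the "annotate an intermediate op" path
  # instead of replacing the underlying data object.
  return t * 1
```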

Fortunately, XLA simplifies away the multiplication (see [1]), so this should become a no-op by the time it hits the TPU.

[1]: https://github.com/openxla/xla/blob/869f57d0082d7adbb9efc10cc18f51a562fc7bf3/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc#L4755-L4770

@tengyifei tengyifei marked this pull request as ready for review February 28, 2025 04:12
@tengyifei tengyifei force-pushed the yifeit/scan-mark-sharding branch from deed400 to 20e63b8 Compare February 28, 2025 05:22
@bhavya01 bhavya01 self-requested a review February 28, 2025 18:08
Comment thread on test/scan/test_scan_pallas.py (Outdated), quoting:

> is an in-place operation as opposed to a transform like found in JAX.
>
> When `fn` contains a `mark_sharding` and the `mark_sharding` operates on one
> of the carry or xs fake tensors, the original device data will be discarded
Collaborator


Can we elaborate on "device data will be discarded"? What about `mark_sharding` creates this illusion of a new tensor being created? Is it setting a new IR value or creating a new XLA node?

Probably too much LTC detail, but it will be very helpful for understanding why this happens. I am also trying to learn more about this and will let you know if I find anything.

Collaborator Author


Yes. I've updated the explanation:

> When `fn` contains a `mark_sharding` and the `mark_sharding` operates on one
> of the carry or xs fake tensors, the original device data will be discarded
> and a new one will be created in its place. That's because `mark_sharding` has
> different code paths depending on whether the IR holds device data.
> If the IR is an intermediate operation like add or matmul, `mark_sharding` will
> update the sharding annotation. If the IR holds data, `mark_sharding` will
> transfer the data to the TPU in a sharded manner and update the data object
> in the IR to point to a sharded data object, as can be seen in [2].
>
> As a result, `fn` will appear to create a few empty tensors internally that
> are unrelated to the carry and xs fake tensors, and the carry and xs will
> appear completely unused.

[2]: https://github.com/pytorch/xla/blob/2675e6892c6f955fc2baf88d85dfdfa72062273c/torch_xla/csrc/xla_sharding_util.cpp#L799-L846
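
A user-level sketch of those two paths (the mesh setup and shapes are illustrative assumptions; the actual dispatch happens in the C++ linked above):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n), (n,), ('data',))

t = torch.randn(4, 4, device=xm.xla_device())
# `t` holds a device_data IR node: mark_sharding re-transfers the tensor to
# the device in a sharded manner and replaces its data object.
xs.mark_sharding(t, mesh, ('data', None))

u = t + 1
# `u` is an intermediate `add` IR node: mark_sharding only attaches a
# sharding annotation; no data object is replaced.
xs.mark_sharding(u, mesh, ('data', None))
```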

@tengyifei tengyifei requested a review from zpcore February 28, 2025 19:28
@tengyifei tengyifei force-pushed the yifeit/scan-mark-sharding branch from 20e63b8 to 20d5575 Compare February 28, 2025 19:34
@tengyifei tengyifei force-pushed the yifeit/scan-mark-sharding branch from 20d5575 to 1403694 Compare February 28, 2025 19:37
@tengyifei tengyifei requested a review from bhavya01 February 28, 2025 21:06


Development

Successfully merging this pull request may close this issue: Scan + flash attention kernel = NaN (#8742)

3 participants