[Functionalization] Enable FSDP #4691
Summary: This pull request enables FSDP by replacing .set_ with our own _replace_xla_tensor API. The reason is that .set_ is an in-place op, so the Functionalization pass reapplies the new value to all of the tensor's aliases. That reapplication assumes the source and the destination have the same number of elements (view_copy), and .set_ does not follow this rule. P.S. It also removes two .data tests that are no longer applicable. Test Plan: CI.
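To illustrate the element-count assumption described in the summary, here is a minimal toy sketch in plain Python. It is not PyTorch or PyTorch/XLA code: `ToyAlias` and `reapply_to_alias` are hypothetical names invented for this illustration, and plain lists stand in for tensor storage. It only shows why writing a new value back through an alias works when the sizes match and breaks when they do not.

```python
# Toy model only: plain Python lists stand in for tensor storage.
# ToyAlias and reapply_to_alias are hypothetical names, not PyTorch APIs.


class ToyAlias:
    """A view over a shared base buffer, described by an offset and a length."""

    def __init__(self, base, offset, length):
        self.base = base
        self.offset = offset
        self.length = length

    def read(self):
        return self.base[self.offset:self.offset + self.length]


def reapply_to_alias(alias, new_values):
    """Write new_values back through the alias into the shared base.

    The write-back is only well defined when the element counts match.
    """
    if len(new_values) != alias.length:
        raise ValueError(
            f"element count mismatch: alias has {alias.length}, "
            f"new value has {len(new_values)}"
        )
    alias.base[alias.offset:alias.offset + alias.length] = new_values


base = [0.0] * 8
view = ToyAlias(base, offset=2, length=4)

# Same element count: the write-back succeeds and the base sees the update.
reapply_to_alias(view, [1.0, 2.0, 3.0, 4.0])
same_size_result = list(base)

# Different element count (the situation a .set_-style rebinding can create):
# the write-back has no valid meaning, so the toy model rejects it.
try:
    reapply_to_alias(view, [9.0] * 6)
    mismatch_error = None
except ValueError as exc:
    mismatch_error = str(exc)
```

In this toy model the fix corresponds to swapping the underlying buffer directly rather than going through the alias write-back path, which is roughly the role the summary assigns to _replace_xla_tensor.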
@bdhirsh Hit a weird crash while running .set in our resnet FSDP:
Is resnet failure blocking?
No, I created a new API
    self.assertEqual(met.counter_value('DestroyXlaTensor'), 5)

    # shouldn't crash
    t2.cpu()
nit: can we do a value check here instead of just calling .cpu()?
Yeah, let me add that.
Sounds good. I'm having trouble repro'ing that crash without XLA, although I am able to repro at least one issue, which fails with:
Glad that you have a workaround for now!
Thanks, Brian. Let me see if I can write a repro for you without XLA.