
Reenable the distributed checkpointing test #8424

Merged
JackCaoG merged 1 commit into master from JackCaoG/reeneable_checkpoint_test on Dec 2, 2024

Conversation

@JackCaoG (Collaborator) commented Nov 27, 2024

This is a follow-up to #8386.

In the previous PR I found that somewhere during fallback, PyTorch will try to update an existing XLATensor with a CPU tensor of a different shape. In that case we need to remove the sharding spec, otherwise there will be a shape mismatch. However, I found that in distributed checkpointing we swap the existing XLATensor with the CPU tensor, and it seems we want to keep the sharding spec there.

@jonb377 one concern I have is that the test only covers the single-host case. In an actual multi-host case, wouldn't the CPU tensor have a different (sharded) shape than the sharding spec? I am not sure if we have such a test somewhere. Even if we clear the sharding spec, after a torch_xla.sync() the tensor will be moved to the device, but most likely replicated. I am a bit worried that I am breaking distributed checkpointing here; see the sketch below.
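
A minimal sketch of the single-host vs. multi-host concern, assuming torch_xla's SPMD API (`xs.mark_sharding` / `xs.clear_sharding`); the tensor shapes, mesh layout, and variable names are illustrative and not taken from the actual test:

```python
# Illustrative sketch only: shapes, mesh layout, and names are assumptions.
import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("data", "model"))

# Global tensor sharded along the "data" axis.
t = torch.zeros(16, 16, device=torch_xla.device())
xs.mark_sharding(t, mesh, ("data", "model"))

# Single host: a CPU tensor coming back from fallback / checkpoint restore has
# the full global shape (16, 16), so keeping the sharding spec stays consistent.
cpu_full = torch.ones(16, 16)

# Multi-host (hypothetical): each host might only hold its local shard,
# e.g. (16 // num_hosts, 16), which no longer matches the global sharding spec.
# Clearing the spec avoids the shape mismatch, but after torch_xla.sync() the
# tensor would most likely come back replicated rather than sharded:
# xs.clear_sharding(t)
# torch_xla.sync()
```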

@JackCaoG JackCaoG added the tpuci label Nov 27, 2024
@JackCaoG JackCaoG marked this pull request as ready for review November 28, 2024 09:06
@JackCaoG JackCaoG requested a review from tengyifei November 28, 2024 09:06
@JackCaoG JackCaoG merged commit 591c397 into master Dec 2, 2024
rpsilva-aws pushed a commit to rpsilva-aws/xla that referenced this pull request Dec 6, 2024
