Handle the parameter wrapping for SPMD#7604

Merged
JackCaoG merged 2 commits into master from JackCaoG/auto_wrap_spmd
Jul 2, 2024

Conversation

@JackCaoG
Collaborator

@JackCaoG JackCaoG commented Jul 1, 2024

should fix #7161

@alanwaketan
Collaborator

Let me know when it's ready for review.

@JackCaoG JackCaoG marked this pull request as ready for review July 2, 2024 18:49
@JackCaoG
Collaborator Author

JackCaoG commented Jul 2, 2024

Not sure why the TPU CI is skipped, but I think this PR is ready for review.

<< xla::HloSharding::FromProto(instr.sharding())->ToString();
}
}
return std::move(param_shardings);
Collaborator

I don't think we need to do std::move for the return value.

output = linears(input)
torch_xla.sync()
xm.wait_device_ops()
self.assertEqual(output.shape, torch.Size([100, 40]))
Collaborator

Is this check sufficient? Without this change, what will be the output shape?

Collaborator Author

I was just lazy and wanted to make sure this does not crash, haha. I can add a unit test to check the value.

Collaborator

I guess my point is that I cannot link the test to the change: the change attaches the sharding to the inputs, but you are checking the output.

Collaborator Author

@JackCaoG JackCaoG Jul 2, 2024

Without the change, the test will crash during execution, since the parameter wrapping threshold is set to 1.
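To make the threshold discussion concrete, the wrapping behavior can be sketched generically: once the number of graph parameters exceeds the threshold, they are bundled into a single tuple parameter, so any per-parameter metadata (such as sharding) has to be attached to the wrapped inputs. This is a hand-wavy stdlib-only sketch with hypothetical names, not the PR's actual implementation:

```python
# Mirrors XLA_PARAMETER_WRAPPING_THREADSHOLD=1 from the test setup below;
# the real threshold is read from that env var by the runtime.
WRAPPING_THRESHOLD = 1

def maybe_wrap_parameters(params):
    """Hypothetical helper: bundle parameters into one tuple when there
    are too many, so the compiled computation sees a single parameter
    instead of thousands. Returns (parameters, was_wrapped)."""
    if len(params) > WRAPPING_THRESHOLD:
        return (tuple(params),), True
    return tuple(params), False

params = [1.0, 2.0, 3.0]
wrapped, was_wrapped = maybe_wrap_parameters(params)
# With a threshold of 1, even this tiny test exercises the wrapping path.
```

Setting the threshold to 1 in the test is what forces even a small model through the wrapping code path, which is why the test crashes without the fix.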

Comment thread test/run_tests.sh

function run_parameter_warpping {
echo "Running in parameter wrapping mode: $@"
XLA_PARAMETER_WRAPPING_THREADSHOLD=1 run_test "$@"
Collaborator Author

Now that I think about it, maybe it is better to set this env var as part of the test.
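Setting the variable inside the test could look like the following stdlib-only sketch, using `unittest.mock.patch.dict` to scope the override to one test instead of the whole shell invocation (the `read_threshold` helper and the default value of 100 are hypothetical, not the runtime's actual reader or default):

```python
import os
from unittest import mock

def read_threshold():
    # Hypothetical stand-in for whatever runtime code consults the env var.
    return int(os.environ.get("XLA_PARAMETER_WRAPPING_THREADSHOLD", "100"))

# patch.dict scopes the override to this block and restores the previous
# value on exit, so other tests in the same process are unaffected.
with mock.patch.dict(os.environ, {"XLA_PARAMETER_WRAPPING_THREADSHOLD": "1"}):
    inside = read_threshold()   # sees the overridden value, 1
outside = read_threshold()      # falls back to the (hypothetical) default
```

One caveat with this approach: if the runtime reads the env var once at process startup, patching it per-test has no effect, which may be why the shell-level override was used here.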

@JackCaoG
Collaborator Author

JackCaoG commented Jul 2, 2024

Let me fix the review comments in a follow up pr so we can land this change for @alanwaketan to try.

@JackCaoG JackCaoG merged commit 4b7e518 into master Jul 2, 2024

Development

Successfully merging this pull request may close these issues.

A large number of Tensors (>8000) in the graph will trigger an spmd sharding error

2 participants