Handle the parameter wrapping for SPMD #7604
Conversation
Let me know when it's ready for review.

Not sure why TPU CI is skipped, but I think this PR is ready for review.
| << xla::HloSharding::FromProto(instr.sharding())->ToString();
| }
| }
| return std::move(param_shardings);
I don't think we need `std::move` for the return value; returning a local by value already allows NRVO or an implicit move, and an explicit `std::move` can inhibit copy elision.
| output = linears(input)
| torch_xla.sync()
| xm.wait_device_ops()
| self.assertEqual(output.shape, torch.Size([100, 40]))
Is this check sufficient? Without this change, what will be the output shape?
I was just lazy and wanted to make sure this does not crash haha. I can add a unit test to check the value.
I guess my point is that I cannot link the test to the change: the change attaches the sharding to the inputs, but you are checking the output.
Without the change, the test will crash during execution, since the parameter wrapping threshold is set to 1.
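To illustrate the idea being tested (a hypothetical sketch only, not the actual torch_xla implementation): when the number of parameters exceeds the wrapping threshold, they get packed into a single tuple argument and unpacked inside the computation. The names `WRAP_THRESHOLD`, `maybe_wrap_params`, and `run_computation` are invented for this example.

```python
# Hypothetical sketch of parameter wrapping; not the torch_xla code.
# A threshold of 1 mirrors the test's env-var setting, so that any
# computation with more than one parameter exercises the wrapped path.
WRAP_THRESHOLD = 1  # assumed value for illustration

def maybe_wrap_params(params):
    """Return (args, was_wrapped): pack params into one tuple if over threshold."""
    if len(params) > WRAP_THRESHOLD:
        return (tuple(params),), True
    return tuple(params), False

def run_computation(fn, params):
    args, was_wrapped = maybe_wrap_params(params)
    if was_wrapped:
        # Destructure the single tuple parameter back into individual args,
        # the way a wrapped computation would unpack it internally.
        (inner,) = args
        return fn(*inner)
    return fn(*args)

print(run_computation(lambda a, b: a + b, [1, 2]))  # -> 3
```

The crash the test guards against would correspond to the unwrap step being skipped while the inputs were wrapped.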
| function run_parameter_warpping {
|   echo "Running in parameter wrapping mode: $@"
|   XLA_PARAMETER_WRAPPING_THREADSHOLD=1 run_test "$@"
Now that I think about it, maybe it is better to set this env var as part of the test.
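One way to do what the reviewer suggests, sketched with `unittest.mock.patch.dict` (the env var name is taken from the quoted script; whether torch_xla reads it at import time or per-run is not addressed here, so treat this as illustrative):

```python
# Sketch: scope the env var to the test itself rather than the CI shell
# wrapper. patch.dict restores os.environ when the block exits, assuming
# the variable is not already set in the surrounding environment.
import os
from unittest import mock

ENV = "XLA_PARAMETER_WRAPPING_THREADSHOLD"  # name as quoted in the script

with mock.patch.dict(os.environ, {ENV: "1"}):
    # Inside the block the threshold is set for the wrapped-mode test.
    assert os.environ[ENV] == "1"

# Outside the block the variable is removed/restored automatically,
# so other tests in the same process are unaffected.
print(ENV in os.environ)  # -> False (if it wasn't set beforehand)
```

The same decorator form (`@mock.patch.dict(os.environ, {...})`) can be applied to an individual test method.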
Let me fix the review comments in a follow-up PR so we can land this change for @alanwaketan to try.
Should fix #7161.