[SPMD] Support SPMDFullToShardShape#6922
Conversation
return xtensors;
}

bool IsIr(const at::Tensor& tensor) {
# It looks like XLA doesn't like only having manual sharding in the HLO.
# It needs to be paired with SPMDFullToShardShape/SPMDShardToFullShape.
# The following exception cannot be caught somehow.
# xx.cpu()
do you intend to keep this xx.cpu?
Yeah, it's more of a note that this won't work... I was trying to use self.assertRaises, but that doesn't capture the exception... I've noticed this before too: when libtpu crashes, it's hard to catch at the Python level. Not sure why. Maybe you have some better ideas?
Oh, I think I ran into a similar issue before... The way I handled it was ugly, through
xla/test/spmd/test_dynamo_spmd.py
Lines 172 to 181 in a7a1357
A C++ crash at the PyTorch level can be caught with self.assertRaises, but not one at the libtpu level... I'm not sure why... yeah, not even with this hack...
cc @will-cromar Do you know how to catch a libtpu exception in Python? I'd appreciate your insights.
I don't think you can. To make a proper runtime error, you have to raise an exception, and Google internal binaries don't generally do that. I wrote about a similar case in #6700 (comment)
Thanks, Will. That makes a lot of sense now.
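To make the "can't catch it" point concrete: a native abort (like a libtpu crash) kills the whole interpreter before any Python exception is raised, so `self.assertRaises` never fires. A hedged workaround sketch, assuming the crashing snippet can be run in isolation, is to execute it in a subprocess and assert on the exit status instead. Here `os.abort()` is a stand-in for the libtpu crash, not the real `xx.cpu()` call:

```python
import subprocess
import sys

# Run the crashing snippet in a child process; a native abort kills the
# child before any Python exception exists, but the parent can still
# observe the non-zero exit status. os.abort() simulates the crash.
result = subprocess.run(
    [sys.executable, "-c", "import os; os.abort()"],
    capture_output=True,
)

# A clean run would return 0; an abort yields a non-zero status
# (on POSIX, a negative value encoding the killing signal).
assert result.returncode != 0
print("child crashed as expected")
```

The same pattern works inside a unittest method: spawn the child, then assert on `result.returncode` rather than wrapping the call in `assertRaises`.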
  tensor_methods::custom_sharding_(output_tensor,
-                                  input_tensor->sharding_spec());
+                                  input_tensor->sharding_spec(),
+                                  CustomSharding::Type::kSharding);
So you assume only tensors with kSharding will be called with in-place ops?
That's the original design, which aligns with the original design of SPMD... So yeah, for kSharding...
Can we make kSharding the default then? That way most people reading this code won't need to figure out what kSharding actually means.
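For illustration, a hedged Python sketch of that suggestion (the real API is the C++ `tensor_methods::custom_sharding_`; the names below are hypothetical mirrors): defaulting the type to the kSharding case means the common call sites stay unchanged, and only the new behavior is an explicit opt-in.

```python
from enum import Enum, auto

class CustomShardingType(Enum):
    # Plain sharding-annotation custom call (the common case).
    K_SHARDING = auto()
    K_SPMD_FULL_TO_SHARD_SHAPE = auto()
    K_SPMD_SHARD_TO_FULL_SHAPE = auto()

# Hypothetical mirror of tensor_methods::custom_sharding_: a default
# value means existing call sites never have to name K_SHARDING.
def custom_sharding(tensor, sharding_spec,
                    sharding_type=CustomShardingType.K_SHARDING):
    return (tensor, sharding_spec, sharding_type)

# Existing call site, unchanged by the new enum:
_, _, t = custom_sharding("tensor", "spec")
assert t is CustomShardingType.K_SHARDING

# New behavior is requested explicitly:
_, _, t = custom_sharding("tensor", "spec",
                          CustomShardingType.K_SPMD_FULL_TO_SHARD_SHAPE)
print(t.name)
```

The same idea carries over to C++ as a default argument on the `CustomSharding::Type` parameter.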
enum class Type {
  kSharding,
  kSPMDFullToShardShape,
  kSPMDShardToFullShape,
};
This enum is really confusing; can you add some comments about what these actually do? I was reading the SPMD code again: this op itself only means we want to shard the underlying value, and the actual sharding resides in the XlaTensor or backing XLA IR object?
Right, this is just the name of the custom call. The sharding annotation is in XlaTensor as normal. I can add more explanations.
Maybe we can annotate explicitly that this is a sharding type for a custom call, in the enum class name or something.
I guess the current approach sort of does it already? Can you be more specific? @yeounoh
I agree, Type is already defined under CustomSharding
Thanks, Yeounoh!
Summary:
This pull request supports SPMDFullToShardShape, a custom op that opens a region for a non-partitioned graph in an SPMD program. It stops SPMD auto-sharding and partitioning in that region, and therefore allows manual sharding, e.g. for collective communication (cc) ops.
To implement it, this pull request expands the CustomSharding node to accept a new type. Note that the output shape of the op needs to be the shard shape of the input, and the node needs to have a manual sharding annotation.
Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py -v -k test_spmd_full_to_shard_shape
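The shard-shape requirement from the summary can be illustrated with a small sketch. This is pure Python, not the torch_xla implementation, and the helper name is hypothetical; it assumes a simple tiled sharding, where each dimension of the full shape is divided by the mesh size along that dimension:

```python
# Hypothetical helper illustrating the shard-shape rule for tiled
# sharding: the output of SPMDFullToShardShape has the per-device shard
# shape, i.e. each dim of the full shape divided (ceiling division) by
# the mesh dim size.
def shard_shape(full_shape, mesh_shape):
    assert len(full_shape) == len(mesh_shape)
    return tuple(-(-dim // m) for dim, m in zip(full_shape, mesh_shape))

# An (8, 16) tensor on a (4, 1) mesh is split into (2, 16) shards.
print(shard_shape((8, 16), (4, 1)))  # → (2, 16)
```

In the PR, this per-device shape is what the CustomSharding node must report as its output shape, alongside the manual sharding annotation.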