Conversation
OK, there are 2 issues.
Thanks @JackCaoG, I am going to merge an output param sharding patch, which might change the code path a bit. Let's chat offline, I can explain further.
…aceholder if SPMD is enabled
// Device will be Virtual device if SPMD is enabled.
torch::lazy::BackendDevice device =
    ShardingUtil::UseVirtualDevice() ? ParseDeviceString("SPMD:0")
                                     : torch_xla::GetCurrentDevice();
@yeounoh I am not sure if we should just update GetCurrentDevice instead, any thoughts? We need to sit down and think about how to surface this virtual device to users soon.
I voted for GetCurrentDevice, as there might be other scenarios where the caller will also need to distinguish SPMD:0 from XLA:0.
GetCurrentDevice is used in over 30 places in our code base now, mostly during tracing where the caller is trying to figure out the hardware type. I think it should be fine as long as SPMD:0 can be resolved into the correct hardware type. I would leave that to a separate PR since it touches too much code and might introduce noise.
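To make the suggestion concrete, here is a minimal sketch of what folding the virtual-device check into GetCurrentDevice could look like; this is not the actual torch_xla implementation, and GetCurrentPhysicalDevice is a hypothetical stand-in for the existing lookup.

```cpp
// Hypothetical sketch only, not the actual torch_xla code: fold the
// virtual-device check into GetCurrentDevice so every caller receives
// SPMD:0 automatically when SPMD mode is on.
torch::lazy::BackendDevice GetCurrentDevice() {
  if (ShardingUtil::UseVirtualDevice()) {
    // Report the virtual device; it still needs to resolve to the correct
    // underlying hardware type for callers that only inspect the hw type.
    return ParseDeviceString("SPMD:0");
  }
  // Hypothetical stand-in for the existing physical-device lookup.
  return GetCurrentPhysicalDevice();
}
```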
I think this one is ready for review. I will add more test cases (input data sharding, which I am not sure works or not) and features in the next PR.
    WrapXlaData(xla::ComputationClient::Get()->CreateDataPlaceholder(
        device.toString(), std::move(shape)));
// if SPMD is enabled, we assume all output will be replicated
if (ShardingUtil::UseVirtualDevice()) {
Why are we adding this for the dynamo path now? Do we not need this for the LTC path?
Looks like this patch is dynamo-exclusive... should we hint at this somewhere?
The lazy code path already has this logic; in fact, I copied this logic from the lazy code path lol
I smell an opportunity to merge the two code paths further, but let's do it in a follow-up.
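As a rough illustration of that follow-up, the common logic could live in one helper that both paths call; CreateOutputPlaceholder and MarkReplicated are invented names for this sketch, assuming the replicated-output behavior described above.

```cpp
// Sketch of a shared helper for the lazy and dynamo paths; the names
// CreateOutputPlaceholder and MarkReplicated are hypothetical.
torch::lazy::BackendDataPtr CreateOutputPlaceholder(
    const torch::lazy::BackendDevice& device, xla::Shape shape) {
  torch::lazy::BackendDataPtr handle =
      WrapXlaData(xla::ComputationClient::Get()->CreateDataPlaceholder(
          device.toString(), std::move(shape)));
  // With SPMD enabled, assume all outputs are replicated across devices.
  if (ShardingUtil::UseVirtualDevice()) {
    handle = MarkReplicated(handle);  // hypothetical step
  }
  return handle;
}
```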
if (auto xla_tensor_ptr = bridge::TryGetXlaTensor(ivalue.toTensor())) {
  dataptr = xla_tensor_ptr->GetXlaData();
} else {
  XLA_CHECK(device.type() != (int8_t)XlaDeviceType::SPMD)
What's this XLA_CHECK for?
It's not strictly needed; it's more of a sanity check I added to ensure that this doesn't happen. Basically, we want to make sure that the SPMD device type always stays on the backend (device data).
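For readers hitting this later, the check could also carry a message spelling out that invariant; the wording below is an assumption, not the actual code.

```cpp
// Sketch: same sanity check with an explanatory message (message text assumed).
XLA_CHECK(device.type() != (int8_t)XlaDeviceType::SPMD)
    << "Plain at::Tensor inputs should not carry the SPMD virtual device; "
    << "the SPMD device type should only appear on backend device data.";
```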
# Add an additional 1x1 layer at the end to ensure the final layer
# is not sharded.
self.fc3 = nn.Linear(1, 1)
Is this due to the lack of output sharding propagation?
Yeah, in this PR I tried to keep the output replicated. We can expand this after the output sharding PR is ready.
Input sharding should (used to) work if the sharded input is used for the torch compilation. Let me know. I will take a pass on the changes now as well, thanks.
@@ -590,6 +593,15 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
torch::lazy::BackendDataPtr handle =
    WrapXlaData(xla::ComputationClient::Get()->CreateDataPlaceholder(
If it's the SPMD virtual device, then we should always use a PjRtShardedData handle.
Hmm, isn't the logic below that calls WrapDataShards enough? This code path is shared between the SPMD and non-SPMD cases.
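To spell out what I mean, a minimal sketch of the shared path, assuming the WrapDataShards call shape implied by this thread (not the exact signature):

```cpp
// Sketch only: pick between a plain placeholder and a sharded
// (PjRtShardedData-backed) handle. The WrapDataShards usage is assumed
// from this discussion, not copied from the actual code.
torch::lazy::BackendDataPtr handle =
    WrapXlaData(xla::ComputationClient::Get()->CreateDataPlaceholder(
        device.toString(), std::move(shape)));
if (ShardingUtil::UseVirtualDevice()) {
  // On the SPMD virtual device, wrap the placeholder shards so the handle
  // is backed by PjRtShardedData instead of single-device data.
  handle = WrapDataShards({handle}, device.toString(), shape);  // assumed signature
}
```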
This work was done by @yeounoh and I am trying to land this PR on his behalf. The last attempt was made by @steventk-g in #4862.
Currently the test fails with a `Check failed: handle->HasValue()` error, so this is still WIP.