[SPMD][PoC] compile & execute with PjRt #3684
Conversation
* Add device assignment for SPMD compilation
The CPU test passes, but the GPU CI fails with the following, seemingly unrelated (at least at the outset), error. The master branch is green, though. cc @JackCaoG

Seems irrelevant, let me just restart the GPU CI.
I will take another pass and try to merge it. |
expected = t + t
xt = t.to(xm.xla_device())
n_devices = xm.xrt_world_size()
Does CI run this test, or do we only run it on TPU?
We only run the cpp tests -- they cover the internal changes that affect the non-SPMD code paths -- and the python API tests are disabled (link). I will re-enable them after debugging/adding the API unit tests.
virtual void TransferToServer(absl::Span<const TensorSource> tensors,
                              absl::Span<const DataPtr> datas) = 0;

// Transfers local sharded tensor values to the TPU servers and returns a
I would use TPU device instead of TPU server; there is no server in the PJRT context.
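To make the transfer path above concrete, here is a hedged sketch of what sharded parameter transfer (the `InputHandler` role) computes. This is a conceptual numpy model, not the actual C++ `ComputationClient` API; `prepare_device_arguments` and the `"tiled"`/`"replicated"` string tags are hypothetical names for illustration.

```python
import numpy as np

def prepare_device_arguments(parameters, shardings, num_devices):
    """Hypothetical sketch of an input handler: turn a flat list of
    (possibly sharded) parameters into one argument list per device.
    'tiled' parameters are split along dim 0; 'replicated' parameters
    are copied to every device."""
    per_device = [[] for _ in range(num_devices)]
    for param, sharding in zip(parameters, shardings):
        if sharding == "tiled":
            shards = np.array_split(param, num_devices, axis=0)
        else:  # replicated
            shards = [param] * num_devices
        for device_ordinal, shard in enumerate(shards):
            per_device[device_ordinal].append(shard)
    return per_device

weights = np.ones((8, 2))   # tiled across devices along dim 0
bias = np.zeros(2)          # replicated on every device
args = prepare_device_arguments([weights, bias], ["tiled", "replicated"], 4)
assert len(args) == 4                  # one argument list per device
assert args[0][0].shape == (2, 2)      # each device gets an 8/4 = 2-row shard
assert args[0][1].shape == (2,)        # and a full copy of the bias
```

The point of the sketch is that each device receives a complete argument list, with tiled inputs reduced to its own shard and replicated inputs copied whole.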
void XLATensor::SetShardingSpec(const ShardingSpec& sharding_spec) {
  XLA_CHECK(GetIrValue().node != nullptr) << "Trying to access a null cursor";
-  dynamic_cast<XlaNode*>(data()->ir_value.node.get())
+  dynamic_cast<XlaNode*>(GetIrValue().node.get())
Hmm, we should add an XlaNodeCast to replace dynamic_cast<XlaNode*> so it is cleaner.
I see. I normally prefer more explicit type identifiers, especially for casting (similar to avoiding overuse of auto).
// TODO(yeounoh): Sharding annotation must be removed by explicit call to
// ClearSharding.
ShardingSpecPtr sharding = sharding_spec();
if (sharding != nullptr) {
We need a test for this. For example, when we deep copy a tensor with sharding, the result tensor should also have sharding. Something similar to Line 1670 in b3f79cc. @steventk-g can you add a test case?
Good point. @steventk-g, let me handle this if you haven't already started.
auto cached_computation = std::make_shared<CachedComputation>(
-    std::move(compile_result.computation));
+    std::move(compile_result.computation), compile_result.is_sharded);
Why do we need is_sharded separately in CachedComputation?
We could either pass is_sharded around between APIs, or wrap it inside the CachedComputation. is_sharded is later needed for execution (it will be associated with the cached computation only), and the latter option doesn't require changing the function APIs here and there.
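The design choice above can be sketched in a few lines. This is a hypothetical Python analogue of the C++ `CachedComputation` (the `execute` dispatcher and its behavior are illustrative assumptions, not the PR's actual execution code): storing `is_sharded` on the cached object lets the execution site dispatch without the flag being threaded through every intermediate API.

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class CachedComputation:
    """Hypothetical analogue of the C++ struct: the compiled
    computation carries its own is_sharded flag."""
    computation: Callable
    is_sharded: bool

def execute(cached, args):
    # Dispatch on the flag stored with the cached computation,
    # rather than on an extra parameter passed alongside it.
    if cached.is_sharded:
        return [cached.computation(a) for a in args]  # one result per shard
    return cached.computation(args)

sharded = CachedComputation(lambda x: x * 2, is_sharded=True)
assert execute(sharded, [1, 2, 3]) == [2, 4, 6]

unsharded = CachedComputation(lambda x: x * 2, is_sharded=False)
assert execute(unsharded, 10) == 20
```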
Mostly LGTM. I had a question regarding ExecuteReplicated in #3684 (comment). If we can align on that, this PR is ready to merge.
JackCaoG left a comment
Thanks @yeounoh! I will merge this PR to unblock @steventk-g.
This is a follow-up to #3476 and contributes to #3871. The changes include:
* Compile partitioned HLO computation graph with sharding annotations.
* `PjRtComputationClient` integration to support SPMD sharded operations.
* `PjRtShardedData` struct to represent sharded `Data`.
* `InputHandler` for parameter sharding and sharded data transfer.
* `ExecuteReplicated` for partitioned computation.

The PoC implementation supports `replicated` and `tiled` sharding annotations, and a single-host `xla:tpu` backend. This enables a simple sharded computation on v3-8, like
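The concrete torch_xla snippet is elided in this extraction; as a stand-in, here is a hedged numpy sketch of what a tiled SPMD execution of `t + t` computes across 8 devices (e.g. a v3-8). It models the semantics only, not the torch_xla sharding API: each "device" runs the same program on its own tile, and concatenating the per-device results reproduces the unsharded computation.

```python
import numpy as np

# Simulate t + t under SPMD on 8 "devices": split the input into
# per-device tiles along dim 0, run the same program on each tile,
# and reassemble the full result.
num_devices = 8
t = np.arange(16.0).reshape(8, 2)
expected = t + t                                  # unsharded reference

shards = np.array_split(t, num_devices, axis=0)   # tiled along dim 0
partial = [s + s for s in shards]                 # same program per device
result = np.concatenate(partial, axis=0)
assert np.array_equal(result, expected)
```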