IFRT prototype #5677
Conversation
At this point, I am able to run our SPMD ResNet50 example successfully, but it's extremely slow. Anywhere I resharded an array, I chose "copy" semantics for safety. I'll have to take another pass and carefully think through the ownership of the underlying data, since the excessive copies are likely contributing to poor performance. Sharded execution does not appear to be implemented at all within the IFRT/PJRT wrapper, so I marked that method unimplemented for now. Likewise, dynamic shape is currently unsupported. Until we have feature parity, I will keep IFRT as a separate implementation.
Coming back to this PR (finally) after merging supporting changes in separate PRs. Performance is significantly better after rebasing -- it only lags PJRT by ~10% on ResNet50 now, compared to 80% in my first draft. There's still room for optimization, particularly around reducing the number of copies used when transforming IFRT arrays. IFRT is still highly experimental in this state, and there are known outstanding issues beyond performance.

I'll clean up this PR and send it for review as an optional/experimental setting.
Performance on Llama 7B is not bad! It's somewhere between PJRT now and PJRT before I started working on some optimizations this month.
If you don't intend to merge this for the 2.2 release, I will hold on the review until the branch cut.
Merging after the cut sounds good to me. This won't be useful in the 2.2 release.
I will take a look today |
```cpp
// Builds a map from the device's global ordinal to its index in the `devices`
// array.
std::unordered_map<int, int> build_index_map(
```
For these utils, can we share between PJRT and IFRT? Or are they actually different?
This was intentional. I wanted to minimize changes to common code, and ideally we would be able to remove the PJRT computation client at some point. Would you prefer I try to factor out all common functionality in this PR?
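For context, a standalone sketch of the helper under discussion might look like the following. The `Device` struct here is a hypothetical stand-in (the real client holds PJRT/IFRT device objects); this is not the PR's exact implementation, just an illustration of what the duplicated utility does.

```cpp
#include <unordered_map>
#include <vector>

// Hypothetical minimal device type; only the global ordinal matters here.
struct Device {
  int global_ordinal;
};

// Builds a map from a device's global ordinal to its index in the `devices`
// array (sketch of the utility quoted in the diff above).
std::unordered_map<int, int> build_index_map(
    const std::vector<Device>& devices) {
  std::unordered_map<int, int> ordinal_to_index;
  ordinal_to_index.reserve(devices.size());
  for (int i = 0; i < static_cast<int>(devices.size()); ++i) {
    ordinal_to_index[devices[i].global_ordinal] = i;
  }
  return ordinal_to_index;
}
```

Since the function depends only on the device list, it could in principle be shared between the two clients, which is the reviewer's point above.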
```cpp
IfrtComputationClient::IfrtComputationClient() {
  std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, "");
```
Let's do a grep for pjrt in this file and replace them with ifrt. Though I am curious: did you intend to query `kEnvPjRtDevice` here?
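For reference, `sys_util::GetEnvString` just reads an environment variable with a fallback default. A standalone equivalent (my sketch, not the torch_xla source) is:

```cpp
#include <cstdlib>
#include <string>

// Standalone equivalent of sys_util::GetEnvString (sketch): returns the
// value of the environment variable `name`, or `defval` if it is unset.
std::string GetEnvString(const char* name, const std::string& defval) {
  const char* value = std::getenv(name);
  return value != nullptr ? std::string(value) : defval;
}
```

Whether the constructor should read the PJRT device variable or an IFRT-specific one is exactly the naming question raised above; the lookup mechanism itself is the same either way.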
```cpp
    const std::vector<DataPtr>& shards, std::string device, xla::Shape shape,
    xla::OpSharding sharding) {
  // TODO: implement CreateDataPlaceholder for sharded data
  if (shards.size() == 0) {
```
Is there a legit use case for `shards.size() == 0`?
This is how sharded data placeholders get created right now: `xla/torch_xla/csrc/xla_sharding_util.cpp`, lines 590 to 593 (at a2f80e4).

I'd rather update `CreateDataPlaceholder` to take a sharding and make it more explicit.
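That suggestion could look roughly like the sketch below, which passes an optional sharding instead of signaling "sharded placeholder" with an empty `shards` vector. The types here are simplified stand-ins for illustration; the real signatures use `xla::Shape`, `xla::OpSharding`, and the client's own `Data` type.

```cpp
#include <memory>
#include <optional>
#include <string>
#include <utility>

// Simplified stand-ins for the client's types (hypothetical).
struct Shape {};
struct OpSharding {};
struct Data {
  std::string device;
  Shape shape;
  bool is_sharded = false;
};
using DataPtr = std::shared_ptr<Data>;

// Sketch: an explicit optional sharding replaces the `shards.size() == 0`
// convention for creating sharded placeholders.
DataPtr CreateDataPlaceholder(
    std::string device, Shape shape,
    std::optional<OpSharding> sharding = std::nullopt) {
  auto data = std::make_shared<Data>();
  data->device = std::move(device);
  data->shape = shape;
  data->is_sharded = sharding.has_value();
  return data;
}
```

The upside is that callers state their intent directly, and an accidentally empty shard list can then be treated as an error rather than silently meaning "placeholder".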
This reverts commit 7d52f67.
- Implemented `ComputationClient` with IFRT, which currently just wraps PJRT. Enabled with `XLA_USE_IFRT=1`.
- Factored out `initialize_pjrt.cc/h`, since IFRT wraps the same `PjRtClient`.
- `pjrt_computation_client`: `spmd_device_str`
- Changed the `const PjRtClient` to a `unique_ptr`. Only SE:TPU required us to use `shared_ptr`, which is now removed.

There's still some opportunity to refactor common functionality up to `ComputationClient`, but I'm trying to minimize changes to the high-level API for this PR.

IFRT is still highly experimental. Use at your own risk. See my comments below for caveats and limitations.
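Opting in follows the usual environment-variable pattern. The flag name `XLA_USE_IFRT` comes from the description above; the launch command is a placeholder.

```shell
# Enable the experimental IFRT-backed client for this shell session.
export XLA_USE_IFRT=1
# Then launch as usual, e.g.: python train.py  (placeholder script name)
echo "IFRT enabled: ${XLA_USE_IFRT}"
```

Leaving the variable unset keeps the existing PJRT computation client, so the default behavior is unchanged.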