[PJRT] Don't use futures on ExecuteSharded output buffers #4503
Merged
will-cromar merged 1 commit into master on Jan 25, 2023
dcdfd7b to c1572e8
c1572e8 to f8e622e
ManfeiBai pushed a commit that referenced this pull request on Jan 30, 2023
Setting callbacks on each output buffer (see #4400) led to a significant performance regression on both TPU v3 and TPU v4. Remove those callbacks entirely.
Instead, only use the `returned_future` of `ExecuteSharded` when not using the PJRT C API. This means the `ExecuteTime` metric will be broken with the PJRT C API until our next pin update, when we can remove this condition.

Tested with both `PJRT_DEVICE=TPU` and `PJRT_DEVICE=TPU_C_API`. Both get performance similar to before #4400.
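
For readers unfamiliar with the pattern, here is a minimal, library-agnostic sketch of the idea. It is not the PyTorch/XLA or PJRT code: `execute_sharded`, `run_shards`, `use_c_api`, `record_execute_time`, and the `metrics` dict are hypothetical names, and Python's `concurrent.futures` stands in for PJRT futures. It only illustrates attaching one completion callback to the single future for an execution, and skipping it on the C API path, rather than attaching one callback per output buffer.

```python
import time
from concurrent.futures import ThreadPoolExecutor

# Hypothetical stand-in for a metrics registry.
metrics = {"ExecuteTime": []}


def run_shards(num_outputs):
    """Pretend device execution: returns one placeholder buffer per output."""
    return [f"buffer-{i}" for i in range(num_outputs)]


def execute_sharded(pool, num_outputs, use_c_api):
    """Hypothetical wrapper that times an execution through one future."""
    start = time.perf_counter()
    future = pool.submit(run_shards, num_outputs)

    if not use_c_api:
        # One callback on the single execution future,
        # instead of one callback per output buffer.
        def record_execute_time(_done):
            metrics["ExecuteTime"].append(time.perf_counter() - start)

        future.add_done_callback(record_execute_time)
    # On the (assumed) C API path no callback is attached,
    # so ExecuteTime is simply not recorded.
    return future


with ThreadPoolExecutor(max_workers=1) as pool:
    buffers = execute_sharded(pool, num_outputs=4, use_c_api=False).result()
    c_api_buffers = execute_sharded(pool, num_outputs=4, use_c_api=True).result()
```

After both calls, `metrics["ExecuteTime"]` holds a single entry, from the call that did not use the C API path. This mirrors the note above that the metric is unavailable with the PJRT C API for now.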