Don't set $TPU_LIBRARY_PATH during import#5698
Merged
will-cromar merged 2 commits intomasterfrom Oct 16, 2023
Merged
Conversation
Collaborator
|
I am a bit confuse after this pr, what will happen if user still set |
Collaborator
Author
|
We'll still take Internally, we can use |
JackCaoG
approved these changes
Oct 11, 2023
vanbasten23
reviewed
Oct 12, 2023
| 3. libtpu-nightly pip package | ||
|
|
||
| Sets $PTXLA_TPU_LIBRARY_PATH if path is inferred by us to prevent conflicts | ||
| with other frameworks. This env var will be removed in a future version. |
Collaborator
There was a problem hiding this comment.
I wonder under what condition can we remove this env var.
zpcore
pushed a commit
that referenced
this pull request
Oct 19, 2023
* Don't set $TPU_LIBRARY_PATH during import * remove chekck for new env var so people don't use it
Collaborator
|
This PR regress the profiler. I cannot take multi-host profilers after this change. I'm going to revert it. |
alanwaketan
added a commit
that referenced
this pull request
Oct 25, 2023
This reverts commit 146f2a0.
alanwaketan
added a commit
that referenced
this pull request
Oct 25, 2023
jonb377
pushed a commit
that referenced
this pull request
Oct 31, 2023
ghpvnist
pushed a commit
to ghpvnist/pytorch-xla
that referenced
this pull request
Oct 31, 2023
* Don't set $TPU_LIBRARY_PATH during import * remove chekck for new env var so people don't use it
will-cromar
added a commit
that referenced
this pull request
Nov 13, 2023
will-cromar
added a commit
that referenced
this pull request
Nov 14, 2023
mbzomowski
pushed a commit
to mbzomowski-test-org/xla
that referenced
this pull request
Nov 16, 2023
* Don't set $TPU_LIBRARY_PATH during import * remove chekck for new env var so people don't use it
mbzomowski
pushed a commit
to mbzomowski-test-org/xla
that referenced
this pull request
Nov 16, 2023
…torch#5731) This reverts commit 146f2a0.
chunnienc
pushed a commit
to chunnienc/xla
that referenced
this pull request
Dec 14, 2023
* Don't set $TPU_LIBRARY_PATH during import * remove chekck for new env var so people don't use it
chunnienc
pushed a commit
to chunnienc/xla
that referenced
this pull request
Dec 14, 2023
…torch#5731) This reverts commit 146f2a0.
golechwierowicz
pushed a commit
that referenced
this pull request
Jan 12, 2024
* Don't set $TPU_LIBRARY_PATH during import * remove chekck for new env var so people don't use it
golechwierowicz
pushed a commit
that referenced
this pull request
Jan 12, 2024
bhavya01
pushed a commit
that referenced
this pull request
Apr 22, 2024
* Don't set $TPU_LIBRARY_PATH during import * remove chekck for new env var so people don't use it
bhavya01
pushed a commit
that referenced
this pull request
Apr 22, 2024
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
JAX and PyTorch/XLA are both overriding
$TPU_LIBRARY_PATHduring import, leading to confusing issues that depend on import order, e.g. #5625get_library_pathfunction to get that path from the libtpu package.$TPU_LIBRARY_PATHas an override for compatibilityCorresponding fix in JAX: jax-ml/jax@b81a3e1
The issue will be resolved when both JAX and PyTorch/XLA release new packages that include these two changes.
$PTXLA_TPU_LIBRARY_PATHis a temporary hack. I expect we'll have a better way to get the PJRT Plugin path in the near future.cc @jyingl3