Conversation
will-cromar
left a comment
Can you use xr.world_size instead? I don't think we need to rename the function in xla_model
Lines 148 to 152 in 5b8e8e0
xla/torch_xla/core/xla_model.py, lines 129 to 132 in 5b8e8e0:
_WORLD_SIZE will cause recompilation. @alanwaketan for insights.
You will know if there is a recompilation from the test.
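For instance, one way a test could check for recompilation (a hedged sketch; `run_workload` is a hypothetical stand-in, and this assumes the `torch_xla.debug.metrics` API):

```python
import torch_xla.debug.metrics as met

met.clear_all()
run_workload()  # hypothetical first run: expect at least one XLA compilation
# metric_data returns (num_samples, accumulator, samples); num_samples
# counts how many times the graph was compiled.
baseline = met.metric_data('CompileTime')[0]

run_workload()  # second run: the compile count should not grow
assert met.metric_data('CompileTime')[0] == baseline, 'recompilation detected'
```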
Hi @will-cromar, I tried functools.lru_cache and it crashes in multiprocess. I noticed that if we use functools.wraps and assign attributes to the wrapped function, it crashes; probably lru_cache relies on the function attributes in this case. I created the run_once decorator instead:
```python
_ORDINAL = runtime.global_ordinal()


def run_once(func):
```
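A minimal sketch of such a run_once decorator, assuming the result is cached in closure variables rather than as function attributes (not necessarily the PR's exact implementation):

```python
def run_once(func):
  """Run `func` at most once per process and cache its result.

  The cache lives in closure variables instead of attributes on the
  wrapped function, sidestepping the multiprocess crash described above.
  """
  result = None
  called = False

  def wrapper(*args, **kwargs):
    nonlocal result, called
    if not called:
      result = func(*args, **kwargs)
      called = True
    return result

  return wrapper
```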
What is this run_once for? It's a neat idea, but I see you opted for global variables for world size and ordinal.
Without @run_once on using_pjrt, the test test_mp_replication fails because it builds a dynamic graph during dynamo compilation.
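For reference, a rough sketch of the module-level caching pattern used for world size and ordinal (the guard logic here is an assumption, not the PR's exact code):

```python
import torch_xla.runtime as runtime

_WORLD_SIZE = None
_ORDINAL = None


def _init_world_size_ordinal():
  # Query the runtime once per process; afterwards dynamo traces see
  # plain Python ints instead of runtime calls.
  global _WORLD_SIZE, _ORDINAL
  if _WORLD_SIZE is None:
    _WORLD_SIZE = runtime.world_size()
    _ORDINAL = runtime.global_ordinal()
```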
Do you know why that is? Is it because now all calls to get e.g. world_size go through functions wrapped in requires_pjrt, which in turn is actually checking an env var (device_type and _maybe_select_default_device)? Whereas before, the call would have been stopped by xm.world_size.
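A self-contained sketch of the call chain being described (the helper bodies are illustrative assumptions; the real ones live in torch_xla.runtime):

```python
import functools
import os


def _maybe_select_default_device():
  # Illustrative stand-in: pick a default if PJRT_DEVICE is unset.
  os.environ.setdefault('PJRT_DEVICE', 'CPU')


def device_type():
  # Re-reads the environment on every call unless cached upstream.
  _maybe_select_default_device()
  pjrt_device = os.environ.get('PJRT_DEVICE')
  return pjrt_device.split('_')[0] if pjrt_device else None


def using_pjrt() -> bool:
  return device_type() is not None


def requires_pjrt(fn):
  # Every wrapped call re-runs the env check unless using_pjrt is cached.
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    if not using_pjrt():
      raise NotImplementedError(f'`{fn.__name__}` not implemented for XRT')
    return fn(*args, **kwargs)

  return wrapper
```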
```python
@run_once
def using_pjrt() -> bool:
```
Hah, this function also needs to be deprecated, since I assume this is always True.
We still need to call _maybe_select_default_device(); that is why I call using_pjrt() only once.
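Putting this together with the run_once and helper sketches above (again an illustration, not the exact source):

```python
@run_once
def using_pjrt() -> bool:
  # _maybe_select_default_device() still runs once per process to
  # populate PJRT_DEVICE; subsequent calls return the cached bool
  # without touching the environment again.
  _maybe_select_default_device()
  return device_type() is not None
```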
will-cromar
left a comment
General question: if run_once makes dynamo happy for using_pjrt, do you know why it doesn't work for world_size and global_ordinal?
Take unittest
The C API binding
will-cromar
left a comment
Thanks for the explanation! Filed a follow-up bug to clean up using_pjrt and requires_pjrt and fix a concrete usability issue at #7730.
The TPU CI failure seems relevant; can we fix forward or revert this PR?
I forgot to update the TPU CI test. Let me make a follow-up PR now.
```diff
 new_rank = xm.get_ordinal()
-world_size = xm.xrt_world_size()
+world_size = xr.world_size()
```
@zpcore please add `import torch_xla.runtime as xr` to section 1. It feels like this line comes out of left field without the import in the documentation.
xm.world_size() is also deprecated. They all point to the same thing. We should only use xr.world_size().
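With the import in place, the fully migrated snippet from the docs would then read along these lines (xr.global_ordinal is the runtime counterpart of the deprecated xm.get_ordinal):

```python
import torch_xla.runtime as xr

new_rank = xr.global_ordinal()
world_size = xr.world_size()
```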
Deprecate `torch_xla.xla_model.xrt_world_size` and use `torch_xla.runtime.world_size` instead.

Add the `run_once` decorator to the function `runtime.using_pjrt`, since we only need to run it once per process. This helps get rid of the dynamo compilation issue with `xm.all_reduce`.
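A minimal sketch of what the deprecation alias in xla_model might look like (an assumption; the actual PR may route through a shared deprecation helper instead):

```python
import warnings

import torch_xla.runtime as xr


def xrt_world_size() -> int:
  # Deprecated alias kept for backward compatibility; forwards to the
  # runtime module and warns callers to migrate.
  warnings.warn(
      'xrt_world_size is deprecated. Use torch_xla.runtime.world_size instead.',
      DeprecationWarning,
      stacklevel=2)
  return xr.world_size()
```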