Set all_reduce_token to None when exiting #6247
Conversation
I am curious: why does the all_reduce_token matter when we exit pytorch/xla?
Oh OK, I saw #6246, my bad. Can you add the test case you mentioned in the issue as a separate test? An example would be https://github.com/pytorch/xla/blob/master/test/test_mp_collective_permute.py, which is run at line 231 in d7c4430.
@JackCaoG I have added a test case, but I cannot run it locally. The program hangs after setting GPU_NUM_DEVICES=2; other mp test cases hang as well.
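For reference, a minimal sketch of the kind of multiprocess test being discussed, modeled on test_mp_collective_permute.py; the tensor shapes, function names, and structure here are assumptions, not the exact test added in this PR:

# Hypothetical sketch: run an all_reduce under xla_multiprocessing so that an
# all-reduce token exists for the device, then let the process exit. Before
# this fix, the exit path could hang (see #6246).
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  device = xm.xla_device()
  # all_reduce installs an all-reduce token for this device.
  result = xm.all_reduce(xm.REDUCE_SUM, torch.ones(2, 2, device=device))
  xm.mark_step()
  print(result.cpu())


if __name__ == '__main__':
  # Each spawned process should exit cleanly after the reduction.
  xmp.spawn(_mp_fn, args=())

On GPU this would be launched with GPU_NUM_DEVICES=2, which is exactly the configuration reported to hang above.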
Let's see if CI will be able to run it then.
CI seems to be happy. If you can fix the linter, I can help you land it.
@JackCaoG I have fixed the linter; sorry, I forgot about it.
@vanbasten23 does this PR break multi-device GPU training, or just the test?
@vanbasten23 let's revert his PR. @ManfeiBai can you also help revert this in the 2.2 release branch? The release date is approaching, and I don't want to take this risk.
|
The review comments below refer to these lines in torch_xla/__init__.py:

def _prepare_to_exit():
  device = _XLAC._xla_get_default_device()
Does this always return the same device regardless of different processes in the pool?
I think it will return the device that belongs to the current process, assuming each process only has one device.
Then I'm not sure why it would break other test cases...
I think the error occurs because the ComputationClient has already exited in atexit._run_exitfuncs when using xla_multiprocessing, while the client still exists when using torchrun. Therefore, under xla_multiprocessing, this invocation results in the creation of a new PjRtComputationClient, causing a hang.
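To make that concrete, here is the original exit hook with editorial comments marking the assumed failure mode; the code mirrors the lines this PR removes, while the comments are annotations rather than part of the source:

# torch_xla/__init__.py before the fix, annotated (sketch):
import torch_xla


def _prepare_to_exit():
  # Under xla_multiprocessing the ComputationClient can already be torn down
  # by the time atexit runs this hook. Asking for the default device then
  # forces a new PjRtComputationClient to be constructed, and the process
  # hangs instead of exiting.
  device = torch_xla._XLAC._xla_get_default_device()
  torch_xla._XLAC._set_all_reduce_token(device, None)
  torch_xla._XLAC._prepare_to_exit()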
So the better solution is to set the all-reduce token to null inside PrepareToExit, where it sits behind the existing client != nullptr check and therefore cannot trigger creation of a new client, like this:
diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index 8d4997e28..d753f8f7c 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -148,8 +148,6 @@ _aws_ec2_inf_trn_init()
 
 
 def _prepare_to_exit():
-  device = _XLAC._xla_get_default_device()
-  _XLAC._set_all_reduce_token(device, None)
   _XLAC._prepare_to_exit()
   if int(os.environ.get('PT_XLA_DEBUG', '0')):
     _summarize_fn_tracker()
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 3281f0e9a..b255bb043 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -97,6 +97,8 @@ void PrepareToExit() {
   runtime::ComputationClient* client =
       runtime::GetComputationClientIfInitialized();
   if (client != nullptr) {
+    auto xla_device = GetDeviceOrCurrent("");
+    SetAllReduceToken(xla_device, nullptr);
     XLAGraphExecutor::Get()->WaitDeviceOps({});
   }
 }
This PR fixes #6246.