Set all_reduce_token to None when exiting#6247

Merged
JackCaoG merged 1 commit into pytorch:master from yitongh:fix_token on Jan 5, 2024

Conversation

@yitongh (Contributor) commented Jan 2, 2024

This PR fixes #6246.

@JackCaoG (Collaborator) commented Jan 2, 2024

I am curious why reduce_token matters when we exit pytorch/xla?

@JackCaoG (Collaborator) commented Jan 2, 2024

Oh ok, I saw #6246, my bad. Can you add the test case you mentioned in the issue as a separate test? An example is https://github.com/pytorch/xla/blob/master/test/test_mp_collective_permute.py, which is run via

run_test "$CDIR/test_mp_collective_permute.py"
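
For illustration only, a minimal sketch of what such an mp all_reduce test could look like, modeled on test_mp_collective_permute.py. The torch_xla calls used are existing APIs, but this particular test body and its assertion are assumptions, not the test actually added in this PR:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  # Each process contributes its ordinal; after the all_reduce every
  # replica should hold the sum 0 + 1 + ... + (world_size - 1).
  value = torch.tensor([float(index)], device=device)
  reduced = xm.all_reduce(xm.REDUCE_SUM, value)
  expected = world_size * (world_size - 1) / 2
  assert reduced.item() == expected, f'{reduced.item()} != {expected}'
  # Returning here lets the normal exit path handle the all-reduce token,
  # which is the situation this PR is about.


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())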

@yitongh (Contributor, Author) commented Jan 3, 2024

@JackCaoG I have added a test case, but I cannot run it. The program hangs after setting GPU_NUM_DEVICES=2. Other mp test cases also hang.

@JackCaoG (Collaborator) commented Jan 3, 2024

let's see if CI will be able to run it then.

@JackCaoG (Collaborator) commented Jan 3, 2024

CI seems to be happy; if you can fix the linter, I can help you land it.

@JackCaoG self-requested a review on January 3, 2024 at 23:30
@yitongh (Contributor, Author) commented Jan 4, 2024

@JackCaoG I have fixed the linter; sorry, I forgot about this.

@JackCaoG merged commit f9c12fc into pytorch:master on Jan 5, 2024
@yitongh deleted the fix_token branch on January 5, 2024 at 01:56
@vanbasten23 (Collaborator) commented

> @JackCaoG I have added a test case, but I cannot run it. The program hangs after setting GPU_NUM_DEVICES=2. Other mp test cases also hang.

This seems to be a new issue; I have also observed the same in #6320. Before this change, none of the mp test cases hung. @yitongh @JackCaoG

@JackCaoG (Collaborator) commented

@vanbasten23 does this PR break multi-device GPU training or just the test?

@JackCaoG (Collaborator) commented

@vanbasten23 let's revert this PR. @ManfeiBai can you also help revert this in the 2.2 release branch? The release date is approaching, and I don't want to take this risk.

Comment thread on torch_xla/__init__.py:

def _prepare_to_exit():
  device = _XLAC._xla_get_default_device()
Collaborator:

Does this always return the same device regardless of different processes in the pool?

Collaborator:

I think it will return the device that belongs to the current process, assuming each process only has one device.

Collaborator:

Then I'm not sure why it will break other test cases...

Contributor (Author):

I think the error is because the ComputationClient has already exited in atexit._run_exitfuncs when using xla_multiprocessing, whereas the client still exists when using torchrun. Therefore, with xla_multiprocessing, this invocation results in the creation of a new PjRtComputationClient, causing a hang.

So the better solution is to set the all-reduce token in PrepareToExit, like this:

diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index 8d4997e28..d753f8f7c 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -148,8 +148,6 @@ _aws_ec2_inf_trn_init()


 def _prepare_to_exit():
-  device = _XLAC._xla_get_default_device()
-  _XLAC._set_all_reduce_token(device, None)
   _XLAC._prepare_to_exit()
   if int(os.environ.get('PT_XLA_DEBUG', '0')):
     _summarize_fn_tracker()
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 3281f0e9a..b255bb043 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -97,6 +97,8 @@ void PrepareToExit() {
   runtime::ComputationClient* client =
       runtime::GetComputationClientIfInitialized();
   if (client != nullptr) {
+    auto xla_device = GetDeviceOrCurrent("");
+    SetAllReduceToken(xla_device, nullptr);
     XLAGraphExecutor::Get()->WaitDeviceOps({});
   }
 }
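
With this change, the token is cleared only inside the client != nullptr branch, i.e. while the ComputationClient is known to still be initialized, so exiting never triggers the creation of a new PjRtComputationClient.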


Development

Successfully merging this pull request may close these issues.

SIGSEGV when exiting the dataloader in the middle of training
