
[train] Refactor AcceleratorSetupCallback to use before_init_train_context#56509

Merged
justinvyu merged 8 commits into ray-project:master from matthewdeng:accelerator-callback
Sep 19, 2025

Conversation

@matthewdeng
Contributor

@matthewdeng matthewdeng commented Sep 13, 2025

This fixes an issue in which the CUDA context is not properly configured during import deserialization.

```
RuntimeError: device >= 0 && device < num_gpus INTERNAL ASSERT FAILED at "../aten/src/ATen/cuda/CUDAContext.cpp":49, please report a bug to PyTorch. device=1, num_gpus=
```

Context

The relevant logic happens in WorkerGroup._start_impl:

```python
for callable in self._callbacks:
    args = callable.before_init_train_context(workers)
    for arg, arg_values in args.items():
        assert len(arg_values) == worker_group_context.num_workers, (
            f"Callback {callable} returned {arg} with "
            f"{len(arg_values)} values, expected {worker_group_context.num_workers}."
        )
        assert (
            arg not in train_context_args
        ), f"Callback {callable} returned {arg} which is already set."
        train_context_args[arg] = arg_values

self._init_train_context_on_workers(
    workers, sync_actor, train_context_args
)
self._worker_group_state = worker_group_state_builder.build()

for callback in self._callbacks:
    callback.after_worker_group_start(self)
```

The logic is as follows:

  1. WorkerGroupCallback.before_init_train_context
  2. WorkerGroup._init_train_context_on_workers
  3. WorkerGroupCallback.after_worker_group_start

Problem

The error occurs when CUDA_VISIBLE_DEVICES is not properly configured before torch.cuda initialization happens (when the TrainContext is initialized). torch.cuda.is_available() forces CUDA initialization, which reads CUDA_VISIBLE_DEVICES at that moment and locks that state in.

Here's the order of events of the original issue:

  1. The rank X RayTrainWorker actor is created as a blank slate with CUDA_VISIBLE_DEVICES=X, since Ray Core sets the environment variable automatically.
  2. The Controller calls RayTrainWorker.init_train_context as a remote task, passing in parameters such as TrainRunContext, which holds user code (e.g., datasets and train_loop_config) that can depend on user modules.
  3. The call to init_train_context deserializes all of the arguments on the RayTrainWorker, which triggers a bunch of imports, including torch and the user modules.
  4. If the user module calls torch.cuda.is_available() at the import level, then the CUDA initialization locks in the CUDA_VISIBLE_DEVICES=X state.
  5. Only afterward do we set CUDA_VISIBLE_DEVICES to the correct value. But at that point it is inconsistent with the already-initialized CUDA state, so further calls to torch.cuda raise an assertion error.
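
The lock-in behavior in steps 4–5 can be simulated without a GPU using a toy stand-in for torch.cuda (FakeCuda below is purely illustrative; real CUDA initialization happens in native code):

```python
import os

# Toy stand-in for torch.cuda: it caches CUDA_VISIBLE_DEVICES the first time
# it is "initialized", the way torch.cuda.is_available() locks the value in
# at CUDA init time.
class FakeCuda:
    def __init__(self):
        self._visible = None  # not initialized yet

    def is_available(self):
        if self._visible is None:  # lazy init reads the env var exactly once
            self._visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        return bool(self._visible)

    def set_device(self, index: int):
        # After init, the visible-device count is frozen; a later, larger
        # CUDA_VISIBLE_DEVICES is ignored, mirroring the assertion failure.
        num_gpus = len(self._visible.split(",")) if self._visible else 0
        if not (0 <= index < num_gpus):
            raise RuntimeError(f"device={index}, num_gpus={num_gpus}")

os.environ["CUDA_VISIBLE_DEVICES"] = "0"    # step 1: Ray Core sets one device
cuda = FakeCuda()
cuda.is_available()                         # step 4: import-time init locks it in
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # step 5: too late
try:
    cuda.set_device(1)
except RuntimeError as e:
    print(e)  # device=1, num_gpus=1
```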

Solution

The fix is to move the CUDA_VISIBLE_DEVICES sharing logic from after_worker_group_start to before_init_train_context, so that any calls to torch.cuda happen after the devices are set up properly.
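
A minimal sketch of this ordering (the class, field, and helper names below are illustrative, not the actual AcceleratorSetupCallback implementation):

```python
import os
from typing import Dict, List

# Hypothetical sketch of the fix: the accelerator callback writes
# CUDA_VISIBLE_DEVICES in before_init_train_context, so any torch.cuda
# initialization triggered later (e.g., during init_train_context
# deserialization) already sees the shared device list.
class AcceleratorSetupCallbackSketch:
    def before_init_train_context(self, workers: List[dict]) -> Dict[str, list]:
        # Union of GPU ids across co-located workers, shared with each worker.
        all_gpu_ids = sorted({gid for w in workers for gid in w["gpu_ids"]})
        visible = ",".join(str(g) for g in all_gpu_ids)
        for _ in workers:
            # In Ray this would be a remote call executed on each worker
            # before TrainContext initialization; here we set it locally.
            os.environ["CUDA_VISIBLE_DEVICES"] = visible
        return {}  # no extra train-context args needed

workers = [{"gpu_ids": [0]}, {"gpu_ids": [1]}]
AcceleratorSetupCallbackSketch().before_init_train_context(workers)
print(os.environ["CUDA_VISIBLE_DEVICES"])  # 0,1
```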

Alternative Solution

There is another option: set the EXPERIMENTAL_NOSET_CUDA environment variable on the TrainWorkers, so that when they are first scheduled they are not restricted to a single GPU device. However, this also exposes them to more devices than needed, which may not be desired if the user wants to restrict GPU access to the devices required for the training job. The solution implemented in this PR grants the least access while still solving the problem.

Repro

Run on a GPU node with multiple GPUs.

repro.py:

```python
import os
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"

from typing import List

import torch

import ray
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.v2._internal.execution.callback import WorkerGroupCallback
from ray.train.v2._internal.execution.worker_group import Worker

def train_func():
    ...

def init_torch():
    torch.cuda.is_available()

class InitTorchCallback(WorkerGroupCallback):
    def before_init_train_context(self, workers: List[Worker]):
        futures = []
        for worker in workers:
            futures.append(worker.execute_async(init_torch))
        ray.get(futures)
        return {}

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
    run_config=RunConfig(callbacks=[InitTorchCallback()])
)

trainer.fit()
```

```shell
RAY_TRAIN_V2_ENABLED=1 python repro.py
```

Failure

```
Traceback (most recent call last):
File "/home/ray/anaconda3/lib/python3.11/site-packages/torch/cuda/__init__.py", line 174, in _check_capability
capability = get_device_capability(d)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ray/anaconda3/lib/python3.11/site-packages/torch/cuda/__init__.py", line 430, in get_device_capability
prop = get_device_properties(device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ray/anaconda3/lib/python3.11/site-packages/torch/cuda/__init__.py", line 448, in get_device_properties
return _get_device_properties(device) # type: ignore[name-defined]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: device >= 0 && device < num_gpus INTERNAL ASSERT FAILED at "../aten/src/ATen/cuda/CUDAContext.cpp":50, please report a bug to PyTorch. device=, num_gpus=

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/train/torch/config.py", line 34, in __enter__
torch.cuda.set_device(device)
File "/home/ray/anaconda3/lib/python3.11/site-packages/torch/cuda/__init__.py", line 399, in set_device
torch._C._cuda_setDevice(device)
File "/home/ray/anaconda3/lib/python3.11/site-packages/torch/cuda/__init__.py", line 312, in _lazy_init
raise DeferredCudaCallError(msg) from e
torch.cuda.DeferredCudaCallError: CUDA call failed lazily at initialization with error: device >= 0 && device < num_gpus INTERNAL ASSERT FAILED at "../aten/src/ATen/cuda/CUDAContext.cpp":50, please report a bug to PyTorch. device=, num_gpus=

CUDA call was originally invoked at:

File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/workers/default_worker.py", line 322, in <module>
worker.main_loop()
File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/worker.py", line 1041, in main_loop
self.core_worker.run_task_loop()
File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/worker.py", line 940, in deserialize_objects
return context.deserialize_objects(
File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/serialization.py", line 586, in deserialize_objects
obj = self._deserialize_object(
File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/serialization.py", line 423, in _deserialize_object
return self._deserialize_msgpack_data(
File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/serialization.py", line 370, in _deserialize_msgpack_data
python_objects = self._deserialize_pickle5_data(
File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/serialization.py", line 352, in _deserialize_pickle5_data
obj = pickle.loads(in_band)
File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/cloudpickle/cloudpickle.py", line 457, in subimport
__import__(name)
File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
File "<frozen importlib._bootstrap>", line 1126, in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
File "<frozen importlib._bootstrap_external>", line 940, in exec_module
File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
File "/home/ray/anaconda3/lib/python3.11/site-packages/torch/__init__.py", line 1478, in <module>
_C._initExtension(manager_path())
File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
File "<frozen importlib._bootstrap_external>", line 940, in exec_module
File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
File "/home/ray/anaconda3/lib/python3.11/site-packages/torch/cuda/__init__.py", line 238, in <module>
_lazy_call(_check_capability)
File "/home/ray/anaconda3/lib/python3.11/site-packages/torch/cuda/__init__.py", line 235, in _lazy_call
_queued_calls.append((callable, traceback.format_stack()))
```

Repro Notes: The repro requires all of the following characteristics; the failure does not occur if any of them is false.

| Line | Reason |
| --- | --- |
| `use_gpu=True` | Must run on GPUs. |
| `num_workers >= 2` | Must use multiple workers. |
| `torch.cuda.is_available()` | Must trigger `torch.cuda` initialization at import time. |

@matthewdeng matthewdeng requested a review from a team as a code owner September 13, 2025 20:58
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors AcceleratorSetupCallback to use the before_init_train_context hook instead of after_worker_group_start. This change is crucial for correctly setting up the CUDA context on workers before the training context is initialized, which resolves an import deserialization issue with PyTorch. The refactoring correctly passes the list of workers down through _maybe_share_cuda_visible_devices, _share_cuda_visible_devices, and _share_accelerator_ids, and updates the remote execution calls accordingly. My feedback includes a minor style improvement for a docstring. As noted in the PR description, tests will need to be updated to reflect these changes, as the current tests for AcceleratorSetupCallback will likely fail.

@ray-gardener ray-gardener bot added the train Ray Train Related Issue label Sep 14, 2025
Contributor

@justinvyu justinvyu left a comment


I feel this fix could be related to the ordering of the TorchBackend setup and the cuda visible device sharing callback. But the AcceleratorSetupCallback already happens before the BackendSetupCallback based on the default callback ordering:

https://github.com/anyscale/rayturbo/blob/dbcb0742f85ca489b8d7d63e53640536ce009411/python/ray/train/v2/api/data_parallel_trainer.py#L172-L175

Another hypothesis I have is that the first torch import on the Worker actor happens on init_train_context, through the deserialization of something in the train run context depending on torch.

@justinvyu
Contributor

justinvyu commented Sep 17, 2025

More minimal repro without Ray datasets confusion:

```python
import os
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"

import ray
ray.init(ignore_reinit_error=True)

from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig

from helper import noop

def train_func():
    print(os.environ["CUDA_VISIBLE_DEVICES"])
    # Capturing in the train function scope only doesn't fail:
    print(noop)
    ...

trainer = TorchTrainer(
    train_func,
    # This fails:
    # train_loop_config={"asdf": noop},
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
)

trainer.fit()
```

helper.py:

```python
import torch
torch.cuda.is_available()

def noop(batch):
    return batch
```

@justinvyu
Contributor

More minimal repro with just torch, mimicking what happens on the worker actor initialization:

```python
import os
import torch

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Fresh RayTrainWorker state at the beginning, ex: CUDA_VISIBLE_DEVICES=0
print(f"BEFORE SETTING! {os.environ['CUDA_VISIBLE_DEVICES']=}")

# init_train_context(train_run_context) gets called and a bunch of imports
# happen on deserialization.
# A local module import with a torch.cuda.is_available() call initializes
# CUDA using this incorrect CUDA_VISIBLE_DEVICES, which "locks in" the invalid
# state and won't be re-initialized.
print(f"{torch.cuda.is_available()=}")

# AcceleratorSetupCallback updates the CUDA_VISIBLE_DEVICES AFTER the CUDA init
# has already happened. Ex: CUDA_VISIBLE_DEVICES=0,1
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
print(f"AFTER SETTING! {os.environ['CUDA_VISIBLE_DEVICES']=}")

# Setting the CUDA device now errors, probably an assertion that fails due to
# the mismatch between the "locked in" state with the old CUDA_VISIBLE_DEVICES
# and the updated CUDA_VISIBLE_DEVICES.
# torch.cuda.DeferredCudaCallError: CUDA call failed lazily at initialization with error: device >= 0 && device < num_gpus INTERNAL ASSERT FAILED at "../aten/src/ATen/cuda/CUDAContext.cpp":50, please report a bug to PyTorch. device=, num_gpus=
device = "cuda:0"
print(f"{device=}")
torch.cuda.set_device(device)
```

@justinvyu
Contributor

Can you also describe why we don't fix with EXPERIMENTAL_NOSET_CUDA on the workers instead?

@matthewdeng matthewdeng added the go add ONLY when ready to merge, run all tests label Sep 19, 2025
Contributor

@justinvyu justinvyu left a comment


LFGTM

@justinvyu justinvyu enabled auto-merge (squash) September 19, 2025 17:27
@justinvyu justinvyu merged commit 975f363 into ray-project:master Sep 19, 2025
6 checks passed
ZacAttack pushed a commit to ZacAttack/ray that referenced this pull request Sep 24, 2025
…_context` (ray-project#56509)

This fixes an issue in which the CUDA context is not properly configured
during import deserialization.

```
RuntimeError: device >= 0 && device < num_gpus INTERNAL ASSERT FAILED at "../aten/src/ATen/cuda/CUDAContext.cpp":49, please report a bug to PyTorch. device=1, num_gpus=
```

The fix is to update the CUDA_VISIBLE_DEVICE sharing logic to be implemented in before_init_train_context instead of after_worker_group_start, so that torch.cuda initialization happens after the environment variable is set up properly.

---------

Signed-off-by: Matthew Deng <matthew.j.deng@gmail.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Co-authored-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: zac <zac@anyscale.com>
elliot-barn pushed a commit that referenced this pull request Sep 24, 2025
marcostephan pushed a commit to marcostephan/ray that referenced this pull request Sep 24, 2025
elliot-barn pushed a commit that referenced this pull request Sep 27, 2025
dstrodtman pushed a commit to dstrodtman/ray that referenced this pull request Oct 6, 2025
justinyeh1995 pushed a commit to justinyeh1995/ray that referenced this pull request Oct 20, 2025
landscapepainter pushed a commit to landscapepainter/ray that referenced this pull request Nov 17, 2025
Future-Outlier pushed a commit to Future-Outlier/ray that referenced this pull request Dec 7, 2025

Labels

go add ONLY when ready to merge, run all tests train Ray Train Related Issue
