
[rllib][regression] num_gpus=1 on a device with GPU uses CPU for Pytorch in 0.8.7 #10271

@Iravoo

Description


What is the problem?

The following configuration of the PPOTrainer (with a custom model and environment) leads to unexpected behavior since version 0.8.7:

    num_workers: 1
    num_gpus: 2
    num_gpus_per_worker: 2
    framework: "torch"

In version 0.8.6, the backward pass executed by the PPOTrainer runs on the GPU as expected.
In version 0.8.7, this is no longer the case and the CPU is used instead.
This appears to be a regression introduced by PR #9516.
The rollout workers are not impacted, only the training.

Removing the call to ray.get_gpu_ids() from the following piece of code in torch_policy.py solves the issue:

    if torch.cuda.is_available():  # and ray.get_gpu_ids():
        self.device = torch.device("cuda")
    else:
        self.device = torch.device("cpu")

The original code is here:

if torch.cuda.is_available() and ray.get_gpu_ids():

The problem is that ray.get_gpu_ids() does not list the GPU IDs and instead returns an empty list. The bug seems to come from global_worker.core_worker.resource_ids(), since "LOCAL_MODE" works as expected.
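
Until this is fixed, a possible stop-gap that avoids editing the installed torch_policy.py could be to patch ray.get_gpu_ids in the driver process before the trainer is built. This is only a sketch and assumes, as the quoted code suggests, that the return value is only used as a truthiness check:

    import ray
    import torch

    # Stop-gap only (not a proper fix): restore the 0.8.6 behaviour for the
    # locally built policy by making ray.get_gpu_ids() fall back to the local
    # CUDA device indices when Ray has not assigned any GPUs to this process.
    _orig_get_gpu_ids = ray.get_gpu_ids

    def _patched_get_gpu_ids():
        ids = _orig_get_gpu_ids()
        if not ids and torch.cuda.is_available():
            return list(range(torch.cuda.device_count()))
        return ids

    ray.get_gpu_ids = _patched_get_gpu_ids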

If my understanding is correct, the interactions with the environment are done in the rollout workers and the core worker deals with the training. Is that correct? What I don't get, however, is why ray.get_gpu_ids() always returns an empty list in the core worker. This prevents the core worker from using the GPUs for training, which should not be the case.
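
For what it's worth, outside of RLlib ray.get_gpu_ids() appears to only report the GPUs that Ray has assigned to the calling task or actor, which would explain the empty list in the process where the local policy is built. A minimal sketch (the assigned_gpus task is just an illustration, not part of the repro):

    import ray

    ray.init(num_gpus=2)

    # In the driver process, Ray has not assigned any GPUs, so this prints []
    # even on a machine where torch.cuda.is_available() is True.
    print("driver:", ray.get_gpu_ids())

    @ray.remote(num_gpus=2)
    def assigned_gpus():
        # Inside a task that requested GPUs, the assigned IDs are visible,
        # similar to the [0, 1] printed by the rollout worker in the outputs below.
        return ray.get_gpu_ids()

    print("remote task:", ray.get(assigned_gpus.remote()))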

OS: Linux
Python Version: 3.8
Ray Version: 0.8.7

Reproduction (REQUIRED)

  • I have verified my script runs in a clean environment and reproduces the issue.
  • I have verified the issue also occurs with the latest wheels.

Script to run:

import ray
from ray.rllib import agents

print("Ray Version:", ray.__version__)

if not ray.is_initialized():
    ray.init(num_gpus=2)  # skip, or pass ignore_reinit_error=True, if Ray is already initialized

config = {
    "gamma": 0.9,
    "lr": 1e-2,
    "num_workers": 1,
    "num_gpus": 2,
    "num_gpus_per_worker": 2,
    "train_batch_size": 1000,
    "model": {"fcnet_hiddens": [128, 128]},
    "framework": "torch",

}
trainer = agents.ppo.PPOTrainer(env="CartPole-v0", config=config)
results = trainer.train()
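
To check which device the locally built policy ends up on without editing torch_policy.py, the reproduction script can be extended with a couple of lines (assuming Trainer.get_policy() returns the local TorchPolicy whose device attribute is set in the code quoted above):

    # Append after trainer.train() in the script above.
    local_policy = trainer.get_policy()
    print("local policy device:", local_policy.device)  # cpu in 0.8.7, cuda in 0.8.6
    print("driver GPU ids:", ray.get_gpu_ids())         # [] in the driver process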

Code inserted in torch_policy.py to see what's going on:

    print("DEBUG CUSTOM", "=" * 80)
    print(ray.get_gpu_ids(), self.device)

Output in 0.8.6:

Ray Version: 0.8.6
DEBUG CUSTOM ================================================================================
[] cuda

2020-08-23 13:19:12,489	WARNING worker.py:1047 -- The actor or task with ID ffffffffffffffff9d28cb170100 is pending and cannot currently be scheduled. It requires {CPU: 1.000000}, {GPU: 2.000000} for execution and {CPU: 1.000000}, {GPU: 2.000000} for placement, but this node only has remaining {node:192.168.1.20: 1.000000}, {CPU: 23.000000}, {memory: 63.232422 GiB}, {object_store_memory: 21.435547 GiB}. In total there are 0 pending tasks and 1 pending actors on this node. This is likely due to all cluster resources being claimed by actors. To resolve the issue, consider creating fewer actors or increase the resources available to this Ray cluster. You can ignore this message if this Ray cluster is expected to auto-scale.

(pid=190646) 2020-08-23 13:19:12,614	INFO (unknown file):0 -- gc.collect() freed 85 refs in 0.03185255825519562 seconds
(pid=190660) DEBUG CUSTOM ================================================================================
(pid=190660) [0, 1] cuda
(pid=190660) /pytorch/torch/csrc/utils/tensor_numpy.cpp:141: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program.

Output in 0.8.7:

Ray Version: 0.8.7

2020-08-23 13:27:47,126	INFO resource_spec.py:223 -- Starting Ray with 63.23 GiB memory available for workers and up to 31.09 GiB for objects. You can adjust these settings with ray.init(memory=<bytes>, object_store_memory=<bytes>).
2020-08-23 13:27:47,574	INFO services.py:1191 -- View the Ray dashboard at localhost:8265
2020-08-23 13:27:48,672	INFO trainer.py:630 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.

DEBUG CUSTOM ================================================================================
[] cpu
(pid=196632) DEBUG CUSTOM ================================================================================
(pid=196632) [1, 0] cuda

(pid=196632) /pytorch/torch/csrc/utils/tensor_numpy.cpp:141: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program.

Labels: P1 (issue that should be fixed within a few weeks), bug (something that is supposed to be working, but isn't)