[RLlib] Issues: 17397, 17425, 16715, 17174. When on driver, Torch|TFPolicy should not use ray.get_gpu_ids() (b/c no GPUs assigned by ray). #17444
Conversation
rllib/policy/torch_policy.py
Outdated
from ray.worker import global_worker
if global_worker.mode == 1:
global_worker can be None sometimes right?
also, can we use the "WORKER_MODE" enum?
Also added a few test cases and better error messages for tf and torch.
rllib/policy/tf_policy.py
Outdated
elif len(gpu_ids) < num_gpus:
    raise ValueError(
        "TFPolicy was not able to find enough GPU IDs! Found "
        f"{gpu_ids}, but num_gpus={num_gpus}.")
I think we should use if len(self.devices) > 0 below. This condition fails on num_gpus=0.5. for i, _ in enumerate(...) if i < num_gpus can handle fractional GPUs.
Not sure this would be necessary.
E.g.: if num_gpus=0.5 and gpu_ids=["/physical_device:gpu:0"],
then this tf check would pass, no (and the error would not be raised)?
Also:
self.devices = [f"/gpu:{i}" for i, _ in enumerate(gpu_ids) if i < num_gpus]
would still generate a device list with exactly 1 GPU in it, despite num_gpus being 0.5.
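To illustrate the fractional-GPU corner case being discussed, here is a minimal, self-contained sketch (the helper name build_tf_device_list is hypothetical, not RLlib's actual API) of how math.ceil could cap the device list and validate the GPU count so that num_gpus=0.5 claims exactly one GPU:

```python
import math


def build_tf_device_list(gpu_ids, num_gpus):
    """Hypothetical sketch: validate and build a TF-style device list.

    Uses math.ceil(num_gpus) so that a fractional setting like 0.5
    requires (and claims) exactly one physical GPU, while num_gpus
    larger than the number of found IDs raises early.
    """
    needed = math.ceil(num_gpus)
    if num_gpus > 0 and len(gpu_ids) < needed:
        raise ValueError(
            f"Not enough GPU IDs! Found {gpu_ids}, but num_gpus={num_gpus}.")
    # Cap the device list at ceil(num_gpus) entries.
    return [f"/gpu:{i}" for i, _ in enumerate(gpu_ids) if i < needed]
```

With this sketch, build_tf_device_list(["/physical_device:gpu:0"], 0.5) yields a one-entry list, and an empty gpu_ids with num_gpus=1 raises a ValueError instead of silently passing.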
rllib/policy/torch_policy.py
Outdated
elif len(gpu_ids) < num_gpus:
    raise ValueError(
        "TorchPolicy was not able to find enough GPU IDs! Found "
        f"{gpu_ids}, but num_gpus={num_gpus}.")
I think we should use if len(self.devices) > 0 below. Same reason: fractional GPUs.
rllib/evaluation/rollout_worker.py
Outdated
if policy_config["framework"] in ["tf2", "tf", "tfe"]:
    if len(get_tf_gpu_devices()) < num_gpus:
        raise RuntimeError(
            f"Not enough GPUs found for num_gpus={num_gpus}! "
            f"Found only these IDs: {get_tf_gpu_devices()}.")
elif policy_config["framework"] == "torch":
    if torch.cuda.device_count() < num_gpus:
        raise RuntimeError(
            f"Not enough GPUs found ({torch.cuda.device_count()}) "
            f"for num_gpus={num_gpus}!")
Maybe add math.ceil(num_gpus) to handle fractional GPUs.
Great catch @XuehaiPan! This would indeed fail for fractional numbers due to the range not handling floats. I will update.
Running some last tests now on a multi-GPU machine.
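The math.ceil fix suggested above can be sketched in isolation (the helper name check_enough_gpus is hypothetical; found_gpu_count stands in for len(get_tf_gpu_devices()) or torch.cuda.device_count()):

```python
import math


def check_enough_gpus(found_gpu_count, num_gpus):
    """Hypothetical sketch of the rollout-worker GPU check.

    Comparing against math.ceil(num_gpus) makes fractional settings
    work: num_gpus=0.5 requires one physical GPU (not zero), while
    num_gpus=0 requires none.
    """
    if found_gpu_count < math.ceil(num_gpus):
        raise RuntimeError(
            f"Not enough GPUs found ({found_gpu_count}) "
            f"for num_gpus={num_gpus}!")
```

For example, check_enough_gpus(1, 0.5) passes, while check_enough_gpus(0, 0.5) raises, which is the behavior the direct found < num_gpus comparison misses (0 < 0.5 is also True, but 1 < 0.5 is False only by accident of the float comparison; ceil makes the intent explicit).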
Running the new test case on local laptop (no GPUs) and 4-GPU machine looks all ok now.
ray.worker._mode() != ray.worker.LOCAL_MODE and \
        not policy_config.get("_fake_gpus"):
Nice. Let's file a feature request for a better way of detecting local mode (on #api-changes).
Will do, you don't like ray.worker._mode() != ray.worker.LOCAL_MODE? :D
Any idea on when this pull request will be merged/authorized?
All tests are passing now, including the new one, which tests all combinations of
    num_gpus = config["num_gpus"]
else:
    num_gpus = config["num_gpus_per_worker"]
gpu_ids = list(range(torch.cuda.device_count()))
I don't think we can get all devices directly here.
Imagine we run the driver on a node with 5 devices, and a remote worker is also scheduled to this node: all 5 devices would be visible to the remote worker, which I don't think makes sense.
The environment variable CUDA_VISIBLE_DEVICES is set for the remote worker. torch.cuda.device_count() will respect the CUDA_VISIBLE_DEVICES and return the number of CUDA visible GPUs.
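The effect of CUDA_VISIBLE_DEVICES can be illustrated without torch. The helper below (visible_gpu_count is a hypothetical stand-in, not torch's implementation) mimics how a device-count call restricted by the variable differs from the node's physical count:

```python
import os


def visible_gpu_count(total_physical_gpus):
    """Hypothetical stand-in for torch.cuda.device_count().

    Ray sets CUDA_VISIBLE_DEVICES on remote workers to the GPU IDs it
    assigned them, so a count that honors the variable sees only those
    GPUs. total_physical_gpus stands in for the node's real GPU count,
    used when the variable is unset (e.g. on the driver).
    """
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        return total_physical_gpus  # no restriction set
    if visible == "":
        return 0  # explicitly no GPUs visible
    return len(visible.split(","))
```

E.g., on the 5-GPU node from the comment above: a remote worker that Ray assigned GPUs 2 and 3 (CUDA_VISIBLE_DEVICES="2,3") would count 2 devices, not 5, so the over-assignment concern does not apply to workers launched by Ray.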
Why are these changes needed?
Issues: 17397, 17425, 16715, 17174. When on driver, Torch|TFPolicy should not use ray.get_gpu_ids() (b/c no GPUs assigned by ray).
Issue #17397
Issue #17425
Issue #16715
Issue #17174
Related issue number
Closes #17397
Closes #17425
Closes #16715
Closes #17174
Checks
scripts/format.sh to lint the changes in this PR.