
[ROCm] Add ROCR_VISIBLE_DEVICE parsing logic to initialisation#140398

Closed
jataylo wants to merge 2 commits intopytorch:mainfrom
jataylo:amdsmi-fixes

Conversation

@jataylo
Collaborator

@jataylo jataylo commented Nov 12, 2024

Fixes #140318

Currently, ROCR_VISIBLE_DEVICES is not respected in the _parse_visible_devices logic. This change updates that code so that the visible device count returned does not exceed the number of devices specified in ROCR_VISIBLE_DEVICES, which restricts the number of GPUs available at runtime.
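A minimal sketch of the capping behaviour described above, assuming comma-separated device tokens; the helper names parse_rocr_visible_devices and capped_device_count are hypothetical illustrations, not the PR's actual implementation:

```python
import os

def parse_rocr_visible_devices():
    """Parse ROCR_VISIBLE_DEVICES into a list of device tokens.

    Returns None when the variable is unset (no restriction); an empty
    string yields an empty list, i.e. no devices visible at runtime.
    """
    val = os.environ.get("ROCR_VISIBLE_DEVICES")
    if val is None:
        return None
    return [tok.strip() for tok in val.split(",") if tok.strip()]

def capped_device_count(raw_count):
    """Cap a raw device count by the ROCR_VISIBLE_DEVICES restriction.

    The ROCm runtime itself honours ROCR_VISIBLE_DEVICES, so the number
    of entries it lists is a hard upper bound on usable GPUs, regardless
    of what HIP_VISIBLE_DEVICES enumerates.
    """
    rocr = parse_rocr_visible_devices()
    if rocr is None:
        return raw_count
    return min(raw_count, len(rocr))
```

With ROCR_VISIBLE_DEVICES="0" and a raw count of 2 (e.g. from HIP_VISIBLE_DEVICES="0,1"), capped_device_count returns 1, matching the PyTorch 2.0 behaviour shown below.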

PyTorch 2.5 (current, incorrect — ROCR_VISIBLE_DEVICES is ignored):

ROCR_VISIBLE_DEVICES="0" HIP_VISIBLE_DEVICES="0,1"
print(torch.cuda.device_count())
>>> 2

PyTorch 2.0 (expected):

ROCR_VISIBLE_DEVICES="0" HIP_VISIBLE_DEVICES="0,1"
print(torch.cuda.device_count())
>>> 1

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @hongxiayang @naromero77amd

@pytorch-bot

pytorch-bot bot commented Nov 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140398

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 2757c1b with merge base 51e8a13:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/rocm Trigger "default" config CI on ROCm module: rocm AMD GPU support for Pytorch labels Nov 12, 2024
@jataylo
Collaborator Author

jataylo commented Nov 12, 2024

Without this change, even simple code fails if we only specify ROCR_VISIBLE_DEVICES, but only when the amdsmi path is enabled.

Repro:

import torch

for i in range(torch.cuda.device_count()):
    device_properties = torch.cuda.get_device_properties(i)
    uuid = device_properties.uuid
    print(f"GPU {i}: UUID = {uuid}")

Before:

ROCR_VISIBLE_DEVICES=0 python test.py
  
File "/root/test.py", line 4, in <module>
   device_properties = torch.cuda.get_device_properties(i)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/cuda/__init__.py", line 529, in get_device_properties
   return _get_device_properties(device)  # type: ignore[name-defined]
RuntimeError: device >= 0 && device < num_gpus INTERNAL ASSERT FAILED at "/tmp/pytorch/aten/src/ATen/hip/HIPContext.cpp":50, please report a bug to PyTorch. device=1, num_gpus=

After:

ROCR_VISIBLE_DEVICES=0 python test.py
GPU 0: UUID = 39343831-6166-3238-3063-306432653239

@jataylo
Collaborator Author

jataylo commented Nov 13, 2024

Closing in favour of #140320

@jataylo jataylo closed this Nov 13, 2024

Labels

ciflow/rocm Trigger "default" config CI on ROCm module: rocm AMD GPU support for Pytorch open source topic: not user facing topic category


Development

Successfully merging this pull request may close these issues.

Visible devices are not respected on AMD systems

3 participants