Skip to content

Respect ROCR_VISIBLE_DEVICES on AMD GPU device discovery#140320

Closed
tbennun wants to merge 5 commits intopytorch:mainfrom
tbennun:rocr-visible-devices
Closed

Respect ROCR_VISIBLE_DEVICES on AMD GPU device discovery#140320
tbennun wants to merge 5 commits intopytorch:mainfrom
tbennun:rocr-visible-devices

Conversation

@tbennun
Copy link
Contributor

@tbennun tbennun commented Nov 11, 2024

Fixes #140318

@pytorch-bot
Copy link

pytorch-bot bot commented Nov 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140320

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b5f6863 with merge base 8bdcdae (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@tbennun
Copy link
Contributor Author

tbennun commented Nov 11, 2024

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Nov 11, 2024
@pruthvistony pruthvistony requested a review from jataylo November 12, 2024 04:50
@jataylo
Copy link
Collaborator

jataylo commented Nov 12, 2024

I left a more extended comment here #140318 (comment)

I'm not sure adding ROCR_VISIBLE_DEVICES to the _parse_visible_devices logic like this is a true fix here, I think this is a workaround. That being said there does appear to be issues with the current implementation more details in issue above. I think this may be due to us returning incorrect device count that no longer respects ROCR_VISIBLE_DEVICES

@jataylo
Copy link
Collaborator

jataylo commented Nov 12, 2024

Added a new PR here #140398

Which will ensure interoperability between the two visible devices

@soulitzer soulitzer added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Nov 13, 2024
Copy link
Collaborator

@jithunnair-amd jithunnair-amd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks right. @tbennun can you please post some examples showing what the updated logic gives as output of _parse_visible_devices()?

@tbennun
Copy link
Contributor Author

tbennun commented Nov 14, 2024

Looks right. @tbennun can you please post some examples showing what the updated logic gives as output of _parse_visible_devices()?

@jithunnair-amd Of course, essentially now PyTorch takes ROCR_VISIBLE_DEVICES into account the same way it would with CUDA_VISIBLE_DEVICES.

@jataylo already gave some examples in #140318 (comment) so I can start from there.

Would you like this in the form of a test with, e.g., ROCR_VISIBLE_DEVICES=0,3,4 and HIP_VISIBLE_DEVICES=1,2?
I will use this opportunity to fix the mypy issue too.

@tbennun
Copy link
Contributor Author

tbennun commented Nov 16, 2024

@jithunnair-amd @jataylo please re-review. I added tests, fixed the linter issue, and improved the behavior when both environment variables are given. Thanks!

@tbennun
Copy link
Contributor Author

tbennun commented Nov 25, 2024

@jithunnair-amd @jataylo Any updates on this PR? Thanks!

@jataylo
Copy link
Collaborator

jataylo commented Nov 26, 2024

LGTM, approved workflow to see if the UT passes.

@jeffdaily
Copy link
Collaborator

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 2, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / win-vs2019-cuda12.1-py3 / build

Details for Dev Infra team Raised by workflow job

@tbennun
Copy link
Contributor Author

tbennun commented Dec 3, 2024

The failures seem unrelated to this PR and related to the version of distutils.

@jataylo
Copy link
Collaborator

jataylo commented Dec 6, 2024

@pytorchbot rebase

@jataylo
Copy link
Collaborator

jataylo commented Dec 6, 2024

Rebasing to see if unrelated errors go away :)

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@tbennun tbennun deleted the rocr-visible-devices branch December 6, 2024 20:11
@huydhn
Copy link
Contributor

huydhn commented Dec 7, 2024

@pytorchbot revert -m 'Sorry for reverting your change but test_hip_device_count is failing in trunk after this land' -c nosignal

test_cuda.py::TestCuda::test_hip_device_count GH job link HUD commit link

@pytorchmergebot
Copy link
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Copy link
Collaborator

@tbennun your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Dec 7, 2024
…0320)"

This reverts commit add4a42.

Reverted #140320 on behalf of https://github.com/huydhn due to Sorry for reverting your change but test_hip_device_count is failing in trunk after this land ([comment](#140320 (comment)))
@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Dec 7, 2024
@tbennun
Copy link
Contributor Author

tbennun commented Dec 7, 2024

@huydhn Thanks. The test passed locally, I will check.

Strange that it didn't show up in the tests. Are rocm tests not running on PRs? If not, is there a way I can trigger them?

@huydhn
Copy link
Contributor

huydhn commented Dec 7, 2024

It could be an issue with our target determination when some tests were wrongly skipped. Let me double check that, but you could addci-no-td on your PR to opt-out and this will ensure that all tests are run on your reland PR

@huydhn
Copy link
Contributor

huydhn commented Dec 7, 2024

It doesn't looks like the case, I think you just need to add ciflow/rocm to run all ROCm tests then. Having both ci-no-td and ciflow/rocm labels just to be sure will do.

@tbennun
Copy link
Contributor Author

tbennun commented Dec 7, 2024

Turns out the rocm tests did run (and terminated successfully) before: https://github.com/pytorch/pytorch/actions/runs/12199014152/job/34033407881#step:15:961

pytorch-bot bot pushed a commit that referenced this pull request Dec 9, 2024
pytorch-bot bot pushed a commit that referenced this pull request Dec 9, 2024
…0320)"

This reverts commit add4a42.

Reverted #140320 on behalf of https://github.com/huydhn due to Sorry for reverting your change but test_hip_device_count is failing in trunk after this land ([comment](#140320 (comment)))
@jataylo
Copy link
Collaborator

jataylo commented Dec 9, 2024

@tbennun I wonder if it has anything to do with the CI jobs setting HIP_VISIBLE_DEVICES=0 before execution then us trying to set ROCR_VISIBLE_DEVICES=0,1,2 at runtime during the UT.

The application may only see a single GPU and then try to set device_count to 3. Might need some local testing to try and reproduce.

@tbennun
Copy link
Contributor Author

tbennun commented Dec 9, 2024

@jataylo it might be, which is why I made sure to remove the environment variable, like other tests do, in the reland PR (#142292). What’s weird is that it passed on the PR build but not on main afterwards. Regardless, the new PR should resolve potential index errors.

pytorchmergebot pushed a commit that referenced this pull request Dec 25, 2024
Reland of #140320 after failing test on trunk. Fixes potential environment clobbering in test, makes ROCr+HIP devices (if specified together) more robust to index errors.

Fixes #140318

Pull Request resolved: #142292
Approved by: https://github.com/jataylo, https://github.com/huydhn, https://github.com/jeffdaily

Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
pytorchbot pushed a commit that referenced this pull request Dec 31, 2024
Reland of #140320 after failing test on trunk. Fixes potential environment clobbering in test, makes ROCr+HIP devices (if specified together) more robust to index errors.

Fixes #140318

Pull Request resolved: #142292
Approved by: https://github.com/jataylo, https://github.com/huydhn, https://github.com/jeffdaily

Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
(cherry picked from commit c0d7106)
kit1980 pushed a commit that referenced this pull request Jan 6, 2025
Respect ROCR_VISIBLE_DEVICES on AMD GPU device discovery (#142292)

Reland of #140320 after failing test on trunk. Fixes potential environment clobbering in test, makes ROCr+HIP devices (if specified together) more robust to index errors.

Fixes #140318

Pull Request resolved: #142292
Approved by: https://github.com/jataylo, https://github.com/huydhn, https://github.com/jeffdaily

Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
(cherry picked from commit c0d7106)

Co-authored-by: Tal Ben-Nun <tbennun@users.noreply.github.com>
jataylo pushed a commit to jataylo/pytorch that referenced this pull request Feb 12, 2025
)

Respect ROCR_VISIBLE_DEVICES on AMD GPU device discovery (pytorch#142292)

Reland of pytorch#140320 after failing test on trunk. Fixes potential environment clobbering in test, makes ROCr+HIP devices (if specified together) more robust to index errors.

Fixes pytorch#140318

Pull Request resolved: pytorch#142292
Approved by: https://github.com/jataylo, https://github.com/huydhn, https://github.com/jeffdaily

Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
(cherry picked from commit c0d7106)

Co-authored-by: Tal Ben-Nun <tbennun@users.noreply.github.com>
(cherry picked from commit 23e390c)
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Feb 19, 2025
…covery (pytorch#144026) (#1895)

Respect ROCR_VISIBLE_DEVICES on AMD GPU device discovery (pytorch#142292)

Reland of pytorch#140320 after failing test on trunk. Fixes potential
environment clobbering in test, makes ROCr+HIP devices (if specified
together) more robust to index errors.

Fixes pytorch#140318

Pull Request resolved: pytorch#142292
Approved by: https://github.com/jataylo, https://github.com/huydhn,
https://github.com/jeffdaily

Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
(cherry picked from commit c0d7106)

Co-authored-by: Tal Ben-Nun <tbennun@users.noreply.github.com>
(cherry picked from commit 23e390c)

Fixes #ISSUE_NUMBER

Co-authored-by: pytorchbot <soumith+bot@pytorch.org>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-no-td Do not run TD on this PR ciflow/trunk Trigger trunk jobs on your pull request Merged open source Reverted topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Visible devices are not respected on AMD systems

9 participants