fix(ROCm): remove dead code fix_rocm_triton_key_error #4125
Code Review
This pull request correctly identifies and fixes an issue where torch.compile was not being disabled on ROCm when triton_key is missing. The change to directly set torch._dynamo.config.disable = True is the right approach. My feedback includes a suggestion to improve the accuracy of the logging to reflect whether the disable operation was successful.
  # Runtime disable - this is the only mechanism that works after torch
  # has already been imported (env vars are read once at import time).
  try:
      torch._dynamo.config.disable = True
  except Exception:
      pass

  logger.info(
      "Unsloth: ROCm detected and Triton lacks triton_key; "
-     "disabling torch.compile/Inductor to avoid backend crash."
+     "disabling torch.compile/Dynamo to avoid backend crash."
  )
The log message indicating that torch.compile is disabled is currently unconditional. If setting torch._dynamo.config.disable = True fails for some reason (e.g., due to future changes in PyTorch's internal API), the log message would be misleading as it would suggest the operation was successful when it wasn't. It's better to move the logging inside the try block to ensure it's only printed upon successful disabling.
# Runtime disable - this is the only mechanism that works after torch
# has already been imported (env vars are read once at import time).
try:
    torch._dynamo.config.disable = True
    logger.info(
        "Unsloth: ROCm detected and Triton lacks triton_key; "
        "disabling torch.compile/Dynamo to avoid backend crash."
    )
except Exception:
    pass

352ecba to 8574d28
8574d28 to 1f6cda0
The function (introduced in unslothai#3923) assumed that the absence of `triton.runtime.triton_key` on ROCm means torch.compile will crash. Investigation shows this is incorrect:

1. `triton.runtime.triton_key` was renamed/removed in the ROCm Triton fork; it does not exist at that path. However, `triton.compiler.compiler.triton_key` (the path torch._inductor actually imports) exists and works correctly on ROCm.
2. Both call-sites in torch._inductor (codecache.py and async_compile.py) already wrap the import in try/except, so even a genuinely missing triton_key would be handled gracefully.
3. Comprehensive testing on ROCm 7.1 + Triton 3.4.0 + gfx1100 confirms torch.compile works correctly for matmul, cross-entropy, RMSNorm, multi-layer transformer forward+backward, and LoRA, all without triton.runtime.triton_key.

The original code was also ineffective (environment variables set after torch import have no effect on torch._dynamo config), so removing it has zero behavioral change on existing installations.

Supersedes the compile-disable portion of unslothai#3923.
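The difference between the two import paths in point 1 can be checked directly on a given machine. The following is a minimal diagnostic sketch, not code from this PR: `has_attr_path` is a hypothetical helper (it is not part of Unsloth, Triton, or PyTorch) that reports whether a dotted module path imports cleanly and exposes a given attribute. It treats any import failure as "absent", in the same spirit as the guarded imports described in point 2.

```python
import importlib


def has_attr_path(module_path: str, attr: str) -> bool:
    """Return True if `module_path` imports cleanly and exposes `attr`.

    Hypothetical helper for illustration only.
    """
    try:
        module = importlib.import_module(module_path)
    except Exception:
        # A missing or broken module is reported as "absent" instead of
        # propagating the error to the caller.
        return False
    return hasattr(module, attr)


if __name__ == "__main__":
    # The two paths discussed in the commit message.
    for path in ("triton.runtime", "triton.compiler.compiler"):
        found = has_attr_path(path, "triton_key")
        print(f"{path}.triton_key present: {found}")
```

On a system without Triton installed, both lines simply report `False`; the result on a ROCm install depends on the Triton build in use.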
danielhanchen
left a comment
Thank you! This works great!
Remove `fix_rocm_triton_key_error()` from #3923. The function is dead code with multiple issues:

- It checks `triton.runtime.triton_key` (missing on ROCm), but `torch._inductor` uses `triton.compiler.compiler.triton_key`, which exists and works on ROCm.
- It uses `TORCHINDUCTOR_DISABLE`, which does not exist in PyTorch (102 `TORCHINDUCTOR_*` vars exist; this is not one of them).
- `TORCH_COMPILE_DISABLE` is set after `import torch`; `torch._dynamo` reads it only once at import time, so it has no effect.
- Had it taken effect, it would have disabled `torch.compile` on working ROCm setups.
- `torch.compile` works correctly on Triton 3.4.0 / gfx1100, verified with matmul, CE loss, RMSNorm, transformer forward+backward, and LoRA backward.

cc @danielhanchen
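The "read once at import time" behavior is the reason a late environment-variable assignment does nothing. The sketch below models that behavior without importing torch. Everything in it is an assumption made for illustration: `FakeDynamoConfig` is a stand-in, not `torch._dynamo.config`, and `DEMO_COMPILE_DISABLE` is a made-up variable name, not a PyTorch setting.

```python
import os


class FakeDynamoConfig:
    """Stand-in for a config module that snapshots an env var once.

    Hypothetical: models only the "read once at import time" behavior,
    not the real torch._dynamo.config.
    """

    def __init__(self) -> None:
        # Mimics module-level code that runs a single time on import.
        self.disable = os.environ.get("DEMO_COMPILE_DISABLE", "0") == "1"


# Start from a clean environment, then simulate the import.
os.environ.pop("DEMO_COMPILE_DISABLE", None)
config = FakeDynamoConfig()

# Setting the variable after the "import" does not change the snapshot.
os.environ["DEMO_COMPILE_DISABLE"] = "1"
late_env_effective = config.disable

# Assigning the attribute at runtime does change it.
config.disable = True
runtime_effective = config.disable

print("late env var took effect:", late_env_effective)
print("runtime assignment took effect:", runtime_effective)
```

In this model the late environment variable is ignored while the direct attribute assignment is honored, which matches the reasoning given earlier in the review thread for assigning `torch._dynamo.config.disable` at runtime.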