fix: set fused option for Adam optimizer based on device type #4669
Conversation
Pull Request Overview
This PR disables the fused optimizer for CPU backends in the PyTorch training module.
- Conditionally disables the fused parameter based on DEVICE type.
- Updates the torch.optim.Adam instantiation to support CPU-specific behavior.
Comments suppressed due to low confidence (1)
deepmd/pt/train/training.py:599
- [nitpick] Consider using a clearer boolean expression for the 'fused' parameter, for example, 'fused=(DEVICE.type != "cpu")', to improve readability.
fused=False if DEVICE.type == "cpu" else True,
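For context, a minimal sketch of the suggested device-dependent form. `DEVICE` here is a hypothetical stand-in for the module-level device object used in `deepmd/pt/train/training.py`, and the model and learning rate are placeholders:

```python
import torch

# Hypothetical stand-in for the module-level device object used in
# deepmd/pt/train/training.py.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.nn.Linear(8, 8).to(DEVICE)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,  # placeholder learning rate
    # Request the fused kernel only off-CPU; CPU builds of older PyTorch
    # versions do not provide a fused Adam implementation.
    fused=(DEVICE.type != "cpu"),
)
```

With this pattern the fused implementation is only requested when parameters live on an accelerator, which matches the conditional the PR introduces.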
📝 Walkthrough
The changes modify the initialization of the Adam optimizer in `deepmd/pt/train/training.py`, disabling the fused implementation when the device type is CPU.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Trainer as warm_up_linear()
    participant Device as Device Checker
    participant Optimizer as torch.optim.Adam
    Trainer->>Device: Check device type
    Device-->>Trainer: Return "cpu" or "gpu/other"
    alt If device is CPU
        Trainer->>Optimizer: Initialize Adam(fused=False, …)
    else Otherwise
        Trainer->>Optimizer: Initialize Adam(fused=True, …)
    end
```
Actionable comments posted: 0
🧹 Nitpick comments (1)
deepmd/pt/train/training.py (1)
597-599: Good refactoring for device-dependent fused optimization.
The changes correctly implement device-dependent fused optimization by disabling it for CPU devices. This is a solid approach, since optimizer updates take significantly less time than forward/backward passes during training on CPUs.
The code could be slightly improved for readability by using a more idiomatic Python pattern.
```diff
- fused=False if DEVICE.type == "cpu" else True,
+ fused=not (DEVICE.type == "cpu"),
```

🧰 Tools
🪛 Ruff (0.8.2)
599-599: Use `not ...` instead of `False if ... else True`
Replace with `not ...`
(SIM211)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/pt/train/training.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/train/training.py
599-599: Use not ... instead of False if ... else True
Replace with not ...
(SIM211)
⏰ Context from checks skipped due to timeout of 90000ms (29)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Analyze (python)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test C++ (false)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test C++ (true)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Test Python (1, 3.9)
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files

```
@@           Coverage Diff            @@
##            devel    #4669    +/-   ##
========================================
  Coverage   84.81%   84.81%
========================================
  Files         692      692
  Lines       66360    66360
  Branches     3539     3538        -1
========================================
  Hits        56282    56282
  Misses       8937     8937
  Partials     1141     1141
```

☔ View full report in Codecov by Sentry.
Parsing the PyTorch version to determine whether the CPU fused optimizer is supported would be verbose, and the optimizer update time is usually much smaller than the forward/backward time when training on CPU. So this PR disables the fused optimizer for the CPU backend.
The issue only affects the PyTorch backend.
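For illustration only, a sketch of the more verbose version-gated check that the PR avoids; the 2.4 cutoff for CPU fused Adam support is an assumption, not something stated in this PR:

```python
import torch
from packaging.version import Version

def fused_adam_supported(device: torch.device) -> bool:
    """Return True if the fused Adam implementation can be used on `device`.

    The 2.4 cutoff for CPU fused support is an assumed value for illustration.
    """
    if device.type != "cpu":
        return True
    # Drop local version suffixes such as "+cpu" before comparing.
    return Version(torch.__version__.split("+")[0]) >= Version("2.4")
```

The PR instead sidesteps this check entirely by always passing `fused=False` on CPU.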
Fix #4667
Summary by CodeRabbit
Bug Fixes
- Training on CPU no longer requests the fused Adam optimizer, avoiding failures with PyTorch builds that do not support it.
Refactor
- The `fused` option of `torch.optim.Adam` is now set based on the device type.