perf: use einsum to calculate virial #4746
Conversation
Pull Request Overview
This PR optimizes the virial computation by replacing multiple matmul operations with a single einsum call, reducing runtime overhead especially for smaller batches and models.
- Switch PyTorch virial calc from `@` matmul to `torch.einsum`
- Switch Paddle virial calc from `@` matmul to `paddle.einsum` (see the sketch after this list)
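To make the change concrete, here is a minimal, self-contained sketch (tensor names and shapes are assumptions, not the exact deepmd-kit code) showing that the single einsum expresses the same per-atom outer product as the original unsqueeze + `@` matmul formulation:

```python
# Minimal sketch (assumed tensor names/shapes, not the exact deepmd-kit code):
# the original outer product via unsqueeze + @ matmul vs. the single einsum.
import torch

nf, nall = 2, 6  # hypothetical frame/atom counts
extended_force = torch.randn(nf, nall, 3)
extended_coord = torch.randn(nf, nall, 3)

# Original formulation: per-atom outer product via unsqueeze + batched matmul
virial_matmul = extended_force.unsqueeze(-1) @ extended_coord.unsqueeze(-2)

# New formulation: the same outer product expressed as one einsum
virial_einsum = torch.einsum(
    "...ik,...ij->...ikj", extended_force, extended_coord
)

assert torch.allclose(virial_matmul, virial_einsum)
```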
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| deepmd/pt/model/model/transform_output.py | Replace unsqueeze + @ matmul with torch.einsum for virial |
| deepmd/pd/model/model/transform_output.py | Replace unsqueeze + @ matmul with paddle.einsum for virial |
Comments suppressed due to low confidence (2)
deepmd/pt/model/model/transform_output.py:88
- Add a unit test to verify that the new einsum-based virial matches the original `unsqueeze` + `@` matmul output for various tensor shapes to ensure correctness (a sketch of such a test follows below).
extended_virial = torch.einsum("...ik,...ij->...ikj", extended_force, extended_coord)
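A hypothetical shape-sweep test along these lines (illustrative name and shapes, not part of this PR) might look like:

```python
# Hypothetical unit test (illustrative, not part of this PR): check that the
# einsum-based virial matches the original unsqueeze + @ matmul output for a
# few tensor shapes.
import torch


def test_einsum_virial_matches_matmul():
    for nf, nall in [(1, 1), (2, 5), (4, 32)]:
        force = torch.randn(nf, nall, 3, dtype=torch.float64)
        coord = torch.randn(nf, nall, 3, dtype=torch.float64)
        ref = force.unsqueeze(-1) @ coord.unsqueeze(-2)
        new = torch.einsum("...ik,...ij->...ikj", force, coord)
        torch.testing.assert_close(new, ref)
```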
deepmd/pd/model/model/transform_output.py:85
- [nitpick] Add a gradient consistency test to ensure the Paddle einsum variant produces equivalent backward results compared to the original matmul approach (a sketch of such a test follows below).
extended_virial = paddle.einsum(
    "...ik,...ij->...ikj", extended_force, extended_coord
)
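A hypothetical gradient-consistency check (illustrative name and shapes, not part of this PR) could compare the backward results of the two formulations directly:

```python
# Hypothetical gradient-consistency check (illustrative, not part of this PR):
# the Paddle einsum variant should backpropagate the same gradients as the
# original matmul formulation.
import paddle


def test_einsum_virial_grad_matches_matmul():
    force = paddle.randn([2, 5, 3])
    coord = paddle.randn([2, 5, 3])
    force.stop_gradient = False  # track gradients w.r.t. the force

    ref = (force.unsqueeze(-1) @ coord.unsqueeze(-2)).sum()
    (ref_grad,) = paddle.grad(ref, [force])

    new = paddle.einsum("...ik,...ij->...ikj", force, coord).sum()
    (new_grad,) = paddle.grad(new, [force])

    assert paddle.allclose(ref_grad, new_grad).item()
```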
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files

@@            Coverage Diff             @@
##            devel    #4746      +/-   ##
==========================================
- Coverage   84.69%   84.69%   -0.01%
==========================================
  Files         697      697
  Lines       67474    67473       -1
  Branches     3540     3540
==========================================
- Hits        57147    57145       -2
+ Misses       9197     9196       -1
- Partials     1130     1132       +2

☔ View full report in Codecov by Sentry.
According to profiling results, the `@` matmul formulation launches 7 matmul operations to complete this computation, which is time-consuming; `einsum` requires only one matmul. This optimization mainly benefits smaller batch sizes and models.
Profiling results on PyTorch
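The profiling screenshot is not reproduced here. As a rough way to reproduce such a comparison (assumed shapes and setup, not the PR author's actual benchmark), the kernels launched by each formulation can be inspected with torch.profiler:

```python
# Rough sketch (assumed shapes, not the PR's actual benchmark) of inspecting
# how many matmul kernels each formulation launches, using torch.profiler.
import torch
from torch.profiler import ProfilerActivity, profile

force = torch.randn(1, 192, 3)
coord = torch.randn(1, 192, 3)

with profile(activities=[ProfilerActivity.CPU]) as prof_matmul:
    force.unsqueeze(-1) @ coord.unsqueeze(-2)

with profile(activities=[ProfilerActivity.CPU]) as prof_einsum:
    torch.einsum("...ik,...ij->...ikj", force, coord)

print(prof_matmul.key_averages().table(sort_by="cpu_time_total"))
print(prof_einsum.key_averages().table(sort_by="cpu_time_total"))
```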