perf: use torch.embedding for type embedding
#4747
Conversation
Pull Request Overview
This PR optimizes the type embedding lookup in the model's forward pass by replacing direct tensor indexing with the dedicated torch.embedding function, improving backward-pass performance.
- Swaps manual indexing of the embedding tensor for the torch.embedding functional call in forward (see the sketch after this list).
- Achieves a significant GPU backward-step speedup (32 ms → 0.5 ms).
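The gist of the change, as a minimal sketch: the names ntypes, dim, and embedding here are illustrative, not the actual deepmd-kit identifiers. Both lookups return the same values; they differ only in which backward kernel PyTorch dispatches.

```python
import torch

# Illustrative shapes: a table of per-type embeddings.
ntypes, dim = 8, 128
embedding = torch.randn(ntypes, dim, requires_grad=True)
atype = torch.randint(0, ntypes, (4096,))

# Before: manual indexing; its backward is a costly scatter on GPU.
out_indexed = embedding[atype]

# After: the dedicated embedding op with a much faster backward kernel.
out_embedded = torch.embedding(embedding, atype)

# The forward results are identical; only the backward path differs.
assert torch.equal(out_indexed, out_embedded)
```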
📝 Walkthrough
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller
    participant TypeEmbedNet
    participant torch
    Caller->>TypeEmbedNet: forward(atype)
    TypeEmbedNet->>TypeEmbedNet: self.embedding(atype.device)
    TypeEmbedNet->>torch: embedding(embedding_tensor, atype)
    torch-->>TypeEmbedNet: embedding_result
    TypeEmbedNet-->>Caller: embedding_result
```
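For orientation, here is a minimal sketch of a module whose forward matches the diagram above; the class body (the embedding_table parameter and the embedding helper) is hypothetical, not the actual deepmd-kit source.

```python
import torch


class TypeEmbedNet(torch.nn.Module):
    """Hypothetical sketch of a type-embedding module matching the diagram."""

    def __init__(self, ntypes: int, dim: int) -> None:
        super().__init__()
        # Full table of per-type embeddings (ntypes rows of size dim).
        self.embedding_table = torch.nn.Parameter(torch.randn(ntypes, dim))

    def embedding(self, device: torch.device) -> torch.Tensor:
        # Return the table on the requested device
        # (mirrors the self.embedding(atype.device) call in the diagram).
        return self.embedding_table.to(device)

    def forward(self, atype: torch.Tensor) -> torch.Tensor:
        # Look up rows with torch.embedding instead of embedding_tensor[atype].
        embedding_tensor = self.embedding(atype.device)
        return torch.embedding(embedding_tensor, atype)
```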
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##            devel    #4747      +/-   ##
==========================================
- Coverage   84.69%   84.69%   -0.01%
==========================================
  Files         697      697
  Lines       67474    67473       -1
  Branches     3540     3540
==========================================
- Hits        57147    57145       -2
+ Misses       9197     9196       -1
- Partials     1130     1132       +2
```

☔ View full report in Codecov by Sentry.
The backward step of the indexing operation is costly on GPU. Using the dedicated torch.embedding mitigates this problem.

Profiling results
Before: 32 ms
After: 0.5 ms
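A sketch of how such a backward-pass timing could be reproduced on a CUDA device; the time_backward helper and the tensor shapes are assumptions for illustration, and the 32 ms / 0.5 ms figures above come from the author's profiling, not from this snippet.

```python
import torch


def time_backward(lookup, weight, idx, iters=100):
    # Average CUDA time (ms) of the backward pass of the given lookup.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    total = 0.0
    for _ in range(iters):
        out = lookup(weight, idx).sum()
        torch.cuda.synchronize()
        start.record()
        out.backward()
        end.record()
        torch.cuda.synchronize()
        total += start.elapsed_time(end)
        weight.grad = None  # reset accumulated gradients between runs
    return total / iters


weight = torch.randn(128, 256, device="cuda", requires_grad=True)
idx = torch.randint(0, 128, (1_000_000,), device="cuda")

print("indexing :", time_backward(lambda w, i: w[i], weight, idx), "ms")
print("embedding:", time_backward(torch.embedding, weight, idx), "ms")
```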