fix(jax): fix NaN in sigmoid grad #4724
Fix deepmodeling#4718.
See jax-ml/jax#15617
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Pull Request Overview
This PR fixes a NaN issue in the sigmoid gradient by refactoring the sigmoid implementation for both JAX and non-JAX arrays.
- Introduces a new helper function, sigmoid_t, which conditionally uses jax.nn.sigmoid for JAX arrays.
- Updates the activation functions for "sigmoid" and "silu" to leverage sigmoid_t for consistent behavior.
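A helper of the kind described above could look roughly like the following. This is a minimal sketch, not the code merged in this PR: the `isinstance(x, jax.Array)` check and the NumPy fallback are assumptions made for illustration, and the real module may detect array types differently. The second half illustrates the kind of failure the PR title refers to: differentiating a hand-written `1 / (1 + exp(-x))` at a very negative input overflows `exp`, and the backward pass then multiplies zero by infinity.

```python
import jax
import jax.numpy as jnp
import numpy as np


def sigmoid_t(x):
    """Sigmoid that defers to jax.nn.sigmoid for JAX arrays (illustrative sketch)."""
    # Assumption: detect JAX inputs with isinstance; tracers created by
    # jax.grad also count as jax.Array instances.
    if isinstance(x, jax.Array):
        return jax.nn.sigmoid(x)
    # Non-JAX path: plain formula on NumPy-like input.
    x = np.asarray(x)
    return 1.0 / (1.0 + np.exp(-x))


def naive_sigmoid(x):
    # Hand-written formula; exp(-x) overflows to inf for very negative x.
    return 1.0 / (1.0 + jnp.exp(-x))


x = jnp.asarray(-1000.0)
grad_naive = jax.grad(naive_sigmoid)(x)  # backward pass hits 0 * inf
grad_fixed = jax.grad(sigmoid_t)(x)      # uses jax.nn.sigmoid's own gradient rule
print("naive grad:", grad_naive)
print("sigmoid_t grad:", grad_fixed)
```

Routing JAX arrays through `jax.nn.sigmoid` keeps the forward values the same while relying on JAX's built-in derivative, which stays finite at extreme inputs.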
Comments suppressed due to low confidence (1)
deepmd/dpmodel/utils/network.py:337
- [nitpick] Verify that using sigmoid_t in the 'silu' activation function produces results consistent with the previous implementation, for both JAX and non-JAX arrays.
return x * sigmoid_t(x)
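One way to act on this nitpick is a small numerical comparison. The sketch below is hypothetical and not part of the PR's test suite: `silu_reference` stands in for a formula-based implementation (its exact previous form is an assumption), and `silu_jax` mirrors the `x * sigmoid(x)` shape of the line above using `jax.nn.sigmoid` directly.

```python
import jax
import jax.numpy as jnp
import numpy as np


def silu_reference(x):
    # Formula-based SiLU evaluated in NumPy (assumed reference for comparison).
    return x / (1.0 + np.exp(-x))


def silu_jax(x):
    # SiLU written as x * sigmoid(x), with the sigmoid taken from jax.nn.
    return x * jax.nn.sigmoid(x)


# Compare forward values on a moderate range where neither path overflows.
xs = np.linspace(-10.0, 10.0, 41).astype(np.float32)
ref = silu_reference(xs)
new = np.asarray(silu_jax(jnp.asarray(xs)))
max_abs_diff = float(np.max(np.abs(ref - new)))
print("max abs difference:", max_abs_diff)

# Check that gradients stay finite even at extreme inputs.
extreme = jnp.asarray([-1000.0, 0.0, 1000.0])
grads = jax.vmap(jax.grad(silu_jax))(extreme)
print("gradients at extremes:", grads)
```

The tolerance to accept depends on the floating-point precision in use; the comparison above runs in single precision.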
📝 Walkthrough
A new utility function ...
Sequence Diagram(s)
sequenceDiagram
    participant User
    participant get_activation_fn
    participant sigmoid_t
    User->>get_activation_fn: Request activation function (e.g., "sigmoid", "silu", "silut")
    get_activation_fn->>sigmoid_t: Compute sigmoid (for relevant activations)
    sigmoid_t-->>get_activation_fn: Return sigmoid result
    get_activation_fn-->>User: Return activation function output
Codecov Report
✅ All modified and coverable lines are covered by tests.

@@            Coverage Diff            @@
##            devel    #4724   +/-   ##
=======================================
  Coverage   84.81%   84.81%
=======================================
  Files         696      696
  Lines       67264    67267    +3
  Branches     3541     3540    -1
=======================================
+ Hits        57047    57050    +3
+ Misses       9085     9084    -1
- Partials     1132     1133    +1
Fix deepmodeling#4718. See jax-ml/jax#15617
- **New Features** - Improved activation functions with optimized and consistent sigmoid computation, including enhanced support for JAX arrays.
- **Refactor** - Centralized sigmoid logic for better maintainability and compatibility across different array types.
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
(cherry picked from commit 97633fb)
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>