-
Notifications
You must be signed in to change notification settings - Fork 584
fix(jax): set default_matmul_precision to tensorfloat32
#4726
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
See jax-ml/jax#24909. Without setting this flag, the precision will become very low, which seems to be a bug. See https://docs.jax.dev/en/latest/_autosummary/jax.default_matmul_precision.html for what this option is. Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR addresses an issue with the JAX matmul precision by configuring the default value to "tensorfloat32" to prevent unintentional low precision.
- Sets the default matmul precision flag to "tensorfloat32"
- Adds a reference to the related GitHub issue for context
📝 WalkthroughWalkthroughA configuration update was made in the JAX environment setup to set the default matrix multiplication precision to "tensorfloat32". This change involves a single line addition with an explanatory comment referencing a related JAX issue. No other logic or public interfaces were altered. Changes
Suggested reviewers
✨ Finishing Touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (1)
deepmd/jax/env.py (1)
15-15: Nit: correct the GitHub issue URL
The comment refers tohttps://github.com/jax-ml/jax/issues/24909, but the official JAX repository is undergoogle/jax. Please update the link tohttps://github.com/google/jax/issues/24909to avoid confusion.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/jax/env.py(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (29)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Analyze (python)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
🔇 Additional comments (1)
deepmd/jax/env.py (1)
15-16: Explicitly set JAX matmul precision to tensorfloat32
This change aligns with the JAX documentation and resolves issue #24909, ensuring that GPU matrix multiplications use higher‐precision tensorfloat32 by default.
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## devel #4726 +/- ##
=======================================
Coverage 84.81% 84.81%
=======================================
Files 696 696
Lines 67264 67265 +1
Branches 3541 3541
=======================================
+ Hits 57047 57048 +1
Misses 9085 9085
Partials 1132 1132 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
…ling#4726) See jax-ml/jax#24909. Without setting this flag, the precision will become very low, which seems to be a bug (the documentation says the GPU uses tensorfloat32 or float32, but the default behavior seems wrong...). See https://docs.jax.dev/en/latest/_autosummary/jax.default_matmul_precision.html for what this option is. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Chores** - Updated environment configuration to set default matrix multiplication precision to "tensorfloat32" for improved performance with JAX. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> (cherry picked from commit 17da9d2)
…ling#4726) See jax-ml/jax#24909. Without setting this flag, the precision will become very low, which seems to be a bug (the documentation says the GPU uses tensorfloat32 or float32, but the default behavior seems wrong...). See https://docs.jax.dev/en/latest/_autosummary/jax.default_matmul_precision.html for what this option is. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Chores** - Updated environment configuration to set default matrix multiplication precision to "tensorfloat32" for improved performance with JAX. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> (cherry picked from commit 17da9d2)
See jax-ml/jax#24909. Without setting this flag, the precision will become very low, which seems to be a bug (the documentation says the GPU uses tensorfloat32 or float32, but the default behavior seems wrong...).
See https://docs.jax.dev/en/latest/_autosummary/jax.default_matmul_precision.html for what this option is.
Summary by CodeRabbit