-
Notifications
You must be signed in to change notification settings - Fork 584
fix(jax): workaround for "xxTracer is not a valid JAX type" #4776
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> (cherry picked from commit c62c356)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This pull request applies a workaround to resolve the "xxTracer is not a valid JAX type" error by explicitly using ellipsis indexing (i.e. “[...]”) on various array variables. The changes update multiple files to ensure that array inputs are correctly passed to underlying JAX‐compatible functions.
- Updated ellipsis indexing in matrix multiplications and bias additions in network operations
- Added ellipsis indexing for masks and coefficient extractions in descriptor and fitting functions
- Introduced ellipsis indexing to improve type consistency before numerical operations
Reviewed Changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
| deepmd/dpmodel/utils/network.py | Added ellipsis indexing in matrix multiplication for activation function calls |
| deepmd/dpmodel/utils/exclude_mask.py | Added ellipsis indexing on type_mask slices |
| deepmd/dpmodel/fitting/general_fitting.py | Revised ellipsis indexing for fparam, aparam, case_embd, and bias_atom_e |
| deepmd/dpmodel/descriptor/se_t_tebd.py | Applied ellipsis indexing for mean and stddev in env_mat.call |
| deepmd/dpmodel/descriptor/se_t.py | Updated ellipsis indexing on davg and dstd |
| deepmd/dpmodel/descriptor/se_r.py | Updated ellipsis indexing on davg and dstd with flag |
| deepmd/dpmodel/descriptor/se_e2_a.py | Revised ellipsis indexing on davg and dstd |
| deepmd/dpmodel/descriptor/repformers.py | Applied ellipsis indexing for mean and stddev in env_mat.call |
| deepmd/dpmodel/descriptor/repflows.py | Updated ellipsis indexing for mean and stddev in edge env_mat.call |
| deepmd/dpmodel/descriptor/dpa1.py | Revised ellipsis indexing for mean and stddev in env_mat.call |
| deepmd/dpmodel/atomic_model/pairtab_atomic_model.py | Added ellipsis indexing on tab_data during coefficient extraction |
Comments suppressed due to low confidence (15)
deepmd/dpmodel/fitting/general_fitting.py:449
- [nitpick] Ensure that the ellipsis indexing on 'self.case_embd' does not alter its intended shape before tiling.
case_embd = xp.tile(xp.reshape(self.case_embd[...], [1, 1, -1]), [nf, nloc, 1])
deepmd/dpmodel/utils/network.py:262
- [nitpick] Ensure that using '[...]' indexing on 'self.w' and 'self.b' preserves their intended dimensions and behavior in all cases.
xp.matmul(x, self.w[...]) + self.b[...]
deepmd/dpmodel/utils/exclude_mask.py:56
- [nitpick] Confirm that the ellipsis indexing on 'self.type_mask' maintains the expected shape and behavior when used with xp.take.
xp.take(self.type_mask[...], xp.reshape(atype, [-1]), axis=0),
deepmd/dpmodel/utils/exclude_mask.py:135
- [nitpick] Verify that applying '[...]' indexing on 'self.type_mask' produces the correct reshaped output for the exclusion mask.
xp.take(self.type_mask[...], xp.reshape(type_ij, (-1,))),
deepmd/dpmodel/fitting/general_fitting.py:413
- [nitpick] Ensure that the ellipsis indexing on 'self.fparam_avg' and 'self.fparam_inv_std' does not affect the expected broadcasting and dimensions of 'fparam'.
fparam = (fparam - self.fparam_avg[...]) * self.fparam_inv_std[...]
deepmd/dpmodel/fitting/general_fitting.py:435
- [nitpick] Check that the use of '[...]' on 'self.aparam_avg' and 'self.aparam_inv_std' preserves the dimensional integrity for further operations on 'aparam'.
aparam = (aparam - self.aparam_avg[...]) * self.aparam_inv_std[...]
deepmd/dpmodel/fitting/general_fitting.py:487
- [nitpick] Verify that applying '[...]' to 'self.bias_atom_e' maintains the correct shape for subsequent use with xp.take.
xp.astype(self.bias_atom_e[...], outs.dtype),
deepmd/dpmodel/descriptor/se_t_tebd.py:739
- [nitpick] Confirm that the ellipsis indexing on 'self.mean' (and similarly on 'self.stddev') in the env_mat.call call retains the expected dimensionality.
self.mean[...],
deepmd/dpmodel/descriptor/se_t.py:355
- [nitpick] Validate that using '[...]' with 'self.davg' (and 'self.dstd') yields correct dimensions for the env_mat.call function.
self.davg[...],
deepmd/dpmodel/descriptor/se_r.py:379
- [nitpick] Ensure the ellipsis indexing on 'self.davg' and 'self.dstd' in the env_mat.call call (with the flag) does not impact the expected output shape.
self.davg[...],
deepmd/dpmodel/descriptor/se_e2_a.py:597
- [nitpick] Make sure that applying '[...]' on 'self.davg' (and 'self.dstd') leaves the array structure intact for the subsequent operations.
self.davg[...],
deepmd/dpmodel/descriptor/repformers.py:447
- [nitpick] Review that the ellipsis indexing on 'self.mean' and 'self.stddev' is applied consistently to support the env_mat.call interface.
self.mean[...],
deepmd/dpmodel/descriptor/repflows.py:478
- [nitpick] Ensure that using '[...]' indexing on 'self.mean' and 'self.stddev' in the call to env_mat_edge maintains the correct array shapes.
self.mean[...],
deepmd/dpmodel/descriptor/dpa1.py:957
- [nitpick] Confirm that the ellipsis indexing on 'self.mean' and 'self.stddev' does not change the expected behavior of the env_mat.call function.
self.mean[...],
deepmd/dpmodel/atomic_model/pairtab_atomic_model.py:296
- [nitpick] Verify that applying '[...]' to 'self.tab_data' does not alter the intended extraction of spline coefficients.
i_type, j_type, idx, self.tab_data[...], nspline
📝 WalkthroughWalkthroughThis change updates multiple internal method calls to explicitly use ellipsis slicing ( Changes
Possibly related PRs
Suggested labels
Suggested reviewers
📜 Recent review detailsConfiguration used: CodeRabbit UI 📒 Files selected for processing (11)
⏰ Context from checks skipped due to timeout of 90000ms (23)
🔇 Additional comments (15)
✨ Finishing Touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## devel #4776 +/- ##
=======================================
Coverage 84.79% 84.79%
=======================================
Files 698 698
Lines 67746 67746
Branches 3540 3540
=======================================
+ Hits 57444 57445 +1
Misses 9171 9171
+ Partials 1131 1130 -1 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Summary by CodeRabbit