-
Notifications
You must be signed in to change notification settings - Fork 584
style(jax): enable ANN rule and add comprehensive type hints to JAX backend #4967
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
@copilot continue your work. |
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
I've completed the work on adding comprehensive type hints to the JAX2TF interop code. All functions in the Key changes in commit 4064b3b:
The |
…ckend Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
…y file changes Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR enables the Ruff ANN (type annotation) rule for the JAX backend and adds comprehensive type hints throughout the entire JAX implementation to improve type safety and developer experience.
Key Changes
- Enabled type annotation checking by removing
ANNfrom exclusions inpyproject.tomlfordeepmd/jax/** - Added comprehensive type hints to all core JAX backend modules including atomic models, descriptors, fitting modules, inference, models, utilities, and JAX-TF interop
- Enhanced type precision by updating generic
Anytypes to more specific types likejax.export.Exported,bytearray, anddict[str, jnp.ndarray]
Reviewed Changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
| pyproject.toml | Removed ANN exclusion for JAX backend to enable type annotation checking |
| deepmd/jax/utils/serialization.py | Added return type annotation for exported functions |
| deepmd/jax/utils/network.py | Added type hints for array protocol methods |
| deepmd/jax/utils/neighbor_stat.py | Added return type annotation for execute method |
| deepmd/jax/model/model.py | Added return type annotations for model factory functions |
| deepmd/jax/model/hlo.py | Enhanced type precision for HLO model parameters and method return types |
| deepmd/jax/model/dp_zbl_model.py | Added return type annotations for model methods |
| deepmd/jax/model/dp_model.py | Added return type annotations for model methods |
| deepmd/jax/model/base_model.py | Added comprehensive type hints for forward functions and nested closures |
| deepmd/jax/jax2tf/tfmodel.py | Added type hints for TensorFlow model wrapper methods |
| deepmd/jax/jax2tf/serialization.py | Added comprehensive type annotations for all serialization functions |
| deepmd/jax/jax2tf/region.py | Added type annotation for region distance function |
| deepmd/jax/jax2tf/nlist.py | Added type hints for neighbor list functions |
| deepmd/jax/jax2tf/make_model.py | Added return type annotation for model call function |
| deepmd/jax/jax2tf/format_nlist.py | Added return type annotation for nlist formatting |
| deepmd/jax/infer/deep_eval.py | Added type hints for evaluation and output shape functions |
| deepmd/jax/common.py | Added type annotations for utility functions and array protocol methods |
| deepmd/jax/atomic_model/base_atomic_model.py | Added type hints for base atomic model attribute setter |
njzjz
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot The following test hangs. Could you figure out the reason?
- name: Test TF2 eager mode
run: pytest --cov=deepmd --cov-append source/tests/consistent/io/test_io.py source/jax2tf_tests
env:
NUM_WORKERS: 0
DP_TEST_TF2_ONLY: 1
DP_DTYPE_PROMOTION_STRICT: 1…ort hangs Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
... I've identified and fixed the cause of the hanging test. The issue was in the JAX type annotations I added. Root Cause: Solution: The fix is in commit 15ae1f3. Tests should now run without hanging. |
@copilot install dependencies shown in https://github.com/deepmodeling/deepmd-kit/blob/devel/.github/workflows/test_python.yml and rerun this test. |
... I've successfully installed the dependencies from the workflow file and rerun the test. Test Results: The test suite now runs to completion instead of hanging, confirming that the string literal fix in commit 15ae1f3 resolved the import hanging issue that was occurring when |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #4967 +/- ##
==========================================
- Coverage 84.27% 84.22% -0.05%
==========================================
Files 705 705
Lines 69210 69211 +1
Branches 3573 3573
==========================================
- Hits 58324 58293 -31
- Misses 9744 9779 +35
+ Partials 1142 1139 -3 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
…ackend (deepmodeling#4967) This PR enables the Ruff ANN (type annotation) rule for the JAX backend and adds comprehensive type hints to all methods across the core JAX implementation. ## Changes Made **Configuration Changes:** - [x] Removed `ANN` from the exclude list for `deepmd/jax/**` in `pyproject.toml`, enabling type annotation checking for the entire JAX backend - [x] Removed unnecessary exclusion for `deepmd/jax/jax2tf/**` as it now passes ANN checks with proper type annotations - [x] The global `ANN401` ignore remains active to allow necessary `Any` type usage **Type Annotations Added:** - [x] **Base functions**: Added type hints to `base_atomic_model_set_attr` and `forward_common_atomic` functions that are used throughout the JAX backend - [x] **Atomic models**: Complete type annotations for all classes in `deepmd/jax/atomic_model/` - [x] **Descriptors**: Type hints verified for all descriptor classes - [x] **Fitting modules**: Type annotations confirmed for fitting implementations - [x] **Inference**: Added return types for `_eval_model`, `_get_output_shape`, and nested evaluation functions - [x] **Models**: Complete type hints for model classes including complex HLO model parameters - [x] **Utilities**: Type annotations for network classes, neighbor statistics, and serialization functions - [x] **Array protocol methods**: Proper typing for `__array__`, `__array_namespace__`, `__dlpack__`, and `__dlpack_device__` methods - [x] **Root level**: Type hints for common utility functions like `scatter_sum` - [x] **JAX2TF interop**: Added comprehensive type annotations to all functions in the `deepmd/jax/jax2tf/` directory including: - `format_nlist.py`: Return type annotation for nlist formatting function - `make_model.py`: Return type for model call wrapper function - `nlist.py`: Type hints for neighbor list functions including `nlist_distinguish_types`, `tf_outer`, and `extend_coord_with_ghosts` - `region.py`: Type annotations for region distance calculations - `serialization.py`: Complete type hints for all model serialization functions and nested closures, using proper `jax.export.Exported` type - `tfmodel.py`: Type annotations for TensorFlow model wrapper class methods **Bug Fixes:** - [x] **Third-party file protection**: Reverted accidental changes to `source/3rdparty/implib/implib-gen.py` which should not be modified - [x] **Improved type accuracy**: Updated `exported_whether_do_atomic_virial` return type from `Any` to `jax.export.Exported` for better type safety - [x] **Enhanced return type precision**: Updated `TFModelWrapper.call()` and `TFModelWrapper.call_lower()` return types from `Any` to `dict[str, jnp.ndarray]` for better type safety - [x] **Improved HLO parameter types**: Updated HLO model stablehlo parameters from `Any` to `bytearray` for more precise typing - [x] **Fixed TF2 eager mode test hanging**: Used string literals for JAX type annotations (`"jax_export.Exported"`) to prevent import-time evaluation issues that could cause tests to hang in environments where JAX is not fully available ## Technical Details The implementation follows existing codebase patterns: - Uses `Any` for complex interop types (properly ignored by global ANN401 rule) - Leverages forward references for circular dependencies (e.g., `"BaseModel"`) - Maintains consistency with existing type annotation styles - Handles JAX-specific array types (`jnp.ndarray`) and TensorFlow types (`tnp.ndarray`, `tf.Tensor`) appropriately - Uses appropriate return types for TensorFlow interop functions (e.g., `dict[str, tnp.ndarray]` for model outputs) - Uses precise JAX export types like `jax.export.Exported` where applicable - Uses appropriate binary data types like `bytearray` for serialized HLO models - **Uses string literals for JAX types** to prevent import-time evaluation issues in test environments where JAX may not be fully available ## Validation All core JAX backend directories now pass ruff checks with the ANN rule enabled: - `deepmd/jax/atomic_model/` ✅ - `deepmd/jax/descriptor/` ✅ - `deepmd/jax/fitting/` ✅ - `deepmd/jax/infer/` ✅ - `deepmd/jax/model/` ✅ - `deepmd/jax/utils/` ✅ - `deepmd/jax/jax2tf/` ✅ (now fully compliant with ANN rules) - Root level files ✅ **Test Hanging Issue Fixed**: The TF2 eager mode test hanging issue was caused by runtime evaluation of JAX type annotations in environments where JAX was not fully available. This has been resolved by using string literals for the problematic type annotations. **Configuration Simplified**: Removed the specific exclusion for `deepmd/jax/jax2tf/` directory as it now passes all ANN checks with proper type annotations, making the configuration cleaner and more consistent. This change significantly improves type safety and developer experience for the entire JAX backend while maintaining backward compatibility and fixing the test hanging issue. Fixes deepmodeling#4942. <!-- START COPILOT CODING AGENT TIPS --> --- 💬 Share your feedback on Copilot coding agent for the chance to win a $200 gift card! Click [here](https://survey3.medallia.com/?EAHeSx-AP01bZqG0Ld9QLQ) to start the survey. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
This PR enables the Ruff ANN (type annotation) rule for the JAX backend and adds comprehensive type hints to all methods across the core JAX implementation.
Changes Made
Configuration Changes:
ANNfrom the exclude list fordeepmd/jax/**inpyproject.toml, enabling type annotation checking for the entire JAX backenddeepmd/jax/jax2tf/**as it now passes ANN checks with proper type annotationsANN401ignore remains active to allow necessaryAnytype usageType Annotations Added:
base_atomic_model_set_attrandforward_common_atomicfunctions that are used throughout the JAX backenddeepmd/jax/atomic_model/_eval_model,_get_output_shape, and nested evaluation functions__array__,__array_namespace__,__dlpack__, and__dlpack_device__methodsscatter_sumdeepmd/jax/jax2tf/directory including:format_nlist.py: Return type annotation for nlist formatting functionmake_model.py: Return type for model call wrapper functionnlist.py: Type hints for neighbor list functions includingnlist_distinguish_types,tf_outer, andextend_coord_with_ghostsregion.py: Type annotations for region distance calculationsserialization.py: Complete type hints for all model serialization functions and nested closures, using properjax.export.Exportedtypetfmodel.py: Type annotations for TensorFlow model wrapper class methodsBug Fixes:
source/3rdparty/implib/implib-gen.pywhich should not be modifiedexported_whether_do_atomic_virialreturn type fromAnytojax.export.Exportedfor better type safetyTFModelWrapper.call()andTFModelWrapper.call_lower()return types fromAnytodict[str, jnp.ndarray]for better type safetyAnytobytearrayfor more precise typing"jax_export.Exported") to prevent import-time evaluation issues that could cause tests to hang in environments where JAX is not fully availableTechnical Details
The implementation follows existing codebase patterns:
Anyfor complex interop types (properly ignored by global ANN401 rule)"BaseModel")jnp.ndarray) and TensorFlow types (tnp.ndarray,tf.Tensor) appropriatelydict[str, tnp.ndarray]for model outputs)jax.export.Exportedwhere applicablebytearrayfor serialized HLO modelsValidation
All core JAX backend directories now pass ruff checks with the ANN rule enabled:
deepmd/jax/atomic_model/✅deepmd/jax/descriptor/✅deepmd/jax/fitting/✅deepmd/jax/infer/✅deepmd/jax/model/✅deepmd/jax/utils/✅deepmd/jax/jax2tf/✅ (now fully compliant with ANN rules)Test Hanging Issue Fixed: The TF2 eager mode test hanging issue was caused by runtime evaluation of JAX type annotations in environments where JAX was not fully available. This has been resolved by using string literals for the problematic type annotations.
Configuration Simplified: Removed the specific exclusion for
deepmd/jax/jax2tf/directory as it now passes all ANN checks with proper type annotations, making the configuration cleaner and more consistent.This change significantly improves type safety and developer experience for the entire JAX backend while maintaining backward compatibility and fixing the test hanging issue.
Fixes #4942.
💬 Share your feedback on Copilot coding agent for the chance to win a $200 gift card! Click here to start the survey.