-
Notifications
You must be signed in to change notification settings - Fork 584
Fix(pt): add comm_dict for zbl, linear, dipole, dos, polar model to fix bugs mentioned in issue #4906 #4908
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR fixes bugs mentioned in issue #4906 by adding the missing comm_dict parameter to the forward_lower methods across multiple model classes and ensuring proper parameter passing in the linear atomic model.
- Add
comm_dictparameter toforward_lowermethod signatures in 5 model classes - Pass
comm_dictparameter through toforward_common_lowermethod calls - Fix parameter passing in linear atomic model's
forward_atomicmethod
Reviewed Changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
| deepmd/pt/model/model/polar_model.py | Add comm_dict parameter to forward_lower method and pass it to forward_common_lower |
| deepmd/pt/model/model/dp_zbl_model.py | Add comm_dict parameter to forward_lower method and pass it to forward_common_lower |
| deepmd/pt/model/model/dp_linear_model.py | Add comm_dict parameter to forward_lower method and pass it to forward_common_lower |
| deepmd/pt/model/model/dos_model.py | Add comm_dict parameter to forward_lower method and pass it to forward_common_lower |
| deepmd/pt/model/model/dipole_model.py | Add comm_dict parameter to forward_lower method and pass it to forward_common_lower |
| deepmd/pt/model/atomic_model/linear_atomic_model.py | Fix parameter passing by adding comm_dict to method call |
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
📝 WalkthroughWalkthroughAdds an optional comm_dict parameter to multiple model forward_lower methods and threads it into forward_common_lower. Updates linear atomic model to pass comm_dict to sub-model forward_common_atomic calls. No other logic, return values, or public APIs (beyond method signatures) are changed. Changes
Sequence Diagram(s)sequenceDiagram
participant Caller
participant Model as Model.forward_lower
participant Common as forward_common_lower
Caller->>Model: forward_lower(..., comm_dict?)
Note over Model: Validate/prepare inputs
Model->>Common: forward_common_lower(..., comm_dict)
Common-->>Model: outputs
Model-->>Caller: outputs
sequenceDiagram
participant Caller
participant Atomic as LinearEnergyAtomicModel.forward_atomic
participant SubA as SubModel A
participant SubB as SubModel B
Caller->>Atomic: forward_atomic(..., comm_dict?)
Atomic->>SubA: forward_common_atomic(..., comm_dict)
SubA-->>Atomic: atomic outputs
Atomic->>SubB: forward_common_atomic(..., comm_dict)
SubB-->>Atomic: atomic outputs
Atomic-->>Caller: aggregated energy
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes Possibly related PRs
Suggested labels
Suggested reviewers
Tip 🔌 Remote MCP (Model Context Protocol) integration is now available!Pro plan users can now connect to remote MCP servers from the Integrations page. Connect with popular remote MCPs such as Notion and Linear to add more context to your reviews and chats. ✨ Finishing Touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. CodeRabbit Commands (Invoked using PR/Issue comments)Type Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/pt/model/model/dipole_model.py (1)
37-53: Fix no-op squeeze calls in translated_output_def (current code doesn’t store the squeezed views).
Tensor.squeeze(...)returns a new tensor; the current calls don’t assign the result, so shapes remain unsqueezed. This is likely to produce downstream shape mismatches.Apply this diff:
@@ - if self.do_grad_r("dipole"): - output_def["force"] = out_def_data["dipole_derv_r"] - output_def["force"].squeeze(-2) + if self.do_grad_r("dipole"): + output_def["force"] = out_def_data["dipole_derv_r"].squeeze(-2) @@ - if self.do_grad_c("dipole"): - output_def["virial"] = out_def_data["dipole_derv_c_redu"] - output_def["virial"].squeeze(-2) - output_def["atom_virial"] = out_def_data["dipole_derv_c"] - output_def["atom_virial"].squeeze(-3) + if self.do_grad_c("dipole"): + output_def["virial"] = out_def_data["dipole_derv_c_redu"].squeeze(-2) + output_def["atom_virial"] = out_def_data["dipole_derv_c"].squeeze(-3)Optional: add a lightweight unit/integration test to assert the returned shapes of
force,virial, andatom_virial.
🧹 Nitpick comments (16)
deepmd/pt/model/atomic_model/linear_atomic_model.py (3)
238-257: Docstring is missing the new comm_dict parameterforward_atomic added a comm_dict argument but the docstring hasn’t been updated. Please document it to avoid confusion for users and downstream bindings.
Apply this diff inside the existing docstring’s “Parameters” section:
aparam atomic parameter. (nframes, nloc, nda) + comm_dict + Optional dict[str, torch.Tensor]. A scratch communication dictionary + forwarded to sub-models’ forward_common_atomic. Implementations may + read/write entries to share intermediates; pass None to disable.
282-295: Consider namespacing comm_dict per sub-model to avoid key collisionsIf multiple sub-models write the same keys, they can clobber each other. If shared keys are not intended, namespace the dict per sub-model index. If shared keys are intended, ignore this.
Example change (only if you observe collisions in practice):
for i, model in enumerate(self.models): type_map_model = self.mapping_list[i].to(extended_atype.device) # apply bias to each individual model ener_list.append( model.forward_common_atomic( extended_coord, type_map_model[extended_atype], nlists_[i], mapping, fparam, aparam, - comm_dict=comm_dict, + comm_dict=( + None + if comm_dict is None + else comm_dict.setdefault(f"linear[{i}]", {}) + ), )["energy"] )To decide, please verify whether sub-models rely on shared comm_dict keys or not.
226-235: TorchScript typing: prefer typing.Dict for broader compatibilityThe annotation uses Optional[dict[str, torch.Tensor]]. Older TorchScript versions have been finicky with PEP 585 generics. If your CI still scripts models under older PyTorch, consider switching to Optional[Dict[str, torch.Tensor]] (from typing) across the codebase for forward_* methods.
Would you like me to prepare a follow-up patch that replaces dict[...] with Dict[...] and adds the necessary imports consistently?
deepmd/pt/model/model/dp_linear_model.py (2)
90-112: Add a brief docstring entry for comm_dict to forward_lowerHelps users of the TorchScript-exported API understand the new parameter.
@torch.jit.export def forward_lower( self, extended_coord, extended_atype, nlist, mapping: Optional[torch.Tensor] = None, fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, comm_dict: Optional[dict[str, torch.Tensor]] = None, ): + """ + Lower-level forward. + Parameters + ---------- + comm_dict + Optional dict[str, torch.Tensor] used as a communication scratchpad + and forwarded to forward_common_lower. + """ model_ret = self.forward_common_lower(
90-112: TorchScript typing consistencySame note as in the atomic model: if CI scripts modules with an older PyTorch, consider Optional[Dict[str, torch.Tensor]] for comm_dict and import Dict from typing.
deepmd/pt/model/model/polar_model.py (3)
87-97: Use named argument for mapping for consistencyOther files (e.g., dp_linear_model.py, dp_zbl_model.py) pass mapping as a named argument. Aligning improves readability and resilience to parameter order changes.
model_ret = self.forward_common_lower( extended_coord, extended_atype, nlist, - mapping, + mapping=mapping, fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, comm_dict=comm_dict, extra_nlist_sort=self.need_sorted_nlist_for_lower(), )
75-86: Optional: document comm_dict in forward_lowerIf you prefer method-level docs, mirror a short note about comm_dict here as well.
75-86: TorchScript typing consistencyConsider using Optional[Dict[str, torch.Tensor]] for broader TorchScript compatibility, as noted in other files.
deepmd/pt/model/model/dos_model.py (3)
93-103: Use named argument for mapping for consistencyMatch the style used elsewhere by naming mapping explicitly.
model_ret = self.forward_common_lower( extended_coord, extended_atype, nlist, - mapping, + mapping=mapping, fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, comm_dict=comm_dict, extra_nlist_sort=self.need_sorted_nlist_for_lower(), )
82-92: Optional: document comm_dict in forward_lowerAdd a brief note/docstring for comm_dict to aid users.
82-92: TorchScript typing consistencyConsider Optional[Dict[str, torch.Tensor]] to maximize TorchScript compatibility, mirroring other files.
deepmd/pt/model/model/dp_zbl_model.py (2)
90-101: Optional: document comm_dict in forward_lowerConsider a short docstring snippet for the new parameter for parity with other forward APIs.
90-101: TorchScript typing consistencyAs with other files, consider switching to Optional[Dict[str, torch.Tensor]] if your TorchScript environment benefits from typing.Dict.
deepmd/pt/model/model/dipole_model.py (3)
101-103: Propagating comm_dict looks correct; adjust type hint for TorchScript compatibility.Good: The new
comm_dictis threaded intoforward_common_lowerand keeps the parameter optional, preserving call-site compatibility.Risk:
@torch.jit.exportfunctions can be picky about annotations. Usingdict[str, torch.Tensor]may not be accepted by TorchScript in all environments;typing.Dict[str, torch.Tensor]is safer and consistent across PyTorch versions.Suggested change:
- Switch to
Dict[str, torch.Tensor]- Import
DictApply this diff:
@@ -from typing import ( - Optional, -) +from typing import ( + Optional, + Dict, +) @@ - comm_dict: Optional[dict[str, torch.Tensor]] = None, + comm_dict: Optional[Dict[str, torch.Tensor]] = None,Verification checklist:
- Ensure the definition of
forward_common_loweracceptscomm_dictwith a compatible type.- Run a TorchScript export path for DipoleModel to confirm no schema/type errors occur when scripting or tracing.
- Confirm other models in this PR use the same
Dict[...]convention for consistency.Also applies to: 111-113
91-103: Consider adding return type annotation and brief docstring for comm_dict.For parity with
forward(), annotateforward_lower’s return type and document the expected keys incomm_dict(e.g., required/optional keys, shapes, device). This aids users and static tooling.Example:
- def forward_lower( + def forward_lower( self, extended_coord, extended_atype, nlist, mapping: Optional[torch.Tensor] = None, fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, - comm_dict: Optional[dict[str, torch.Tensor]] = None, - ): + comm_dict: Optional[Dict[str, torch.Tensor]] = None, + ) -> dict[str, torch.Tensor]: + """ + Args: + comm_dict: Optional dictionary for inter-stage communication. + Expected keys (if any): e.g. "neighbor_mask", "env_info". + Values are torch.Tensors on the same device as inputs. + """
55-89: API symmetry: does forward() also need comm_dict?If bug fixes rely on
comm_dictonly in the lower path, this is fine. If upstream callers sometimes only useforward(), consider optionally addingcomm_dictthere for consistency across models. Otherwise, document thatcomm_dictis a lower-path-only feature.Please confirm whether the other models updated in this PR expose
comm_dictonly inforward_lower()or also inforward(), and that callers don’t need it at the higher level in dipole inference/training flows.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (6)
deepmd/pt/model/atomic_model/linear_atomic_model.py(1 hunks)deepmd/pt/model/model/dipole_model.py(2 hunks)deepmd/pt/model/model/dos_model.py(2 hunks)deepmd/pt/model/model/dp_linear_model.py(2 hunks)deepmd/pt/model/model/dp_zbl_model.py(2 hunks)deepmd/pt/model/model/polar_model.py(2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (cpu, cpu)
🔇 Additional comments (7)
deepmd/pt/model/atomic_model/linear_atomic_model.py (1)
282-295: Good: comm_dict is correctly propagated to sub-modelsPassing comm_dict through to each sub-model’s forward_common_atomic aligns the linear composition path with the rest of the models and should resolve missing cross-module communication for issue #4906. No functional regressions spotted around weighting or aggregation.
deepmd/pt/model/model/dp_linear_model.py (2)
100-111: Good: comm_dict added and correctly threaded to forward_common_lowerSignature change is backward-compatible (new trailing optional arg) and the value is forwarded as expected.
90-112: Allcomm_dictpropagation checks out—no changes neededI’ve verified across the PyTorch backend that:
- Every
forward_lowerdefinition includes thecomm_dict: Optional[...]parameter.- Inside each
forward_lower, the call toself.forward_common_lower(...)explicitly passescomm_dict=comm_dict.- All
forward_common_atomicdefinitions includecomm_dictin their signatures.- Within the lower‐level implementations,
forward_common_atomicis always invoked withcomm_dictwhen called from the atomic‐aware paths.Since the propagation of
comm_dictis consistent and complete, no further modifications are required here.deepmd/pt/model/model/polar_model.py (1)
85-97: Good: comm_dict added and forwarded to forward_common_lowerChange is localized and preserves backward compatibility.
deepmd/pt/model/model/dos_model.py (1)
91-103: Good: comm_dict added and forwarded to forward_common_lowerChange looks correct and non-breaking.
deepmd/pt/model/model/dp_zbl_model.py (1)
100-112: Good: comm_dict threaded through the ZBL lower pathSignature and forwarding look correct; backward compatibility is preserved.
deepmd/pt/model/model/dipole_model.py (1)
91-129: Ensurecomm_dictIs Uniformly Plumbed Through Lower‐Level ForwardsTo avoid any breaking changes when propagating the new
comm_dictargument, please verify across all model implementations:
- Confirm every
forward_common_lowerdefinition includes thecomm_dict: Optional[dict[str, torch.Tensor]] = Noneparameter in its signature.- Confirm each
forward_lowermethod signature likewise declarescomm_dictand passes it into its call toforward_common_lower.- Search for any external or legacy call sites of
forward_lowerandforward_common_lower(including in tests) that might still rely on the previous signatures—especially multi‐line signatures or invocations—and update them to passcomm_dictexplicitly.Recommended checks (run from the repo root):
# Verify comm_dict in all forward_common_lower definitions rg -U -P 'def\s+forward_common_lower.*comm_dict' -n --type=py # Verify comm_dict in all forward_lower definitions rg -U -P 'def\s+forward_lower.*comm_dict' -n --type=py # Find any calls to forward_common_lower without comm_dict rg -U -P 'forward_common_lower\s*\(' -n --type=py | grep -v 'comm_dict' # Find any external calls to forward_lower (e.g., in tests) that lack the comm_dict keyword rg -U -P 'forward_lower\s*\(' -n --type=py | grep -v 'comm_dict'Please run these checks to ensure no model variant or test is accidentally omitted.
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## devel #4908 +/- ##
=======================================
Coverage 84.29% 84.29%
=======================================
Files 702 702
Lines 68665 68664 -1
Branches 3573 3573
=======================================
Hits 57882 57882
+ Misses 9643 9642 -1
Partials 1140 1140 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
njzjz
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need to sync to other backends.
…ix bugs mentioned in issue deepmodeling#4906 (deepmodeling#4908) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - New Features - Added optional support to pass a communication dictionary through lower-level model computations across energy, dipole, DOS, polarization, and related models. This enables advanced workflows while remaining fully backward compatible. - Refactor - Standardized internal propagation of the communication dictionary across sub-models to ensure consistent behavior. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
Summary by CodeRabbit