Skip to content

Conversation

@OutisLi
Copy link
Collaborator

@OutisLi OutisLi commented Aug 23, 2025

Summary by CodeRabbit

  • New Features
    • Added optional support to pass a communication dictionary through lower-level model computations across energy, dipole, DOS, polarization, and related models. This enables advanced workflows while remaining fully backward compatible.
  • Refactor
    • Standardized internal propagation of the communication dictionary across sub-models to ensure consistent behavior.

Copilot AI review requested due to automatic review settings August 23, 2025 07:41
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR fixes bugs mentioned in issue #4906 by adding the missing comm_dict parameter to the forward_lower methods across multiple model classes and ensuring proper parameter passing in the linear atomic model.

  • Add comm_dict parameter to forward_lower method signatures in 5 model classes
  • Pass comm_dict parameter through to forward_common_lower method calls
  • Fix parameter passing in linear atomic model's forward_atomic method

Reviewed Changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated no comments.

Show a summary per file
File Description
deepmd/pt/model/model/polar_model.py Add comm_dict parameter to forward_lower method and pass it to forward_common_lower
deepmd/pt/model/model/dp_zbl_model.py Add comm_dict parameter to forward_lower method and pass it to forward_common_lower
deepmd/pt/model/model/dp_linear_model.py Add comm_dict parameter to forward_lower method and pass it to forward_common_lower
deepmd/pt/model/model/dos_model.py Add comm_dict parameter to forward_lower method and pass it to forward_common_lower
deepmd/pt/model/model/dipole_model.py Add comm_dict parameter to forward_lower method and pass it to forward_common_lower
deepmd/pt/model/atomic_model/linear_atomic_model.py Fix parameter passing by adding comm_dict to method call

Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Aug 23, 2025

📝 Walkthrough

Walkthrough

Adds an optional comm_dict parameter to multiple model forward_lower methods and threads it into forward_common_lower. Updates linear atomic model to pass comm_dict to sub-model forward_common_atomic calls. No other logic, return values, or public APIs (beyond method signatures) are changed.

Changes

Cohort / File(s) Summary
Thread comm_dict through forward_lower
deepmd/pt/model/model/dipole_model.py, deepmd/pt/model/model/dos_model.py, deepmd/pt/model/model/dp_linear_model.py, deepmd/pt/model/model/dp_zbl_model.py, deepmd/pt/model/model/polar_model.py
Add optional parameter comm_dict: Optional[dict[str, torch.Tensor]] to forward_lower and pass it to forward_common_lower. No other control-flow changes.
Atomic model comm_dict forwarding
deepmd/pt/model/atomic_model/linear_atomic_model.py
Forward comm_dict to each sub-model’s forward_common_atomic within forward_atomic.

Sequence Diagram(s)

sequenceDiagram
  participant Caller
  participant Model as Model.forward_lower
  participant Common as forward_common_lower

  Caller->>Model: forward_lower(..., comm_dict?)
  Note over Model: Validate/prepare inputs
  Model->>Common: forward_common_lower(..., comm_dict)
  Common-->>Model: outputs
  Model-->>Caller: outputs
Loading
sequenceDiagram
  participant Caller
  participant Atomic as LinearEnergyAtomicModel.forward_atomic
  participant SubA as SubModel A
  participant SubB as SubModel B

  Caller->>Atomic: forward_atomic(..., comm_dict?)
  Atomic->>SubA: forward_common_atomic(..., comm_dict)
  SubA-->>Atomic: atomic outputs
  Atomic->>SubB: forward_common_atomic(..., comm_dict)
  SubB-->>Atomic: atomic outputs
  Atomic-->>Caller: aggregated energy
Loading

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Possibly related PRs

Suggested labels

Python

Suggested reviewers

  • njzjz
  • iProzd

Tip

🔌 Remote MCP (Model Context Protocol) integration is now available!

Pro plan users can now connect to remote MCP servers from the Integrations page. Connect with popular remote MCPs such as Notion and Linear to add more context to your reviews and chats.

✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

CodeRabbit Commands (Invoked using PR/Issue comments)

Type @coderabbitai help to get the list of available commands.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Status, Documentation and Community

  • Visit our Status Page to check the current availability of CodeRabbit.
  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/pt/model/model/dipole_model.py (1)

37-53: Fix no-op squeeze calls in translated_output_def (current code doesn’t store the squeezed views).

Tensor.squeeze(...) returns a new tensor; the current calls don’t assign the result, so shapes remain unsqueezed. This is likely to produce downstream shape mismatches.

Apply this diff:

@@
-        if self.do_grad_r("dipole"):
-            output_def["force"] = out_def_data["dipole_derv_r"]
-            output_def["force"].squeeze(-2)
+        if self.do_grad_r("dipole"):
+            output_def["force"] = out_def_data["dipole_derv_r"].squeeze(-2)
@@
-        if self.do_grad_c("dipole"):
-            output_def["virial"] = out_def_data["dipole_derv_c_redu"]
-            output_def["virial"].squeeze(-2)
-            output_def["atom_virial"] = out_def_data["dipole_derv_c"]
-            output_def["atom_virial"].squeeze(-3)
+        if self.do_grad_c("dipole"):
+            output_def["virial"] = out_def_data["dipole_derv_c_redu"].squeeze(-2)
+            output_def["atom_virial"] = out_def_data["dipole_derv_c"].squeeze(-3)

Optional: add a lightweight unit/integration test to assert the returned shapes of force, virial, and atom_virial.

🧹 Nitpick comments (16)
deepmd/pt/model/atomic_model/linear_atomic_model.py (3)

238-257: Docstring is missing the new comm_dict parameter

forward_atomic added a comm_dict argument but the docstring hasn’t been updated. Please document it to avoid confusion for users and downstream bindings.

Apply this diff inside the existing docstring’s “Parameters” section:

         aparam
             atomic parameter. (nframes, nloc, nda)
 
+        comm_dict
+            Optional dict[str, torch.Tensor]. A scratch communication dictionary
+            forwarded to sub-models’ forward_common_atomic. Implementations may
+            read/write entries to share intermediates; pass None to disable.

282-295: Consider namespacing comm_dict per sub-model to avoid key collisions

If multiple sub-models write the same keys, they can clobber each other. If shared keys are not intended, namespace the dict per sub-model index. If shared keys are intended, ignore this.

Example change (only if you observe collisions in practice):

         for i, model in enumerate(self.models):
             type_map_model = self.mapping_list[i].to(extended_atype.device)
             # apply bias to each individual model
             ener_list.append(
                 model.forward_common_atomic(
                     extended_coord,
                     type_map_model[extended_atype],
                     nlists_[i],
                     mapping,
                     fparam,
                     aparam,
-                    comm_dict=comm_dict,
+                    comm_dict=(
+                        None
+                        if comm_dict is None
+                        else comm_dict.setdefault(f"linear[{i}]", {})
+                    ),
                 )["energy"]
             )

To decide, please verify whether sub-models rely on shared comm_dict keys or not.


226-235: TorchScript typing: prefer typing.Dict for broader compatibility

The annotation uses Optional[dict[str, torch.Tensor]]. Older TorchScript versions have been finicky with PEP 585 generics. If your CI still scripts models under older PyTorch, consider switching to Optional[Dict[str, torch.Tensor]] (from typing) across the codebase for forward_* methods.

Would you like me to prepare a follow-up patch that replaces dict[...] with Dict[...] and adds the necessary imports consistently?

deepmd/pt/model/model/dp_linear_model.py (2)

90-112: Add a brief docstring entry for comm_dict to forward_lower

Helps users of the TorchScript-exported API understand the new parameter.

     @torch.jit.export
     def forward_lower(
         self,
         extended_coord,
         extended_atype,
         nlist,
         mapping: Optional[torch.Tensor] = None,
         fparam: Optional[torch.Tensor] = None,
         aparam: Optional[torch.Tensor] = None,
         do_atomic_virial: bool = False,
         comm_dict: Optional[dict[str, torch.Tensor]] = None,
     ):
+        """
+        Lower-level forward.
+        Parameters
+        ----------
+        comm_dict
+            Optional dict[str, torch.Tensor] used as a communication scratchpad
+            and forwarded to forward_common_lower.
+        """
         model_ret = self.forward_common_lower(

90-112: TorchScript typing consistency

Same note as in the atomic model: if CI scripts modules with an older PyTorch, consider Optional[Dict[str, torch.Tensor]] for comm_dict and import Dict from typing.

deepmd/pt/model/model/polar_model.py (3)

87-97: Use named argument for mapping for consistency

Other files (e.g., dp_linear_model.py, dp_zbl_model.py) pass mapping as a named argument. Aligning improves readability and resilience to parameter order changes.

         model_ret = self.forward_common_lower(
             extended_coord,
             extended_atype,
             nlist,
-            mapping,
+            mapping=mapping,
             fparam=fparam,
             aparam=aparam,
             do_atomic_virial=do_atomic_virial,
             comm_dict=comm_dict,
             extra_nlist_sort=self.need_sorted_nlist_for_lower(),
         )

75-86: Optional: document comm_dict in forward_lower

If you prefer method-level docs, mirror a short note about comm_dict here as well.


75-86: TorchScript typing consistency

Consider using Optional[Dict[str, torch.Tensor]] for broader TorchScript compatibility, as noted in other files.

deepmd/pt/model/model/dos_model.py (3)

93-103: Use named argument for mapping for consistency

Match the style used elsewhere by naming mapping explicitly.

         model_ret = self.forward_common_lower(
             extended_coord,
             extended_atype,
             nlist,
-            mapping,
+            mapping=mapping,
             fparam=fparam,
             aparam=aparam,
             do_atomic_virial=do_atomic_virial,
             comm_dict=comm_dict,
             extra_nlist_sort=self.need_sorted_nlist_for_lower(),
         )

82-92: Optional: document comm_dict in forward_lower

Add a brief note/docstring for comm_dict to aid users.


82-92: TorchScript typing consistency

Consider Optional[Dict[str, torch.Tensor]] to maximize TorchScript compatibility, mirroring other files.

deepmd/pt/model/model/dp_zbl_model.py (2)

90-101: Optional: document comm_dict in forward_lower

Consider a short docstring snippet for the new parameter for parity with other forward APIs.


90-101: TorchScript typing consistency

As with other files, consider switching to Optional[Dict[str, torch.Tensor]] if your TorchScript environment benefits from typing.Dict.

deepmd/pt/model/model/dipole_model.py (3)

101-103: Propagating comm_dict looks correct; adjust type hint for TorchScript compatibility.

Good: The new comm_dict is threaded into forward_common_lower and keeps the parameter optional, preserving call-site compatibility.

Risk: @torch.jit.export functions can be picky about annotations. Using dict[str, torch.Tensor] may not be accepted by TorchScript in all environments; typing.Dict[str, torch.Tensor] is safer and consistent across PyTorch versions.

Suggested change:

  • Switch to Dict[str, torch.Tensor]
  • Import Dict

Apply this diff:

@@
-from typing import (
-    Optional,
-)
+from typing import (
+    Optional,
+    Dict,
+)
@@
-        comm_dict: Optional[dict[str, torch.Tensor]] = None,
+        comm_dict: Optional[Dict[str, torch.Tensor]] = None,

Verification checklist:

  • Ensure the definition of forward_common_lower accepts comm_dict with a compatible type.
  • Run a TorchScript export path for DipoleModel to confirm no schema/type errors occur when scripting or tracing.
  • Confirm other models in this PR use the same Dict[...] convention for consistency.

Also applies to: 111-113


91-103: Consider adding return type annotation and brief docstring for comm_dict.

For parity with forward(), annotate forward_lower’s return type and document the expected keys in comm_dict (e.g., required/optional keys, shapes, device). This aids users and static tooling.

Example:

-    def forward_lower(
+    def forward_lower(
         self,
         extended_coord,
         extended_atype,
         nlist,
         mapping: Optional[torch.Tensor] = None,
         fparam: Optional[torch.Tensor] = None,
         aparam: Optional[torch.Tensor] = None,
         do_atomic_virial: bool = False,
-        comm_dict: Optional[dict[str, torch.Tensor]] = None,
-    ):
+        comm_dict: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> dict[str, torch.Tensor]:
+        """
+        Args:
+            comm_dict: Optional dictionary for inter-stage communication.
+                Expected keys (if any): e.g. "neighbor_mask", "env_info".
+                Values are torch.Tensors on the same device as inputs.
+        """

55-89: API symmetry: does forward() also need comm_dict?

If bug fixes rely on comm_dict only in the lower path, this is fine. If upstream callers sometimes only use forward(), consider optionally adding comm_dict there for consistency across models. Otherwise, document that comm_dict is a lower-path-only feature.

Please confirm whether the other models updated in this PR expose comm_dict only in forward_lower() or also in forward(), and that callers don’t need it at the higher level in dipole inference/training flows.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 6dac4f9 and f4f0926.

📒 Files selected for processing (6)
  • deepmd/pt/model/atomic_model/linear_atomic_model.py (1 hunks)
  • deepmd/pt/model/model/dipole_model.py (2 hunks)
  • deepmd/pt/model/model/dos_model.py (2 hunks)
  • deepmd/pt/model/model/dp_linear_model.py (2 hunks)
  • deepmd/pt/model/model/dp_zbl_model.py (2 hunks)
  • deepmd/pt/model/model/polar_model.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Test C++ (false)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
🔇 Additional comments (7)
deepmd/pt/model/atomic_model/linear_atomic_model.py (1)

282-295: Good: comm_dict is correctly propagated to sub-models

Passing comm_dict through to each sub-model’s forward_common_atomic aligns the linear composition path with the rest of the models and should resolve missing cross-module communication for issue #4906. No functional regressions spotted around weighting or aggregation.

deepmd/pt/model/model/dp_linear_model.py (2)

100-111: Good: comm_dict added and correctly threaded to forward_common_lower

Signature change is backward-compatible (new trailing optional arg) and the value is forwarded as expected.


90-112: All comm_dict propagation checks out—no changes needed

I’ve verified across the PyTorch backend that:

  • Every forward_lower definition includes the comm_dict: Optional[...] parameter.
  • Inside each forward_lower, the call to self.forward_common_lower(...) explicitly passes comm_dict=comm_dict.
  • All forward_common_atomic definitions include comm_dict in their signatures.
  • Within the lower‐level implementations, forward_common_atomic is always invoked with comm_dict when called from the atomic‐aware paths.

Since the propagation of comm_dict is consistent and complete, no further modifications are required here.

deepmd/pt/model/model/polar_model.py (1)

85-97: Good: comm_dict added and forwarded to forward_common_lower

Change is localized and preserves backward compatibility.

deepmd/pt/model/model/dos_model.py (1)

91-103: Good: comm_dict added and forwarded to forward_common_lower

Change looks correct and non-breaking.

deepmd/pt/model/model/dp_zbl_model.py (1)

100-112: Good: comm_dict threaded through the ZBL lower path

Signature and forwarding look correct; backward compatibility is preserved.

deepmd/pt/model/model/dipole_model.py (1)

91-129: Ensure comm_dict Is Uniformly Plumbed Through Lower‐Level Forwards

To avoid any breaking changes when propagating the new comm_dict argument, please verify across all model implementations:

  • Confirm every forward_common_lower definition includes the comm_dict: Optional[dict[str, torch.Tensor]] = None parameter in its signature.
  • Confirm each forward_lower method signature likewise declares comm_dict and passes it into its call to forward_common_lower.
  • Search for any external or legacy call sites of forward_lower and forward_common_lower (including in tests) that might still rely on the previous signatures—especially multi‐line signatures or invocations—and update them to pass comm_dict explicitly.

Recommended checks (run from the repo root):

# Verify comm_dict in all forward_common_lower definitions
rg -U -P 'def\s+forward_common_lower.*comm_dict' -n --type=py

# Verify comm_dict in all forward_lower definitions
rg -U -P 'def\s+forward_lower.*comm_dict' -n --type=py

# Find any calls to forward_common_lower without comm_dict
rg -U -P 'forward_common_lower\s*\(' -n --type=py | grep -v 'comm_dict'

# Find any external calls to forward_lower (e.g., in tests) that lack the comm_dict keyword
rg -U -P 'forward_lower\s*\(' -n --type=py | grep -v 'comm_dict'

Please run these checks to ensure no model variant or test is accidentally omitted.

@codecov
Copy link

codecov bot commented Aug 23, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 84.29%. Comparing base (6dac4f9) to head (f4f0926).
⚠️ Report is 70 commits behind head on devel.

Additional details and impacted files
@@           Coverage Diff           @@
##            devel    #4908   +/-   ##
=======================================
  Coverage   84.29%   84.29%           
=======================================
  Files         702      702           
  Lines       68665    68664    -1     
  Branches     3573     3573           
=======================================
  Hits        57882    57882           
+ Misses       9643     9642    -1     
  Partials     1140     1140           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@iProzd iProzd linked an issue Aug 23, 2025 that may be closed by this pull request
Copy link
Member

@njzjz njzjz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to sync to other backends.

@njzjz njzjz added this pull request to the merge queue Aug 23, 2025
Merged via the queue into deepmodeling:devel with commit 191759b Aug 23, 2025
60 checks passed
@OutisLi OutisLi deleted the fix/zbl branch September 17, 2025 08:08
ChiahsinChu pushed a commit to ChiahsinChu/deepmd-kit that referenced this pull request Dec 17, 2025
…ix bugs mentioned in issue deepmodeling#4906 (deepmodeling#4908)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- New Features
- Added optional support to pass a communication dictionary through
lower-level model computations across energy, dipole, DOS, polarization,
and related models. This enables advanced workflows while remaining
fully backward compatible.
- Refactor
- Standardized internal propagation of the communication dictionary
across sub-models to ensure consistent behavior.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] parameter miss match for ZBL model in python and cpp source codes

3 participants