Feat(pt): Support fitting_net input statistics. #4504
Actionable comments posted: 1
🧹 Nitpick comments (2)
deepmd/pt/model/task/fitting.py (2)
437-440: Use a ternary operator for compactness.
Ruff suggests replacing the `if callable(...)` block with a ternary operator. This is a minor readability enhancement.
- if callable(merged):
-     sampled = merged()
- else:
-     sampled = merged
+ sampled = merged() if callable(merged) else merged
🧰 Tools
🪛 Ruff (0.8.2)
437-440: Use ternary operator `sampled = merged() if callable(merged) else merged` instead of the if-else block (SIM108)
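The suggestion relies on `merged` being either a list of samples or a zero-argument callable that produces one. A minimal sketch of that pattern follows; the helper name `resolve_samples` and the sample dictionaries are hypothetical and only illustrate the idea, they are not taken from the PR.

```python
def resolve_samples(merged):
    """Return the sample list, calling `merged` first if it is a lazy loader.

    `merged` is assumed to be either a list of per-system dicts or a
    zero-argument callable returning such a list (hypothetical shapes).
    """
    return merged() if callable(merged) else merged


# Eager input: the list is returned unchanged.
eager = [{"fparam": [1.0, 2.0]}]

# Lazy input: the callable is invoked only when the samples are needed.
def lazy():
    return [{"fparam": [3.0, 4.0]}]

eager_result = resolve_samples(eager)
lazy_result = resolve_samples(lazy)
```

The lazy form lets the caller defer loading data until statistics are actually computed.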
457-457: Implement `aparam` statistics.
The TODO suggests you plan to handle `aparam` similarly to `fparam`. Let me know if you’d like help implementing that.
📒 Files selected for processing (2)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1 hunk)
- deepmd/pt/model/task/fitting.py (3 hunks)
🔇 Additional comments (3)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
301-301: Consider verifying that fitting_net is defined before usage.
If fitting_net were ever None, invoking compute_input_stats would raise an exception. A quick check ensures safe execution.
deepmd/pt/model/task/fitting.py (2)
7-7: New import statements look good.
Thanks for adding Callable, Union, and DPPath; these additions enable clearer type hints and better modularity.
Also applies to: 43-45
416-436: Comprehensive documentation.
The docstring clearly explains the purpose and usage of compute_input_stats. This addition aligns with the PR objective to compute input statistics for fitting parameters.
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@           Coverage Diff           @@
##            devel    #4504   +/-   ##
=======================================
  Coverage   84.58%   84.58%
=======================================
  Files         680      680
  Lines       64510    64544   +34
  Branches     3539     3539
=======================================
+ Hits        54563    54595   +32
  Misses       8807     8807
- Partials     1140     1142    +2
Actionable comments posted: 1
♻️ Duplicate comments (1)
deepmd/pt/model/task/fitting.py (1)
442-456: ⚠️ Potential issue
Add protection against zero standard deviation.
The division by `fparam_std` could lead to numerical instability or `inf` values when the standard deviation is zero or very small.
Apply this diff to add protection:
  fparam_std = torch.std(cat_data, axis=0)
+ epsilon = 1e-12  # Small constant to prevent division by zero
+ fparam_std = torch.where(fparam_std < epsilon, torch.tensor(epsilon, dtype=fparam_std.dtype, device=fparam_std.device), fparam_std)
  fparam_inv_std = 1.0 / fparam_std
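To make the failure mode concrete, here is a dependency-free sketch of per-column statistics with a floor on the standard deviation. It uses plain Python lists instead of tensors and the population (biased) standard deviation; the function name `fparam_stats` and the `protection` argument are assumptions for illustration, not the PR's actual code.

```python
import math


def fparam_stats(rows, protection=1e-12):
    """Column-wise mean, floored std, and inverse std for a list of rows.

    `protection` is a hypothetical floor that keeps 1/std finite when a
    column is constant across all frames.
    """
    n = len(rows)
    ncol = len(rows[0])
    avg = [sum(row[j] for row in rows) / n for j in range(ncol)]
    # Population standard deviation (divide by n, not n - 1).
    std = [
        math.sqrt(sum((row[j] - avg[j]) ** 2 for row in rows) / n)
        for j in range(ncol)
    ]
    # Replace any std below the floor with the floor itself.
    std = [s if s >= protection else protection for s in std]
    inv_std = [1.0 / s for s in std]
    return avg, std, inv_std


# Second column is constant, so its raw std is exactly zero.
avg, std, inv_std = fparam_stats([[1.0, 5.0], [3.0, 5.0]])
```

Without the floor, the second column would produce a division by zero; with it, the inverse is large but finite.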
🧹 Nitpick comments (2)
deepmd/pt/model/task/fitting.py (2)
456-457: Implement aparam statistics calculation.
The TODO comment indicates a missing implementation for aparam statistics, which is needed for complete functionality as indicated by the PR objectives.
Would you like me to generate the implementation for aparam statistics calculation? It would follow a similar pattern to the fparam implementation but handle the different dimensionality of aparam.
437-440: Simplify conditional assignment using ternary operator.
The if-else block can be simplified using a ternary operator.
Apply this diff to improve code style:
- if callable(merged):
-     sampled = merged()
- else:
-     sampled = merged
+ sampled = merged() if callable(merged) else merged
🧰 Tools
🪛 Ruff (0.8.2)
437-440: Use ternary operator `sampled = merged() if callable(merged) else merged` instead of the if-else block (SIM108)
📒 Files selected for processing (2)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1 hunk)
- deepmd/pt/model/task/fitting.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pt/model/atomic_model/dp_atomic_model.py
🔇 Additional comments (1)
deepmd/pt/model/task/fitting.py (1)
7-7: LGTM! The new imports are correctly organized and necessary for the added functionality.
Also applies to: 43-45
Actionable comments posted: 1
🧹 Nitpick comments (1)
deepmd/pt/model/task/fitting.py (1)
99-102: Simplify the code using a ternary operator.
The code can be more concise using a ternary operator.
- if callable(merged):
-     sampled = merged()
- else:
-     sampled = merged
+ sampled = merged() if callable(merged) else merged
🧰 Tools
🪛 Ruff (0.8.2)
99-102: Use ternary operator `sampled = merged() if callable(merged) else merged` instead of the if-else block (SIM108)
📒 Files selected for processing (1)
deepmd/pt/model/task/fitting.py(3 hunks)
🔇 Additional comments (2)
deepmd/pt/model/task/fitting.py (2)
43-45: LGTM! The import statement is correctly placed and follows the existing import structure.
78-98: LGTM! The method signature and docstring are well-structured with clear parameter descriptions.
Actionable comments posted: 1
♻️ Duplicate comments (2)
deepmd/pt/model/task/fitting.py (2)
103-116: ⚠️ Potential issue
Add protection against division by zero.
The standard deviation calculation needs protection against zero or near-zero values.
Apply this diff to handle potential division by zero:
  fparam_std = torch.std(cat_data, dim=0, unbiased=False)
+ epsilon = 1e-12
+ fparam_std = torch.where(fparam_std < epsilon, torch.tensor(epsilon, dtype=fparam_std.dtype, device=fparam_std.device), fparam_std)
  fparam_inv_std = 1.0 / fparam_std
118-140: ⚠️ Potential issue
Add protection against division by zero in aparam calculations.
Similar to fparam, the aparam standard deviation calculation needs protection.
Apply this diff to handle potential division by zero:
  aparam_std = torch.sqrt(sumv2 / sumn - (sumv / sumn) ** 2)
+ epsilon = 1e-12
+ aparam_std = torch.where(aparam_std < epsilon, torch.tensor(epsilon, dtype=aparam_std.dtype, device=aparam_std.device), aparam_std)
  aparam_inv_std = 1.0 / aparam_std
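The aparam formula above computes the standard deviation from running sums, sqrt(E[x²] − E[x]²), which allows systems with different atom counts to be pooled. The sketch below illustrates that accumulation in plain Python; `aparam_stats`, the nested-list input layout, and the clamp of tiny negative variances to zero are assumptions added for illustration.

```python
import math


def aparam_stats(systems, protection=1e-12):
    """Pooled per-column mean and floored std from running sums.

    `systems` is assumed to be a list of systems, each a list of
    per-atom parameter rows; systems may differ in length.
    """
    ncol = len(systems[0][0])
    sumv = [0.0] * ncol   # running sum of values
    sumv2 = [0.0] * ncol  # running sum of squared values
    sumn = 0              # running count of rows
    for rows in systems:
        for row in rows:
            for j, v in enumerate(row):
                sumv[j] += v
                sumv2[j] += v * v
        sumn += len(rows)
    avg = [s / sumn for s in sumv]
    # Rounding can make E[x^2] - E[x]^2 slightly negative; clamp at zero.
    var = [max(s2 / sumn - a * a, 0.0) for s2, a in zip(sumv2, avg)]
    std = [max(math.sqrt(v), protection) for v in var]
    return avg, std


# Two systems with 2 and 1 atoms respectively, one parameter per atom.
avg, std = aparam_stats([[[1.0], [3.0]], [[5.0]]])
```

Here the pooled values are 1, 3, and 5, so the mean is 3 and the population variance is 8/3, matching a direct computation over the concatenated data.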
🧹 Nitpick comments (2)
source/tests/pt/test_fitting_stat.py (1)
71-93: Add more test cases for edge conditions.
While the current test case validates the basic functionality, consider adding tests for:
- Empty data
- Zero standard deviation
- Single frame/atom scenarios
Would you like me to generate additional test cases to improve coverage?
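As a rough idea of what such edge-case tests could look like, the sketch below exercises a small stand-in function rather than the real fitting class. `protected_std` and its default `protection` value are hypothetical; the real tests would call the PR's `compute_input_stats` instead.

```python
import math
import unittest


def protected_std(values, protection=1e-2):
    """Stand-in for the statistics under test (hypothetical helper)."""
    if not values:
        raise ValueError("No data samples provided")
    avg = sum(values) / len(values)
    std = math.sqrt(sum((v - avg) ** 2 for v in values) / len(values))
    return avg, max(std, protection)


class TestEdgeCases(unittest.TestCase):
    def test_empty_data(self):
        # Empty input should fail loudly instead of dividing by zero.
        with self.assertRaises(ValueError):
            protected_std([])

    def test_zero_std(self):
        # Constant data has zero spread, so the floor must take over.
        avg, std = protected_std([2.0, 2.0, 2.0])
        self.assertEqual(avg, 2.0)
        self.assertEqual(std, 1e-2)

    def test_single_frame(self):
        # A single sample also has zero spread.
        avg, std = protected_std([7.0])
        self.assertEqual(avg, 7.0)
        self.assertEqual(std, 1e-2)
```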
deepmd/pt/model/task/fitting.py (1)
98-101: Simplify conditional with ternary operator.
The if-else block can be simplified.
- if callable(merged):
-     sampled = merged()
- else:
-     sampled = merged
+ sampled = merged() if callable(merged) else merged
🧰 Tools
🪛 Ruff (0.8.2)
98-101: Use ternary operator `sampled = merged() if callable(merged) else merged` instead of the if-else block (SIM108)
📒 Files selected for processing (2)
- deepmd/pt/model/task/fitting.py (3 hunks)
- source/tests/pt/test_fitting_stat.py (1 hunk)
🔇 Additional comments (3)
source/tests/pt/test_fitting_stat.py (2)
17-41: LGTM! Well-structured test data generation.
The `_make_fake_data_pt` function is well-implemented with:
- Clear parameter handling
- Proper data structure generation
- Correct tensor conversion
44-55: LGTM! Robust statistical computation helpers.
The `_brute_fparam_pt` and `_brute_aparam_pt` functions provide a reliable "brute force" approach to compute statistics, serving as a good reference for validation.
Also applies to: 58-69
deepmd/pt/model/task/fitting.py (1)
78-97: LGTM! Well-documented method signature.
The method signature and docstring are clear and comprehensive.
Actionable comments posted: 0
♻️ Duplicate comments (1)
deepmd/pt/model/task/fitting.py (1)
101-101: 🛠️ Refactor suggestion
Add input validation for empty data.
The method should validate the input data before processing to ensure robustness.
  sampled = merged
+ if not sampled:
+     raise ValueError("No data samples provided")
+ if self.numb_fparam > 0 and not all("fparam" in frame for frame in sampled):
+     raise ValueError("Missing 'fparam' in some data samples")
+ if self.numb_aparam > 0 and not all("aparam" in frame for frame in sampled):
+     raise ValueError("Missing 'aparam' in some data samples")
🧰 Tools
🪛 Ruff (0.8.2)
98-101: Use ternary operator `sampled = merged() if callable(merged) else merged` instead of the if-else block (SIM108)
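The proposed checks can be tried in isolation. The following sketch lifts them into a free function so the behaviour is easy to verify; `validate_samples` and its argument names are hypothetical, and the sample dictionaries only mimic the assumed data layout.

```python
def validate_samples(sampled, numb_fparam, numb_aparam):
    """Raise ValueError if the samples cannot support the requested stats.

    `sampled` is assumed to be a list of dicts, one per data system.
    """
    if not sampled:
        raise ValueError("No data samples provided")
    if numb_fparam > 0 and not all("fparam" in frame for frame in sampled):
        raise ValueError("Missing 'fparam' in some data samples")
    if numb_aparam > 0 and not all("aparam" in frame for frame in sampled):
        raise ValueError("Missing 'aparam' in some data samples")


def raises_value_error(*args):
    """Small helper: report whether validate_samples rejects the input."""
    try:
        validate_samples(*args)
    except ValueError:
        return True
    return False
```

Note that a key is only required when the corresponding parameter count is positive, so models without `aparam` are not forced to supply it.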
🧹 Nitpick comments (1)
deepmd/pt/model/task/fitting.py (1)
98-101: Simplify the code using a ternary operator.
The code can be more concise by using a ternary operator.
- if callable(merged):
-     sampled = merged()
- else:
-     sampled = merged
+ sampled = merged() if callable(merged) else merged
🧰 Tools
🪛 Ruff (0.8.2)
98-101: Use ternary operator `sampled = merged() if callable(merged) else merged` instead of the if-else block (SIM108)
📒 Files selected for processing (2)
- deepmd/pt/model/task/fitting.py (3 hunks)
- source/tests/pt/test_fitting_stat.py (1 hunk)
🚧 Files skipped from review as they are similar to previous changes (1)
- source/tests/pt/test_fitting_stat.py
🔇 Additional comments (2)
deepmd/pt/model/task/fitting.py (2)
43-45: LGTM! The import statement is correctly placed and follows the project's import style.
103-152: LGTM! Well-implemented statistical computations.
The implementation correctly:
- Handles both frame and atomic parameters
- Prevents division by zero using epsilon
- Uses efficient tensor operations
- Properly handles data dimensions
Actionable comments posted: 0
🧹 Nitpick comments (2)
deepmd/pt/model/task/fitting.py (2)
95-98: Simplify the callable check using a ternary operator.
The code can be more concise.
- if callable(merged):
-     sampled = merged()
- else:
-     sampled = merged
+ sampled = merged() if callable(merged) else merged
🧰 Tools
🪛 Ruff (0.8.2)
95-98: Use ternary operator `sampled = merged() if callable(merged) else merged` instead of the if-else block (SIM108)
122-151: Add a comment explaining the aparam statistics computation approach.
The implementation is correct, but it would be helpful to add a comment explaining why the statistics are computed differently for atomic parameters compared to frame parameters.
  # stat aparam
  if self.numb_aparam > 0:
+     # Computing statistics for atomic parameters requires accumulating sums
+     # across all systems due to varying number of atoms per frame
      sys_sumv = []
      sys_sumv2 = []
      sys_sumn = []
📒 Files selected for processing (2)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1 hunk)
- deepmd/pt/model/task/fitting.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pt/model/atomic_model/dp_atomic_model.py
🔇 Additional comments (3)
deepmd/pt/model/task/fitting.py (3)
7-7: LGTM! The `Callable` import is correctly added and properly used in type hints.
75-79: LGTM! Protection value aligns with team's decision.
The protection value of 1e-2 was chosen based on team discussion, providing a good balance for numerical stability.
100-120: LGTM! Robust implementation of fparam statistics.
The implementation correctly handles:
- Data reshaping and statistics computation
- Protection against numerical instability
- Device and dtype consistency
I do not know why |
d5417fb
Solve issue deepmodeling#4281
Support fitting_net statistics to calculate the mean value and standard deviation of `fparam`/`aparam`, so that `fparam`/`aparam` can be normalized automatically before being concatenated to the descriptor.
## Summary by CodeRabbit
- **New Features**
  - Introduced a method to compute input statistics, including mean and standard deviation for fitting parameters.
  - Enhanced functionality to compute additional statistics alongside existing ones.
  - Added new parameters for data protection statistics to model configurations.
  - Added unit tests to validate the energy fitting model's statistical computations.
- **Bug Fixes**
  - Improved error handling for input data dimensions to ensure consistency.
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Solve issue #4281
Support fitting_net statistics to calculate the mean value and standard deviation of `fparam`/`aparam`, so that `fparam`/`aparam` can be normalized automatically before being concatenated to the descriptor.
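For readers unfamiliar with the normalization step the description refers to, a minimal sketch is given below. It shows a frame parameter being shifted by its mean, scaled by the inverse standard deviation, and appended to a descriptor vector. The function `normalize_and_concat` and the plain-list representation are assumptions for illustration; the actual model operates on tensors.

```python
def normalize_and_concat(descriptor, fparam, fparam_avg, fparam_inv_std):
    """Normalize `fparam` element-wise and append it to `descriptor`.

    All arguments are plain lists of floats here (hypothetical layout);
    `fparam_avg` and `fparam_inv_std` are the precomputed statistics.
    """
    normalized = [
        (x - avg) * inv_std
        for x, avg, inv_std in zip(fparam, fparam_avg, fparam_inv_std)
    ]
    return list(descriptor) + normalized


# One frame parameter with mean 2.0 and std 2.0 (inverse std 0.5).
fitting_input = normalize_and_concat([0.1, 0.2], [3.0], [2.0], [0.5])
```

Multiplying by a stored inverse standard deviation, rather than dividing by the standard deviation each time, is why the reviews above focus on keeping that inverse finite.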