-
Notifications
You must be signed in to change notification settings - Fork 584
(fix)Make the weighted avarange fit for all kinds of systems #4593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
📝 WalkthroughWalkthroughThis pull request modifies the testing functionalities within the DeePMD framework. In Changes
Suggested labels
Suggested reviewers
✨ Finishing Touches
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (4)
source/tests/pt/test_weighted_avg.py (1)
70-101: Well-structured test cases with good coverage!The test cases effectively validate different component combinations. Consider making the variable names more descriptive for better readability.
Consider renaming variables to be more descriptive:
- expected_mae_f = (2*3 +1*3 )/(3+3) + expected_force_mae = (2*3 + 1*3)/(3+3) - expected_mae_v = (3*5 +1*5 )/(5+5) + expected_virial_mae = (3*5 + 1*5)/(5+5)deepmd/entrypoints/test.py (3)
331-333: Good addition of component flags!Consider using more descriptive variable names for better clarity.
- find_energy = test_data.get('find_energy') - find_force = test_data.get('find_force') - find_virial = test_data.get('find_virial') + has_energy_component = test_data.get('find_energy') + has_force_component = test_data.get('find_force') + has_virial_component = test_data.get('find_virial')
146-167: Good selective error collection logic!Consider adding error handling for missing components.
Add error handling for missing components:
err_part = {} + if test_data.get('find_energy') is None: + log.warning("Energy component flag not found in test data") + if test_data.get('find_force') is None: + log.warning("Force component flag not found in test data") + if test_data.get('find_virial') is None: + log.warning("Virial component flag not found in test data") if find_energy == 1: err_part['mae_e'] = err['mae_e']
459-470: Good conditional logging implementation!Consider adding debug logging for better troubleshooting.
Add debug logging:
+ log.debug(f"Processing system with energy={find_energy}, force={find_force}, virial={find_virial}") if find_force == 1: if not out_put_spin: log.info(f"Force MAE : {mae_f:e} eV/A")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
deepmd/entrypoints/test.py(12 hunks)source/tests/pt/test_weighted_avg.py(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (19)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Analyze (python)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Analyze (javascript-typescript)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Analyze (c-cpp)
🔇 Additional comments (3)
source/tests/pt/test_weighted_avg.py (3)
7-31: Well-structured implementation for handling different error metrics!The function effectively handles different combinations of energy, force, and virial metrics with clean conditional logic and proper error collection.
33-39: Clean baseline implementation!The function provides a good reference point for comparing weighted averages with and without filtering.
43-67: Comprehensive test coverage for energy-only metrics!The test case effectively validates:
- Correct weighted average calculations
- Proper handling of energy-only systems
- Edge cases with force and virial metrics
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (4)
source/tests/pt/test_weighted_avg.py (2)
11-36: Consider adding type hints and docstring.The function lacks type hints and documentation which would improve code maintainability and help users understand the expected input/output format.
Apply this diff to add type hints and docstring:
-def test(all_sys): +def test(all_sys: list[tuple[dict, int, int, int]]) -> dict[str, tuple[float, int]]: + """Calculate weighted average of errors with selective inclusion of metrics. + + Args: + all_sys: List of tuples containing (error_dict, find_energy, find_force, find_virial) + where error_dict contains the error metrics + + Returns: + Dictionary mapping error names to tuples of (error_value, sample_size) + """ err_coll = []
38-44: Add type hints and docstring to test_ori function.Similar to the
testfunction, this function would benefit from type hints and documentation.Apply this diff:
-def test_ori(all_sys): +def test_ori(all_sys: list[tuple[dict, int, int, int]]) -> dict[str, tuple[float, int]]: + """Calculate weighted average of all errors without selective inclusion. + + Args: + all_sys: List of tuples containing (error_dict, find_energy, find_force, find_virial) + where error_dict contains the error metrics + + Returns: + Dictionary mapping error names to tuples of (error_value, sample_size) + """ err_coll = []deepmd/entrypoints/test.py (2)
331-334: Consider using dictionary get() with default values.The code uses
get()without default values which could return None. Consider providing default values for safety.Apply this diff:
- find_energy = test_data.get("find_energy") - find_force = test_data.get("find_force") - find_virial = test_data.get("find_virial") + find_energy = test_data.get("find_energy", 0) + find_force = test_data.get("find_force", 0) + find_virial = test_data.get("find_virial", 0)
744-747: Use f-strings instead of % operator for string formatting.The code uses the older % operator for string formatting. Consider using f-strings for better readability and maintainability.
Apply this diff:
- detail_path.with_suffix(".dos.out.%.d" % ii), + detail_path.with_suffix(f".dos.out.{ii:d}"), frame_output, - header="%s - %.d: data_dos pred_dos" % (system, ii), + header=f"{system} - {ii:d}: data_dos pred_dos",- detail_path.with_suffix(".ados.out.%.d" % ii), + detail_path.with_suffix(f".ados.out.{ii:d}"), frame_output, - header="%s - %.d: data_ados pred_ados" % (system, ii), + header=f"{system} - {ii:d}: data_ados pred_ados",Also applies to: 758-761
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
deepmd/entrypoints/test.py(10 hunks)source/tests/pt/test_weighted_avg.py(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (19)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (javascript-typescript)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
🔇 Additional comments (6)
source/tests/pt/test_weighted_avg.py (3)
47-113: LGTM! Well-structured test case for energy-only scenario.The test case thoroughly validates the weighted average calculation for energy metrics, including proper assertions for both the main function and original implementation.
114-171: LGTM! Comprehensive test case for energy and force metrics.The test case effectively validates the combined energy and force calculations, with appropriate assertions to verify the differences between the two implementations.
172-229: LGTM! Complete test coverage for all components.The test case provides thorough validation of all metrics (energy, force, virial) with appropriate assertions.
deepmd/entrypoints/test.py (3)
137-168: LGTM! Improved error handling with selective metric inclusion.The changes enhance error handling by:
- Using flags to determine which metrics to include
- Organizing errors into a separate dictionary
- Handling different force metric cases (regular vs spin)
459-468: LGTM! Improved conditional logging for force metrics.The changes enhance logging by:
- Only logging when force data is present
- Handling different force metric types (regular vs spin)
469-469: LGTM! Added condition for virial logging.The change ensures virial metrics are only logged when appropriate (PBC enabled and virial data present).
…md-kit into debug-weightedavg
for more information, see https://pre-commit.ci
…md-kit into debug-weightedavg
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (3)
source/tests/pt/test_weighted_avg.py (3)
11-36: Add type hints and docstring for better maintainability.The function needs documentation and type hints to improve maintainability. Also, consider using an enum or constants for flag values.
-def fake_test(all_sys): +from typing import List, Dict, Tuple, Union + +def fake_test(all_sys: List[Tuple[Dict[str, Tuple[float, float]], int, int, int]]) -> Dict[str, float]: + """Calculate weighted average of error metrics based on system components. + + Args: + all_sys: List of tuples containing (error_dict, has_energy, has_force, has_virial) + where error_dict contains metrics like mae_e, rmse_e etc. + + Returns: + Dictionary containing weighted averages of error metrics. + """Also, add error handling for missing force metrics:
if find_force == 1: + if not any(key in err for key in ['rmse_f', 'rmse_fr']): + raise KeyError("No force metrics found in error dictionary") if "rmse_f" in err:
38-44: Add type hints and docstring for consistency.For consistency with
fake_test, add type hints and documentation.-def fake_test_ori(all_sys): +def fake_test_ori(all_sys: List[Tuple[Dict[str, Tuple[float, float]], int, int, int]]) -> Dict[str, float]: + """Calculate weighted average of error metrics ignoring component flags. + + Used as a baseline for comparison with fake_test. + + Args: + all_sys: List of tuples containing (error_dict, has_energy, has_force, has_virial) + + Returns: + Dictionary containing weighted averages of all error metrics. + """
47-229: Enhance test maintainability and coverage.While the test cases are comprehensive, consider these improvements:
- Move test data to class-level setup
- Add docstrings to test methods
- Add negative test cases
Example refactor:
class TestWeightedAverage(unittest.TestCase): + def setUp(self): + """Set up test data.""" + # Define common test data structure + self.base_system = { + "mae_e": (2, 2), + "mae_ea": (4, 2), + "rmse_e": (3, 2), + "rmse_ea": (5, 2), + "mae_f": (2, 3), + "rmse_f": (1, 3), + "mae_v": (3, 5), + "rmse_v": (3, 3), + } + def test_case1_energy_only(self): + """Test weighted average calculation with energy metrics only."""Add negative test cases:
def test_invalid_input(self): """Test handling of invalid input data.""" # Test with empty system list with self.assertRaises(ValueError): fake_test([]) # Test with invalid flag values invalid_sys = [(self.base_system, 2, 0, 0)] # Invalid energy flag with self.assertRaises(ValueError): fake_test(invalid_sys)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
source/tests/pt/test_weighted_avg.py(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (20)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (javascript-typescript)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
🔇 Additional comments (2)
source/tests/pt/test_weighted_avg.py (2)
1-8: LGTM! Imports are well-organized.The license is properly specified and imports follow Python conventions.
231-232: LGTM! Standard unittest main block.Follows Python conventions for unittest execution.
…md-kit into debug-weightedavg
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
🔭 Outside diff range comments (2)
deepmd/entrypoints/test.py (2)
562-574:⚠️ Potential issueFix incorrect type hints.
The function's return type hint
tuple[list[np.ndarray], list[int]]doesn't match the actual return type, which is a dictionary of tuples containing error metrics and their sizes.Update the type hint to match the actual return type:
-) -> tuple[list[np.ndarray], list[int]]: +) -> dict[str, tuple[float, int]]:Also, remove the commented-out flags from the return statements:
- } # find_energy,find_force,find_virial, + }Also applies to: 575-588
137-169: 💡 Verification agent🧩 Analysis chain
Verify the impact of removing error filtering.
The commented-out code previously filtered errors based on flags (
find_energy,find_force,find_virial). Removing this filtering could lead to inconsistencies in error reporting, especially when certain metrics are not available or relevant.Run the following script to verify error handling in test cases:
🏁 Script executed:
#!/bin/bash # Description: Check test cases for error handling patterns. # Search for test cases that rely on filtered errors rg -A 5 'test_ener.*find_' 'source/tests' # Search for error dictionary usage in test cases rg -A 5 'err\[.*(mae|rmse)_(e|f|v)' 'source/tests'Length of output: 4540
Inconsistent error reporting: Unfiltered error dictionary vs. filtered expectations in tests
The removal of the error filtering (i.e., the commented-out aggregation logic that created a filtered
err_part) in thetest_enerfunction means that the function now returns a complete error dictionary. However, test cases insource/tests/pt/test_weighted_avg.pystill expect the filtered error metrics (e.g., keys such as"mae_e","rmse_e", etc., populated conditionally viafind_energy,find_force, andfind_virial). This mismatch can lead to discrepancies in the averaged error computation or even test failures if additional keys or unexpected values are present.
- Action Item: Either update the test case expectations to work with the full error dictionary or reintroduce/adjust the filtering logic to ensure that the returned error dictionary strictly contains the expected keys and values.
🧹 Nitpick comments (1)
deepmd/entrypoints/test.py (1)
735-738: Use f-strings for better readability.The code uses the
%operator for string formatting, which is outdated in Python. Using f-strings would improve readability and maintainability.Apply this diff to update the string formatting:
- detail_path.with_suffix(".dos.out.%.d" % ii), - header="%s - %.d: data_dos pred_dos" % (system, ii), + detail_path.with_suffix(f".dos.out.{ii}"), + header=f"{system} - {ii}: data_dos pred_dos", - detail_path.with_suffix(".ados.out.%.d" % ii), - header="%s - %.d: data_ados pred_ados" % (system, ii), + detail_path.with_suffix(f".ados.out.{ii}"), + header=f"{system} - {ii}: data_ados pred_ados", - detail_path.with_suffix(".property.out.%.d" % ii), - header="%s - %.d: data_property pred_property" % (system, ii), + detail_path.with_suffix(f".property.out.{ii}"), + header=f"{system} - {ii}: data_property pred_property", - detail_path.with_suffix(".aproperty.out.%.d" % ii), - header="%s - %.d: data_aproperty pred_aproperty" % (system, ii), + detail_path.with_suffix(f".aproperty.out.{ii}"), + header=f"{system} - {ii}: data_aproperty pred_aproperty",Also applies to: 749-752, 891-894, 905-908
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/entrypoints/test.py(12 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/entrypoints/test.py
332-332: Local variable find_energy is assigned to but never used
Remove assignment to unused variable find_energy
(F841)
⏰ Context from checks skipped due to timeout of 90000ms (5)
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (javascript-typescript)
- GitHub Check: Analyze (c-cpp)
🔇 Additional comments (1)
deepmd/entrypoints/test.py (1)
810-812: LGTM! Improved code consistency.The changes to the property test function improve readability by:
- Using consistent variable names (
propertyinstead of mixed usage)- Using
task_dimconsistently for dimensionalityAlso applies to: 863-864, 868-869
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (2)
deepmd/entrypoints/test.py (2)
734-750: Use f-strings for string formatting.For better readability and consistency with modern Python practices, consider using f-strings instead of the % formatting style.
Apply this diff to update the string formatting:
- detail_path.with_suffix(".dos.out.%.d" % ii), + detail_path.with_suffix(f".dos.out.{ii:d}"), frame_output, - header="%s - %.d: data_dos pred_dos" % (system, ii), + header=f"{system} - {ii:d}: data_dos pred_dos", append=append_detail, ) save_txt_file( - detail_path.with_suffix(".ados.out.%.d" % ii), + detail_path.with_suffix(f".ados.out.{ii:d}"), frame_output, - header="%s - %.d: data_ados pred_ados" % (system, ii), + header=f"{system} - {ii:d}: data_ados pred_ados", append=append_detail, )
890-906: Use f-strings for string formatting.For better readability and consistency with modern Python practices, consider using f-strings instead of the % formatting style.
Apply this diff to update the string formatting:
- detail_path.with_suffix(".property.out.%.d" % ii), + detail_path.with_suffix(f".property.out.{ii:d}"), frame_output, - header="%s - %.d: data_property pred_property" % (system, ii), + header=f"{system} - {ii:d}: data_property pred_property", append=append_detail, ) save_txt_file( - detail_path.with_suffix(".aproperty.out.%.d" % ii), + detail_path.with_suffix(f".aproperty.out.{ii:d}"), frame_output, - header="%s - %.d: data_aproperty pred_aproperty" % (system, ii), + header=f"{system} - {ii:d}: data_aproperty pred_aproperty", append=append_detail, )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/entrypoints/test.py(12 hunks)
🔇 Additional comments (4)
deepmd/entrypoints/test.py (4)
137-145: LGTM! Function signature update improves type safety.The updated function signature now correctly returns the error dictionary along with the flags, making the return type explicit and improving type safety.
146-167: LGTM! Error filtering logic is now consistent.The error filtering logic has been improved to selectively include metrics based on the flags, addressing the inconsistency mentioned in past review comments.
809-811: LGTM! Using task_dim improves flexibility.Using
dp.task_dimfor property dimensions makes the code more flexible and maintainable by relying on the model's configuration.
331-333: Static analysis warning can be ignored.The flags
find_energy,find_force, andfind_virialare now properly used in error filtering (lines 148-165) and returned from the function. The static analysis warning is outdated.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (1)
deepmd/entrypoints/test.py (1)
744-747: Modernize string formatting.Replace old-style % string formatting with f-strings for better readability and maintainability.
- detail_path.with_suffix(".dos.out.%.d" % ii), + detail_path.with_suffix(f".dos.out.{ii:d}"), - header="%s - %.d: data_dos pred_dos" % (system, ii), + header=f"{system} - {ii:d}: data_dos pred_dos",
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/entrypoints/test.py(10 hunks)
🔇 Additional comments (4)
deepmd/entrypoints/test.py (4)
137-167: LGTM! Improved error handling and type safety.The changes improve error handling by making the function return type explicit and filtering error metrics based on what was actually calculated. The structured error dictionary now only includes relevant metrics.
331-333: LGTM! Resolved unused flag issues.The previously unused flags are now properly utilized for error filtering and conditional logging, addressing past review comments and static analysis warnings.
459-470: LGTM! Consistent flag usage in logging.The changes implement consistent conditional logging based on the presence of force and virial calculations, with proper handling of spin calculations.
819-821: LGTM! Consistent property testing implementation.The property testing implementation follows the same patterns as other test functions, with consistent dimension handling and error calculation.
Also applies to: 872-878
…md-kit into debug-weightedavg
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #4593 +/- ##
=======================================
Coverage 84.77% 84.78%
=======================================
Files 688 688
Lines 66097 66115 +18
Branches 3539 3539
=======================================
+ Hits 56036 56054 +18
Misses 8919 8919
Partials 1142 1142 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
♻️ Duplicate comments (3)
source/tests/pt/test_weighted_avg.py (3)
64-77: 🛠️ Refactor suggestionValidate calculated error values, not just presence of keys.
The test only checks for the presence/absence of keys without validating the actual calculated error values. This doesn't fully test whether the weighted average functionality works correctly.
Add validation for the error calculations:
err = dp_test_ener( dp, data, system, numb_test=1, detail_file=None, has_atom_ener=False, ) self.assertIn("mae_e", err, "'mae_e' key is missing in the result") self.assertNotIn("mae_fm", err, "'mae_fm' key should not be present in the result") self.assertNotIn( "mae_v", err, "'mae_v' key should not be present in the result" ) self.assertIn("mae_f", err, "'mae_f' key is missing in the result") + + # Validate error calculations by comparing with reference values + # Reset data for fresh loading + data.add_dict = {} + data.get_test_set(0) + + # Get raw data for validation + data.get_test_batch(1) + energy = data.energy[0] + force = data.force[0] + + # Calculate expected errors + expected_mae_e = np.mean(np.abs(energy)) + expected_rmse_e = np.sqrt(np.mean(np.square(energy))) + expected_mae_f = np.mean(np.abs(force)) + expected_rmse_f = np.sqrt(np.mean(np.square(force))) + + # Verify error calculations + self.assertAlmostEqual(err["mae_e"][0], expected_mae_e, places=5, + msg="Energy MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_e"][0], expected_rmse_e, places=5, + msg="Energy RMSE calculation incorrect") + self.assertAlmostEqual(err["mae_f"][0], expected_mae_f, places=5, + msg="Force MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_f"][0], expected_rmse_f, places=5, + msg="Force RMSE calculation incorrect")
113-140: 🛠️ Refactor suggestionValidate virial error calculations.
Similar to the first test, this test only checks for the presence/absence of keys without validating the actual virial error calculations.
Add validation for the virial error calculations:
def test_dp_test_ener_with_virial(self) -> None: virial_path_fake = os.path.join( self.config["training"]["validation_data"]["systems"][0], "set.000", "virial.npy", ) np.save(virial_path_fake, np.ones([1, 9], dtype=np.float64)) dp = DeepEval(self.tmp_model.name, head="PyTorch") system = self.config["training"]["validation_data"]["systems"][0] data = DeepmdData( sys_path=system, set_prefix="set", shuffle_test=False, type_map=dp.get_type_map(), sort_atoms=False, ) err = dp_test_ener( dp, data, system, numb_test=1, detail_file=None, has_atom_ener=False, ) self.assertIn("mae_e", err, "'mae_e' key is missing in the result") self.assertNotIn("mae_fm", err, "'mae_fm' key should not be present in the result") self.assertIn("mae_v", err, "'mae_v' key is missing in the result") self.assertIn("mae_f", err, "'mae_f' key is missing in the result") + + # Validate virial error calculations + # Reset data for fresh loading + data.add_dict = {} + data.get_test_set(0) + + # Get raw data for validation + data.get_test_batch(1) + energy = data.energy[0] + force = data.force[0] + virial = data.virial[0] + + # Calculate expected errors + expected_mae_e = np.mean(np.abs(energy)) + expected_rmse_e = np.sqrt(np.mean(np.square(energy))) + expected_mae_f = np.mean(np.abs(force)) + expected_rmse_f = np.sqrt(np.mean(np.square(force))) + expected_mae_v = np.mean(np.abs(virial)) + expected_rmse_v = np.sqrt(np.mean(np.square(virial))) + + # Verify error calculations + self.assertAlmostEqual(err["mae_e"][0], expected_mae_e, places=5, + msg="Energy MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_e"][0], expected_rmse_e, places=5, + msg="Energy RMSE calculation incorrect") + self.assertAlmostEqual(err["mae_f"][0], expected_mae_f, places=5, + msg="Force MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_f"][0], expected_rmse_f, places=5, + msg="Force RMSE calculation incorrect") + self.assertAlmostEqual(err["mae_v"][0], expected_mae_v, places=5, + msg="Virial MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_v"][0], expected_rmse_v, places=5, + msg="Virial RMSE calculation incorrect")
183-209: 🛠️ Refactor suggestionValidate spin-related calculations.
This test is critical for the PR's purpose of making the weighted average work for all system types, but it only checks for the presence of keys without validating the actual weighted average calculations.
Add validation for the spin-related calculations:
def test_dp_test_ener_with_spin(self) -> None: dp = DeepEval(self.tmp_model.name, head="PyTorch") system = self.config["training"]["validation_data"]["systems"][0] data = DeepmdData( sys_path=system, set_prefix="set", shuffle_test=False, type_map=dp.get_type_map(), sort_atoms=False, ) err = dp_test_ener( dp, data, system, numb_test=1, detail_file=None, has_atom_ener=False, ) self.assertIn("mae_e", err, "'mae_e' key is missing in the result") self.assertIn("mae_fm", err, "'mae_fm' key is missing in the result") self.assertNotIn( "mae_v", err, "'mae_v' key should not be present in the result" ) self.assertNotIn( "mae_f", err, "'mae_f' key should not be present in the result" ) + + # Validate spin-related calculations + # Reset data for fresh loading + data.add_dict = {} + data.get_test_set(0) + + # Get raw data for validation + data.get_test_batch(1) + energy = data.energy[0] + force_m = data.force_m[0] if hasattr(data, 'force_m') else None + + # Calculate expected errors + expected_mae_e = np.mean(np.abs(energy)) + expected_rmse_e = np.sqrt(np.mean(np.square(energy))) + expected_mae_fm = np.mean(np.abs(force_m)) if force_m is not None else 0.0 + expected_rmse_fm = np.sqrt(np.mean(np.square(force_m))) if force_m is not None else 0.0 + + # Verify error calculations for energy and magnetic forces + self.assertAlmostEqual(err["mae_e"][0], expected_mae_e, places=5, + msg="Energy MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_e"][0], expected_rmse_e, places=5, + msg="Energy RMSE calculation incorrect") + if "mae_fm" in err: + self.assertAlmostEqual(err["mae_fm"][0], expected_mae_fm, places=5, + msg="Magnetic force MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_fm"][0], expected_rmse_fm, places=5, + msg="Magnetic force RMSE calculation incorrect")
🧹 Nitpick comments (3)
source/tests/pt/test_weighted_avg.py (3)
79-79: Move file cleanup to tearDown method.The temporary model file is being deleted in the test method, but it would be better to handle all cleanup in the tearDown method for consistency and to ensure cleanup happens even if the test fails.
Apply this diff:
- os.unlink(self.tmp_model.name)And add it to the tearDown method:
def tearDown(self) -> None: + if hasattr(self, 'tmp_model') and os.path.exists(self.tmp_model.name): + os.unlink(self.tmp_model.name) for f in os.listdir("."): if f.startswith("model") and f.endswith(".pt"): os.remove(f)
141-141: Move file cleanup to tearDown method.Same issue as in the first test class.
Apply this diff:
- os.unlink(self.tmp_model.name)And add it to the tearDown method:
def tearDown(self) -> None: + if hasattr(self, 'tmp_model') and os.path.exists(self.tmp_model.name): + os.unlink(self.tmp_model.name) for f in os.listdir("."): if f.startswith("model") and f.endswith(".pt"): os.remove(f)
210-210: Move file cleanup to tearDown method.Same issue as in the previous test classes.
Apply this diff:
- os.unlink(self.tmp_model.name)And add it to the tearDown method:
def tearDown(self) -> None: + if hasattr(self, 'tmp_model') and os.path.exists(self.tmp_model.name): + os.unlink(self.tmp_model.name) for f in os.listdir("."): if f.startswith("model") and f.endswith(".pt"): os.remove(f)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
source/tests/pt/test_weighted_avg.py(1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
source/tests/pt/test_weighted_avg.py
51-51: Use a context manager for opening files
(SIM115)
110-110: Use a context manager for opening files
(SIM115)
180-180: Use a context manager for opening files
(SIM115)
🔇 Additional comments (3)
source/tests/pt/test_weighted_avg.py (3)
51-52: Use context manager for temporary file.The current temporary file handling could lead to resource leaks. Using a context manager would ensure proper cleanup even if an exception occurs.
Apply this diff to use a context manager:
- self.tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, self.tmp_model.name) + with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as self.tmp_model: + torch.jit.save(model, self.tmp_model.name)🧰 Tools
🪛 Ruff (0.8.2)
51-51: Use a context manager for opening files
(SIM115)
110-111: Use context manager for temporary file.Same issue as in the first test class.
Apply this diff to use a context manager:
- self.tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, self.tmp_model.name) + with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as self.tmp_model: + torch.jit.save(model, self.tmp_model.name)🧰 Tools
🪛 Ruff (0.8.2)
110-110: Use a context manager for opening files
(SIM115)
180-181: Use context manager for temporary file.Same issue as in the previous test classes.
Apply this diff to use a context manager:
- self.tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, self.tmp_model.name) + with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as self.tmp_model: + torch.jit.save(model, self.tmp_model.name)🧰 Tools
🪛 Ruff (0.8.2)
180-180: Use a context manager for opening files
(SIM115)
…md-kit into debug-weightedavg
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (5)
source/tests/pt/test_weighted_avg.py (5)
70-71: Use a context manager for opening files.While the code ultimately unlinks the temporary file in
tearDown, adopting a context manager for opening files is a safer, more conventional practice. This ensures resources are released predictably, even if exceptions occur.Below is an example of how you might adapt lines 70-71. Apply a similar approach to lines 150-151, 239-240, and 356-357:
- trainer = get_trainer(deepcopy(self.config)) - model = torch.jit.script(trainer.model) - self.tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, self.tmp_model.name) + with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as tmp_model: + trainer = get_trainer(deepcopy(self.config)) + model = torch.jit.script(trainer.model) + torch.jit.save(model, tmp_model.name) # ... # Run testsAlso applies to: 150-151, 239-240, 356-357
🧰 Tools
🪛 Ruff (0.8.2)
70-70: Use a context manager for opening files
(SIM115)
83-99: Add numeric result validation to strengthen the test.Currently, this test only verifies the presence or absence of certain keys (e.g., "mae_e"). Consider adding assertions to compare actual error values against expected values, ensuring the correctness of computed metrics rather than just confirming the keys exist.
153-183: Introduce numeric result checks for virial-related metrics.Similar to the previous test, this checks only for key existence ("mae_e", "mae_v", etc.) but does not validate the computed values. Strengthen the test by comparing computed results (e.g., force, virial values) to expected values to confirm overall correctness of the virial handling.
291-292: Use unittest assertions instead of plain 'assert'.Unittest's specialized methods (e.g.,
assertEqual,assertAlmostEqual) provide more comprehensive reporting and remain active even if Python is run with optimizations. Consider refactoring lines 291-292 to:- assert avg_err["mae_v"] == mae_v, f"Expected mae_v in avg_err to be {mae_v} but got {avg_err['mae_v']}" - assert avg_err["mae_e"] == mae_e_expected, f"Expected mae_e in avg_err to be {mae_e_expected} but got {avg_err['mae_e']}" + self.assertAlmostEqual(avg_err["mae_v"], mae_v, places=7, + msg=f"Expected mae_v to be {mae_v} but got {avg_err['mae_v']}") + self.assertAlmostEqual(avg_err["mae_e"], mae_e_expected, places=7, + msg=f"Expected mae_e to be {mae_e_expected} but got {avg_err['mae_e']}")
359-385: Validate the computed spin-related metrics.This test verifies only the presence or absence of "mae_e", "mae_fm", etc. Strengthen spin functionality testing by comparing computed spin metrics to a known reference or manually computed baseline, ensuring the weighted average logic is correct for spin systems.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
source/tests/pt/test_weighted_avg.py(1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
source/tests/pt/test_weighted_avg.py
70-70: Use a context manager for opening files
(SIM115)
150-150: Use a context manager for opening files
(SIM115)
239-239: Use a context manager for opening files
(SIM115)
356-356: Use a context manager for opening files
(SIM115)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
♻️ Duplicate comments (3)
source/tests/pt/test_weighted_avg.py (3)
330-353:⚠️ Potential issueFix duplicate variable assignment.
There is a subtle syntax error in the model_spin variable definition. The variable is being assigned to itself.
- model_spin=model_spin = { + model_spin = {
152-181: 🛠️ Refactor suggestionEnhance test validation for virial system.
Similar to the first test, this test only verifies the presence of keys without validating the actual calculated values. Add validation for the virial error calculations.
def test_dp_test_ener_with_virial(self) -> None: virial_path_fake = os.path.join( self.config["training"]["validation_data"]["systems"][0], "set.000", "virial.npy", ) np.save(virial_path_fake, np.ones([1, 9], dtype=np.float64)) dp = DeepEval(self.tmp_model.name, head="PyTorch") system = self.config["training"]["validation_data"]["systems"][0] data = DeepmdData( sys_path=system, set_prefix="set", shuffle_test=False, type_map=dp.get_type_map(), sort_atoms=False, ) err = dp_test_ener( dp, data, system, numb_test=1, detail_file=None, has_atom_ener=False, ) self.assertIn("mae_e", err, "'mae_e' key is missing in the result") self.assertNotIn( "mae_fm", err, "'mae_fm' key should not be present in the result" ) self.assertIn("mae_v", err, "'mae_v' key is missing in the result") self.assertIn("mae_f", err, "'mae_f' key is missing in the result") + + # Validate virial error calculations + # Reset data for fresh loading + data.add_dict = {} + data.get_test_set(0) + + # Get raw data for validation + data.get_test_batch(1) + energy = data.energy[0] + force = data.force[0] + virial = data.virial[0] + + # Calculate expected errors + expected_mae_e = np.mean(np.abs(energy)) + expected_rmse_e = np.sqrt(np.mean(np.square(energy))) + expected_mae_f = np.mean(np.abs(force)) + expected_rmse_f = np.sqrt(np.mean(np.square(force))) + expected_mae_v = np.mean(np.abs(virial)) + expected_rmse_v = np.sqrt(np.mean(np.square(virial))) + + # Verify error calculations + self.assertAlmostEqual(err["mae_e"][0], expected_mae_e, places=5, + msg="Energy MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_e"][0], expected_rmse_e, places=5, + msg="Energy RMSE calculation incorrect") + self.assertAlmostEqual(err["mae_f"][0], expected_mae_f, places=5, + msg="Force MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_f"][0], expected_rmse_f, places=5, + msg="Force RMSE calculation incorrect") + self.assertAlmostEqual(err["mae_v"][0], expected_mae_v, places=5, + msg="Virial MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_v"][0], expected_rmse_v, places=5, + msg="Virial RMSE calculation incorrect")
364-390: 🛠️ Refactor suggestionValidate spin-related calculations.
This test is critical for the PR's purpose of making the weighted average work for all system types, but it only checks for the presence of keys without validating that the weighted average calculations are correct for systems with spin.
def test_dp_test_ener_with_spin(self) -> None: dp = DeepEval(self.tmp_model.name, head="PyTorch") system = self.config["training"]["validation_data"]["systems"][0] data = DeepmdData( sys_path=system, set_prefix="set", shuffle_test=False, type_map=dp.get_type_map(), sort_atoms=False, ) err = dp_test_ener( dp, data, system, numb_test=1, detail_file=None, has_atom_ener=False, ) self.assertIn("mae_e", err, "'mae_e' key is missing in the result") self.assertIn("mae_fm", err, "'mae_fm' key is missing in the result") self.assertNotIn( "mae_v", err, "'mae_v' key should not be present in the result" ) self.assertNotIn( "mae_f", err, "'mae_f' key should not be present in the result" ) + + # Validate spin-related calculations + # Reset data for fresh loading + data.add_dict = {} + data.get_test_set(0) + + # Get raw data for validation + data.get_test_batch(1) + energy = data.energy[0] + force_m = data.force_m[0] if hasattr(data, 'force_m') else None + + # Calculate expected errors + expected_mae_e = np.mean(np.abs(energy)) + expected_rmse_e = np.sqrt(np.mean(np.square(energy))) + expected_mae_fm = np.mean(np.abs(force_m)) if force_m is not None else 0.0 + expected_rmse_fm = np.sqrt(np.mean(np.square(force_m))) if force_m is not None else 0.0 + + # Verify error calculations for energy and magnetic forces + self.assertAlmostEqual(err["mae_e"][0], expected_mae_e, places=5, + msg="Energy MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_e"][0], expected_rmse_e, places=5, + msg="Energy RMSE calculation incorrect") + if "mae_fm" in err: + self.assertAlmostEqual(err["mae_fm"][0], expected_mae_fm, places=5, + msg="Magnetic force MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_fm"][0], expected_rmse_fm, places=5, + msg="Magnetic force RMSE calculation incorrect")
🧹 Nitpick comments (2)
source/tests/pt/test_weighted_avg.py (2)
69-70: Use context manager for temporary file handling.Using
NamedTemporaryFilewithout a context manager can lead to resource leaks if exceptions occur. Context managers ensure proper cleanup even in error scenarios.- self.tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, self.tmp_model.name) + with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as tmp_file: + self.tmp_model_name = tmp_file.name + torch.jit.save(model, self.tmp_model_name)Remember to update other references to
self.tmp_model.nametoself.tmp_model_name.🧰 Tools
🪛 Ruff (0.8.2)
69-69: Use a context manager for opening files
(SIM115)
291-296: Use unittest's assertion methods instead of assert statements.Using Python's built-in
assertstatements in unittest methods is not recommended. If Python is run with the-Oflag, these assertions will be stripped out. Use the unittest's assertion methods instead.- assert avg_err["mae_v"] == mae_v, ( - f"Expected mae_v in avg_err to be {mae_v} but got {avg_err['mae_v']}" - ) - assert avg_err["mae_e"] == mae_e_expected, ( - f"Expected mae_e in avg_err to be {mae_e_expected} but got {avg_err['mae_e']}" - ) + self.assertEqual(avg_err["mae_v"], mae_v, + f"Expected mae_v in avg_err to be {mae_v} but got {avg_err['mae_v']}" + ) + self.assertEqual(avg_err["mae_e"], mae_e_expected, + f"Expected mae_e in avg_err to be {mae_e_expected} but got {avg_err['mae_e']}" + )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
source/tests/pt/test_weighted_avg.py(1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
source/tests/pt/test_weighted_avg.py
69-69: Use a context manager for opening files
(SIM115)
149-149: Use a context manager for opening files
(SIM115)
239-239: Use a context manager for opening files
(SIM115)
361-361: Use a context manager for opening files
(SIM115)
🔇 Additional comments (3)
source/tests/pt/test_weighted_avg.py (3)
149-150: Use context manager for temporary file handling.Using
NamedTemporaryFilewithout a context manager can lead to resource leaks if exceptions occur.- self.tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, self.tmp_model.name) + with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as tmp_file: + self.tmp_model_name = tmp_file.name + torch.jit.save(model, self.tmp_model_name)🧰 Tools
🪛 Ruff (0.8.2)
149-149: Use a context manager for opening files
(SIM115)
239-240: Use context manager for temporary file handling.Using
NamedTemporaryFilewithout a context manager can lead to resource leaks if exceptions occur.- self.tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, self.tmp_model.name) + with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as tmp_file: + self.tmp_model_name = tmp_file.name + torch.jit.save(model, self.tmp_model_name)🧰 Tools
🪛 Ruff (0.8.2)
239-239: Use a context manager for opening files
(SIM115)
361-362: Use context manager for temporary file handling.Using
NamedTemporaryFilewithout a context manager can lead to resource leaks if exceptions occur.- self.tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, self.tmp_model.name) + with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as tmp_file: + self.tmp_model_name = tmp_file.name + torch.jit.save(model, self.tmp_model_name)🧰 Tools
🪛 Ruff (0.8.2)
361-361: Use a context manager for opening files
(SIM115)
…md-kit into debug-weightedavg
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
♻️ Duplicate comments (5)
source/tests/pt/test_weighted_avg.py (5)
115-145: 🛠️ Refactor suggestionValidate virial error calculations.
The test only checks for the presence of keys without validating the actual calculated values for virial errors.
Add validation for virial error calculations:
def test_dp_test_ener_with_virial(self) -> None: virial_path_fake = os.path.join( self.config["training"]["validation_data"]["systems"][0], "set.000", "virial.npy", ) np.save(virial_path_fake, np.ones([1, 9], dtype=np.float64)) dp = DeepEval(self.tmp_model.name, head="PyTorch") system = self.config["training"]["validation_data"]["systems"][0] data = DeepmdData( sys_path=system, set_prefix="set", shuffle_test=False, type_map=dp.get_type_map(), sort_atoms=False, ) err = dp_test_ener( dp, data, system, numb_test=1, detail_file=None, has_atom_ener=False, ) self.assertIn("mae_e", err, "'mae_e' key is missing in the result") self.assertNotIn( "mae_fm", err, "'mae_fm' key should not be present in the result" ) self.assertIn("mae_v", err, "'mae_v' key is missing in the result") self.assertIn("mae_f", err, "'mae_f' key is missing in the result") + + # Validate virial error calculations + # Reset data for fresh loading + data.add_dict = {} + data.get_test_set(0) + + # Get raw data for validation + data.get_test_batch(1) + energy = data.energy[0] + force = data.force[0] + virial = data.virial[0] + + # Calculate expected errors + expected_mae_e = np.mean(np.abs(energy)) + expected_rmse_e = np.sqrt(np.mean(np.square(energy))) + expected_mae_f = np.mean(np.abs(force)) + expected_rmse_f = np.sqrt(np.mean(np.square(force))) + expected_mae_v = np.mean(np.abs(virial)) + expected_rmse_v = np.sqrt(np.mean(np.square(virial))) + + # Verify error calculations + self.assertAlmostEqual(err["mae_e"][0], expected_mae_e, places=5, + msg="Energy MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_e"][0], expected_rmse_e, places=5, + msg="Energy RMSE calculation incorrect") + self.assertAlmostEqual(err["mae_f"][0], expected_mae_f, places=5, + msg="Force MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_f"][0], expected_rmse_f, places=5, + msg="Force RMSE calculation incorrect") + self.assertAlmostEqual(err["mae_v"][0], expected_mae_v, places=5, + msg="Virial MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_v"][0], expected_rmse_v, places=5, + msg="Virial RMSE calculation incorrect") os.unlink(self.tmp_model.name)
143-144: 🛠️ Refactor suggestionFix inconsistent assertion messages.
The assertion messages contradict the assertions themselves. You're checking for presence of keys, but the messages state they're missing.
- self.assertIn("mae_v", err, "'mae_v' key is missing in the result") - self.assertIn("mae_f", err, "'mae_f' key is missing in the result") + self.assertIn("mae_v", err, "'mae_v' key should be present in the result") + self.assertIn("mae_f", err, "'mae_f' key should be present in the result")
304-304: 🛠️ Refactor suggestionFix inconsistent assertion message.
The assertion message states that the key is missing while you're checking for its presence.
- self.assertIn("mae_fm", err, "'mae_fm' key is missing in the result") + self.assertIn("mae_fm", err, "'mae_fm' key should be present in the result")
80-80: 🛠️ Refactor suggestionFix inconsistent assertion message.
The assertion message contradicts the assertion itself. You're asserting that the key should be present, but the message implies it shouldn't be present.
- self.assertIn("mae_f", err, "'mae_f' key is missing in the result") + self.assertIn("mae_f", err, "'mae_f' key should be present in the result")
55-82: 🛠️ Refactor suggestionValidate calculated errors against expected values.
The test currently only verifies that certain keys exist in the error dictionary but doesn't validate the actual calculated values. This doesn't fully test that the weighted average functionality is working correctly.
Add validation of calculated errors by comparing with reference values:
def test_dp_test_ener_without_spin(self) -> None: dp = DeepEval(self.tmp_model.name, head="PyTorch") system = self.config["training"]["validation_data"]["systems"][0] data = DeepmdData( sys_path=system, set_prefix="set", shuffle_test=False, type_map=dp.get_type_map(), sort_atoms=False, ) err = dp_test_ener( dp, data, system, numb_test=1, detail_file=None, has_atom_ener=False, ) self.assertIn("mae_e", err, "'mae_e' key is missing in the result") self.assertNotIn( "mae_fm", err, "'mae_fm' key should not be present in the result" ) self.assertNotIn( "mae_v", err, "'mae_v' key should not be present in the result" ) self.assertIn("mae_f", err, "'mae_f' key is missing in the result") + + # Validate error calculations by comparing with reference values + # Reset data for fresh loading + data.add_dict = {} # Reset the data dictionary + data.get_test_set(0) # Load the first test set + + # Get raw data for validation + data.get_test_batch(1) + energy = data.energy[0] + force = data.force[0] + + # Calculate expected errors + expected_mae_e = np.mean(np.abs(energy)) + expected_rmse_e = np.sqrt(np.mean(np.square(energy))) + expected_mae_f = np.mean(np.abs(force)) + expected_rmse_f = np.sqrt(np.mean(np.square(force))) + + # Verify error calculations + self.assertAlmostEqual(err["mae_e"][0], expected_mae_e, places=5, + msg="Energy MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_e"][0], expected_rmse_e, places=5, + msg="Energy RMSE calculation incorrect") + self.assertAlmostEqual(err["mae_f"][0], expected_mae_f, places=5, + msg="Force MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_f"][0], expected_rmse_f, places=5, + msg="Force RMSE calculation incorrect") os.unlink(self.tmp_model.name)
🧹 Nitpick comments (3)
source/tests/pt/test_weighted_avg.py (3)
52-53: Use context manager for temporary file handling.The current approach of creating a temporary file without a context manager could lead to resource leaks if an exception occurs before
os.unlink()is called.Apply this diff to use a context manager instead:
- self.tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, self.tmp_model.name) + with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as self.tmp_model: + torch.jit.save(model, self.tmp_model.name)🧰 Tools
🪛 Ruff (0.8.2)
52-52: Use a context manager for opening files
(SIM115)
186-241: Enhance validation in multisystem test.While this test does some validation of the weighted average calculation, it would be better to use proper assertions.
Replace the custom assertions with proper unittest assertions:
- assert avg_err["mae_v"] == mae_v, ( - f"Expected mae_v in avg_err to be {mae_v} but got {avg_err['mae_v']}" - ) - assert avg_err["mae_e"] == mae_e_expected, ( - f"Expected mae_e in avg_err to be {mae_e_expected} but got {avg_err['mae_e']}" - ) + self.assertEqual( + avg_err["mae_v"], + mae_v, + f"Expected mae_v in avg_err to be {mae_v} but got {avg_err['mae_v']}" + ) + self.assertEqual( + avg_err["mae_e"], + mae_e_expected, + f"Expected mae_e in avg_err to be {mae_e_expected} but got {avg_err['mae_e']}" + )Also, consider using
assertAlmostEqualfor floating-point comparisons to account for minor numerical differences.
35-326: Add comprehensive test for multiple system types.While the PR is about making weighted average fit for all kinds of systems, there isn't a dedicated test that verifies the weighted average works across different system types (e.g., with/without spin, with/without virial).
Consider adding a test that builds on the multi-system test but uses different system types:
class Test_testener_with_mixed_systems(unittest.TestCase): def setUp(self) -> None: # Similar setup as other classes # ... def test_dp_test_ener_with_mixed_systems(self) -> None: # Test weighted average across systems with different properties # 1. System without spin or virial # 2. System with virial # 3. System with spin # Then verify the weighted average is calculated correctly # ... def tearDown(self) -> None: # Similar teardown as other classes # ...This would provide a more comprehensive test of the weighted average functionality across all system types.
🧰 Tools
🪛 Ruff (0.8.2)
52-52: Use a context manager for opening files
(SIM115)
112-112: Use a context manager for opening files
(SIM115)
183-183: Use a context manager for opening files
(SIM115)
281-281: Use a context manager for opening files
(SIM115)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
source/tests/pt/test_weighted_avg.py(1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
source/tests/pt/test_weighted_avg.py
52-52: Use a context manager for opening files
(SIM115)
112-112: Use a context manager for opening files
(SIM115)
183-183: Use a context manager for opening files
(SIM115)
281-281: Use a context manager for opening files
(SIM115)
🔇 Additional comments (3)
source/tests/pt/test_weighted_avg.py (3)
112-113: Use context manager for temporary file handling.Same issue as before - not using a context manager for temporary file handling could lead to resource leaks.
Apply this diff:
- self.tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, self.tmp_model.name) + with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as self.tmp_model: + torch.jit.save(model, self.tmp_model.name)🧰 Tools
🪛 Ruff (0.8.2)
112-112: Use a context manager for opening files
(SIM115)
183-184: Use context manager for temporary file handling.Same issue with temporary file handling without context manager.
Apply this diff:
- self.tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, self.tmp_model.name) + with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as self.tmp_model: + torch.jit.save(model, self.tmp_model.name)🧰 Tools
🪛 Ruff (0.8.2)
183-183: Use a context manager for opening files
(SIM115)
281-282: Use context manager for temporary file handling.Same issue with temporary file handling without context manager.
Apply this diff:
- self.tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, self.tmp_model.name) + with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as self.tmp_model: + torch.jit.save(model, self.tmp_model.name)🧰 Tools
🪛 Ruff (0.8.2)
281-281: Use a context manager for opening files
(SIM115)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
♻️ Duplicate comments (5)
source/tests/pt/test_weighted_avg.py (5)
57-83: 🛠️ Refactor suggestionEnhance test validation for non-spin system.
The test only checks for the presence of keys without validating the calculated error values, which doesn't fully test the weighted average functionality.
Add validation for the calculated error values:
def test_dp_test_ener_without_spin(self) -> None: dp = DeepEval(self.tmp_model.name, head="PyTorch") system = self.config["training"]["validation_data"]["systems"][0] data = DeepmdData( sys_path=system, set_prefix="set", shuffle_test=False, type_map=dp.get_type_map(), sort_atoms=False, ) err = dp_test_ener( dp, data, system, numb_test=1, detail_file=None, has_atom_ener=False, ) self.assertIn("mae_e", err, "'mae_e' key is missing in the result") self.assertNotIn( "mae_fm", err, "'mae_fm' key should not be present in the result" ) self.assertNotIn( "mae_v", err, "'mae_v' key should not be present in the result" ) self.assertIn("mae_f", err, "'mae_f' key is missing in the result") + + # Validate error calculations by comparing with reference values + # Reset data for fresh loading + data.add_dict = {} # Reset the data dictionary + data.get_test_set(0) # Load the first test set + + # Get raw data for validation + data.get_test_batch(1) + energy = data.energy[0] + force = data.force[0] + + # Calculate expected errors + expected_mae_e = np.mean(np.abs(energy)) + expected_rmse_e = np.sqrt(np.mean(np.square(energy))) + expected_mae_f = np.mean(np.abs(force)) + expected_rmse_f = np.sqrt(np.mean(np.square(force))) + + # Verify error calculations + self.assertAlmostEqual(err["mae_e"][0], expected_mae_e, places=5, + msg="Energy MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_e"][0], expected_rmse_e, places=5, + msg="Energy RMSE calculation incorrect") + self.assertAlmostEqual(err["mae_f"][0], expected_mae_f, places=5, + msg="Force MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_f"][0], expected_rmse_f, places=5, + msg="Force RMSE calculation incorrect")
82-82: 🛠️ Refactor suggestionFix incorrect assertion message.
The assertion message contradicts the assertion itself.
- self.assertIn("mae_f", err, "'mae_f' key is missing in the result") + self.assertIn("mae_f", err, "'mae_f' key should be present in the result")
118-147: 🛠️ Refactor suggestionEnhance test validation for virial system.
Similar to the first test, this test only verifies the presence of keys without validating the actual calculated values.
Add validation for the virial error calculations:
def test_dp_test_ener_with_virial(self) -> None: virial_path_fake = os.path.join( self.config["training"]["validation_data"]["systems"][0], "set.000", "virial.npy", ) np.save(virial_path_fake, np.ones([1, 9], dtype=np.float64)) dp = DeepEval(self.tmp_model.name, head="PyTorch") system = self.config["training"]["validation_data"]["systems"][0] data = DeepmdData( sys_path=system, set_prefix="set", shuffle_test=False, type_map=dp.get_type_map(), sort_atoms=False, ) err = dp_test_ener( dp, data, system, numb_test=1, detail_file=None, has_atom_ener=False, ) self.assertIn("mae_e", err, "'mae_e' key is missing in the result") self.assertNotIn( "mae_fm", err, "'mae_fm' key should not be present in the result" ) self.assertIn("mae_v", err, "'mae_v' key is missing in the result") self.assertIn("mae_f", err, "'mae_f' key is missing in the result") + + # Validate virial error calculations + # Reset data for fresh loading + data.add_dict = {} # Reset the data dictionary + data.get_test_set(0) # Load the first test set + + # Get raw data for validation + data.get_test_batch(1) + energy = data.energy[0] + force = data.force[0] + virial = data.virial[0] + + # Calculate expected errors + expected_mae_e = np.mean(np.abs(energy)) + expected_rmse_e = np.sqrt(np.mean(np.square(energy))) + expected_mae_f = np.mean(np.abs(force)) + expected_rmse_f = np.sqrt(np.mean(np.square(force))) + expected_mae_v = np.mean(np.abs(virial)) + expected_rmse_v = np.sqrt(np.mean(np.square(virial))) + + # Verify error calculations + self.assertAlmostEqual(err["mae_e"][0], expected_mae_e, places=5, + msg="Energy MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_e"][0], expected_rmse_e, places=5, + msg="Energy RMSE calculation incorrect") + self.assertAlmostEqual(err["mae_f"][0], expected_mae_f, places=5, + msg="Force MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_f"][0], expected_rmse_f, places=5, + msg="Force RMSE calculation incorrect") + self.assertAlmostEqual(err["mae_v"][0], expected_mae_v, places=5, + msg="Virial MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_v"][0], expected_rmse_v, places=5, + msg="Virial RMSE calculation incorrect")
287-314: 🛠️ Refactor suggestionValidate spin-related calculations.
This test is critical for the PR's purpose of making the weighted average work for all system types, but it only checks for the presence of keys without validating the calculations.
Add validation for the spin-related calculations:
def test_dp_test_ener_with_spin(self) -> None: dp = DeepEval(self.tmp_model.name, head="PyTorch") system = self.config["training"]["validation_data"]["systems"][0] data = DeepmdData( sys_path=system, set_prefix="set", shuffle_test=False, type_map=dp.get_type_map(), sort_atoms=False, ) err = dp_test_ener( dp, data, system, numb_test=1, detail_file=None, has_atom_ener=False, ) self.assertIn("mae_e", err, "'mae_e' key is missing in the result") self.assertIn("mae_fm", err, "'mae_fm' key is missing in the result") self.assertNotIn( "mae_v", err, "'mae_v' key should not be present in the result" ) self.assertNotIn( "mae_f", err, "'mae_f' key should not be present in the result" ) + + # Validate spin-related calculations + # Reset data for fresh loading + data.add_dict = {} + data.get_test_set(0) + + # Get raw data for validation + data.get_test_batch(1) + energy = data.energy[0] + force_m = data.force_m[0] if hasattr(data, 'force_m') else None + + # Calculate expected errors + expected_mae_e = np.mean(np.abs(energy)) + expected_rmse_e = np.sqrt(np.mean(np.square(energy))) + expected_mae_fm = np.mean(np.abs(force_m)) if force_m is not None else 0.0 + expected_rmse_fm = np.sqrt(np.mean(np.square(force_m))) if force_m is not None else 0.0 + + # Verify error calculations for energy and magnetic forces + self.assertAlmostEqual(err["mae_e"][0], expected_mae_e, places=5, + msg="Energy MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_e"][0], expected_rmse_e, places=5, + msg="Energy RMSE calculation incorrect") + if "mae_fm" in err: + self.assertAlmostEqual(err["mae_fm"][0], expected_mae_fm, places=5, + msg="Magnetic force MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_fm"][0], expected_rmse_fm, places=5, + msg="Magnetic force RMSE calculation incorrect")
146-147: 🛠️ Refactor suggestionFix incorrect assertion messages.
The assertion messages should indicate that the keys should be present, not missing.
- self.assertIn("mae_v", err, "'mae_v' key is missing in the result") - self.assertIn("mae_f", err, "'mae_f' key is missing in the result") + self.assertIn("mae_v", err, "'mae_v' key should be present in the result") + self.assertIn("mae_f", err, "'mae_f' key should be present in the result")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
source/tests/pt/test_weighted_avg.py(1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
source/tests/pt/test_weighted_avg.py
54-54: Use a context manager for opening files
(SIM115)
115-115: Use a context manager for opening files
(SIM115)
186-186: Use a context manager for opening files
(SIM115)
284-284: Use a context manager for opening files
(SIM115)
🔇 Additional comments (5)
source/tests/pt/test_weighted_avg.py (5)
54-56: Use context manager for temporary file handling.The current approach to managing temporary files may lead to resource leaks if exceptions occur.
Apply this diff to use a context manager:
- self.tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, self.tmp_model.name) + with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as self.tmp_model: + torch.jit.save(model, self.tmp_model.name)🧰 Tools
🪛 Ruff (0.8.2)
54-54: Use a context manager for opening files
(SIM115)
115-117: Use context manager for temporary file handling.Similar to the first test class, use a context manager for better resource management.
Apply this diff:
- self.tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, self.tmp_model.name) + with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as self.tmp_model: + torch.jit.save(model, self.tmp_model.name)🧰 Tools
🪛 Ruff (0.8.2)
115-115: Use a context manager for opening files
(SIM115)
186-188: Use context manager for temporary file handling.Apply the same pattern for consistent resource management.
- self.tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, self.tmp_model.name) + with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as self.tmp_model: + torch.jit.save(model, self.tmp_model.name)🧰 Tools
🪛 Ruff (0.8.2)
186-186: Use a context manager for opening files
(SIM115)
284-286: Use context manager for temporary file handling.Apply the same pattern for consistent resource management.
- self.tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, self.tmp_model.name) + with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as self.tmp_model: + torch.jit.save(model, self.tmp_model.name)🧰 Tools
🪛 Ruff (0.8.2)
284-284: Use a context manager for opening files
(SIM115)
189-243: Excellent implementation of weighted average test.This test correctly validates the weighted average functionality by comparing manually calculated expected values with the output from the weighted_average function. This approach directly addresses the PR objective of making the weighted average fit for all kinds of systems.
…md-kit into debug-weightedavg
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
♻️ Duplicate comments (2)
source/tests/pt/test_weighted_avg.py (2)
215-222: 🛠️ Refactor suggestionFix inconsistent assertion messages.
The assertion messages don't match their purpose. For the
assertInchecks, the message states the key is missing, which contradicts the assertion itself.- self.assertIn("mae_e", err, "'mae_e' key is missing in the result") - self.assertIn("mae_fm", err, "'mae_fm' key is missing in the result") + self.assertIn("mae_e", err, "'mae_e' key should be present in the result") + self.assertIn("mae_fm", err, "'mae_fm' key should be present in the result") self.assertNotIn( "mae_v", err, "'mae_v' key should not be present in the result" ) self.assertNotIn( "mae_f", err, "'mae_f' key should not be present in the result" )
196-223: 🛠️ Refactor suggestionValidate spin-related calculations.
This test is critical for the PR's purpose of making the weighted average work for all system types, but it only checks for the presence of keys without validating the calculations.
def test_dp_test_ener_with_spin(self) -> None: dp = DeepEval(self.tmp_model.name, head="PyTorch") system = self.config["training"]["validation_data"]["systems"][0] data = DeepmdData( sys_path=system, set_prefix="set", shuffle_test=False, type_map=dp.get_type_map(), sort_atoms=False, ) err = dp_test_ener( dp, data, system, numb_test=1, detail_file=None, has_atom_ener=False, ) self.assertIn("mae_e", err, "'mae_e' key is missing in the result") self.assertIn("mae_fm", err, "'mae_fm' key is missing in the result") self.assertNotIn( "mae_v", err, "'mae_v' key should not be present in the result" ) self.assertNotIn( "mae_f", err, "'mae_f' key should not be present in the result" ) + + # Validate spin-related calculations + # Reset data for fresh loading + data.add_dict = {} + data.get_test_set(0) + + # Get raw data for validation + data.get_test_batch(1) + energy = data.energy[0] + force_m = data.force_m[0] if hasattr(data, 'force_m') else None + + # Calculate expected errors + expected_mae_e = np.mean(np.abs(energy)) + expected_rmse_e = np.sqrt(np.mean(np.square(energy))) + expected_mae_fm = np.mean(np.abs(force_m)) if force_m is not None else 0.0 + expected_rmse_fm = np.sqrt(np.mean(np.square(force_m))) if force_m is not None else 0.0 + + # Verify error calculations for energy and magnetic forces + self.assertAlmostEqual(err["mae_e"][0], expected_mae_e, places=5, + msg="Energy MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_e"][0], expected_rmse_e, places=5, + msg="Energy RMSE calculation incorrect") + if "mae_fm" in err: + self.assertAlmostEqual(err["mae_fm"][0], expected_mae_fm, places=5, + msg="Magnetic force MAE calculation incorrect") + self.assertAlmostEqual(err["rmse_fm"][0], expected_rmse_fm, places=5, + msg="Magnetic force RMSE calculation incorrect") os.unlink(self.tmp_model.name)
🧹 Nitpick comments (2)
source/tests/pt/test_weighted_avg.py (2)
54-55: Use context manager for temporary file operations.The current implementation does not properly handle file cleanup if an exception occurs during the test. Using a context manager would ensure proper resource management.
- self.tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, self.tmp_model.name) + with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as self.tmp_model: + torch.jit.save(model, self.tmp_model.name)🧰 Tools
🪛 Ruff (0.8.2)
54-54: Use a context manager for opening files
(SIM115)
84-152: Enhance test validation for virial system.Similar to the first test, you're only partially validating the calculations. Add validation for the virial error calculations to ensure that the weighted average is correctly computed.
def test_dp_test_ener_with_multisys_and_with_virial(self) -> None: # ... existing code ... self.assertIn("mae_e", err_virial, "'mae_e' key is missing in the result") self.assertNotIn( "mae_fm", err_virial, "'mae_fm' key should not be present in the result" ) self.assertIn("mae_v", err_virial, "'mae_v' key is missing in the result") self.assertIn("mae_f", err_virial, "'mae_f' key is missing in the result") + # Validate virial error calculations + # Reset data for fresh loading + data.add_dict = {} + data.get_test_set(0) + + # Get raw data for validation + data.get_test_batch(1) + energy = data.energy[0] + force = data.force[0] + virial = data.virial[0] + + # Calculate expected errors + expected_mae_e = np.mean(np.abs(energy)) + expected_rmse_e = np.sqrt(np.mean(np.square(energy))) + expected_mae_f = np.mean(np.abs(force)) + expected_rmse_f = np.sqrt(np.mean(np.square(force))) + expected_mae_v = np.mean(np.abs(virial)) + expected_rmse_v = np.sqrt(np.mean(np.square(virial))) + + # Verify error calculations + self.assertAlmostEqual(err_virial["mae_e"][0], expected_mae_e, places=5, + msg="Energy MAE calculation incorrect") + self.assertAlmostEqual(err_virial["rmse_e"][0], expected_rmse_e, places=5, + msg="Energy RMSE calculation incorrect") + self.assertAlmostEqual(err_virial["mae_f"][0], expected_mae_f, places=5, + msg="Force MAE calculation incorrect") + self.assertAlmostEqual(err_virial["rmse_f"][0], expected_rmse_f, places=5, + msg="Force RMSE calculation incorrect") + self.assertAlmostEqual(err_virial["mae_v"][0], expected_mae_v, places=5, + msg="Virial MAE calculation incorrect") + self.assertAlmostEqual(err_virial["rmse_v"][0], expected_rmse_v, places=5, + msg="Virial RMSE calculation incorrect")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
source/tests/pt/test_weighted_avg.py(1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
source/tests/pt/test_weighted_avg.py
54-54: Use a context manager for opening files
(SIM115)
193-193: Use a context manager for opening files
(SIM115)
⏰ Context from checks skipped due to timeout of 90000ms (20)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
🔇 Additional comments (1)
source/tests/pt/test_weighted_avg.py (1)
193-194: Use context manager for temporary file operations.The current implementation using NamedTemporaryFile without a context manager could lead to resource leaks. Use a context manager for better resource management.
- self.tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") - torch.jit.save(model, self.tmp_model.name) + with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as self.tmp_model: + torch.jit.save(model, self.tmp_model.name)🧰 Tools
🪛 Ruff (0.8.2)
193-193: Use a context manager for opening files
(SIM115)
|
The modification was tested using the command 'dp test' in the command line, and the result met the expectations. @iProzd |
|
@SumGuo-88 please resolve all conversations to let the PR be merged. |
Summary by CodeRabbit
Refactor
Tests