-
Notifications
You must be signed in to change notification settings - Fork 584
feat(pt): add compression support for se_e3_tebd #4992
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds compression support for the se_e3_tebd (SE_T_TEBD) descriptor type. The compression functionality allows tabulation of the embedding network to improve computational efficiency during inference.
- Adds "T_TEBD" descriptor type support to the tabulation system
- Refactors variable names from
xx/vv/ttto more descriptive names likemesh/value/stride - Implements compression methods for the SE_T_TEBD descriptor block
Reviewed Changes
Copilot reviewed 5 out of 6 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| examples/water/.gitignore | Adds *.hdf5 to ignored files |
| deepmd/utils/tabulate.py | Adds T_TEBD support and refactors variable naming throughout tabulation logic |
| deepmd/tf/utils/tabulate.py | Updates parameter names from xx to mesh in TensorFlow tabulation |
| deepmd/pt/utils/tabulate.py | Adds T_TEBD descriptor support and updates PyTorch tabulation implementation |
| deepmd/pt/model/descriptor/se_t_tebd.py | Implements compression functionality for SE_T_TEBD descriptor |
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
|
Note Other AI code review bot(s) detectedCodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review. 📝 WalkthroughWalkthroughAdds TEBD (se_t_tebd) tabulation support across Python and C++ (CPU/GPU): new TEBD tabulation kernels, autograd op and Torch binding, shared-table tabulation flow, enable_compression APIs and compression state in PyTorch descriptor classes, serialization of compression metadata, and unit tests; updates docs to mark compression supported. Changes
Sequence Diagram(s)sequenceDiagram
participant Py as Python caller
participant Desc as DescrptSeTTebd
participant TabPy as DPTabulate / Python tabulate
participant Lib as C++ tabulate (CPU/GPU)
Note over Py,Desc: Enable compression flow
Py->>Desc: enable_compression(min_nbor_dist,...)
Desc->>TabPy: build DPTabulate table (ActivationFn, strides, config)
TabPy-->>Desc: return table_data, table_config, lower, upper
Desc->>Desc: store compress_info/compress_data, set compress=True
Desc->>Desc: call underlying se_ttebd.enable_compression(...)
Note over Py,Desc: Forward/backward when compressed
Py->>Desc: forward(inputs)
alt compressed && tebd_input_mode == "strip"
Desc->>TabPy: prepare em_x/em, call tabulate op with stored table_info
TabPy->>Lib: tabulate_fusion_se_t_tebd(table, table_info, em_x, em, last_layer_size)
Lib-->>TabPy: descriptor (forward) / dy_dem_x (backward)
TabPy-->>Desc: deliver outputs (and gradients on backward)
else fallback
Desc->>Desc: standard (non-tabulated) embedding path
Desc-->>Py: outputs
end
Estimated code review effort🎯 5 (Critical) | ⏱️ ~120 minutes Possibly related PRs
Suggested labels
Suggested reviewers
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/pt/utils/tabulate.py (1)
314-318: Off-by-one in stride window selection (T/T_TEBD) corrupts spline segmentsUsing +1 shifts the fine-stride window by one interval; the first fine interval remains coarse, breaking coefficients.
Apply:
- start_index = int((lower - extrapolate * lower) / stride1) + 1 + start_index = int((lower - extrapolate * lower) / stride1) end_index = start_index + int((upper - lower) / stride0) stride[start_index:end_index, :] = stride0
🧹 Nitpick comments (4)
deepmd/pt/model/descriptor/se_t_tebd.py (3)
527-603: Enable-compression flow for TEBD — mostly OK; tighten message and comment
- Logic and wiring to DPTabulate look correct.
- Improve the mode check error to reflect the actual value; update the comment to TEBD.
- if self.tebd_input_mode != "strip": - raise RuntimeError("Cannot compress model when tebd_input_mode == 'concat'") + if self.tebd_input_mode != "strip": + raise RuntimeError(f"Compression requires tebd_input_mode='strip' (got '{self.tebd_input_mode}')") ... - # Scale the stride values for SE_T descriptor + # Scale the stride values for TEBD descriptorAs per coding guidelines
1021-1060: Signature/type hints mismatch for table_configtable_config is used as an indexable sequence (list), but annotated as dict. Align the type hints and docstring.
- def enable_compression( - self, - table_data: dict, - table_config: dict, + def enable_compression( + self, + table_data: dict, + table_config: list[float] | tuple[float, float, float, int], lower: dict, upper: dict, ) -> None: @@ - table_config : dict - Configuration for table compression + table_config : list[float] | tuple[float, float, float, int] + [extrapolate, stride0, stride1, check_frequency]As per coding guidelines
552-577: Shorten/standardize exception messages (ruff TRY003)A couple of raised messages are verbose. Keep them concise to satisfy TRY003.
- assert not self.se_ttebd.resnet_dt, ( - "Model compression error: descriptor resnet_dt must be false!" - ) + assert not self.se_ttebd.resnet_dt, "resnet_dt must be False for compression" @@ - raise RuntimeError( - "Empty embedding-nets are not supported in model compression!" - ) + raise RuntimeError("Empty embedding nets not supported for compression")Run ruff check . to confirm.
deepmd/utils/tabulate.py (1)
293-304: _generate_spline_table type hints incorrectstride0/stride1 and extrapolate are floats; current annotations are misleading.
- stride0: int, - stride1: int, - extrapolate: bool, + stride0: float, + stride1: float, + extrapolate: float,As per coding guidelines
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
.gitignore(1 hunks)deepmd/pt/model/descriptor/se_t_tebd.py(6 hunks)deepmd/pt/utils/tabulate.py(9 hunks)deepmd/tf/utils/tabulate.py(2 hunks)deepmd/utils/tabulate.py(12 hunks)examples/water/.gitignore(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (AGENTS.md)
Always run
ruff check .andruff format .before committing changes to Python code
Files:
deepmd/pt/utils/tabulate.pydeepmd/tf/utils/tabulate.pydeepmd/utils/tabulate.pydeepmd/pt/model/descriptor/se_t_tebd.py
🧬 Code graph analysis (4)
deepmd/pt/utils/tabulate.py (2)
deepmd/utils/tabulate.py (1)
_make_data(391-408)deepmd/tf/utils/tabulate.py (3)
_make_data(319-466)_layer_1(471-473)_layer_0(468-469)
deepmd/tf/utils/tabulate.py (2)
deepmd/pt/utils/tabulate.py (5)
_make_data(114-281)_layer_0(283-286)unaggregated_dy_dx_s(502-521)unaggregated_dy2_dx_s(524-550)_layer_1(288-292)deepmd/tf/utils/compress.py (2)
_layer_0(59-60)_layer_1(63-65)
deepmd/utils/tabulate.py (2)
deepmd/pt/utils/tabulate.py (1)
_make_data(114-281)deepmd/tf/utils/tabulate.py (1)
_make_data(319-466)
deepmd/pt/model/descriptor/se_t_tebd.py (4)
deepmd/pt/utils/tabulate.py (1)
DPTabulate(30-441)deepmd/pt/model/descriptor/se_t.py (3)
enable_compression(284-330)enable_compression(743-770)serialize(412-438)deepmd/pt/utils/utils.py (1)
ActivationFn(175-220)deepmd/pt/model/descriptor/se_atten.py (4)
enable_compression(427-448)serialize(780-804)serialize(892-914)serialize(1082-1105)
🪛 Ruff (0.13.1)
deepmd/utils/tabulate.py
320-320: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/pt/model/descriptor/se_t_tebd.py
552-552: Avoid specifying long messages outside the exception class
(TRY003)
572-574: Avoid specifying long messages outside the exception class
(TRY003)
577-577: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Test C++ (true)
- GitHub Check: Test C++ (false)
- GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
🔇 Additional comments (7)
deepmd/tf/utils/tabulate.py (1)
318-466: Mesh rename keeps TensorFlow path alignedThe reshaping/derivative flow remains identical while matching the new mesh terminology, so the TensorFlow tabulator stays in sync with the PT backend. Looks good.
deepmd/pt/utils/tabulate.py (4)
309-311: Add TEBD descriptor type mapping — OKCorrectly recognizes DescrptSeTTebd as "T_TEBD".
325-326: Layer-size handling for T_TEBD — OKTreating "T_TEBD" like "Atten" for layer sizing aligns with shared embedding net design.
394-401: Shared network variables for T_TEBD — OKFetching the single shared embedding network per layer matches the TEBD design.
154-155: Device inconsistency for torch.ones (duplicate of prior feedback)Use env.DEVICE for consistency with the rest of the tensor placements.
- ) + torch.ones((1, 1), dtype=yy.dtype, device=yy.device) + ) + torch.ones((1, 1), dtype=yy.dtype, device=env.DEVICE) ... - ) + torch.ones((1, 2), dtype=yy.dtype, device=yy.device) + ) + torch.ones((1, 2), dtype=yy.dtype, device=env.DEVICE)Also applies to: 173-174
deepmd/utils/tabulate.py (2)
203-247: Shared-mesh tabulation for T_TEBD — OKGlobal range + single mesh for the shared geometric net is correct; nspline formula matches the constructed mesh.
Please run ruff check . and a quick smoke build of tables for a small synthetic range to ensure nspline equals len(mesh)-1.
486-505: Env range update for T_TEBD — OKHandling T_TEBD with the (cos theta)^2 bounds matches the T path.
iProzd
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- please remove uncessary changes, such as
.gitignore - please remove rename in this PR, you should make it a seperate PR.
- please add a UT for your modification such as source/tests/pt/test_tabulate_fusion_se_atten.py.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (4)
deepmd/pt/utils/tabulate.py (1)
278-281: Improved variable naming for clarity.The rename from
vvtovaluemakes the final output variable more descriptive and clear, improving code readability without changing functionality.deepmd/utils/tabulate.py (3)
107-109: LGTM! Refactored to use unified spline table generation.The renaming from
_build_lowerto_generate_spline_tablewith updated parameters improves code clarity and provides a consistent interface for spline table generation across different descriptor types.
291-302: LGTM! Improved function signature and documentation.The renamed function with updated parameter names (
xx→mesh, clearer parameter names) and better documentation improves code maintainability. The signature change from the old_build_lowerto_generate_spline_tablemakes the purpose more explicit.
504-512: LGTM! Cleaner spline switch function implementation.The updated parameter names (
xx→x) and simplified variable names improve code readability. The function logic remains correct and maintains the same mathematical behavior.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
deepmd/pt/utils/tabulate.py(8 hunks)deepmd/utils/tabulate.py(10 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (AGENTS.md)
Always run
ruff check .andruff format .before committing changes to Python code
Files:
deepmd/pt/utils/tabulate.pydeepmd/utils/tabulate.py
🧬 Code graph analysis (1)
deepmd/pt/utils/tabulate.py (2)
deepmd/tf/utils/tabulate.py (1)
_layer_1(471-473)deepmd/pt/model/descriptor/se_t_tebd.py (1)
DescrptSeTTebd(78-602)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (17)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Analyze (python)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
🔇 Additional comments (15)
deepmd/pt/utils/tabulate.py (8)
69-69: LGTM! Added T_TEBD to supported descriptor types.The addition of "T_TEBD" to the supported descriptor types correctly enables tabulation support for the se_e3_tebd descriptor, which was previously not supported for model compression according to the documentation.
154-154: Device consistency maintained correctly.
163-180: LGTM! Correct residual handling for TEBD-style descriptors.The changes properly handle residual connections in the first layer for TEBD descriptors, matching the pattern used elsewhere in the codebase for similar descriptor types. The device-aware tensor creation is correctly implemented.
173-173: Device consistency maintained correctly.
234-247: LGTM! Correct residual handling for TEBD-style descriptors in deeper layers.The changes extend the residual connection handling to deeper layers, maintaining consistency with the first layer implementation. The variable naming with
residualimproves code clarity.
309-311: LGTM! Added TEBD descriptor type recognition.The addition correctly maps
DescrptSeTTebdto the "T_TEBD" type, enabling proper classification for the new descriptor type.
325-326: LGTM! Correct layer sizing for TEBD descriptors.Treating "T_TEBD" similarly to "Atten" for layer sizing is appropriate since both use shared network architectures, as evidenced by the network variable handling.
394-400: LGTM! Shared embedding network for TEBD descriptors.The implementation correctly handles the shared embedding network approach for T_TEBD descriptors, where a single network is used for all type pairs. This aligns with the architectural design described in the relevant code snippets.
deepmd/utils/tabulate.py (7)
148-150: LGTM! Consistent use of unified spline generation.The function call correctly uses the new
_generate_spline_tableinterface, maintaining the same functionality with improved naming.
188-199: LGTM! Consistent spline generation for T descriptor.The updated call maintains the same functionality while using the new unified interface.
200-245: LGTM! Well-implemented T_TEBD descriptor support.The T_TEBD implementation correctly:
- Calculates global ranges across all types for the shared network
- Creates a unified input grid based on global bounds
- Generates spline coefficients once for the shared geometric network
- Uses appropriate naming convention ("filter_net")
This approach aligns with the shared network architecture of TEBD descriptors and follows the established patterns for other descriptor types.
280-282: LGTM! Consistent spline generation for R descriptor.The updated call maintains functionality while using the new unified interface.
309-316: Critical stride indexing issue needs attention.The existing review comment identifies a critical off-by-one error in the stride window selection that affects both T and T_TEBD descriptor types. This issue corrupts spline segments and should be addressed.
436-436: LGTM! Correct table size handling for TEBD.Including "T_TEBD" with "Atten" and "AEbdV2" for table size calculation is appropriate since they all use shared network architectures requiring only one table.
490-490: LGTM! Correct environment matrix range handling for TEBD.Including "T_TEBD" with "T" for environment matrix range calculation is appropriate since they handle similar geometric (angular) information and use the same mathematical formulation for computing bounds.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
🧹 Nitpick comments (1)
deepmd/pt/model/descriptor/se_t_tebd.py (1)
527-603: Check the scaled stride values for SE_T_TEBD compression.The stride values are scaled by 10x (lines 587-588) to match SE_T behavior. However, the comment on line 586 says "Scale the stride values for SE_T descriptor" rather than SE_T_TEBD. Consider updating the comment for clarity.
Additionally, the runtime checks for
tebd_input_mode(line 576) and other validations look good.Apply this diff to clarify the comment:
- # Scale the stride values for SE_T descriptor + # Scale the stride values to match SE_T behavior for TEBD compression
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
deepmd/pt/model/descriptor/se_t_tebd.py(10 hunks)source/lib/include/tabulate.h(2 hunks)source/lib/src/gpu/tabulate.cu(3 hunks)source/lib/src/tabulate.cc(2 hunks)source/op/pt/tabulate_multi_device.cc(4 hunks)source/tests/pt/test_tabulate_fusion_se_t_tebd.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (AGENTS.md)
Always run
ruff check .andruff format .before committing changes to Python code
Files:
source/tests/pt/test_tabulate_fusion_se_t_tebd.pydeepmd/pt/model/descriptor/se_t_tebd.py
🧬 Code graph analysis (6)
source/lib/include/tabulate.h (2)
source/lib/src/gpu/tabulate.cu (25)
void(39-48)void(165-252)void(255-362)void(365-471)void(474-514)void(517-578)void(581-631)void(634-672)void(675-718)void(721-761)void(764-795)void(798-840)void(843-882)tabulate_fusion_se_t_tebd_gpu(1058-1078)tabulate_fusion_se_t_tebd_gpu(1058-1066)tabulate_fusion_se_t_tebd_gpu(1393-1401)tabulate_fusion_se_t_tebd_gpu(1403-1411)tabulate_fusion_se_t_tebd_grad_gpu(1081-1104)tabulate_fusion_se_t_tebd_grad_gpu(1081-1090)tabulate_fusion_se_t_tebd_grad_gpu(1413-1423)tabulate_fusion_se_t_tebd_grad_gpu(1425-1435)tabulate_fusion_se_t_tebd_grad_grad_gpu(1107-1132)tabulate_fusion_se_t_tebd_grad_grad_gpu(1107-1116)tabulate_fusion_se_t_tebd_grad_grad_gpu(1437-1447)tabulate_fusion_se_t_tebd_grad_grad_gpu(1449-1459)source/lib/src/tabulate.cc (12)
tabulate_fusion_se_t_tebd_cpu(545-590)tabulate_fusion_se_t_tebd_cpu(545-553)tabulate_fusion_se_t_tebd_cpu(963-972)tabulate_fusion_se_t_tebd_cpu(973-982)tabulate_fusion_se_t_tebd_grad_cpu(593-641)tabulate_fusion_se_t_tebd_grad_cpu(593-602)tabulate_fusion_se_t_tebd_grad_cpu(983-993)tabulate_fusion_se_t_tebd_grad_cpu(994-1004)tabulate_fusion_se_t_tebd_grad_grad_cpu(644-692)tabulate_fusion_se_t_tebd_grad_grad_cpu(644-654)tabulate_fusion_se_t_tebd_grad_grad_cpu(1005-1015)tabulate_fusion_se_t_tebd_grad_grad_cpu(1016-1026)
source/lib/src/tabulate.cc (1)
source/lib/src/gpu/tabulate.cu (1)
locate_xx_se_t(76-83)
source/tests/pt/test_tabulate_fusion_se_t_tebd.py (2)
source/tests/consistent/common.py (1)
parameterized(580-640)source/op/pt/tabulate_multi_device.cc (2)
tabulate_fusion_se_t(1193-1201)tabulate_fusion_se_t(1193-1198)
deepmd/pt/model/descriptor/se_t_tebd.py (6)
source/lib/include/tabulate.h (1)
deepmd(4-293)deepmd/pt/utils/tabulate.py (1)
DPTabulate(30-441)deepmd/pt/model/descriptor/se_t.py (3)
enable_compression(284-330)enable_compression(743-770)serialize(412-438)deepmd/pt/utils/utils.py (1)
ActivationFn(175-220)deepmd/utils/tabulate.py (1)
build(70-289)source/op/pt/tabulate_multi_device.cc (2)
tabulate_fusion_se_t_tebd(1203-1211)tabulate_fusion_se_t_tebd(1203-1208)
source/op/pt/tabulate_multi_device.cc (2)
source/lib/src/gpu/tabulate.cu (12)
tabulate_fusion_se_t_tebd_gpu(1058-1078)tabulate_fusion_se_t_tebd_gpu(1058-1066)tabulate_fusion_se_t_tebd_gpu(1393-1401)tabulate_fusion_se_t_tebd_gpu(1403-1411)tabulate_fusion_se_t_tebd_grad_gpu(1081-1104)tabulate_fusion_se_t_tebd_grad_gpu(1081-1090)tabulate_fusion_se_t_tebd_grad_gpu(1413-1423)tabulate_fusion_se_t_tebd_grad_gpu(1425-1435)tabulate_fusion_se_t_tebd_grad_grad_gpu(1107-1132)tabulate_fusion_se_t_tebd_grad_grad_gpu(1107-1116)tabulate_fusion_se_t_tebd_grad_grad_gpu(1437-1447)tabulate_fusion_se_t_tebd_grad_grad_gpu(1449-1459)source/lib/src/tabulate.cc (12)
tabulate_fusion_se_t_tebd_cpu(545-590)tabulate_fusion_se_t_tebd_cpu(545-553)tabulate_fusion_se_t_tebd_cpu(963-972)tabulate_fusion_se_t_tebd_cpu(973-982)tabulate_fusion_se_t_tebd_grad_cpu(593-641)tabulate_fusion_se_t_tebd_grad_cpu(593-602)tabulate_fusion_se_t_tebd_grad_cpu(983-993)tabulate_fusion_se_t_tebd_grad_cpu(994-1004)tabulate_fusion_se_t_tebd_grad_grad_cpu(644-692)tabulate_fusion_se_t_tebd_grad_grad_cpu(644-654)tabulate_fusion_se_t_tebd_grad_grad_cpu(1005-1015)tabulate_fusion_se_t_tebd_grad_grad_cpu(1016-1026)
source/lib/src/gpu/tabulate.cu (1)
source/lib/src/tabulate.cc (2)
locate_xx_se_t(45-73)locate_xx_se_t(45-52)
🪛 Ruff (0.13.1)
deepmd/pt/model/descriptor/se_t_tebd.py
552-552: Avoid specifying long messages outside the exception class
(TRY003)
572-574: Avoid specifying long messages outside the exception class
(TRY003)
577-577: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (18)
deepmd/pt/model/descriptor/se_t_tebd.py (5)
191-191: Initializecompressstate variable.The implementation correctly initializes the compression state flag to
False.
696-702: LGTM! Compression metadata storage initialized correctly.The
compress_infoandcompress_dataParameterLists are properly initialized for storing compression tables and configuration.
954-970: TEBD compression uses correct tabulation operation.The tabulated TEBD path correctly uses
torch.ops.deepmd.tabulate_fusion_se_t_tebdwith proper tensor reshaping and dimension handling. The compressed embedding computation preserves the full neighbor structure as expected.
1021-1023: Proper combination of geometric and type embeddings.The combination formula
gg = gg_s * gg_t + gg_scorrectly implements thegg_s * (1 + gg_t)pattern for merging geometric and type embeddings.
1041-1080: Block-level compression setup is correct.The
enable_compressionmethod for the block properly configures the compression metadata using the shared geometric embedding network key "filter_net" and updates the compression state.source/tests/pt/test_tabulate_fusion_se_t_tebd.py (2)
18-20: Test class properly parameterized and conditionally skipped.The test class correctly uses the
@parameterizeddecorator for multiple dtypes and skips when PyTorch customized ops are unavailable.
201-270: Incorrect shape concern — test uses the 2D op, so the expected (4,4) shape is correct.The test calls torch.ops.deepmd.tabulate_fusion_se_t (test_forward in source/tests/pt/test_tabulate_fusion_se_t_tebd.py), and the C++ wrapper for tabulate_fusion_se_t allocates descriptor as torch::empty({em_tensor.size(0), last_layer_size}) (see source/op/pt/tabulate_multi_device.cc:953–963). The TEBD-specific wrapper tabulate_fusion_se_t_tebd does allocate a 4D tensor (see source/op/pt/tabulate_multi_device.cc:1119–1125), but this test intentionally invokes the 2D variant; expected_descriptor_tensor.reshape(4,4) matches the actual op output.
Likely an incorrect or invalid review comment.
source/lib/src/tabulate.cc (4)
544-590: TEBD forward implementation preserves full neighbor structure.The implementation correctly preserves the nt_i x nt_j x ng structure for SE_T_TEBD, which differs from SE_T's reduction pattern. The polynomial evaluation and output indexing are correct.
592-641: TEBD gradient computation correctly accumulates over last_layer_size.The gradient implementation properly accumulates gradients across all last_layer_size dimensions and stores the result in the correct index pattern.
643-692: TEBD grad-grad implementation correctly applies chain rule.The second-order gradient computation properly multiplies the incoming gradient with the derivative of the polynomial.
963-1026: Template instantiations properly added for TEBD functions.All six TEBD functions (forward, grad, grad_grad) are correctly instantiated for both float and double types, matching the pattern of existing tabulation functions.
source/lib/include/tabulate.h (2)
114-147: CPU function declarations follow existing patterns.The three new CPU function declarations for TEBD tabulation follow the established pattern and parameter ordering conventions.
258-291: GPU function declarations properly guarded and consistent.The GPU function declarations are correctly placed within the CUDA/ROCM preprocessor guards and maintain consistency with the CPU declarations.
source/op/pt/tabulate_multi_device.cc (5)
338-385: Forward function implementation looks good.The
TabulateFusionSeTTebdForwardfunction correctly validates input dimensions, handles device dispatch, and calls the appropriate CPU/GPU kernels.
387-431: Gradient forward function has incorrect dimension check.At line 396, the function checks that
dy_tensor.dim() != 4, but this seems incorrect. For SE_T_TEBD with shape[nloc, nnei_i, nnei_j, last_layer_size], the gradient should indeed be 4D.The dimension check appears correct - the gradient tensor dy should match the descriptor tensor's 4D shape for SE_T_TEBD.
1086-1168: Autograd operator correctly implements forward and backward passes.The
TabulateFusionSeTTebdOpclass properly:
- Allocates 4D descriptor tensors with correct dimensions
- Saves necessary tensors for backward pass
- Computes gradients only for em_x_tensor (not em_tensor)
1203-1211: Public API function properly exposed.The
tabulate_fusion_se_t_tebdfunction correctly wraps the autograd operator.
1231-1233: Torch library registration complete.The TEBD operation is properly registered with the Torch library.
| TORCH_LIBRARY_FRAGMENT(deepmd, m) { | ||
| m.def("tabulate_fusion_se_t", tabulate_fusion_se_t); | ||
| } | ||
| TORCH_LIBRARY_FRAGMENT(deepmd, m) { |
Check notice
Code scanning / CodeQL
Unused static function Note
TORCH_LIBRARY_FRAGMENT_static_init_deepmd_5
c252e57 to
cea01d2
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
♻️ Duplicate comments (2)
deepmd/utils/tabulate.py (1)
309-312: Remove the off-by-one in the fine-stride window.The
+ 1still shifts the fine-stride block forward by one slot, disaligning the spline segments for both T and T_TEBD descriptors (same bug previously reported). Drop the increment and clamp the slice to valid bounds.- start_index = int((lower - extrapolate * lower) / stride1) + 1 - end_index = start_index + int((upper - lower) / stride0) - tt[start_index:end_index, :] = stride0 + start_index = max(0, int((lower - extrapolate * lower) / stride1)) + end_index = min( + nspline, + start_index + int((upper - lower) / stride0), + ) + tt[start_index:end_index, :] = stride0source/lib/src/gpu/tabulate.cu (1)
689-717: Stop 128× redundant work in the TEBD grad kernel.All threads execute the full
(nnei_i * nnei_j * last_layer_size)loops, but only lane 0 writes the result. That multiplies work (and kernel runtime) by the block size and risks watchdog timeouts. Bail out early for non-zero threads (or parallelise properly) so the math runs once per block.const int_64 block_idx = blockIdx.x; // nloc const int thread_idx = threadIdx.x; // thread within block + + if (thread_idx != 0) { + return; + } for (int ii = 0; ii < nnei_i; ii++) { for (int jj = 0; jj < nnei_j; jj++) { @@ - if (thread_idx == 0) { // Only one thread writes the gradient - dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = grad_sum; - } + dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = grad_sum;
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
deepmd/pt/model/descriptor/se_t_tebd.py(10 hunks)deepmd/pt/utils/tabulate.py(8 hunks)deepmd/utils/tabulate.py(10 hunks)source/lib/include/tabulate.h(2 hunks)source/lib/src/gpu/tabulate.cu(3 hunks)source/lib/src/tabulate.cc(2 hunks)source/op/pt/tabulate_multi_device.cc(4 hunks)source/tests/pt/test_tabulate_fusion_se_t_tebd.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- source/lib/src/tabulate.cc
- deepmd/pt/utils/tabulate.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (AGENTS.md)
Always run
ruff check .andruff format .before committing changes to Python code
Files:
source/tests/pt/test_tabulate_fusion_se_t_tebd.pydeepmd/utils/tabulate.pydeepmd/pt/model/descriptor/se_t_tebd.py
🧬 Code graph analysis (5)
source/tests/pt/test_tabulate_fusion_se_t_tebd.py (2)
source/tests/consistent/common.py (1)
parameterized(580-640)source/op/pt/tabulate_multi_device.cc (2)
tabulate_fusion_se_t_tebd(1203-1211)tabulate_fusion_se_t_tebd(1203-1208)
source/lib/include/tabulate.h (2)
source/lib/src/gpu/tabulate.cu (25)
void(39-48)void(165-252)void(255-362)void(365-471)void(474-514)void(517-578)void(581-631)void(634-672)void(675-718)void(721-761)void(764-795)void(798-840)void(843-882)tabulate_fusion_se_t_tebd_gpu(1058-1078)tabulate_fusion_se_t_tebd_gpu(1058-1066)tabulate_fusion_se_t_tebd_gpu(1393-1401)tabulate_fusion_se_t_tebd_gpu(1403-1411)tabulate_fusion_se_t_tebd_grad_gpu(1081-1104)tabulate_fusion_se_t_tebd_grad_gpu(1081-1090)tabulate_fusion_se_t_tebd_grad_gpu(1413-1423)tabulate_fusion_se_t_tebd_grad_gpu(1425-1435)tabulate_fusion_se_t_tebd_grad_grad_gpu(1107-1132)tabulate_fusion_se_t_tebd_grad_grad_gpu(1107-1116)tabulate_fusion_se_t_tebd_grad_grad_gpu(1437-1447)tabulate_fusion_se_t_tebd_grad_grad_gpu(1449-1459)source/lib/src/tabulate.cc (12)
tabulate_fusion_se_t_tebd_cpu(545-590)tabulate_fusion_se_t_tebd_cpu(545-553)tabulate_fusion_se_t_tebd_cpu(963-972)tabulate_fusion_se_t_tebd_cpu(973-982)tabulate_fusion_se_t_tebd_grad_cpu(593-641)tabulate_fusion_se_t_tebd_grad_cpu(593-602)tabulate_fusion_se_t_tebd_grad_cpu(983-993)tabulate_fusion_se_t_tebd_grad_cpu(994-1004)tabulate_fusion_se_t_tebd_grad_grad_cpu(644-692)tabulate_fusion_se_t_tebd_grad_grad_cpu(644-654)tabulate_fusion_se_t_tebd_grad_grad_cpu(1005-1015)tabulate_fusion_se_t_tebd_grad_grad_cpu(1016-1026)
source/op/pt/tabulate_multi_device.cc (2)
source/lib/src/gpu/tabulate.cu (12)
tabulate_fusion_se_t_tebd_gpu(1058-1078)tabulate_fusion_se_t_tebd_gpu(1058-1066)tabulate_fusion_se_t_tebd_gpu(1393-1401)tabulate_fusion_se_t_tebd_gpu(1403-1411)tabulate_fusion_se_t_tebd_grad_gpu(1081-1104)tabulate_fusion_se_t_tebd_grad_gpu(1081-1090)tabulate_fusion_se_t_tebd_grad_gpu(1413-1423)tabulate_fusion_se_t_tebd_grad_gpu(1425-1435)tabulate_fusion_se_t_tebd_grad_grad_gpu(1107-1132)tabulate_fusion_se_t_tebd_grad_grad_gpu(1107-1116)tabulate_fusion_se_t_tebd_grad_grad_gpu(1437-1447)tabulate_fusion_se_t_tebd_grad_grad_gpu(1449-1459)source/lib/src/tabulate.cc (12)
tabulate_fusion_se_t_tebd_cpu(545-590)tabulate_fusion_se_t_tebd_cpu(545-553)tabulate_fusion_se_t_tebd_cpu(963-972)tabulate_fusion_se_t_tebd_cpu(973-982)tabulate_fusion_se_t_tebd_grad_cpu(593-641)tabulate_fusion_se_t_tebd_grad_cpu(593-602)tabulate_fusion_se_t_tebd_grad_cpu(983-993)tabulate_fusion_se_t_tebd_grad_cpu(994-1004)tabulate_fusion_se_t_tebd_grad_grad_cpu(644-692)tabulate_fusion_se_t_tebd_grad_grad_cpu(644-654)tabulate_fusion_se_t_tebd_grad_grad_cpu(1005-1015)tabulate_fusion_se_t_tebd_grad_grad_cpu(1016-1026)
source/lib/src/gpu/tabulate.cu (1)
source/lib/src/tabulate.cc (2)
locate_xx_se_t(45-73)locate_xx_se_t(45-52)
deepmd/pt/model/descriptor/se_t_tebd.py (5)
source/lib/include/tabulate.h (1)
deepmd(4-293)deepmd/pt/utils/tabulate.py (1)
DPTabulate(30-441)deepmd/pt/model/descriptor/se_t.py (2)
enable_compression(284-330)enable_compression(743-770)deepmd/pt/utils/utils.py (1)
ActivationFn(175-220)source/op/pt/tabulate_multi_device.cc (2)
tabulate_fusion_se_t_tebd(1203-1211)tabulate_fusion_se_t_tebd(1203-1208)
🪛 Ruff (0.13.3)
deepmd/pt/model/descriptor/se_t_tebd.py
552-552: Avoid specifying long messages outside the exception class
(TRY003)
572-574: Avoid specifying long messages outside the exception class
(TRY003)
577-577: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## devel #4992 +/- ##
==========================================
- Coverage 84.28% 84.23% -0.06%
==========================================
Files 708 709 +1
Lines 69909 70074 +165
Branches 3608 3618 +10
==========================================
+ Hits 58923 59027 +104
- Misses 9830 9880 +50
- Partials 1156 1167 +11 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
iProzd
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Plz remove all the refactor of variable/func names, sush as residual/xx/vv etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
♻️ Duplicate comments (1)
deepmd/utils/tabulate.py (1)
306-313: Off-by-one in stride window selection (T/T_TEBD) corrupts spline segments.The
+ 1at line 311 incorrectly shifts the fine-stride window by one index, causing the stride assignment to misalign with the actual grid regions. This corrupts the spline segment boundaries.Apply this diff to remove the off-by-one error:
elif self.descrpt_type in ("T", "T_TEBD"): tt = np.full((nspline, self.last_layer_size), stride1) # pylint: disable=no-explicit-dtype - start_index = int((lower - extrapolate * lower) / stride1) + 1 + start_index = int((lower - extrapolate * lower) / stride1) end_index = start_index + int((upper - lower) / stride0) tt[start_index:end_index, :] = stride0
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
deepmd/pt/utils/tabulate.py(6 hunks)deepmd/utils/tabulate.py(5 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pt/utils/tabulate.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (AGENTS.md)
Always run
ruff check .andruff format .before committing changes to Python code
Files:
deepmd/utils/tabulate.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Analyze (python)
🔇 Additional comments (2)
deepmd/utils/tabulate.py (2)
433-433: LGTM: Correctly adds T_TEBD to single-table descriptors.T_TEBD uses a single shared geometric network, so
table_size = 1is appropriate and consistent with the shared-table implementation in lines 200-241.
487-487: LGTM: Correctly treats T_TEBD like T for environment matrix range.T_TEBD uses the same angular (cosine) features as the T descriptor, so applying the same range calculation is appropriate.
njzjz
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add C++ tests like
https://github.com/deepmodeling/deepmd-kit/blob/devel/source/lib/tests/test_tabulate_se_t.cc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
source/lib/tests/test_tabulate_se_t_tebd.cc(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)
- GitHub Check: Test C++ (true)
- GitHub Check: Test C++ (false)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
njzjz
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please update the documentation if the compression has been supported in this PR.
deepmd-kit/doc/model/train-se-e3-tebd.md
Lines 84 to 86 in 6f00250
| ## Model compression | |
| Model compression is not supported. |
…abulate_fusion_se_t_tebd custom OP but not used and tested
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
…sts since tabulate_fusion_se_t_tebd_grad_gpu is on CPU
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
deepmd/pt/utils/tabulate.py (1)
311-314: Off‑by‑one bug in fine‑stride window (T/T_TEBD).start_index + 1 shifts the fine region and corrupts spline segments. Remove +1 and clip indices.
Apply:
- start_index = int((lower - extrapolate * lower) / stride1) + 1 - end_index = start_index + int((upper - lower) / stride0) - tt[start_index:end_index, :] = stride0 + start_index = int((lower - extrapolate * lower) / stride1) + end_index = start_index + int((upper - lower) / stride0) + start_index = max(0, start_index) + end_index = min(nspline, end_index) + tt[start_index:end_index, :] = stride0deepmd/utils/tabulate.py (1)
405-405: Avoid @lru_cache on instance method and fix lint errors
- Remove/refactor
_all_excludedat deepmd/utils/tabulate.py:405 to eliminate@lru_cacheon an instance method and prevent memory leaks.- Encapsulate long
RuntimeErrormessages into custom exception classes or constants for raises at lines 280, 315, 446, and 495 (resolves TRY003).- Rerun:
ruff check deepmd/utils/tabulate.py ruff format --check deepmd/utils/tabulate.py
♻️ Duplicate comments (5)
deepmd/pt/utils/tabulate.py (2)
154-155: Use env.DEVICE for device consistency.Prefer env.DEVICE over yy.device to stay consistent across the module.
Apply:
- ) + torch.ones((1, 1), dtype=yy.dtype, device=yy.device) + ) + torch.ones((1, 1), dtype=yy.dtype, device=env.DEVICE)
173-174: Use env.DEVICE for device consistency.Prefer env.DEVICE over yy.device.
Apply:
- ) + torch.ones((1, 2), dtype=yy.dtype, device=yy.device) + ) + torch.ones((1, 2), dtype=yy.dtype, device=env.DEVICE)deepmd/pt/model/descriptor/se_t_tebd.py (1)
1065-1080: ParameterList requires nn.Parameter; current assignment will error.Assigning a plain Tensor to nn.ParameterList raises TypeError. Wrap both in nn.Parameter with requires_grad=False.
Apply:
- self.compress_info[0] = torch.as_tensor( + self.compress_info[0] = nn.Parameter(torch.as_tensor( [ lower[net_key], upper[net_key], upper[net_key] * table_config[0], table_config[1], table_config[2], table_config[3], ], dtype=self.prec, device="cpu", - ) - self.compress_data[0] = table_data[net_key].to( - device=env.DEVICE, dtype=self.prec - ) + ), requires_grad=False) + self.compress_data[0] = nn.Parameter( + table_data[net_key].to(device=env.DEVICE, dtype=self.prec), + requires_grad=False, + )Alternatively, replace ParameterList with a plain list if training exclusion is desired without registering parameters.
source/lib/tests/test_tabulate_se_t_tebd.cc (1)
4-6: Include for fabs.Without you may get “fabs was not declared”. Include the header.
Apply:
#include <gtest/gtest.h> +#include <cmath> #include <iostream> #include <vector>source/lib/src/gpu/tabulate.cu (1)
106-136: Remove duplicate locator; reuse locate_xx_se_tlocate_xx_se_t_tebd duplicates locate_xx_se_t. Drop it and call locate_xx_se_t at TEBD call sites to reduce maintenance.
-// same with locate_xx_se_t -template <typename FPTYPE> -__forceinline__ __device__ void locate_xx_se_t_tebd(FPTYPE& xx, - int& table_idx, - const FPTYPE& lower, - const FPTYPE& upper, - const FPTYPE& min, - const FPTYPE& max, - const FPTYPE& stride0, - const FPTYPE& stride1) { - ... -}And replace calls:
- locate_xx_se_t_tebd(xx, table_idx, lower, upper, -max, max, stride0, stride1); + locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);Apply at Lines 698–699, 753–755, and 822–824.
🧹 Nitpick comments (7)
doc/model/train-se-e3-tebd.md (1)
86-86: Clarify compression constraints (backend/mode).State that compression is supported only when tebd_input_mode == "strip" and requires the PT custom ops (ENABLE_CUSTOMIZED_OP) built; clarify other backends (JAX/DP) if not yet supported.
deepmd/pt/utils/tabulate.py (1)
288-299: Type hint for extrapolate is incorrect.extrapolate is a float, not bool. Update the annotation.
Apply:
- extrapolate: bool, + extrapolate: float,deepmd/utils/tabulate.py (1)
289-299: Type hint for extrapolate is incorrect.extrapolate is a float. Fix the annotation.
- extrapolate: bool, + extrapolate: float,deepmd/pt/model/descriptor/se_t_tebd.py (1)
527-603: Docs note: expose compression constraints.Consider surfacing in class docstring that compression requires tebd_input_mode == "strip" and custom ops built; mirrors the doc page change.
source/op/pt/tabulate_multi_device.cc (2)
433-480: Move compression guard before kernel launch (fail fast)The TORCH_CHECK(last_layer_size <= 1024) runs after the GPU call. Move it before launching the kernel to avoid wasted work and earlier erroring (same applies to SeT/SeR for consistency).
if (device == "GPU") { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - deepmd::tabulate_fusion_se_t_tebd_grad_grad_gpu( - dz_dy, table, table_info, em_x, em, dz_dy_dem_x, nloc, nnei_i, nnei_j, - last_layer_size); -#else - throw std::runtime_error( - "The input tensor is on the GPU, but the GPU support for the " - "customized OP library is not enabled."); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - TORCH_CHECK(last_layer_size <= 1024, - "In the process of model compression, the size of the " - "last layer of embedding net must be less than 1024!"); + TORCH_CHECK(last_layer_size <= 1024, + "In the process of model compression, the size of the " + "last layer of embedding net must be less than 1024!"); +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + deepmd::tabulate_fusion_se_t_tebd_grad_grad_gpu( + dz_dy, table, table_info, em_x, em, dz_dy_dem_x, nloc, nnei_i, nnei_j, + last_layer_size); +#else + throw std::runtime_error( + "The input tensor is on the GPU, but the GPU support for the " + "customized OP library is not enabled."); +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } else if (device == "CPU") {
1231-1233: Consolidate TORCH_LIBRARY_FRAGMENT blocks to reduce static init noiseMultiple fragments create separate static initializers and have triggered “Unused static function … TORCH_LIBRARY_FRAGMENT_init_deepmd_*” in past scans. Consider combining defs into a single fragment in this TU.
source/lib/src/gpu/tabulate.cu (1)
1146-1238: GPU wrappers OK; consider avoiding unconditional device syncsWrappers are correct. Optionally remove pre/post gpuDeviceSynchronize in production to improve overlap and throughput; rely on stream semantics and post‑kernel error checks instead.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
deepmd/pt/model/descriptor/se_t_tebd.py(10 hunks)deepmd/pt/utils/tabulate.py(6 hunks)deepmd/utils/tabulate.py(5 hunks)doc/model/train-se-e3-tebd.md(1 hunks)source/lib/include/tabulate.h(2 hunks)source/lib/src/gpu/tabulate.cu(5 hunks)source/lib/src/tabulate.cc(2 hunks)source/lib/tests/test_tabulate_se_t_tebd.cc(1 hunks)source/op/pt/tabulate_multi_device.cc(4 hunks)source/tests/pt/test_tabulate_fusion_se_t_tebd.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (AGENTS.md)
Always run
ruff check .andruff format .before committing changes to Python code
Files:
deepmd/utils/tabulate.pydeepmd/pt/model/descriptor/se_t_tebd.pydeepmd/pt/utils/tabulate.pysource/tests/pt/test_tabulate_fusion_se_t_tebd.py
🧬 Code graph analysis (7)
source/lib/tests/test_tabulate_se_t_tebd.cc (1)
source/lib/src/gpu/tabulate.cu (8)
tabulate_fusion_se_t_tebd_gpu(1148-1175)tabulate_fusion_se_t_tebd_gpu(1148-1156)tabulate_fusion_se_t_tebd_gpu(1498-1506)tabulate_fusion_se_t_tebd_gpu(1508-1516)tabulate_fusion_se_t_tebd_grad_gpu(1178-1205)tabulate_fusion_se_t_tebd_grad_gpu(1178-1187)tabulate_fusion_se_t_tebd_grad_gpu(1518-1528)tabulate_fusion_se_t_tebd_grad_gpu(1530-1540)
deepmd/pt/model/descriptor/se_t_tebd.py (7)
source/lib/include/tabulate.h (1)
deepmd(4-293)deepmd/pt/utils/tabulate.py (1)
DPTabulate(30-441)deepmd/pt/model/descriptor/se_t.py (2)
enable_compression(284-330)enable_compression(743-770)deepmd/pt/utils/utils.py (1)
ActivationFn(175-220)deepmd/pt/model/descriptor/se_atten.py (1)
enable_compression(427-448)deepmd/utils/tabulate.py (1)
build(70-285)source/op/pt/tabulate_multi_device.cc (2)
tabulate_fusion_se_t_tebd(1203-1211)tabulate_fusion_se_t_tebd(1203-1208)
source/lib/include/tabulate.h (2)
source/lib/src/gpu/tabulate.cu (25)
void(39-48)void(197-284)void(287-394)void(397-503)void(506-546)void(549-610)void(613-663)void(667-721)void(725-785)void(789-851)void(854-885)void(888-930)void(933-972)tabulate_fusion_se_t_tebd_gpu(1148-1175)tabulate_fusion_se_t_tebd_gpu(1148-1156)tabulate_fusion_se_t_tebd_gpu(1498-1506)tabulate_fusion_se_t_tebd_gpu(1508-1516)tabulate_fusion_se_t_tebd_grad_gpu(1178-1205)tabulate_fusion_se_t_tebd_grad_gpu(1178-1187)tabulate_fusion_se_t_tebd_grad_gpu(1518-1528)tabulate_fusion_se_t_tebd_grad_gpu(1530-1540)tabulate_fusion_se_t_tebd_grad_grad_gpu(1208-1237)tabulate_fusion_se_t_tebd_grad_grad_gpu(1208-1217)tabulate_fusion_se_t_tebd_grad_grad_gpu(1542-1552)tabulate_fusion_se_t_tebd_grad_grad_gpu(1554-1564)source/lib/src/tabulate.cc (12)
tabulate_fusion_se_t_tebd_cpu(545-590)tabulate_fusion_se_t_tebd_cpu(545-553)tabulate_fusion_se_t_tebd_cpu(963-972)tabulate_fusion_se_t_tebd_cpu(973-982)tabulate_fusion_se_t_tebd_grad_cpu(593-641)tabulate_fusion_se_t_tebd_grad_cpu(593-602)tabulate_fusion_se_t_tebd_grad_cpu(983-993)tabulate_fusion_se_t_tebd_grad_cpu(994-1004)tabulate_fusion_se_t_tebd_grad_grad_cpu(644-692)tabulate_fusion_se_t_tebd_grad_grad_cpu(644-654)tabulate_fusion_se_t_tebd_grad_grad_cpu(1005-1015)tabulate_fusion_se_t_tebd_grad_grad_cpu(1016-1026)
deepmd/pt/utils/tabulate.py (1)
deepmd/pt/model/descriptor/se_t_tebd.py (1)
DescrptSeTTebd(78-602)
source/op/pt/tabulate_multi_device.cc (2)
source/lib/src/gpu/tabulate.cu (12)
tabulate_fusion_se_t_tebd_gpu(1148-1175)tabulate_fusion_se_t_tebd_gpu(1148-1156)tabulate_fusion_se_t_tebd_gpu(1498-1506)tabulate_fusion_se_t_tebd_gpu(1508-1516)tabulate_fusion_se_t_tebd_grad_gpu(1178-1205)tabulate_fusion_se_t_tebd_grad_gpu(1178-1187)tabulate_fusion_se_t_tebd_grad_gpu(1518-1528)tabulate_fusion_se_t_tebd_grad_gpu(1530-1540)tabulate_fusion_se_t_tebd_grad_grad_gpu(1208-1237)tabulate_fusion_se_t_tebd_grad_grad_gpu(1208-1217)tabulate_fusion_se_t_tebd_grad_grad_gpu(1542-1552)tabulate_fusion_se_t_tebd_grad_grad_gpu(1554-1564)source/lib/src/tabulate.cc (12)
tabulate_fusion_se_t_tebd_cpu(545-590)tabulate_fusion_se_t_tebd_cpu(545-553)tabulate_fusion_se_t_tebd_cpu(963-972)tabulate_fusion_se_t_tebd_cpu(973-982)tabulate_fusion_se_t_tebd_grad_cpu(593-641)tabulate_fusion_se_t_tebd_grad_cpu(593-602)tabulate_fusion_se_t_tebd_grad_cpu(983-993)tabulate_fusion_se_t_tebd_grad_cpu(994-1004)tabulate_fusion_se_t_tebd_grad_grad_cpu(644-692)tabulate_fusion_se_t_tebd_grad_grad_cpu(644-654)tabulate_fusion_se_t_tebd_grad_grad_cpu(1005-1015)tabulate_fusion_se_t_tebd_grad_grad_cpu(1016-1026)
source/tests/pt/test_tabulate_fusion_se_t_tebd.py (2)
source/tests/consistent/common.py (1)
parameterized(580-640)source/op/pt/tabulate_multi_device.cc (2)
tabulate_fusion_se_t_tebd(1203-1211)tabulate_fusion_se_t_tebd(1203-1208)
source/lib/src/tabulate.cc (1)
source/lib/src/gpu/tabulate.cu (1)
locate_xx_se_t(76-83)
🪛 Ruff (0.14.0)
deepmd/pt/model/descriptor/se_t_tebd.py
552-552: Avoid specifying long messages outside the exception class
(TRY003)
572-574: Avoid specifying long messages outside the exception class
(TRY003)
577-577: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (15)
source/tests/pt/test_tabulate_fusion_se_t_tebd.py (1)
1852-1860: LGTM: uses TEBD op and validates shapes/values.source/lib/src/tabulate.cc (3)
544-592: LGTM: TEBD CPU forward preserves i×j×ng without reduction.Polynomial eval and indexing match locate_xx_se_t; zero-init sizing is correct.
593-641: LGTM: TEBD grad sums dres/dxx per output channel.Gradient accumulation aligns with forward’s per‑pair outputs.
643-692: LGTM: TEBD grad‑grad propagates via dz_dy_dem_x only.Consistent with forward not depending on em.
deepmd/pt/model/descriptor/se_t_tebd.py (1)
955-970: LGTM: Compression forward path calls TEBD tabulation op correctly..contiguous(), CPU table_info, and reshaping back to nfnl×nnei×nnei×ng look correct.
source/op/pt/tabulate_multi_device.cc (4)
338-385: TEBD forward path looks correctShapes, device dispatch, and CPU/GPU calls align with SeT/SeR patterns. No issues.
387-431: Grad path is consistent and minimalDim checks (dy is 4D), last_layer_size from descriptor, and dy_dem_x allocation are correct. Good use of contiguous grad.
1086-1169: Autograd op for TEBD is well‑structuredTyped dispatch, allocation, save_for_backward, and backward flow look good.
1203-1211: Python entry point wiring is correctExposes TEBD via apply, matching existing patterns.
source/lib/include/tabulate.h (2)
114-148: TEBD CPU API additions align with existing SeT interfacesSignatures, ordering, and namespace usage are consistent. Looks good.
258-292: TEBD GPU API declarations under proper guardsMatches SeT patterns and keeps ABI symmetry across float/double. LGTM.
source/lib/src/gpu/tabulate.cu (4)
665-721: TEBD forward kernel: indexing and grid‑stride pattern look correctIndex decomposition, table lookup, and output layout (nloc, nnei_i, nnei_j, last_layer_size) are consistent. Good.
723-786: TEBD grad kernel: correct accumulation and dy indexingChain rule accumulation and dy indexing align with forward layout. Looks good.
787-852: TEBD grad‑grad kernel: correct mapping back to (…, last_layer_size)Uses shared dz_dy_dem_x per pair and writes per feature; matches CPU implementation semantics. LGTM.
1497-1565: Explicit instantiations completeFloat/double instantiations for all TEBD GPU entry points provided. LGTM.
add compression support for se_e3_tebd <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added TEBD-style descriptor support, optional runtime compression, and a Python-callable tabulation operation with autograd. - **Performance Improvements** - Shared/global tabulation tables and shared embeddings to reduce redundant table builds and improve CPU/GPU throughput. - **Reliability** - Validation and guards around enabling compression and runtime dispatch to prevent misconfiguration. - **Tests** - New CPU/GPU unit and integration tests for forward/backward TEBD paths. - **Documentation** - Docs updated to state model compression is supported. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
add compression support for se_e3_tebd
Summary by CodeRabbit
New Features
Performance Improvements
Reliability
Tests
Documentation