Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Dev as Developer
    participant SimState as SimState (base class)
    participant Subclass as SimState Subclass
    Dev->>Subclass: Define subclass with tensor attribute
    Subclass->>SimState: Triggers __init_subclass__
    SimState->>SimState: Inspect type annotations
    alt Attribute is torch.Tensor | None
        SimState-->>Dev: Raise TypeError
    else Attribute is valid
        SimState->>Subclass: Complete initialization
    end
```
Force-pushed from `7d1f3f2` to `570a2b2`
Actionable comments posted: 0
🧹 Nitpick comments (2)
torch_sim/state.py (2)
140-145: Address the TODO about system index validation reliability. The comment indicates uncertainty about the reliability of the consecutive-integer validation logic. Consider implementing a more robust check or documenting the specific edge cases it might miss.

Would you like me to suggest a more reliable validation approach for ensuring system indices are unique consecutive integers starting from 0?
425-429: Complete the TODO about `InitVar` guidance. The comment suggests providing guidance about using `InitVar` for attributes with default values, but the implementation appears incomplete. Consider either implementing this feature or removing the TODO if it is no longer relevant.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- torch_sim/integrators/nvt.py (1 hunks)
- torch_sim/state.py (6 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
torch_sim/integrators/nvt.py (1)
torch_sim/integrators/npt.py (3)
- npt_nose_hoover_init (1348-1494)
- npt_nose_hoover (899-1560)
- NPTNoseHooverState (807-896)
🪛 Ruff (0.12.2)
torch_sim/state.py
405-405: import should be at the top-level of a file
(PLC0415)
🔇 Additional comments (4)
torch_sim/state.py (3)
25-25: LGTM! Good refactoring approach. Moving to `init=False` with a custom constructor provides better control over initialization logic and validation.
83-83: Good design for handling optional system indices. The pattern of declaring `system_idx` as a required tensor field, while accepting `None` in the constructor and converting it to a zeros tensor, is excellent. This ensures the field is always initialized and avoids concatenation issues with `None` values. Also applies to: 92-92, 134-139.
405-405: Local import is appropriate here. The local import of `typing` inside `__init_subclass__` is intentional and appropriate for this use case, likely to avoid circular imports or reduce module load time. The static analysis warning can be safely ignored.

torch_sim/integrators/nvt.py (1)
392-392: Essential fix for system index propagation. Adding `system_idx=state.system_idx` ensures proper propagation of system indexing information to the Nose-Hoover state, which is critical for batched simulations. This change correctly aligns with the NPT implementation and the updated `SimState` initialization logic.
Force-pushed from `2701b5c` to `4c49f21`
Actionable comments posted: 3
🔭 Outside diff range comments (2)
torch_sim/runners.py (1)
534-565: Critical inconsistency: dataclass requires tensors but None is still passed. The `StaticState` dataclass now requires `forces` and `stress` to be `torch.Tensor`, but the instantiation code at lines 563-564 still conditionally passes `None` when the model doesn't compute these properties. This will cause runtime errors.

Apply this diff to fix the issue by providing appropriate tensor defaults:

```diff
 sub_state = StaticState(
     **vars(sub_state),
     energy=model_outputs["energy"],
-    forces=model_outputs["forces"] if model.compute_forces else None,
-    stress=model_outputs["stress"] if model.compute_stress else None,
+    forces=model_outputs.get("forces", torch.full_like(sub_state.positions, float("nan"))),
+    stress=model_outputs.get("stress", torch.full((sub_state.n_systems, 3, 3), float("nan"), device=sub_state.device, dtype=sub_state.dtype)),
 )
```

Alternatively, the dataclass could be reverted to allow optional tensors if this behavior is intentional, but that would conflict with the PR's objectives.
tests/test_autobatching.py (1)
495-497: Fix undefined variable `state` before the first `next_batch` call. `state` is referenced before assignment on the initial call to `next_batch`. Initialize it to `None` (to fetch the first batch from the internal queue) before the loop.

```diff
-    all_completed_states, convergence_tensor = [], None
-    while True:
-        state, completed_states = batcher.next_batch(state, convergence_tensor)
+    all_completed_states, convergence_tensor = [], None
+    state = None  # initialize: fetch first batch from internal queue
+    while True:
+        state, completed_states = batcher.next_batch(state, convergence_tensor)
```
♻️ Duplicate comments (1)
tests/test_autobatching.py (1)
451-459: Nice parametrize; this addresses the prior DRY feedback. This change eliminates duplicated tests while preserving both scenarios.
🧹 Nitpick comments (1)
tests/test_autobatching.py (1)
451-459: Optional: add ids to the parameterization for clearer test output. This improves readability in pytest reports.

```diff
-@pytest.mark.parametrize(
-    "num_steps_per_batch",
-    [
-        5,  # At 5 steps, not every state will converge before the next batch.
-        # This tests the merging of partially converged states with new states
-        # which has been a bug in the past. See https://github.com/Radical-AI/torch-sim/pull/219
-        10,  # At 10 steps, all states will converge before the next batch
-    ],
-)
+@pytest.mark.parametrize(
+    "num_steps_per_batch",
+    [5, 10],
+    ids=["partial_convergence", "full_convergence"],
+)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- tests/test_autobatching.py (2 hunks)
- tests/test_state.py (1 hunks)
- torch_sim/models/interface.py (1 hunks)
- torch_sim/optimizers.py (7 hunks)
- torch_sim/runners.py (1 hunks)
- torch_sim/state.py (6 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (3)
torch_sim/models/interface.py (1)
torch_sim/models/mattersim.py (1)
- MatterSimModel (24-33)
tests/test_autobatching.py (1)
tests/test_optimizers.py (3)
- test_fire_fixed_cell_unit_cell_consistency (785-879)
- test_unit_cell_fire_multi_batch (709-782)
- test_fire_optimization (113-177)
torch_sim/state.py (2)
torch_sim/integrators/md.py (1)
- MDState (13-48)

torch_sim/monte_carlo.py (1)

- SwapMCState (22-36)
🔇 Additional comments (8)
torch_sim/state.py (5)
9-9: LGTM! The `typing` import is correctly added to support the new `__init_subclass__` method's type inspection functionality.
26-26: LGTM! Setting `init=False` is correct when providing a custom `__init__` method, preventing conflicts with the auto-generated constructor.
84-84: LGTM! Changing `system_idx` to strictly `torch.Tensor` aligns with the PR objective of preventing concatenation issues with mixed tensor/None attributes. The optional behavior is preserved through the constructor parameter.
86-158: LGTM! The custom `__init__` method correctly consolidates initialization and validation logic. Key improvements:

- Proper handling of the optional `system_idx` parameter with a default zeros tensor
- Preserved device compatibility validation
- Maintained shape compatibility checks
- Correct cell dimension handling

The implementation maintains existing functionality while supporting the new tensor attribute restrictions.
401-426: LGTM! Excellent implementation of the subclass validation mechanism. The method correctly:

- Uses `typing.get_type_hints()` for proper type inspection
- Handles both `typing.Union` and Python 3.10+ `|` union syntax
- Provides clear error messaging explaining the concatenation issue
- Follows proper `__init_subclass__` patterns with a `super()` call

This effectively prevents the tensor concatenation issues described in the PR objectives.
tests/test_state.py (1)
648-658: LGTM! Excellent test coverage for the new `__init_subclass__` validation mechanism. The test correctly:

- Uses `pytest.raises` to capture the expected `TypeError`
- Defines a subclass with the prohibited `torch.Tensor | None` type annotation
- Validates that the error message mentions the concatenation issue
- Ensures the restriction is properly enforced
tests/test_autobatching.py (1)
485-493: Align the dtype of the scatter_reduce buffer with the forces dtype. On some PyTorch versions, `scatter_reduce` requires matching dtypes between input and src. Use `state.forces.dtype` (or `state.energy.dtype`) instead of a hardcoded float64 to avoid type promotion or runtime errors.

```diff
-    system_wise_max_force = torch.zeros(
-        state.n_systems, device=state.device, dtype=torch.float64
-    )
+    system_wise_max_force = torch.zeros(
+        state.n_systems, device=state.device, dtype=state.forces.dtype
+    )
```

torch_sim/optimizers.py (1)
592-595: Initializing velocities/cell_velocities with NaN is the right call. This removes `None` from tensor attributes, unblocks concatenation, and the step functions zero them on first use. LGTM.
Also applies to: 867-869, 875-879, 1170-1173, 1178-1181
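The scatter_reduce dtype point above can be illustrated with a small per-system max-force reduction (hypothetical values; the real test uses `state.forces` and `state.system_idx`):

```python
import torch

forces = torch.tensor([[1.0, 0.0, 0.0],
                       [0.0, 3.0, 0.0],
                       [2.0, 0.0, 0.0]])
system_idx = torch.tensor([0, 0, 1])  # atoms 0-1 in system 0, atom 2 in system 1

per_atom_force = forces.norm(dim=1)  # shape (n_atoms,)
# Buffer dtype matches the src dtype, avoiding dtype mismatches in scatter_reduce
buf = torch.zeros(2, dtype=per_atom_force.dtype)
system_max = buf.scatter_reduce(0, system_idx, per_atom_force, reduce="amax")
# system_max -> tensor([3., 2.])
```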
Force-pushed from `4c49f21` to `84d6750`
I think this PR will not totally fix the issues. This is because inside `split_attr`:

```python
def split_attr(
    attr_value: torch.Tensor | None, split_sizes: list[int]
) -> list[torch.Tensor | None]:
    return (
        [None] * n_systems
        if attr_value is None
        else torch.split(attr_value, split_sizes)
    )
```
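Illustratively (with `n_systems` passed explicitly and hypothetical values), the `None` branch still propagates `None` into the split results:

```python
import torch

def split_attr(attr_value, split_sizes, n_systems):
    # same shape as the snippet above, with n_systems as an explicit argument
    return (
        [None] * n_systems
        if attr_value is None
        else torch.split(attr_value, split_sizes)
    )

parts = split_attr(None, [2, 1], n_systems=2)
# parts == [None, None] — these entries would resurface on concatenation
tensor_parts = split_attr(torch.arange(3), [2, 1], n_systems=2)
```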
I would like to add more tests for splitting and concatenating states.
We should just merge his PR first. Then this can be a follow-up PR which adds the extra checks.
Force-pushed from `6755e8b` to `3d23f7d`
Thanks guys for acknowledging my PR and thanks to @curtischong for your additional work on top of it! |
Since a few people have already seen this PR, it's probably best to add the extra tests in another PR. I'll merge this in.
see #219
Summary

This is actually kind of a serious issue, and I'll outline it here in a clear manner.

MD SimStates often track `velocity`. But on the first iteration, the states do NOT have a velocity, so it is currently initialized as `None`. Once the optimizer gets going, these states end up having a `velocity` attribute.

The problem is how we concatenate SimStates. Inside the autobatcher, when some SimStates finish before others, we swap those finished states out for fresh states. This means that inside the combined SimState, some systems have `velocity` set to `None` (since they were just swapped in and are fresh) while other systems have a set velocity.

When we concatenate these "mixed" SimStates (during the optimization process), we end up calling `torch.concatenate([torch.Tensor, None, None])`, where the first system's velocity exists (because it's a `torch.Tensor`) and the last two systems do NOT have a velocity, since they were just swapped in by the autobatcher.

PyTorch cannot concatenate this, because passing `None` as an input is invalid.

@t-reents's solution works pretty well and is valid (which is why I'm touching it up in this PR). His solution is: rather than initializing tensor attributes as `None`, we initialize them as `NaN`, so we can call `torch.concatenate` between states that are old and states that have just been swapped in.

What's in this PR?
This PR is an addition to @t-reents's contribution. I added validation logic to ensure that in all subclasses of SimState, no `torch.Tensor` attribute can be annotated `| None`, because that breaks `concatenate` between SimStates.

I think @t-reents's solution works fine for the FIRE optimizers, since by setting velocities to 0 they do not contribute to the power calculation of the next iteration.
Checklist
Before a pull request can be merged, the following items must be checked:
Run ruff on your code.
We highly recommend installing the pre-commit hooks that run in CI locally to speed up the development process. Simply run `pip install pre-commit && pre-commit install` to install the hooks, which will check your code before each commit.

Summary by CodeRabbit
Bug Fixes

- Resolved errors caused by tensor attributes being `None` during tensor operations. All relevant tensor attributes are now required to always be present.

Tests