
Introduce PRNG to SimState and add reproducibility docs.#460

Merged
CompRhys merged 10 commits into main from prng-simstate-reproduce
Feb 27, 2026
Conversation

Member

@CompRhys CompRhys commented Feb 21, 2026

The only messy bit is resumption: the only way to serialize a torch.Generator seems to be torch.save, and asking the user to store the pickle manually feels awkward.


AI Overview

Every SimState now carries an optional _rng field (a torch.Generator) that controls all stochastic operations: momentum initialization, Langevin OU noise, V-Rescale Gamma draws, and C-Rescale barostat noise. No integrator init or step function accepts a seed or prng argument anymore — seeding is done exclusively through the state.

The rng property

state.rng = 42          # int → coerced to a seeded Generator
state.rng = gen         # torch.Generator used directly
state.rng = None        # reset; next access creates an unseeded Generator
samples = state.rng     # lazily initialises if _rng is None/int, then returns it
  • Lazy: if _rng is None (the default), accessing state.rng creates a new torch.Generator on the state's device and stores it back. No Generator is allocated until first use.
  • Coercing: if _rng is an int, accessing state.rng converts it to a seeded Generator via coerce_prng() and stores it back, so subsequent accesses return the same (advancing) Generator.
  • Advancing: because a single torch.Generator object is stored, its internal state advances with each draw, giving a proper random stream rather than re-seeding every step.
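The lazy/coercing behaviour above can be sketched with a toy stand-in for SimState. Everything here is hypothetical illustration (the names `StateSketch` and this minimal `coerce_prng` body are not the library's code), assuming only plain torch:

```python
import torch

def coerce_prng(rng, device="cpu"):
    """Hypothetical helper: turn None/int/Generator into a Generator on `device`."""
    if isinstance(rng, torch.Generator):
        return rng  # assume it is already on the right device in this sketch
    gen = torch.Generator(device=device)
    if isinstance(rng, int):
        gen.manual_seed(rng)  # int is coerced to a seeded Generator
    return gen

class StateSketch:
    """Toy stand-in for SimState showing the lazy/coercing `rng` property."""
    def __init__(self):
        self._rng = None

    @property
    def rng(self) -> torch.Generator:
        # Lazily create (None) or coerce (int) and store back, so every
        # subsequent access returns the same advancing Generator object.
        self._rng = coerce_prng(self._rng)
        return self._rng

    @rng.setter
    def rng(self, value):
        self._rng = value

state = StateSketch()
state.rng = 42
a = torch.randn(3, generator=state.rng)  # draws advance the stored Generator
b = torch.randn(3, generator=state.rng)  # different numbers: no re-seeding
```

Because the Generator is stored back on first access, `a` and `b` come from one continuous stream, while a fresh state seeded with the same int reproduces `a` exactly.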

Cloning

state.clone() deep-copies the Generator via get_state() / set_state(), producing an independent copy with identical initial RNG state. Drawing from one does not affect the other.
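The clone semantics can be demonstrated with plain torch Generators (a sketch of the mechanism, not the SimState code itself):

```python
import torch

# Copy the RNG state so the clone starts identical but advances independently.
src = torch.Generator()
src.manual_seed(7)

clone = torch.Generator()
clone.set_state(src.get_state())  # identical initial RNG state

x = torch.randn(4, generator=src)    # advances `src` only
y = torch.randn(4, generator=clone)  # same numbers: clone started identical
```

`x` and `y` match because both generators began from the same state, yet each draw advances only its own object.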

Splitting

state.split() copies global attributes (including _rng) to every piece. All resulting single-system states share the same Generator value (copied), not the same object.

Concatenating

concatenate_states([s1, s2, ...]) takes global attributes from the first state. The resulting batch uses s1's Generator; other states' Generators are discarded.

Device movement

state.to(device) moves the Generator to the target device via coerce_prng(), which creates a new Generator on the target device and copies the RNG state if devices differ.

Serialisation

torch.save(state.rng.get_state(), "rng.pt")           # save
gen = torch.Generator(device=state.device)
gen.set_state(torch.load("rng.pt"))
state.rng = gen                                        # restore
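The same round trip can be checked end to end with plain torch (an in-memory buffer stands in for the file): saving the Generator state mid-stream and restoring it reproduces the exact continuation of the random stream.

```python
import io
import torch

gen = torch.Generator()
gen.manual_seed(123)
torch.randn(5, generator=gen)           # advance the stream a bit

buf = io.BytesIO()
torch.save(gen.get_state(), buf)        # checkpoint mid-stream
expected = torch.randn(5, generator=gen)

buf.seek(0)
restored = torch.Generator()
restored.set_state(torch.load(buf))
resumed = torch.randn(5, generator=restored)  # identical continuation
```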

What changed

  • _rng moved from MDState to SimState (it's a global attribute, not MD-specific).
  • All seed= / prng= parameters removed from integrator init functions.
  • initialize_momenta takes generator: torch.Generator | None directly.
  • V-Rescale Gamma sampling switched from torch.distributions.Gamma (unseeded) to torch._standard_gamma(..., generator=rng) so it's now fully seedable.
  • _rattle_sim_state in testing.py refactored to use state.rng instead of saving/restoring global RNG state.
  • coerce_prng handles cross-device Generator transfer.
  • _state_to_device handles Generator device movement.

@CompRhys CompRhys added api API design discussions breaking Breaking changes keep-open PRs to be ignored by StaleBot labels Feb 21, 2026
@CompRhys CompRhys requested a review from thomasloux February 21, 2026 20:05
c2 = torch.sqrt(kT * (1 - torch.square(c1))).unsqueeze(-1)

# Generate random noise from normal distribution
noise = torch.randn_like(state.momenta, device=state.device, dtype=state.dtype)
Member Author

after pytorch/pytorch#165865, randn_like can do this in torch 2.10, but I am not sure we want to pin to 2.10 given that not all the models people want to use will support it.


@CompRhys CompRhys marked this pull request as ready for review February 21, 2026 20:35
weibull = torch.distributions.weibull.Weibull(scale=0.1, concentration=1)
rnd = torch.randn_like(sim_state.positions)
rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True)
shifts = weibull.sample(rnd.shape).to(device=sim_state.positions.device) * rnd
Member Author

avoid weibull.sample() as it cannot be seeded.
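A seedable alternative (a sketch, not necessarily the replacement the PR adopted) is to draw the Weibull by inverse-CDF on seeded uniforms, since `Distribution.sample()` takes no generator:

```python
import torch

# Weibull(scale, concentration) via inverse CDF: x = scale * (-ln(1 - u))^(1/k),
# with u drawn from a seeded torch.Generator so the whole draw is reproducible.
scale, concentration = 0.1, 1.0
gen = torch.Generator()
gen.manual_seed(0)

u = torch.rand(10, 3, generator=gen)                      # U ~ Uniform(0, 1)
samples = scale * (-torch.log1p(-u)) ** (1.0 / concentration)
```

`log1p(-u)` is used instead of `log(1 - u)` for numerical accuracy near u = 0.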

# Generate random numbers
r1 = torch.randn(n_systems, device=device, dtype=dtype)
# Sample Gamma((dof - 1)/2, 1/2) = \sum_2^{dof} X_i^2 where X_i ~ N(0,1)
r2 = torch.distributions.Gamma((dof - 1) / 2, torch.ones_like(dof) / 2).sample()
Member Author

avoid Gamma.sample() as it cannot be seeded.
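A seedable sketch using the identity from the comment above: the sum of (dof - 1) squared standard normals has exactly the Gamma((dof - 1)/2, 1/2) distribution, and `torch.randn` accepts a generator. (The PR summary says the merged code uses `torch._standard_gamma` instead; this is only an illustration of the seedable construction.)

```python
import torch

gen = torch.Generator()
gen.manual_seed(0)

dof = torch.tensor([6.0, 12.0])                        # per-system degrees of freedom
n_max = int(dof.max().item()) - 1
normals = torch.randn(len(dof), n_max, generator=gen)  # X_i ~ N(0, 1), seedable
# Mask so each system sums only its first (dof - 1) squared normals.
mask = torch.arange(n_max) < (dof.unsqueeze(-1) - 1)
r2 = (normals.square() * mask).sum(dim=-1)             # ~ Gamma((dof - 1)/2, 1/2)
```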


@staticmethod
def _clone_attr(value: object) -> object:
"""Clone a single attribute value, handling torch.Generator specially."""
Member Author

forks have identical rng states.

Collaborator

This should be fine, but I'm wondering whether there could be a setup where one wants to clone a state but still needs a different PRNG (some kind of replica exchange).
OK, thinking while writing: it would be fine, because if you batch your systems, each system will get different random numbers.



def calculate_momenta(
def initialize_momenta(
Member Author

driveby: this was a misleading name.

@abhijeetgangan
Collaborator

LGTM. Note that setting deterministic mode can have performance penalties. Also, use of text vs binary file formats for restarting can have an effect.


# Generate atom-specific noise
noise = torch.randn_like(state.momenta)
rng = state.rng
Collaborator

Is there a particular reason not to use state.rng directly instead of defining an rng variable?

Member Author

it only calls coerce once this way

Comment on lines +156 to +161
self._rng = coerce_prng(self._rng, self.device)
return self._rng

@rng.setter
def rng(self, value: int | torch.Generator | None) -> None:
self._rng = value
Collaborator

you probably want to call coerce_prng only when setting the rng with an int value. Otherwise you're creating the object many times per MD step.

Collaborator

If it's a way to handle _rng=None, you may want to do that in post_init

Member Author

I wanted to leave _rng=None alone when not used because then if you did clone an unseeded state all the downstream simstates would diverge

Member Author

    if isinstance(rng, torch.Generator):
        if rng.device == device:
            return rng
        new = torch.Generator(device=device)
        new.set_state(rng.get_state())
        return new

in coerce if it's already a generator and on the right device we don't make a new object. It does add the two ifs into the step cost but I can't imagine that is significant?


@CompRhys CompRhys merged commit 48dcfb1 into main Feb 27, 2026
68 checks passed
@CompRhys CompRhys deleted the prng-simstate-reproduce branch February 27, 2026 13:21
janosh added a commit that referenced this pull request Mar 1, 2026
- remove reintroduced seed kwargs from NVE/NVT/NPT init APIs and route seeding through state.rng again
- restore initialize_momenta usage in integrator init paths and drop calculate_momenta helper reintroduced by #339 edits
- update affected examples/tutorial and testing helper to use state-bound RNG behavior
CompRhys pushed a commit that referenced this pull request Mar 1, 2026

Successfully merging this pull request may close these issues.

  • Add a page in docs about reproducibility
  • Add a seed for integrator step function to reproduce results
  • Allow seeds to be set for individual batches

3 participants