Introduce PRNG to SimState and add reproducibility docs #460
Conversation
```python
c2 = torch.sqrt(kT * (1 - torch.square(c1))).unsqueeze(-1)

# Generate random noise from normal distribution
noise = torch.randn_like(state.momenta, device=state.device, dtype=state.dtype)
```
after pytorch/pytorch#165865, a seedable `randn_like` is in torch 2.10, but I am not sure we want to pin to 2.10, given that not all the models people want to use will support it.
```python
weibull = torch.distributions.weibull.Weibull(scale=0.1, concentration=1)
rnd = torch.randn_like(sim_state.positions)
rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True)
shifts = weibull.sample(rnd.shape).to(device=sim_state.positions.device) * rnd
```
avoid weibull.sample() as it cannot be seeded.
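A seedable alternative is inverse-transform sampling through `torch.rand`, which does take a `generator`. A minimal sketch (the helper name `sample_weibull` is mine, not from the PR):

```python
import torch

def sample_weibull(shape, scale, concentration, generator=None,
                   device=None, dtype=None):
    """Seedable Weibull draw via inverse-transform sampling.

    If U ~ Uniform(0, 1), then scale * (-log(1 - U)) ** (1 / concentration)
    follows Weibull(scale, concentration).
    """
    u = torch.rand(shape, generator=generator, device=device, dtype=dtype)
    return scale * (-torch.log1p(-u)) ** (1.0 / concentration)

rng = torch.Generator().manual_seed(0)
shifts = sample_weibull((4, 3), scale=0.1, concentration=1.0, generator=rng)
```

With `concentration=1` this reduces to an exponential distribution with mean `scale`, matching the parameters in the diff.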
```python
# Generate random numbers
r1 = torch.randn(n_systems, device=device, dtype=dtype)
# Sample Gamma((dof - 1)/2, 1/2) = \sum_2^{dof} X_i^2 where X_i ~ N(0,1)
r2 = torch.distributions.Gamma((dof - 1) / 2, torch.ones_like(dof) / 2).sample()
```
avoid Gamma.sample() as it cannot be seeded.
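A seedable replacement can go through the private op `torch._standard_gamma`, the same op `Gamma.rsample` uses internally and the one the PR ultimately switched to. Being private, its signature could change between releases; this is a sketch:

```python
import torch

def sample_gamma(concentration, rate, generator=None):
    """Seedable Gamma(concentration, rate) draw.

    torch._standard_gamma is the private op behind Gamma.rsample; unlike
    Distribution.sample, it accepts a generator. A standard-gamma draw
    divided by the rate is distributed Gamma(concentration, rate).
    """
    return torch._standard_gamma(concentration, generator=generator) / rate

rng = torch.Generator().manual_seed(0)
dof = torch.tensor([6.0, 8.0])
# Seedable version of torch.distributions.Gamma((dof - 1) / 2, 1/2).sample()
r2 = sample_gamma((dof - 1) / 2, torch.full_like(dof, 0.5), generator=rng)
```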
```python
@staticmethod
def _clone_attr(value: object) -> object:
    """Clone a single attribute value, handling torch.Generator specially."""
```
forks have identical rng states.
This should be fine, but I'm wondering whether there could be a setup where one wants to clone a state but still needs a different PRNG (some kind of replica exchange).
Ok, thinking while writing: it would be fine, because if you batch your systems, each system will get different random numbers.
```diff
-def calculate_momenta(
+def initialize_momenta(
```
driveby: this was a misleading name.
LGTM. Note that setting deterministic mode can have performance penalties. Also, use of text vs binary file formats for restarting can have an effect.
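For context, enabling the deterministic mode mentioned here typically looks like the following sketch; the cuBLAS env var is only needed when running on CUDA:

```python
import os
import torch

# Required for deterministic cuBLAS matmuls on CUDA; must be set
# before the first CUDA kernel launch.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# Raise an error if any op falls back to a nondeterministic kernel.
# Forcing deterministic kernels can disable faster alternatives,
# hence the performance penalty.
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False  # stop autotuned kernel selection
```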
```python
# Generate atom-specific noise
noise = torch.randn_like(state.momenta)
rng = state.rng
```
Is there a particular reason not to use `state.rng` directly instead of defining an `rng` var?
it only calls coerce once this way
```python
    self._rng = coerce_prng(self._rng, self.device)
    return self._rng

@rng.setter
def rng(self, value: int | torch.Generator | None) -> None:
    self._rng = value
```
you probably want to `coerce_prng` only when setting the rng with an int value. Otherwise you're creating the object many times per MD step
If it's a way to handle _rng=None, you may want to do that in post_init
I wanted to leave _rng=None alone when not used because then if you did clone an unseeded state all the downstream simstates would diverge
```python
if isinstance(rng, torch.Generator):
    if rng.device == device:
        return rng
    new = torch.Generator(device=device)
    new.set_state(rng.get_state())
    return new
```

In coerce, if it's already a generator and on the right device we don't make a new object. It does add the two ifs to the step cost, but I can't imagine that is significant?
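Putting the branches together, a full `coerce_prng` along these lines might look like this. A sketch, not the PR's actual helper; note that `set_state` across device types assumes compatible generator state layouts:

```python
from __future__ import annotations

import torch

def coerce_prng(rng: int | torch.Generator | None,
                device: torch.device) -> torch.Generator:
    """Coerce a seed / Generator / None into a Generator on `device`."""
    if rng is None:
        return torch.Generator(device=device)  # fresh, unseeded
    if isinstance(rng, int):
        return torch.Generator(device=device).manual_seed(rng)
    if rng.device == device:
        return rng  # already correct: no new object, no extra step cost
    new = torch.Generator(device=device)
    new.set_state(rng.get_state())  # carry the RNG state across devices
    return new
```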
- remove reintroduced seed kwargs from NVE/NVT/NPT init APIs and route seeding through `state.rng` again
- restore `initialize_momenta` usage in integrator init paths and drop `calculate_momenta` helper reintroduced by #339 edits
- update affected examples/tutorial and testing helper to use state-bound RNG behavior
The only messy bit is resumption: for serialization it seems the only way to do it is with `torch.save`, and I feel asking the user to store the pickle manually is awkward.
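One way to make resumption less awkward is to persist only the Generator's state tensor rather than pickling the object itself; a sketch:

```python
import os
import tempfile

import torch

rng = torch.Generator().manual_seed(42)
_ = torch.rand(10, generator=rng)  # advance the stream mid-run

# Checkpoint: get_state() returns a ByteTensor, which torch.save
# handles directly alongside the rest of the restart data.
path = os.path.join(tempfile.mkdtemp(), "rng.pt")
torch.save({"rng_state": rng.get_state()}, path)

# Resume: restore the stream exactly where it left off.
restored = torch.Generator()
restored.set_state(torch.load(path)["rng_state"])
```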
AI Overview
Every `SimState` now carries an optional `_rng` field (a `torch.Generator`) that controls all stochastic operations: momentum initialization, Langevin OU noise, V-Rescale Gamma draws, and C-Rescale barostat noise. No integrator init or step function accepts a `seed` or `prng` argument anymore — seeding is done exclusively through the state.
**The `rng` property**

- If `_rng` is `None` (the default), accessing `state.rng` creates a new `torch.Generator` on the state's device and stores it back. No Generator is allocated until first use.
- If `_rng` is an `int`, accessing `state.rng` converts it to a seeded Generator via `coerce_prng()` and stores it back, so subsequent accesses return the same (advancing) Generator.
- Once a `torch.Generator` object is stored, its internal state advances with each draw, giving a proper random stream rather than re-seeding every step.
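The lazy behaviour of the `rng` property can be sketched as follows; the `SimState` body here is a minimal stand-in, not the PR's actual class:

```python
from __future__ import annotations

import torch

def coerce_prng(rng, device):
    """Minimal stand-in for the PR's helper: seed-or-Generator -> Generator."""
    if isinstance(rng, torch.Generator):
        return rng
    gen = torch.Generator(device=device)
    return gen if rng is None else gen.manual_seed(rng)

class SimState:
    """Stripped-down stand-in showing only the lazy rng property."""

    def __init__(self, device: str = "cpu",
                 rng: int | torch.Generator | None = None):
        self.device = torch.device(device)
        self._rng = rng  # None | int | torch.Generator

    @property
    def rng(self) -> torch.Generator:
        # Coerce lazily and cache, so later accesses return the same
        # (advancing) Generator instead of re-seeding every step.
        self._rng = coerce_prng(self._rng, self.device)
        return self._rng

    @rng.setter
    def rng(self, value: int | torch.Generator | None) -> None:
        self._rng = value  # coerced on next property access
```

Coercion happens once per value; after that, every access returns the cached Generator.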
**Cloning**

`state.clone()` deep-copies the Generator via `get_state()`/`set_state()`, producing an independent copy with identical initial RNG state. Drawing from one does not affect the other.
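These clone semantics can be exercised with raw Generators, independent of `SimState`:

```python
import torch

src = torch.Generator().manual_seed(7)
_ = torch.rand(5, generator=src)  # advance the source stream

# Clone: copy the state into a fresh Generator object
copy = torch.Generator()
copy.set_state(src.get_state())

a = torch.rand(3, generator=src)
b = torch.rand(3, generator=copy)
# Identical initial state -> identical next draws...
assert torch.equal(a, b)
# ...but the objects are independent: advancing one leaves the other alone
c = torch.rand(3, generator=src)
assert not torch.equal(b, c)
```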
**Splitting**

`state.split()` copies global attributes (including `_rng`) to every piece. All resulting single-system states share the same Generator value (copied), not the same object.
**Concatenating**

`concatenate_states([s1, s2, ...])` takes global attributes from the first state. The resulting batch uses `s1`'s Generator; other states' Generators are discarded.
**Device movement**

`state.to(device)` moves the Generator to the target device via `coerce_prng()`, which creates a new Generator on the target device and copies the RNG state if devices differ.

**Serialisation**
**What changed**
- `_rng` moved from `MDState` to `SimState` (it's a global attribute, not MD-specific).
- `seed=`/`prng=` parameters removed from integrator init functions.
- `initialize_momenta` takes `generator: torch.Generator | None` directly.
- Switched from `torch.distributions.Gamma` (unseeded) to `torch._standard_gamma(..., generator=rng)` so it's now fully seedable.
- `_rattle_sim_state` in `testing.py` refactored to use `state.rng` instead of saving/restoring global RNG state.
- `coerce_prng` handles cross-device Generator transfer.
- `_state_to_device` handles Generator device movement.
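A sketch of the seedable `initialize_momenta` pattern: the signature follows the PR's description, but the Maxwell-Boltzmann body here is illustrative, not the actual implementation:

```python
from __future__ import annotations

import torch

def initialize_momenta(masses: torch.Tensor, kT: float,
                       generator: torch.Generator | None = None) -> torch.Tensor:
    """Draw Maxwell-Boltzmann momenta, p ~ N(0, m * kT) per component.

    Every random draw goes through the supplied generator, so runs with
    the same seed produce bit-identical trajectories.
    """
    noise = torch.randn(*masses.shape, 3, generator=generator,
                        dtype=masses.dtype)
    return noise * torch.sqrt(masses * kT).unsqueeze(-1)

rng = torch.Generator().manual_seed(0)
momenta = initialize_momenta(torch.ones(8), kT=1.0, generator=rng)
```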