Skip to content

Determinism support 1/N#1281

Open
mar-yan24 wants to merge 13 commits into
google-deepmind:mainfrom
mar-yan24:mark/determinism1
Open

Determinism support 1/N#1281
mar-yan24 wants to merge 13 commits into
google-deepmind:mainfrom
mar-yan24:mark/determinism1

Conversation

@mar-yan24

Copy link
Copy Markdown
Contributor

Add opt.deterministic flag with post-narrowphase contact sort (#562)

I was previously working on differentiation support for MJWarp but I am taking a break from that because the contacts are giving me a hard time. I can't seem to figure out how to optimize the gradient landscape while keeping good dynamics from rigid contact and coulombic friction. Thus, I have decided spending some time on this would be of more use for now lol.

Summary

This is one of several phased additions. This is a basic PR that just adds an opt-in opt.deterministic flag that sorts contacts after narrowphase by (worldid, geom0, geom1, geomcollisionid) using wp.utils.radix_sort_pairs. This fixes the most upstream source of run-to-run non-determinism on GPU: contact index permutation from atomic_add counters in narrowphase and CCD. After sorting, d.contact.* is rewritten in canonical order before any downstream kernel reads it.

Downstream state (qacc, qvel, qpos, constraint force, solver reductions) is not yet bitwise reproducible. Follow-ups needed, see Roadmap below.

Changes

  • types.py: Option.deterministic: bool (default False). Docstring notes phase 1 scope and ~5-10% overhead.
  • io.py: Wires the default in put_model, adds the field to override_model so opt.deterministic=True works from the CLI.
  • collision_driver.py: _sort_contacts() runs after _narrowphase() when the flag is set. Composite 32-bit key: ((world * ngeom + geom0) * ngeom + geom1) * gcid_max + gcid. Falls back to gcid_max = 1 on int32 overflow. Three gather-permute kernels rewrite d.contact.* from temp buffers.
  • determinism_test.py: 8 parameterized tests -> contact ordering, field bitwise equality across repeat runs, sort key monotonicity, default-false smoke check.

Test results

8/8 pass on RTX 4060 Laptop (sm_89, Ada Lovelace), Warp 1.13.0.dev20260302:

test_contact_ordering_deterministic[collision.xml, nworld=1]   PASSED
test_contact_ordering_deterministic[collision.xml, nworld=4]   PASSED
test_contact_ordering_deterministic[humanoid.xml, nworld=1]    PASSED
test_contact_ordering_deterministic[humanoid.xml, nworld=4]    PASSED
test_contact_fields_deterministic[collision.xml, nworld=1]     PASSED
test_contact_fields_deterministic[humanoid.xml, nworld=1]      PASSED
test_contacts_sorted_by_geom                                   PASSED
test_deterministic_flag_default_false                          PASSED

Coverage: contact geom arrays bitwise identical across 3 runs x 10 steps at two nworld sizes. All contact fields (dist, pos, frame, dim, worldid, geomcollisionid) bitwise identical. Sort key monotonicity verified. Default False confirmed (no cost unless opted in).

Benchmarks

I had claude help me formulate some benchmarks to see the potential overhead with this implementation. 3 trials x 500 steps, 50-step warmup, wp.synchronize() fences around the timing window.

Newton + Dense, RTX 4060 Laptop (sm_89)

model nworld nacon off (us/step) on (us/step) overhead
humanoid.xml 1 7 3,459 3,903 +12.8%
humanoid.xml 64 448 3,570 4,006 +12.2%
humanoid.xml 512 3,584 3,866 4,550 +17.7%
collision.xml 1 6 4,185 4,449 +6.3%
collision.xml 64 384 4,219 4,610 +9.3%
collision.xml 512 3,072 5,258 5,829 +10.9%

CG + Sparse, RTX 4060 Laptop (sm_89)

model nworld nacon off (us/step) on (us/step) overhead
humanoid.xml 1 7 6,445 6,858 +6.4%
humanoid.xml 64 413 7,205 7,612 +5.6%
humanoid.xml 512 3,385 7,245 7,600 +4.9%
collision.xml 1 6 4,505 4,930 +9.4%
collision.xml 64 384 5,037 5,399 +7.2%
collision.xml 512 3,072 6,342 6,833 +7.7%

All configs under 25% overhead. Worst case is +17.7% (humanoid nworld=512, Newton+Dense); actually one trial in that config hit +28.9% but had 208 ms stdev vs ~65 ms for adjacent configs. Im pretty sure that is likely thermal throttling on my crappy laptop lol.

Overhead % is roughly flat across nworld within each solver path. The bottleneck is the 17 wp.empty_like calls in _sort_contacts , not the GPU sort itself. I am planning on implementing pre-allocated scratch buffers and will fix this in a follow-up, let me know thoughts.

Roadmap

Full reproducibility obviously needs more phases:

My rough plan at the moment is to work on constraint row allocation next, this is probably what will help open up downstream effects. After that I will work on actuator moment allocation. Both of these will be done using prefix-sum.

The biggest fix later will be implementing solver reductions, i.e. cost, grad_dot, search_dot. This should make d.qacc bitwise stable and thus follows qpos and qvel as well.

This current PR does not make simulation bitwise reproducible end to end. It guarantees only that d.contact.* is stable across runs of the same input. End to end full state reproducibility will probably come after some more phases are released.

@thowell thowell self-requested a review April 14, 2026 21:59
@thowell thowell linked an issue Apr 14, 2026 that may be closed by this pull request
@thowell

thowell commented Apr 15, 2026

Copy link
Copy Markdown
Collaborator

@mar-yan24 thank you for contributing this feature to mujoco warp!

Comment thread mujoco_warp/_src/types.py Outdated
@thowell

thowell commented Apr 15, 2026

Copy link
Copy Markdown
Collaborator

@mar-yan24 fyi there is a warp draft pr for introducing determinism in warp NVIDIA/warp#1355

@erikfrey

Copy link
Copy Markdown
Collaborator

We just discussed this - it could be worth pursuing this approach in parallel to Warp's low-level support for determinism as they two different approaches may have different performance tradeoffs.

@mar-yan24

Copy link
Copy Markdown
Contributor Author

@thowell, thanks for the input! Actually I haven't kept up with Warp as closely recently so I'll take a look at the PR brought up there and see if there are similar ideas compared to what I have in my current plan.

Regarding @erikfrey's comment, I don't mind working on the rest of the determinism implementation for this PR and comparing the performance once finished. I'll probably continue working on this for the week and I'll try to finish by around a week from now for the full end-to-end implementation.

Thank you both for the info/updates!

Comment thread mujoco_warp/_src/collision_driver.py Outdated
Comment thread mujoco_warp/_src/collision_driver.py
Comment thread mujoco_warp/_src/collision_driver.py
Comment thread mujoco_warp/_src/determinism_test.py Outdated
Comment thread mujoco_warp/_src/determinism_test.py
Comment thread mujoco_warp/_src/collision_driver.py Outdated
@mar-yan24

Copy link
Copy Markdown
Contributor Author

Thanks for the review @thowell! The changes should be good to go. I am planning the next determinism steps after this PR like constraint row allocation and actuator moment allocation. Before I continue building, would you prefer that I keep extending this branch/PR so the work is all on this PR or split it up into separate requests for review. Either works for me.

@thowell

thowell commented Apr 17, 2026

Copy link
Copy Markdown
Collaborator

@mar-yan24 lets create separate prs for the next deterministic features. thanks!

@mar-yan24 mar-yan24 marked this pull request as ready for review April 17, 2026 23:56
@thowell

thowell commented Apr 23, 2026

Copy link
Copy Markdown
Collaborator

@mar-yan24 can we benchmark the performance of this pr with the built-in determinism from warp NVIDIA/warp#1355? probably makes sense to confirm that writing custom deterministic kernels is more performant compared to the general purpose warp determinism functionality. thanks!

@mar-yan24

Copy link
Copy Markdown
Contributor Author

@thowell I just tried running the NVIDIA/warp#1355 build on my machine and it seems there are several issues with it that currently make it incompatible with mujoco_warp. When I first tried running, it just crashed, so I disabled graph capture. I think the Warp PR crashes inside wp.ScopedCapture when mujoco_warp launches kernels, discussed below.

The PR's codegen looks up the destination array for each atomic_add by matching the kernel's top-level argument names (context.py:7706). That works for wp.atomic_add(arr, idx, value) but not for the sliced pattern wp.atomic_add(arr[outer_idx], inner_idx, value) that we use in other kernels.

I don't mind helping try and fix this/look into this deeper, but I suspect there may be other blockers as well. In the meantime, I can draft a minimal repro kernel for the PR? Let me know your thoughts.

@johnnynunez

Copy link
Copy Markdown

hello guys,
how is it going this?

@mar-yan24

Copy link
Copy Markdown
Contributor Author

Hey @johnnynunez I took a bit of a break on some of the projects I was working on as I started a new job recently. Alongside this, I am also waiting in part to take a look at the final implementation of determinism on the Warp side before doing too much more as I want to wait and see what the reach of the implementation there will be.

If you would like to contribute and help out in what I have here, be my guest. But I might take another week or two before I have enough time to pick things back up :)

@johnnynunez

Copy link
Copy Markdown

Hey @johnnynunez I took a bit of a break on some of the projects I was working on as I started a new job recently. Alongside this, I am also waiting in part to take a look at the final implementation of determinism on the Warp side before doing too much more as I want to wait and see what the reach of the implementation there will be.

If you would like to contribute and help out in what I have here, be my guest. But I might take another week or two before I have enough time to pick things back up :)

thank you... let me push from Nvidia side (warp)

@johnnynunez johnnynunez mentioned this pull request Jun 10, 2026
3 tasks
@johnnynunez

Copy link
Copy Markdown

Following up here after working through the warp side (NVIDIA/warp#1355, fix PR mmacklin/warp#3) — continuing this work as @mar-yan24 invited contributors.

Rebased branch available: I rebased this PR onto current main (96 commits forward, conflicts in io.py override fields and types.py ls_parallel deprecation): https://github.com/johnnynunez/mujoco_warp/tree/det/mjwarp-determinism1 — full test suite passes (1069 passed, 23 skipped) on RTX PRO 6000 Blackwell (sm_120, CUDA 13.3). @mar-yan24 happy to open a PR into your branch or for the team to take it directly, whichever is easier.

Benchmark: custom deterministic kernels vs warp automatic determinism (@thowell asked about this). 20-body contact-rich pile, CUDA graph captured, 1000 steps:

mode 1 world 64 worlds 256 worlds
baseline (main) 2589 steps/s 112k world-steps/s 395k world-steps/s
this PR's opt.deterministic 2356 steps/s (−9%) 99k (−12%) 359k (−9%)
warp RUN_TO_RUN auto 857 steps/s (−67%) 10.0k (−91%) 9.3k (−98%)

The custom-kernel approach in this PR costs ~9-12%; warp's automatic interception costs 3x at 1 world and collapses at high world counts (the per-launch sort-reduce dominates, and scatter buffers hit the int32 allocation limit at 1024 worlds × njmax=2048). This matches @eric-heiden's observation on the warp PR that automatic mode is too slow for MuJoCo and custom kernels are the right path — i.e., this PR's approach is the right architecture.

Determinism measurements on the rebased branch (same-process, bit-identical full-Data resets, 200 steps): contact/constraint assembly is bitwise stable, and single-step replay is fully bitwise. Long rollouts still split at ~step 115 (also ~108 on main with and without any determinism flags; first divergent arrays are solver-internal efc.Ma/qacc with bit-identical inputs and equal solver_niter). So the residual nondeterminism lives in non-atomic solver internals (likely the blocked Cholesky / mathdx path) — consistent with this PR's stated scope of fixing contact ordering first, with solver reductions as follow-up (#1300 and beyond).

@johnnynunez

Copy link
Copy Markdown

Root cause of the residual long-rollout drift found — correcting my earlier guess (it is not mathdx/blocked-Cholesky; it's two racing float atomics in the sparse solver path).

Method: bit-identical full-Data resets in one process, replaying 200-500 steps, replacing solver sub-phases one at a time with order-deterministic equivalents and checking the first divergent step (~108 on main for a 20-body pile).

Eliminations (each alone does NOT fix it): serial sparse LDL factorization, serial sparse LDL backsubstitution, forced full H rebuild, tolerance=0, CG vs Newton (CG drifts identically → Cholesky/mathdx exonerated), warp deterministic_max_records, zeroed workspaces.

The minimal fixing pair (drift gone over 500 steps, 3/3 runs):

  1. _update_constraint_init_qfrc_constraint_sparsewp.atomic_add(qfrc_constraint[worldid], colind, J*force) races across efc rows hitting the same dof; replaced with a per-(world,dof) gather in efc index order.
  2. _JTDAJ_sparse (+ _update_gradient_h_incremental_sparse) — wp.atomic_add(h[worldid,row], col, ...) races across efc rows touching the same (row,col); replaced with a per-upper-tri-element gather.

A third, smaller source remains in linesearch-adjacent atomics (Jaref multi-thread init / search_dot): with only the pair above fixed, onset moves from ~108 → ~197; adding warp's RUN_TO_RUN interception for the rest gives bitwise determinism over 500 steps. Consistent with this: dense-jacobian mode is already fully deterministic (its qfrc_constraint and JTDAJ are gathers/tiles), only sparse drifts.

This cleanly defines determinism 3/N scope: deterministic sparse qfrc_constraint + sparse H assembly (+ optionally the linesearch Jaref/search_dot atomics). The gather versions I used for falsification are naive (O(nefc) per dof/element) — a production version would want the usual count/scan or segmented-reduce treatment, same architecture as 2/N. Happy to draft it.

@johnnynunez

Copy link
Copy Markdown

Follow-up to the root-cause analysis above: the fix is now implemented as determinism 3/N — mar-yan24#8. End result across the stack (1/N rebase + 2/N port + 3/N): bitwise qacc/qpos/qvel over 1000 steps, graph-capture-safe, 1108 tests passing. Details and honest perf table in that PR.

@johnnynunez

Copy link
Copy Markdown

@mar-yan24 could you update your fork and all branch with main? then my PRs will be clean

@mar-yan24

Copy link
Copy Markdown
Contributor Author

Wow @johnnynunez! Your development speed is quite fast; I cannot keep up. Are you still looking for me to update my branches? They are somewhat deprecated, so I need to take a look at the new conflicts. I see your new PRs though, thanks for the contribution! This weekend I will look into your new PRs. In the meantime, I can rebase my code.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Determinism

4 participants