
Iterate elliptic-cone JTCJ over contact support pairs#1411

Closed
adenzler-nvidia wants to merge 1 commit into
google-deepmind:mainfrom
adenzler-nvidia:adenzler/jtcj-support-pairs

Conversation

@adenzler-nvidia

Collaborator

update_gradient_JTCJ_sparse assembled the elliptic-cone Hessian output-stationary: one thread per (contact, dense dof-pair), scanning the contact's colind to locate each pair. With nefc >> nv almost every dof-pair is absent from a given contact's support, so ~99% of the work was scan-and-skip.
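A rough way to see where the scan-and-skip cost comes from is to count pairs. The sketch below assumes both layouts enumerate triangular pairs (diagonal included), and the sizes `nv=100` and `rownnz=12` are hypothetical values chosen for illustration, not measurements from this PR.

```python
def triangular(n: int) -> int:
    """Number of (i, j) pairs with i <= j < n, diagonal included."""
    return n * (n + 1) // 2


def skipped_fraction(nv: int, rownnz: int) -> float:
    """Fraction of dense dof-pairs that are absent from one contact's support.

    Assumes one unit of work per dense pair in the output-stationary layout,
    and that only triangular(rownnz) of those pairs actually contribute.
    """
    return 1.0 - triangular(rownnz) / triangular(nv)


# Hypothetical sizes: 100 dofs, 12 nonzero Jacobian columns for this contact.
print(triangular(100), triangular(12))
print(f"{skipped_fraction(100, 12):.4f}")
```

With these illustrative sizes, 78 of 5050 dense pairs contribute, so roughly 98% of the threads would find nothing to do.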

This makes it input-stationary: one thread per (contact, support-pair), decoding the pair index directly into its two dofs, register-summing the cone block, and accumulating into H with a single atomic add (which also fixes a latent write race in the old +=). The pair count is bounded by a new jtcj_max_pairs, computed once in put_model from MuJoCo's mass-matrix sparsity. The math is unchanged; results match the old kernel to within ~1e-4 (not bit-identical, because float atomic ordering varies).
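The pair-index decode can be sketched in plain Python. Only the `while rem >= rownnz - pos1` condition appears in the diff under review; the loop body, the `pos1 + rem` result, the range check, and the helper name `decode_pair` are assumptions made for illustration, not the kernel's actual code.

```python
def decode_pair(pairid: int, rownnz: int) -> tuple[int, int]:
    """Map a flat pair index to positions (pos1, pos2), pos1 <= pos2 < rownnz.

    Pairs are assumed to be laid out row by row over a triangle: row pos1
    holds the rownnz - pos1 pairs (pos1, pos1), ..., (pos1, rownnz - 1).
    """
    if not 0 <= pairid < rownnz * (rownnz + 1) // 2:
        raise ValueError("pairid out of range for this row")
    pos1 = 0
    rem = pairid
    # Skip whole triangle rows until the remainder falls inside row pos1.
    while rem >= rownnz - pos1:
        rem -= rownnz - pos1
        pos1 += 1
    return pos1, pos1 + rem


# Hypothetical column indices (dof ids) for one contact row.
colind = [2, 5, 9]
for pairid in range(6):
    pos1, pos2 = decode_pair(pairid, len(colind))
    print(pairid, (colind[pos1], colind[pos2]))
```

Each thread would then read the two Jacobian entries at `pos1` and `pos2`, sum its contribution locally, and issue one atomic add to the matching entry of H.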

End-to-end, elliptic cone, vs main:

| model | step time | speedup |
| --- | --- | --- |
| aloha clutter (~60k contacts) | 82.1 → 29.0 µs | 2.8× |
| three_humanoids | 4.6 → 1.8 µs | 2.5× |
| g1 + hands | 3.6 → 2.6 µs | 1.4× |

The JTCJ kernel alone drops from ~6.2 ms to ~140 µs (~44×) on aloha; gains scale with contact count. The dense (small-nv) and island paths are untouched. Existing elliptic solver/forward tests pass.

The sparse elliptic-cone Hessian assembly (JTCJ) previously ran
output-stationary: one thread per (contact, dense dof-pair), scanning
the contact's column indices to locate each pair. When nefc >> nv the
overwhelming majority of those dof-pairs do not appear in the contact's
support, so the kernel spent almost all of its time scanning and
skipping.

Restructure it to be input-stationary: launch one thread per
(contact, support-pair), decode the pair index directly into the two
participating dofs, register-sum the cone block, and accumulate into H
with a single atomic add. This touches only the pairs that actually
contribute and exposes far more parallelism. The pair dimension is
bounded by jtcj_max_pairs, derived once in io.py from the deepest
geom-body dof chain. The math is unchanged.
rowadr0 = efc_J_rowadr_in[worldid, efcid0]
pos1 = int(0)
rem = pairid
while rem >= rownnz - pos1:

@thowell thowell Jun 5, 2026

Collaborator


do we save some computations with something like the following?

dif = rownnz - pos1
while rem >= dif:
  rem -= dif
  ...
  dif = rownnz - pos1

Collaborator


fyi i tested this and didn't see a speedup on clutter, so leaving as is

@thowell thowell left a comment

Collaborator


this is AWESOME

thanks @adenzler-nvidia!

@erikfrey erikfrey mentioned this pull request Jun 5, 2026
@erikfrey

erikfrey commented Jun 5, 2026

Collaborator

#1413 is merged - thanks again @adenzler-nvidia

@erikfrey erikfrey closed this Jun 5, 2026

3 participants