Skip to content

Use scalar tile GEMM instead of cuBLASDx in the constraint solver#1402

Merged
adenzler-nvidia merged 2 commits into
google-deepmind:mainfrom
adenzler-nvidia:adenzler/scoped-mathdx-tile-matmul
Jun 3, 2026
Merged

Use scalar tile GEMM instead of cuBLASDx in the constraint solver#1402
adenzler-nvidia merged 2 commits into
google-deepmind:mainfrom
adenzler-nvidia:adenzler/scoped-mathdx-tile-matmul

Conversation

@adenzler-nvidia

Copy link
Copy Markdown
Collaborator

Summary

The blocked-Cholesky solver kernels (update_gradient_cholesky_blocked and update_gradient_cholesky_blocked_skip_unchanged) run their tile_matmul through cuBLASDx (MathDx). This disables MathDx for them via the per-kernel module_options={"enable_mathdx_gemm": False} decorator option, routing them to Warp's built-in scalar register-blocked GEMM. The scalar path is both faster for these small, transpose-heavy tile shapes and dramatically cheaper to JIT-compile, since it avoids cuBLASDx's LTO step.

It also tidies the dense JTDAJ path. That launch was already forced onto the scalar GEMM through a scoped_mathdx_gemm_disabled() context manager that toggled the global wp.config.enable_mathdx_gemm around the launch. This replaces that fragile global-state toggle with the same per-kernel decorator (no backend change) and retunes its block dim from 96 to 128, the scalar-GEMM optimum.

Motivation

The solver uses the upper-triangular convention (A = UᵀU), so its GEMMs are left-transpose products (AᵀB). For the 16×16 tiles used here, Warp's scalar GEMM keeps the accumulator in registers and handles this shape more efficiently than the cuBLASDx kernel, which also forces an expensive LTO compile on every cold start. Disabling MathDx for the Cholesky kernels recovers both.

The per-kernel decorator is declarative, applied at kernel-build time, and does not mutate global state, so it also lets us delete the scoped_mathdx_gemm_disabled helper entirely.

Changes

  • Add module_options={"enable_mathdx_gemm": False} to the three solver kernels that use tile_matmul: update_gradient_cholesky_blocked, update_gradient_cholesky_blocked_skip_unchanged, and update_gradient_JTDAJ_dense_tiled. The two blocked-Cholesky kernels move from cuBLASDx to scalar; JTDAJ_dense was already scalar (via the scoped wrap) and stays scalar, now selected by the decorator.
  • Remove the scoped_mathdx_gemm_disabled context manager and its remaining call site, now redundant.
  • Bump BlockDim.update_gradient_JTDAJ_dense from 96 to 128, the optimal block dim for the scalar GEMM (the occupancy trade-off shifts once cuBLASDx's shared-memory pressure is gone).

Results

Measured on an RTX PRO 6000 (Blackwell, sm_120) with Warp 1.13, via benchmarks/run.py for end-to-end and cold-start JIT, and Nsight Systems node-level tracing for per-kernel GPU time.

End-to-end throughput and JIT compile time:

Model steps/sec JIT compile
three_humanoids (nv=81) +2.8% 66s → 32s (−52%)
unitree_g1_flat (nv=35) +3.3% 67s → 33s (−51%)
aloha_clutter (nv=136) +3.8% 109s → 75s (−31%)
humanoid (nv=27, dense) neutral (+0.6%) 41s → 41s

Per-kernel GPU time (Nsight, full-factorize instances, low variance):

Kernel before after Δ change
update_gradient_cholesky_blocked 379.7 µs (cuBLASDx) 342 µs (scalar) −10% GEMM backend
update_gradient_JTDAJ_dense_tiled 36.4 µs (scalar, bd 96) 32.7 µs (scalar, bd 128) −10% block dim

The Cholesky backend swap is the real win: it drives the end-to-end and JIT improvements on the blocked-Cholesky models (nv > 32), and eliminating cuBLASDx's LTO compile is what halves the JIT time there. The dense JTDAJ kernel was already on the scalar path, so its 10% gain is purely the block-dim retune and shows up only on dense models such as humanoid (which never enters the blocked Cholesky kernel); its end-to-end delta is within run-to-run noise.

Testing

Existing solver regression tests pass. This is a GEMM-backend selection change with no algorithmic change, so numerical results are unaffected.

The blocked Cholesky performs small (16x16) upper-convention (U^T U)
rank-k updates. Warp's register-blocked scalar GEMM is faster than
cuBLASDx for that left-transpose pattern at this tile size, and skipping
cuBLASDx also avoids its LTO compile cost.

Set enable_mathdx_gemm=False via per-kernel module_options on the
blocked-Cholesky kernels. Also switch the existing dense-JTDAJ disable to
the same per-kernel mechanism and remove the now-unused
scoped_mathdx_gemm_disabled helper, so the option lives on the kernel
definition rather than mutating the global warp config around each launch.
The dense JTDAJ tile_matmul runs on Warp's scalar GEMM (MathDx disabled).
Its optimal block_dim for the scalar path is 128, whereas 96 was tuned for
the cuBLASDx path. Only affects dense models (nv <= 32).
@adenzler-nvidia adenzler-nvidia merged commit 925902a into google-deepmind:main Jun 3, 2026
14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants