Use scalar tile GEMM instead of cuBLASDx in the constraint solver#1402
Merged
adenzler-nvidia merged 2 commits on Jun 3, 2026
The blocked Cholesky performs small (16x16) upper-convention (U^T U) rank-k updates. Warp's register-blocked scalar GEMM is faster than cuBLASDx for that left-transpose pattern at this tile size, and skipping cuBLASDx also avoids its LTO compile cost. Set enable_mathdx_gemm=False via per-kernel module_options on the blocked-Cholesky kernels. Also switch the existing dense-JTDAJ disable to the same per-kernel mechanism and remove the now-unused scoped_mathdx_gemm_disabled helper, so the option lives on the kernel definition rather than mutating the global warp config around each launch.
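The difference between a scoped global toggle and an option stored on the kernel definition can be illustrated with a toy model in plain Python. This is a sketch only, not Warp's implementation: `Config`, `scoped_gemm_disabled`, `kernel`, and `backend_for` are hypothetical stand-ins written for this illustration.

```python
from contextlib import contextmanager


class Config:
    # Hypothetical stand-in for a process-wide configuration flag.
    enable_mathdx_gemm = True


config = Config()


@contextmanager
def scoped_gemm_disabled():
    """Scoped toggle: flip the global flag, restore it on exit."""
    previous = config.enable_mathdx_gemm
    config.enable_mathdx_gemm = False
    try:
        yield
    finally:
        config.enable_mathdx_gemm = previous


def kernel(module_options=None):
    """Toy decorator: record options on the function at definition time."""
    def wrap(fn):
        fn.module_options = dict(module_options or {})
        return fn
    return wrap


def backend_for(fn):
    """A per-kernel option wins; otherwise fall back to the global flag."""
    use_mathdx = fn.module_options.get(
        "enable_mathdx_gemm", config.enable_mathdx_gemm
    )
    return "cublasdx" if use_mathdx else "scalar"


@kernel()
def jtdaj_scoped():
    pass


@kernel(module_options={"enable_mathdx_gemm": False})
def jtdaj_declared():
    pass


# With the scoped toggle, the backend depends on every call site
# remembering to wrap the launch.
with scoped_gemm_disabled():
    inside = backend_for(jtdaj_scoped)
outside = backend_for(jtdaj_scoped)

# With the declared option, the backend is fixed by the definition.
declared = backend_for(jtdaj_declared)

print(inside, outside, declared)  # scalar cublasdx scalar
```

In this toy model the scoped kernel silently falls back to the global default whenever a launch is not wrapped, while the declared kernel selects the same backend everywhere.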
The dense JTDAJ tile_matmul runs on Warp's scalar GEMM (MathDx disabled). Its optimal block_dim for the scalar path is 128, whereas 96 was tuned for the cuBLASDx path. Only affects dense models (nv <= 32).
thowell
approved these changes
Jun 3, 2026
Summary
The blocked-Cholesky solver kernels (update_gradient_cholesky_blocked and update_gradient_cholesky_blocked_skip_unchanged) run their tile_matmul through cuBLASDx (MathDx). This disables MathDx for them via the per-kernel module_options={"enable_mathdx_gemm": False} decorator option, routing them to Warp's built-in scalar register-blocked GEMM. The scalar path is both faster for these small, transpose-heavy tile shapes and dramatically cheaper to JIT-compile, since it avoids cuBLASDx's LTO step.
It also tidies the dense JTDAJ path. That launch was already forced onto the scalar GEMM through a scoped_mathdx_gemm_disabled() context manager that toggled the global wp.config.enable_mathdx_gemm around the launch. This replaces that fragile global-state toggle with the same per-kernel decorator (no backend change) and retunes its block dim from 96 to 128, the scalar-GEMM optimum.
Motivation
The solver uses the upper-triangular convention (A = UᵀU), so its GEMMs are left-transpose products (AᵀB). For the 16×16 tiles used here, Warp's scalar GEMM keeps the accumulator in registers and handles this shape more efficiently than the cuBLASDx kernel, which also forces an expensive LTO compile on every cold start. Disabling MathDx for the Cholesky kernels recovers both.
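To make the left-transpose pattern concrete, here is a minimal pure-Python sketch of a blocked upper-convention Cholesky factorization. It is not the solver's kernel: the helper names are hypothetical, the matrices are plain lists, and the block size is 2 rather than 16 to keep the demo small. The point is the trailing update, which has the form A22 -= U12ᵀ U12, a rank-k update computed as a left-transpose product.

```python
import math


def matmul_at_b(A, B):
    """Left-transpose product A^T B for row-major lists of lists."""
    k, m, n = len(A), len(A[0]), len(B[0])
    return [[sum(A[p][i] * B[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]


def cholesky_upper(A):
    """Unblocked factorization: returns upper-triangular U with A = U^T U."""
    n = len(A)
    U = [[0.0] * n for _ in range(n)]
    for i in range(n):
        s = A[i][i] - sum(U[p][i] ** 2 for p in range(i))
        U[i][i] = math.sqrt(s)
        for j in range(i + 1, n):
            s = A[i][j] - sum(U[p][i] * U[p][j] for p in range(i))
            U[i][j] = s / U[i][i]
    return U


def solve_ut(U, B):
    """Solve U^T X = B by forward substitution (U is upper triangular)."""
    n, m = len(U), len(B[0])
    X = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            s = B[i][j] - sum(U[p][i] * X[p][j] for p in range(i))
            X[i][j] = s / U[i][i]
    return X


def blocked_cholesky_upper(A, bs):
    """Blocked factorization A = U^T U with block size bs."""
    n = len(A)
    S = [row[:] for row in A]  # working copy holding the trailing matrix
    U = [[0.0] * n for _ in range(n)]
    for k in range(0, n, bs):
        ke = min(k + bs, n)
        # Factor the diagonal block: A11 = U11^T U11.
        U11 = cholesky_upper([row[k:ke] for row in S[k:ke]])
        for i in range(k, ke):
            U[i][k:ke] = U11[i - k]
        if ke == n:
            break
        # Panel solve: U11^T U12 = A12.
        U12 = solve_ut(U11, [row[ke:] for row in S[k:ke]])
        for i in range(k, ke):
            U[i][ke:] = U12[i - k]
        # Rank-k trailing update: A22 -= U12^T U12 (left-transpose GEMM).
        update = matmul_at_b(U12, U12)
        for i in range(ke, n):
            for j in range(ke, n):
                S[i][j] -= update[i - ke][j - ke]
    return U


# Build a small symmetric positive-definite test matrix: M^T M + n*I.
n, bs = 6, 2
M = [[((3 * i + 5 * j) % 7) - 3.0 for j in range(n)] for i in range(n)]
G = matmul_at_b(M, M)
A = [[G[i][j] + (n if i == j else 0.0) for j in range(n)] for i in range(n)]

U = blocked_cholesky_upper(A, bs)
R = matmul_at_b(U, U)  # reconstruct U^T U
max_err = max(abs(R[i][j] - A[i][j]) for i in range(n) for j in range(n))
print("max reconstruction error:", max_err)
```

Every product in the trailing update reads its left operand transposed, which is the shape the paragraph above refers to.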
The per-kernel decorator is declarative, applied at kernel-build time, and does not mutate global state, so it also lets us delete the scoped_mathdx_gemm_disabled helper entirely.
Changes
- Add module_options={"enable_mathdx_gemm": False} to the three solver kernels that use tile_matmul: update_gradient_cholesky_blocked, update_gradient_cholesky_blocked_skip_unchanged, and update_gradient_JTDAJ_dense_tiled. The two blocked-Cholesky kernels move from cuBLASDx to scalar; JTDAJ_dense was already scalar (via the scoped wrap) and stays scalar, now selected by the decorator.
- Remove the scoped_mathdx_gemm_disabled context manager and its remaining call site, now redundant.
- Retune BlockDim.update_gradient_JTDAJ_dense from 96 to 128, the optimal block dim for the scalar GEMM (the occupancy trade-off shifts once cuBLASDx's shared-memory pressure is gone).
Results
Measured on an RTX PRO 6000 (Blackwell, sm_120) with Warp 1.13, via benchmarks/run.py for end-to-end and cold-start JIT, and Nsight Systems node-level tracing for per-kernel GPU time.
End-to-end throughput and JIT compile time:
Per-kernel GPU time (Nsight, full-factorize instances, low variance):
update_gradient_cholesky_blocked
update_gradient_JTDAJ_dense_tiled
The Cholesky backend swap is the real win: it drives the end-to-end and JIT improvements on the blocked-Cholesky models (nv > 32), and eliminating cuBLASDx's LTO compile is what halves the JIT time there. The dense JTDAJ kernel was already on the scalar path, so its 10% gain is purely the block-dim retune and shows up only on dense models such as humanoid (which never enters the blocked Cholesky kernel); its end-to-end delta is within run-to-run noise.
Testing
Existing solver regression tests pass. This is a GEMM-backend selection change with no algorithmic change, so numerical results are unaffected.