Raycast sensor optimization#948

Merged
kevinzakka merged 2 commits into
mujocolab:mainfrom
bd-pdomanico:pd/optimize_raycast
Apr 28, 2026

@bd-pdomanico

Contributor

Replaces boolean-mask indexing operations with masked_fill_ and reimplements hit_pos_w using a clamped distance, so that hit_pos_w for misses lands at world_origins. Together these effectively remove all CUDA syncs from the ray postprocess, unblocking the CPU thread while GPU-based sensing runs.
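
For illustration, here is a minimal sketch of the two styles side by side. It is not the code from this PR: the function names, the tensor shapes ((B, N, 3) origins and directions, (B, N) distances and hit mask) and the `MISS` sentinel of -1 are assumptions made for the example. The point it shows is that selecting with a boolean mask produces a result whose size depends on the mask contents, whereas masked_fill_ and clamp keep every shape fixed.

```python
import torch

MISS = -1.0  # hypothetical sentinel distance for rays that hit nothing


def postprocess_indexed(origins, directions, distances, hit):
    """Boolean-mask indexing: intermediate sizes depend on how many rays hit."""
    hit_pos = origins.clone()
    hit_pos[hit] = origins[hit] + directions[hit] * distances[hit].unsqueeze(-1)
    out = distances.clone()
    out[~hit] = MISS
    return out, hit_pos


def postprocess_sync_free(origins, directions, distances, hit):
    """Fixed-shape variant: no data-dependent sizes anywhere."""
    out = distances.clone()
    out.masked_fill_(~hit, MISS)
    # Clamping the sentinel to zero makes a miss travel no distance,
    # so its hit position is exactly the ray origin.
    travel = out.clamp(min=0.0)
    hit_pos = origins + directions * travel.unsqueeze(-1)
    return out, hit_pos
```

Both functions return the same values for non-negative hit distances; the second only uses elementwise operations on full-size tensors.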

Also optimizes quat_from_matrix. This similarly removes CUDA sync operations and brings the implementation up to date with the latest version at https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py. Profiling now shows:

device=NVIDIA GeForce RTX 4090

shape              legacy us   current us   speedup     max |Δ|
---------------------------------------------------------------
B = 4096              231.26        51.01     4.53x    1.79e-07
B = 16384             243.55        56.07     4.34x    1.79e-07
B = 65536             282.05        53.93     5.23x    1.79e-07

Profiled using:

import statistics

import torch


def time_call(fn, x, *, warmup=20, iters=200) -> float:
    """Mean per-call latency in microseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn(x)
    torch.cuda.synchronize()
    starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    for i in range(iters):
        starts[i].record()
        fn(x)
        ends[i].record()
    torch.cuda.synchronize()
    times_ms = [s.elapsed_time(e) for s, e in zip(starts, ends)]
    return statistics.mean(times_ms) * 1000.0  # us

m = random_rotations(shape, device=device, dtype=dtype)

# Numerical agreement (account for ±q sign ambiguity)
with torch.no_grad():
    ql = quat_from_matrix_legacy(m)
    qc = quat_from_matrix_current(m)
    sign = torch.sign((ql * qc).sum(dim=-1, keepdim=True))
    sign = torch.where(sign == 0, torch.ones_like(sign), sign)
    max_diff = (ql - qc * sign).abs().max().item()

t_legacy = time_call(lambda x: quat_from_matrix_legacy(x), m)
t_current = time_call(lambda x: quat_from_matrix_current(x), m)
speedup = t_legacy / t_current
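
For readers unfamiliar with the pytorch3d approach, the following is a rough sketch of a branch-free matrix-to-quaternion conversion in that style. It is an approximation written for this explanation, not the function merged here: the name `quat_from_matrix_sketch` and the wxyz output order are assumptions. It builds all four candidate quaternions and picks the best-conditioned one with gather, so no boolean-mask indexing is needed.

```python
import torch


def quat_from_matrix_sketch(matrix: torch.Tensor) -> torch.Tensor:
    """Convert (..., 3, 3) rotation matrices to (..., 4) quaternions (wxyz assumed)."""
    batch = tuple(matrix.shape[:-2])
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch + (9,)), dim=-1
    )
    # Twice the magnitude of each quaternion component; clamp guards against
    # slightly negative values caused by rounding.
    q_abs = torch.stack(
        [
            1.0 + m00 + m11 + m22,
            1.0 + m00 - m11 - m22,
            1.0 - m00 + m11 - m22,
            1.0 - m00 - m11 + m22,
        ],
        dim=-1,
    ).clamp(min=0.0).sqrt()
    # Row i is the quaternion (times 2 * q_abs[i]) derived assuming component i
    # is the largest.
    candidates = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )
    # The floor keeps the division finite for rows that will not be selected.
    candidates = candidates / (2.0 * q_abs.unsqueeze(-1).clamp(min=0.1))
    # Select the row with the largest denominator using fixed-shape gather.
    best = q_abs.argmax(dim=-1)
    index = best[..., None, None].expand(batch + (1, 4))
    quat = torch.gather(candidates, -2, index).squeeze(-2)
    # Resolve the +q / -q ambiguity by preferring a non-negative w.
    return torch.where(quat[..., :1] < 0, -quat, quat)
```

Every operation has an output shape determined by the input shape alone, which is what allows the host to keep queueing work without waiting on the device.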

With both of these changes, Mjlab-Velocity-Rough-Unitree-Go1 trains ~3% faster.
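
One way to check for synchronizing calls directly, rather than inferring them from timings, is PyTorch's sync debug mode. The helper below is a hypothetical sketch (the name `run_with_sync_check` is not part of mjlab); it assumes `torch.cuda.set_sync_debug_mode` behaves as documented, and that feature may not detect every kind of synchronization.

```python
import torch


def run_with_sync_check(fn, *args):
    """Run fn(*args); on CUDA, raise if PyTorch detects a synchronizing call."""
    if not torch.cuda.is_available():
        return fn(*args)  # nothing to check without a CUDA device
    previous = torch.cuda.get_sync_debug_mode()
    torch.cuda.set_sync_debug_mode("error")
    try:
        return fn(*args)
    finally:
        # Always restore the previous mode so other code is unaffected.
        torch.cuda.set_sync_debug_mode(previous)
```

With CUDA tensors, a call that selects elements via `x[mask]` is expected to raise under this mode, while an equivalent `masked_fill` call is expected to pass.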

@kevinzakka kevinzakka left a comment

Collaborator

This is awesome, thanks @bd-pdomanico! Two small asks before merging:

  1. Rebase onto the latest main
  2. Add a changelog entry

Thanks!

bd-pdomanico and others added 2 commits April 28, 2026 14:42
Replaces boolean-mask indexing operations with `masked_fill_` and reimplements hit_pos_w using a clamped distance, so that hit_pos_w for misses lands at world_origins. Together these effectively remove all CUDA syncs from the ray postprocess, unblocking the CPU thread while GPU-based sensing runs.

Also optimizes quat_from_matrix. This similarly removes CUDA sync operations and brings the implementation up to date with the latest version at https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py. Profiling now shows:

    device=NVIDIA GeForce RTX 4090

    shape              legacy us   current us   speedup     max |Δ|
    ---------------------------------------------------------------
    B = 4096              231.26        51.01     4.53x    1.79e-07
    B = 16384             243.55        56.07     4.34x    1.79e-07
    B = 65536             282.05        53.93     5.23x    1.79e-07

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@bd-pdomanico bd-pdomanico requested a review from kevinzakka April 28, 2026 18:51
kevinzakka added a commit that referenced this pull request Apr 28, 2026
The previous mujoco 3.7 nightly was yanked, breaking CI. This bumps both mujoco (to 3.8.1 nightly) and mujoco-warp (to a post-3.8.0 commit that includes the cache_kernel fix from google-deepmind/mujoco_warp#1318). The multiccd enable flag was removed in mujoco 3.8 (it became default-on), so the test that exercised it now uses the energy flag instead.

Fixes #948
@kevinzakka kevinzakka reopened this Apr 28, 2026

@kevinzakka kevinzakka left a comment

Collaborator

Thank you!

@kevinzakka kevinzakka merged commit f4a7504 into mujocolab:main Apr 28, 2026
11 of 15 checks passed
kevinzakka added a commit that referenced this pull request Apr 28, 2026
…ipt (#954)

The @torch.compile(fullgraph=True) decorator added in #948 hits dynamo's
recompile_limit (8) when the full test suite runs, causing
test_raycast_sensor.py to fail with FailOnRecompileLimitHit. The shape
specialization on the leading batch dim exhausts the budget across the
many env constructions that happen during the suite.

Switching back to @torch.jit.script avoids dynamo entirely while keeping
the pytorch3d-style sync-free implementation rewrite from #948 (which is
the actual source of the speedup). Only the decorator changes.

Also drops the corresponding changelog entry, since the speedup
attribution is now muddier.
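
As a small illustration of the decorator choice described above, the sketch below scripts a toy function with @torch.jit.script and calls it with several leading batch sizes. The function `normalize_rows` is a hypothetical stand-in, not the mjlab function; it only shows that a scripted function is compiled once from its source at decoration time, with no dynamo tracing per input shape.

```python
import torch


@torch.jit.script
def normalize_rows(x: torch.Tensor) -> torch.Tensor:
    # clamp avoids division by zero without a data-dependent branch
    norm = (x * x).sum(dim=-1, keepdim=True).sqrt().clamp(min=1e-12)
    return x / norm


# The same scripted function serves every batch size.
for batch in (1, 7, 4096):
    out = normalize_rows(torch.randn(batch, 4))
    assert out.shape == (batch, 4)
```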
sibisibi pushed a commit to DAVIAN-Robotics/mjlab that referenced this pull request May 5, 2026
The previous mujoco 3.7 nightly was yanked, breaking CI. This bumps both mujoco (to 3.8.1 nightly) and mujoco-warp (to a post-3.8.0 commit that includes the cache_kernel fix from google-deepmind/mujoco_warp#1318). The multiccd enable flag was removed in mujoco 3.8 (it became default-on), so the test that exercised it now uses the energy flag instead.

Fixes mujocolab#948

2 participants