Skip to content

Semantic segmentation parity#1283

Merged
thowell merged 20 commits into
google-deepmind:mainfrom
tkelestemur:tarik/semantic-segmentation-parity
Apr 27, 2026
Merged

Semantic segmentation parity#1283
thowell merged 20 commits into
google-deepmind:mainfrom
tkelestemur:tarik/semantic-segmentation-parity

Conversation

@tkelestemur

Copy link
Copy Markdown
Contributor

Summary

This PR adds first-class semantic segmentation to MJWarp's renderer.

  • Adds mjw.get_semantic_segmentation(...), returning per-pixel (object_id, object_type) pairs aligned with MuJoCo segmentation semantics.
  • Preserves the existing mjw.get_segmentation(...) API and buffer for backward compatibility.
  • Reports regular geom hits as (geom_id, mjOBJ_GEOM) and flex hits as (flex_id, mjOBJ_FLEX), instead of reducing flex hits to the legacy -2 sentinel only.
  • Exposes the new API from the public package surface and adds test coverage for the new renderer/accessor behavior.
  • Adds a dedicated semantic segmentation notebook demo, including flex examples, RGB/depth rendering, legacy segmentation, semantic object ids, semantic object types, and semantic flex ids.

PS: The new notebook imports mujoco_warp directly from the local checkout and disables IPython autoreload for the session, since Warp kernels require file-backed Python source in notebook environments.

Aside from the regular tests, I've also run pytest contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py.

@thowell thowell requested a review from StafaH April 13, 2026 14:51
@thowell

thowell commented Apr 13, 2026

Copy link
Copy Markdown
Collaborator

@tkelestemur thank you for contributing this feature to mujoco warp!

@StafaH @btaba would it make sense to combine this feature (in a potentially breaking way) with the existing segmentation api?

@tkelestemur instead of adding a separate notebook for this feature, can we add an example to the existing tutorial notebook?

@thowell thowell requested a review from btaba April 13, 2026 15:00
@StafaH

StafaH commented Apr 13, 2026

Copy link
Copy Markdown
Collaborator

Hi @tkelestemur, thanks for the PR!

In this case I believe this new feature should not be implemented in MJWarp, and can be implemented at the framework level that you use downstream. The existing segmentation output provides enough information to perform post processing downstream e.g. by implementing semantic segmentation as an extension in Mjlab with a kernel that converts the existing integer id to another value inplace.

@tkelestemur

Copy link
Copy Markdown
Contributor Author

@StafaH The existing segmentation buffer is sufficient to reconstruct semantic output for geom hits and background, but not for flex hits, because all flex hits are collapsed to -2 and lose the underlying flex_id.

@StafaH

StafaH commented Apr 14, 2026

Copy link
Copy Markdown
Collaborator

I agree, we should fix existing segmentation output. I think what @thowell mentioned is correct, we should focus on having the existing segmentation output match MuJoCo exactly instead of making a seperate output, and we can update the tests to check that this segmentation output matches exactly the C MuJoCo output (similar to what was done for depth recently).

WDYT @thowell?

@tkelestemur this might also require opening an issue and a PR in mjlab to fix any changes that happen downstream there.

@tkelestemur

Copy link
Copy Markdown
Contributor Author

@StafaH Sounds good -- I'll update this PR to use the existing segmentation API and also open an issue and a PR in mjlab.

@tkelestemur

Copy link
Copy Markdown
Contributor Author

@thowell I've removed the extra notebook and added a section to the current notebook.

@StafaH I've reverted the new segmentation api and updated the current one. Soon, I'll open a PR on the mjlab side.

Comment thread mujoco_warp/_src/io_test.py Outdated
Comment thread mujoco_warp/_src/io_test.py Outdated
seg = rc.seg_data.numpy()
self.assertTrue(np.any(seg >= 0), "Expected geom hits from auto-detected seg")
self.assertGreater(np.unique(seg).shape[0], 1)
seg = wp.zeros((1, 32, 32, 2), dtype=int)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we create seg from the render context?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to size the segmentation output from rc.cam_res instead of hardcoding the resolution. Is that what you meant?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets change this back to what you had before wp.zeros((1, 32, 32, 2), dtype=int). thanks!

Comment thread notebooks/tutorial.ipynb Outdated
"media.show_image(depth_grid, cmap='gray', vmin=0, vmax=1)"
]
},
{

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think we can simplify this and make it similar to the depth render example (essentially a comment plus 2 lines of code and then an additional segmentation image to display.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplified the tutorial example to extend the existing rendering cell with a small segmentation extraction/display block. Lmk how it looks.

@thowell

thowell commented Apr 15, 2026

Copy link
Copy Markdown
Collaborator

the updated pr looks great! added a few comments. please let us known if you have any questions. thanks!

@tkelestemur

Copy link
Copy Markdown
Contributor Author

Thanks @thowell I've addressed your reviews. I also opened an downstream PR for mjlab here: mujocolab/mjlab#911

Comment thread notebooks/tutorial.ipynb Outdated
"segmentation_grid = segmentation_data.numpy()[..., 0].reshape(4, 4, CAM_RES[1], CAM_RES[0])\n",
"segmentation_grid = segmentation_grid.transpose(0, 2, 1, 3)\n",
"segmentation_grid = segmentation_grid.reshape(4 * CAM_RES[0], 4 * CAM_RES[1])\n",
"media.show_image(segmentation_grid, cmap='viridis', vmin=-1, vmax=max(1, int(segmentation_grid.max())))\n"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please confirm that this line runs as expected

ValueError: Type int32 is not a valid media data type (uint or float).

converting the array to floats produces

Image

@tkelestemur

Copy link
Copy Markdown
Contributor Author

sorry forgot to enable render_seg. This is what I'm getting right now:

image

so should be good to go.

@StafaH StafaH left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is shaping up nicely. Left some comments.

Comment thread mujoco_warp/_src/render_util.py Outdated
Comment thread mujoco_warp/_src/render_util.py Outdated
Comment thread mujoco_warp/_src/io_test.py Outdated
Comment thread mujoco_warp/_src/render_test.py Outdated
Comment thread mujoco_warp/_src/render_test.py Outdated
@StafaH

StafaH commented Apr 17, 2026

Copy link
Copy Markdown
Collaborator

Also might be worth having the seg test xml include 1 flex object for completeness

@tkelestemur

Copy link
Copy Markdown
Contributor Author

@StafaH thanks for the reviews, I think I addressed all of them.

In f4bf916, I also dded a small flex object to the synthetic segmentation XML for completeness, and the test now asserts the model includes one flex before checking the render-context segmentation setup.

Let me know if this is enough.

@StafaH StafaH left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @tkelestemur, left one last comment, but otherwise LGTM

Comment thread mujoco_warp/_src/render_test.py Outdated
@thowell

thowell commented Apr 22, 2026

Copy link
Copy Markdown
Collaborator

@tkelestemur some of the checks are not passing, please take a look. might just need to sync with main. thanks!

@thowell

thowell commented Apr 22, 2026

Copy link
Copy Markdown
Collaborator

there is a segmentation fault with one of the checks, please take a look. thanks! https://github.com/google-deepmind/mujoco_warp/actions/runs/24784925329/job/72532705830?pr=1283

@tkelestemur

Copy link
Copy Markdown
Contributor Author

@thowell on it!

@tkelestemur

Copy link
Copy Markdown
Contributor Author

@thowell I pushed a fix for the Linux CPU segfault in 98d6411.

Root cause was that build_flex_bvh() creates a grouped wp.Mesh, but we were looking up roots with the BVH path. For flexes the correct lookup is mesh_get_group_root, and in this Warp version that is only exposed as a kernel builtin, not a normal Python host API. So the final change keeps the surface area small:

  • add a private _compute_mesh_group_roots helper just to call the mesh-specific builtin
  • switch flex BVH root lookup over to that helper
  • fix flex_group_root layout in create_render_context() to match the renderer indexing ([worldid, flexid])

I considered a smaller nworld == 1 -> -1 workaround, but that would only paper over the single-world case instead of fixing the underlying mesh-vs-BVH mismatch. I also considered splitting flexes into one mesh per world, but that would be a much broader change.

I rechecked the targeted regressions locally on macOS and on Linux CPU (l4):

  • render_util_test.py::test_get_segmentation_preserves_flex_ids
  • render_test.py -k segmentation_matches_mujoco (skips on headless Linux as expected)
  • io_test.py -k segmentation_from_camera_output

All of those pass with this change, so this should address the CI segfault without broadening the PR beyond the failing path.

Comment thread mujoco_warp/_src/bvh.py Outdated

# Warp exposes mesh group-root lookup as a kernel builtin in this version.
@wp.kernel
def _compute_mesh_group_roots(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: the _ is not necessary

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tkelestemur can we update this as @StafaH suggests? thanks!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done in 0fdb3e6

@thowell thowell merged commit 5f63341 into google-deepmind:main Apr 27, 2026
10 checks passed
@thowell

thowell commented Apr 27, 2026

Copy link
Copy Markdown
Collaborator

@tkelestemur thank you for this contribution!

@tkelestemur tkelestemur deleted the tarik/semantic-segmentation-parity branch April 27, 2026 23:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants