[Diffusion] Standalone Rollout API + Denoising Environment Backpass + SP-Aligned Log-Prob for T2I Post-Training#22604
Merged
mickqian merged 55 commits into sgl-project:main on Apr 15, 2026
Conversation
…port

Rebased onto latest main.
Replace the original monolithic flow_matching_with_logprob patch with a modular mixin-based architecture:
- SchedulerRLMixin: core rollout logic (prepare, SDE/CPS/ODE sampling, log-prob accumulation, resource lifecycle)
- SchedulerRLDebugMixin: optional debug tensor collection
- RolloutSessionData: per-batch state dataclass stored on the batch object
- All rollout state lives on the batch, keeping the scheduler stateless

Made-with: Cursor
- Pass the batch object through scheduler.step() to enable per-request rollout
- Add _maybe_prepare_rollout / _maybe_collect_rollout_log_probs lifecycle hooks in the denoising stage
- Wire the rollout flow through the decoding and denoising_dmd stages

Made-with: Cursor
- Add rollout_sde_type, rollout_noise_level, rollout_log_prob_no_const, rollout_debug_mode to SamplingParams with validation
- Propagate the parameters through the OpenAI-compatible image/video endpoints
- Wire them through diffusion_generator and gpu_worker

Made-with: Cursor
- ODE mode: bit-exact alignment against the FlowGRPO reference implementation
- SDE/CPS modes: verify log-prob sign, shape, and noise-injection behavior
- Validate the prepare/consume/release lifecycle and edge cases

Made-with: Cursor
…ults

- Run isort/black/ruff formatting on all changed files
- Remove unused TeaCacheParams imports from schedule_batch.py (F401)
- Rewrite the FlowGRPO alignment test: use a verbatim reference from sd3_sde_with_logprob.py, verify log_prob at atol=1e-6
- Match the FlowGRPO convention: SDE uses the full Gaussian log-prob (no_const=False), CPS uses no_const=True
- Remove explicit defaults from rollout CLI args to fix test_get_cli_args_drops_unset_sampling_params

Made-with: Cursor
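The `no_const` convention mentioned in the commit above can be sketched as follows. This is an illustrative recreation, not code from the PR: `sde_log_prob` is a hypothetical name, and the only assumption is a Gaussian transition kernel, with `no_const=True` dropping the terms of the log-density that do not depend on the sample (the CPS convention) while `no_const=False` keeps the full Gaussian log-density (the SDE convention).

```python
import math

import torch


def sde_log_prob(prev_sample, mean, std, no_const=False):
    # log N(prev_sample; mean, std^2), summed over all non-batch dims.
    # Everything is cast to fp32 at entry, mirroring the commit's
    # "full Gaussian log-prob" convention for SDE.
    x, mu, std = prev_sample.float(), mean.float(), std.float()
    lp = -((x - mu) ** 2) / (2.0 * std**2)
    if not no_const:
        # constant terms (independent of prev_sample); CPS omits these
        lp = lp - torch.log(std) - 0.5 * math.log(2.0 * math.pi)
    return lp.flatten(1).sum(dim=1)  # one scalar per batch element
```

At `prev_sample == mean` the `no_const=True` variant is exactly zero, while the full variant contributes only the normalization constant, which makes the two conventions easy to tell apart in tests.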
Rollout is an internal post-training feature; it should not be exposed through the standardized OpenAI image/video generation endpoints. Parameters remain accessible via SamplingParams CLI and direct generator API. Made-with: Cursor
…rollout_unit The file tests ODE, SDE, and CPS modes — the old name was misleading. Made-with: Cursor
Rename ``_maybe_append_dit_env_step`` to ``_maybe_append_dit_trajectory_step``
and switch it from capturing the post-scale/post-I2V-concat
``latent_model_input`` to the raw pre-processing ``latents`` (x_{t_i}). The
call is moved to the top of each denoising step and both it and
``_postprocess_rollout_outputs`` are now gated with ``batch.rollout`` at the
call site to keep the non-rollout main path strictly untouched.
``_postprocess_rollout_outputs`` additionally appends the last
``scheduler.step`` output as the (T+1)-th latent, and is reordered to run
before ``_post_denoising_loop`` so that the final latent is still SP-sharded
and can be gathered uniformly alongside the per-step trajectory via
``gather_stacked_latents_for_sp``.
Field renames propagated through ``RolloutDitTrajectory``, rollout_api
serialization, and the unit tests: ``latent_model_inputs`` -> ``latents``
(shape ``[B, T+1, ...]``); internal state keys
``trajectory_latent_model_inputs`` -> ``step_latents``.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
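The `[T+1]` trajectory layout described in the commit above can be illustrated with a toy loop (a minimal sketch, not the PR's hooks, which live in the denoising stage): the raw latent `x_{t_i}` is captured at the top of each step, before any preprocessing, and the final scheduler output is appended as the extra entry.

```python
import torch

T = 4
latents = torch.randn(1, 8)  # x_{t_0}, raw initial latent
step_latents = []
for _ in range(T):
    # capture the raw pre-processing latent at the TOP of the step
    step_latents.append(latents.clone())
    latents = latents - 0.1 * latents  # stand-in for scheduler.step()
# append the last step's output as the (T+1)-th latent
step_latents.append(latents.clone())
trajectory = torch.stack(step_latents, dim=1)  # shape [B, T+1, ...]
assert trajectory.shape == (1, T + 1, 8)
```

With this layout the final latent sits alongside the per-step latents, so one gather path can handle the whole stack uniformly.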
- Fix the existing flowGRPO alignment mock to also populate rollout_session_data.noise_buffer (required by the full-buffer log_prob path)
- Add test_ode_bit_exact_with_non_rollout_path: locks in that ODE rollout produces the same prev_sample as scheduling_flow_match_euler_discrete's non-rollout deterministic branch using bf16 model_output
- Add test_sde_cps_force_fp32_with_bf16_model_output: regression test for the wrapped-scalar promotion trap; asserts that the SDE/CPS branches keep log_prob and the noise buffer in fp32 even when given bf16 model_output

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
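The promotion trap that the new regression test guards against can be reproduced in a few lines. This is a minimal repro of PyTorch's type-promotion rule, not code from this PR: a 0-dim tensor participates in binary ops like a "wrapped scalar" and does not upcast an N-dim operand of the same (floating-point) category.

```python
import torch

noise_std_dev = torch.tensor(0.7)                      # 0-dim, fp32
variance_noise = torch.randn(4, dtype=torch.bfloat16)  # N-dim, bf16

# the 0-dim fp32 tensor is treated like a Python scalar: the product
# silently takes the N-dim operand's dtype (bf16), losing precision
assert (noise_std_dev * variance_noise).dtype == torch.bfloat16

# the guard: cast the bf16 operand to fp32 at entry, so the product
# (and everything accumulated from it) stays in fp32
assert (noise_std_dev * variance_noise.float()).dtype == torch.float32
```

This is why asserting on the dtypes of `log_prob` and the noise buffer is enough to catch a regression: any reintroduced wrapped-scalar demotion flips them back to bf16.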
Contributor
Code Review
This pull request introduces a new Rollout HTTP API for post-training and reinforcement learning workflows, featuring new request/response structures, tensor serialization utilities, and a RolloutDenoisingMixin for collecting denoising environments and trajectories. The update also refactors sequence parallelism handling and ensures numerical precision during sampling by casting to float32. Feedback suggests making the _kwargs_to_cpu helper fully recursive to correctly handle nested sequences and preserve tuple types.
- Remove QwenImageEditRolloutPipelineMixin and its Edit pipeline wiring; rollout is T2I-only for now, and the base gather_dit_env_static_for_sp on QwenImagePipelineConfig is reused.
- Generalize _kwargs_to_cpu in rollout_denoising_mixin.py to a recursive walker handling tensor / dict / list / tuple uniformly.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
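The recursive walker described above might look like the following. This is an illustrative recreation, not the PR's actual code: only the helper's purpose and its uniform tensor / dict / list / tuple handling (including the review request to preserve tuple types) come from the commit and review text.

```python
import torch


def kwargs_to_cpu(obj):
    """Recursively move tensors to CPU, preserving container types.

    Handles torch.Tensor, dict, list and tuple uniformly; any other
    value is returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {k: kwargs_to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, tuple):
        # rebuild as a tuple so callers relying on tuple-ness still work
        return tuple(kwargs_to_cpu(v) for v in obj)
    if isinstance(obj, list):
        return [kwargs_to_cpu(v) for v in obj]
    return obj


# nested kwargs survive with their structure intact
moved = kwargs_to_cpu({"embeds": torch.ones(2), "extras": [(torch.zeros(1), 3)]})
```

A non-recursive version would miss the tensor inside `extras[0]`, which is exactly the nested-sequence case the review feedback called out.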
Matches CI pre-commit output: 3 isort fixes (http_server.py, rollout_api.py, zimage_rollout_pipeline_mixin.py) and 5 black fixes (sampling_params.py, rollout_api.py, sp_utils.py, rollout_denoising_mixin.py, test_rollout_api.py). No logic changes. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mickqian reviewed Apr 12, 2026
@@ -0,0 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
mickqian approved these changes Apr 14, 2026
Collaborator
/tag-and-rerun-ci
Collaborator
/tag-and-rerun-ci
…pass

# Conflicts:
#	python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py
#	python/sglang/multimodal_gen/test/run_suite.py
Collaborator
/tag-and-rerun-ci
Collaborator
/rerun-failed-ci
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request on Apr 22, 2026
… backpass and sp-aligned log-prob for T2I post-training (sgl-project#22604)

Co-authored-by: MikukuOvO <mikukuovo@gmail.com>
[Diffusion] Standalone Rollout API + Denoising Environment Backpass + SP-Aligned Log-Prob for T2I Post-Training
1. Architecture & Design
Motivation
Building on PR #21204 (SDE/CPS/ODE rollout log-prob engine), RL-based post-training of diffusion models (e.g., FlowGRPO) needs more than just per-step log-probabilities. It also needs (a) a dedicated serving endpoint that is insulated from the generic T2I generate path, (b) enough trajectory metadata to replay the rollout outside of SGLang-D for policy-gradient computation, and (c) numerical stability under Sequence Parallelism.
This PR delivers four things (T2I scope):
1. **Standalone rollout API.** Splits rollout serving out of the generic `/generate` path into its own FastAPI router, `POST /rollout/generate` (`APIRouter(prefix="/rollout")`), with dedicated `RolloutRequest` / `RolloutResponse[]` I/O structs. The endpoint is sample-granular — batch results are split into per-sample `RolloutResponse` objects, each carrying its own serialized trajectory slice, so RL trainers consuming individual trajectories don't have to demux batches or thread RL-only fields through the main generate path.
2. **Denoising environment backpass.** `RolloutDenoisingMixin` hooks `DenoisingStage` and, when `rollout_return_denoising_env` / `rollout_return_dit_trajectory` are set, returns the frozen transformer kwargs (text embeds, image embeds, guidance) and the per-step DiT trajectory (raw noisy `x_{t_i}` + final latent + timesteps, `[T+1, ...]`) alongside the existing log-probs / debug tensors — everything a trainer needs to re-run the exact forward pass SGLang-D ran.
3. **SP-aligned log-prob.** `rollout_log_probs` under SP drifted from the single-GPU reference by exactly 1 bf16 ulp at each step's log-prob magnitude (Qwen-Image SP=2 max over 50 steps: `4.88e-4 = 2⁻¹¹` for SDE, `1.95e-3 = 2⁻⁹` for CPS). Root cause: a 0-dim fp32 `noise_std_dev` is silently demoted to bf16 by PyTorch wrapped-scalar promotion when multiplied against an N-dim bf16 `variance_noise` on `enable_autocast=False` pipelines, and bf16 summation is non-associative across shards. Fix: each rank computes the log-prob on the full pre-shard noise buffer (no `all_reduce`), plus flowGRPO's fp32 entry-cast policy for SDE/CPS (see §2 Precision Policy).
4. **Bug fixes.** Off-by-one `dit_trajectory`, wrong capture tensor, redundant `_maybe_collect_rollout_log_probs` call, `variance_noise` aliasing on SP=1, CLI-defaults leak, `gather_latents_for_sp` kwarg drift. Full list in §2 "Bug Fixes".

Design Principles
- **Sample-granular responses.** Batch results are split into per-sample `RolloutResponse` objects, each carrying its own serialized trajectory slice, so RL trainers that process individual trajectories don't need to demux the batch.
- **Strict opt-in.** `rollout_return_denoising_env` and `rollout_return_dit_trajectory` are opt-in flags. When disabled, no extra memory or computation is consumed — strict zero overhead on the main path.
- **Reuse of SP infrastructure.** Uses the same `gather_latents_for_sp` / `gather_dit_env_static_for_sp` infrastructure as the core pipeline. Non-rollout inference is strictly unchanged.
- **Full-buffer log-prob, no collective.** Each rank fills `rollout_session_data.noise_buffer` with a common seed, then computes `log_prob_no_const_val` on that full buffer. Every rank gets the same per-step sum, so no `all_reduce` is needed — both SP/single-GPU bit-exactness and communication cost are better than the previous "local sum + all-reduce" path.
- **fp32 entry-cast for SDE/CPS, untouched ODE.** SDE/CPS branches apply `model_output.float()` at entry (matching flowGRPO `sd3_sde_with_logprob.py`), but the ODE branch is deliberately left uncast so that `rollout(sde_type="ode")` produces a bit-identical `prev_sample` to the non-rollout deterministic step in `scheduling_flow_match_euler_discrete.step`. A dedicated unit test locks this in with `torch.equal`.

Module Structure
Data Flow
2. Features, API & Reliability
Rollout API Parameters (`RolloutRequest` + `SamplingParams` extensions)

| Parameter | Type | Default | Values / notes |
| --- | --- | --- | --- |
| `rollout` | bool | `True` |  |
| `rollout_sde_type` | str | `"sde"` | `"sde"`, `"cps"`, or `"ode"` |
| `rollout_noise_level` | float | `0.7` |  |
| `rollout_log_prob_no_const` | bool | `False` |  |
| `rollout_debug_mode` | bool | `True` |  |
| `rollout_return_denoising_env` | bool | `False` |  |
| `rollout_return_dit_trajectory` | bool | `False` |  |

Response Structure (`RolloutResponse`, per sample)

- `generated_output`: Base64-serialized output image / video tensor.
- `rollout_log_probs`: `[T]` per-step log-probability, fp32.
- `rollout_debug_tensors`: `{variance_noise, prev_sample_mean, noise_std_dev, model_output}`, each `[T, ...]`, SP-gathered to full shape.
- `denoising_env`: `{image_kwargs, pos_cond_kwargs, neg_cond_kwargs, guidance}`, SP-gathered.
- `dit_trajectory`: `{latents [T+1, ...], timesteps [T]}`, SP-gathered.

Precision Policy (flowGRPO alignment)
| `rollout_sde_type` | `model_output` cast at entry | `variance_noise` dtype | `log_prob_no_const_val` dtype | Alignment target |
| --- | --- | --- | --- | --- |
| `sde` | `.float()` | fp32 | fp32 | `sd3_sde_with_logprob.py` |
| `cps` | `.float()` | fp32 | fp32 | `sd3_sde_with_logprob.py` |
| `ode` | none | — | — | `scheduling_flow_match_euler_discrete.step` non-rollout branch, bit-exact |

The fp32 entry-cast on SDE/CPS is adopted from FlowGRPO's reference `sde_step_with_logprob`, whose comment reads "bf16 can overflow here when compute prev_sample_mean, we must convert all variable to fp32". This also avoids the PyTorch wrapped-scalar promotion trap where a 0-dim fp32 `noise_std_dev` is silently demoted to bf16 when multiplied against an N-dim bf16 `variance_noise`, which was the actual root cause of the pre-fix SP log-prob drift on `enable_autocast=False` pipelines (Qwen-Image, Z-Image).

Bug Fixes
- **Off-by-one in `dit_trajectory`.** `_postprocess_rollout_outputs` now appends the last `scheduler.step` output as the `(T+1)`-th entry and is reordered to run before `_post_denoising_loop` so it can be SP-gathered uniformly with the per-step trajectory.
- **Wrong capture tensor in `dit_trajectory`.** Moved the per-step capture hook to the top of the denoising loop; it stores the raw pre-scale / pre-I2V-concat latent `x_{t_i}` instead of the post-scale `latent_model_input`. Field renamed `latent_model_inputs` → `latents` (shape `[B, T+1, ...]`).
- **SP log-prob drift** (`max_abs_diff=4.88e-4` SDE / `1.95e-3` CPS on Qwen-Image SP=2, now `0`).
- **`variance_noise` aliasing across debug-append calls.** `append_local_rollout_debug_tensors` now clones before appending; on SP=1, `_rollout_variance_noise` returns a view of the reusable `noise_buffer`, so without the clone every appended entry held the last step's noise after the buffer was overwritten.
- **Redundant `_maybe_collect_rollout_log_probs` call.** Removed the duplicate call from `_post_denoising_loop` — it was already invoked from `_postprocess_rollout_outputs`, and the first one crashed on SP runs because `_rollout_session_data` had already been released.
- **CLI-defaults leak in `sampling_params`.** Removed explicit defaults from rollout CLI args so `test_get_cli_args_drops_unset_sampling_params` passes again.
- **`gather_latents_for_sp` keyword-arg drift** in the Z-Image / Qwen-Image rollout pipeline mixins and the debug mixin — small positional/keyword-arg fix.

Reliability Testing
Three levels of testing validate correctness:
- **Unit tests.** The full CI unit suite (`python3 sglang/multimodal_gen/test/run_suite.py --suite unit`) runs 125 tests (124 passed, 1 env-dependent skip: `test_integration_with_moto`). Of those, 36 are rollout-specific (31 in `test_rollout_api.py` + 5 in `test_scheduler_rollout_unit.py`); see §3.1.
- **FlowGRPO single-step alignment.** SDE/CPS single-step matches a verbatim copy of `sde_step_with_logprob` across 4 seeds, for all four output quantities (`prev_sample`, `prev_sample_mean`, `noise_std_dev`, `log_prob`). All < 1e-6.
- **Parallel consistency.** Multi-GPU rollouts are compared tensor-by-tensor against the single-GPU reference on Qwen-Image (§3.2–3.3) and Z-Image-Turbo (Appendix A).

3. Detailed Test Results
3.1 Unit Tests — Full CI unit suite 124 passed / 1 skipped
Command (matches `.github/workflows/pr-test-multimodal-gen.yml::multimodal-gen-unit-test`):

Rollout-specific breakdown (36/36):
The 5 scheduler rollout tests lock in the following invariants:
- `test_ode_step_does_not_call_variance_noise_sampler`
- `test_ode_debug_tensors_have_shape_safe_noise_std`
- `test_ode_bit_exact_with_non_rollout_path` (new): asserts `torch.equal(rollout_prev, sample + dt * model_output)` with bf16 `model_output`, locking in ODE ≡ non-rollout bit-exactness
- `test_single_step_matches_flowgrpo_reference`: `max_abs_diff < 1e-6` for `prev_sample`, `prev_sample_mean`, `noise_std_dev`, `log_prob`
- `test_sde_cps_force_fp32_with_bf16_model_output` (new): feeding bf16 `model_output` to SDE/CPS must still yield fp32 `log_prob_sum` and fp32 `noise_buffer`, locking in the fp32 entry-cast and guarding against PyTorch wrapped-scalar demotion regressions

3.2 Qwen-Image Parallel Consistency (primary)
Model: `Qwen/Qwen-Image`, seed=42, noise_level=0.5, 50 steps, 1024×1024, 2 GPUs.
Reference: single-GPU (TP1, SP1, no CFGP).
Prompt: "A fluffy silver-gray cat rests on a light surface indoors, wearing small round sunglasses with gold frames and dark blue lenses. …"
SDE mode
CPS mode
ODE mode
3.3 `log_prob` before vs after the SP fix

Historical Qwen-Image TP1-SP2 run (pre-fix, from `outputs/rollout_trajectory_debug_compare/20260411-080016/trajectory_debug_report.md`), comparing `log_prob` max_abs_diff pre-fix against `log_prob` max_abs_diff with this PR:

Key observations (Qwen-Image, primary)
- SP2 is bit-exact on every tensor including `log_prob`.
- The CFG-parallel configuration (`tp1_sp1_cfg1`) shows small `prev_sample_mean` and `model_output` differences from non-deterministic reduction order in the CFG-parallel transformer forward — not the rollout engine. `log_prob`, `variance_noise` and `noise_std_dev` are all bit-exact.
- ODE is bit-exact with the non-rollout path (`test_ode_bit_exact_with_non_rollout_path`). Any residual drift under TP2/CFGP is exactly what the non-rollout path would also see, consistent with the baseline of PR #21204 ([Diffusion] Revamp Rollout Log-Prob Support with SDE/CPS for RL Post-Training).

4. Summary
- New `POST /rollout/generate` API (new FastAPI router, separate from `/generate`), sample-granular, returning per-sample serialized rollout trajectory data (log-probs, debug tensors, denoising env, DiT trajectory) for T2I pipelines.
- `RolloutDenoisingMixin` captures the frozen transformer kwargs + per-step `x_{t_i}` + final latent, SP-gathered for replay.
- SP-aligned `rollout_log_probs`, bit-exact with the single-GPU reference on Qwen-Image and Z-Image-Turbo across all three modes (sde / cps / ode). Achieved by computing the log-prob on the full pre-shard noise buffer plus flowGRPO's fp32 entry-cast policy for SDE/CPS, while ODE preserves bit-exactness with the non-rollout deterministic path (locked in by unit test).
- Full CI unit suite (`run_suite.py --suite unit`): 124 passed / 1 env-skipped, including the 36 rollout-specific tests. Two new invariants lock in (a) ODE ≡ non-rollout bit-exactness and (b) the fp32 entry-cast regression guard against PyTorch wrapped-scalar promotion. SDE/CPS single-step also matches a verbatim FlowGRPO reference across 4 seeds within 1e-6.
- Diff vs `upstream/main`: 20 files, +1276 / −78 lines (excluding unrelated upstream changes).

Appendix A — Z-Image-Turbo parallel consistency (secondary verification)
Z-Image-Turbo was run through the same `test_rollout_trajectory_debug_parallel.py` harness as Qwen-Image to make sure the SP fix generalizes beyond a single model. Same prompt/seed/noise_level as §3.2; 9 steps instead of 50 because Z-Image-Turbo is turbo-distilled.

Model: `Tongyi-MAI/Z-Image-Turbo`, seed=42, noise_level=0.5, 9 steps, 1024×1024, 2 GPUs.
Reference: single-GPU (TP1, SP1, no CFGP).
`max_abs_diff` is the max over all 9 steps.

SDE mode
CPS mode
ODE mode
Z-Image-Turbo observation: SP2 is bit-exact on every tensor including `log_prob`. TP2-SP1 shows slightly more `prev_sample_mean` / `model_output` drift than Qwen-Image does (reduction-order non-determinism in a different DiT architecture), but `log_prob`, `variance_noise` and `noise_std_dev` are all still bit-exact, confirming the full-buffer log-prob path is insulated from downstream DiT non-determinism regardless of the backbone.
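The full-buffer mechanism behind this insulation can be illustrated with a small sketch (hypothetical code, not the PR's implementation; `full_noise_buffer` and the seed-mixing scheme are illustrative): every SP rank fills the complete pre-shard noise buffer from a common seed, so each rank can compute an identical per-step log-prob sum locally, with no `all_reduce` and no cross-rank summation-order effects.

```python
import torch


def full_noise_buffer(shape, seed, step):
    # same seed + step on every SP rank -> bit-identical full buffer,
    # so the per-step log-prob sum needs no cross-rank reduction and is
    # immune to bf16 summation-order differences between shards
    gen = torch.Generator().manual_seed(seed * 1_000_003 + step)
    return torch.randn(shape, generator=gen, dtype=torch.float32)


# two "ranks" reproduce the same buffer independently, in fp32
rank0 = full_noise_buffer((2, 16), seed=42, step=7)
rank1 = full_noise_buffer((2, 16), seed=42, step=7)
assert torch.equal(rank0, rank1)
```

Because the buffer is materialized in full on each rank, downstream DiT non-determinism never feeds back into the noise (or the log-prob computed from it), matching the observations above.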
log_prob. TP2-SP1 shows slightly moreprev_sample_mean/model_outputdrift than Qwen-Image does (reduction-order non-determinism in a different DiT architecture), butlog_prob,variance_noiseandnoise_std_devare all still bit-exact, confirming the full-buffer log-prob path is insulated from downstream DiT non-determinism regardless of the backbone.