
[Diffusion] Standalone Rollout API + Denoising Environment Backpass + SP-Aligned Log-Prob for T2I Post-Training #22604

Merged

mickqian merged 55 commits into sgl-project:main from Rockdu:feat/denoising_backpass on Apr 15, 2026

Conversation

@Rockdu Rockdu (Contributor) commented Apr 11, 2026

[Diffusion] Standalone Rollout API + Denoising Environment Backpass + SP-Aligned Log-Prob for T2I Post-Training

1. Architecture & Design

Motivation

Building on PR #21204 (SDE/CPS/ODE rollout log-prob engine), RL-based post-training of diffusion models (e.g., FlowGRPO) needs more than just per-step log-probabilities. It also needs (a) a dedicated serving endpoint that is insulated from the generic T2I generate path, (b) enough trajectory metadata to replay the rollout outside of SGLang-D for policy-gradient computation, and (c) numerical stability under Sequence Parallelism.

This PR delivers four things (T2I scope):

  1. Standalone rollout HTTP API. Rollout is split off from the generic /generate path into its own FastAPI router POST /rollout/generate (APIRouter(prefix="/rollout")) with dedicated RolloutRequest / RolloutResponse[] I/O structs. The endpoint is sample-granular — batch results are split into per-sample RolloutResponse objects, each carrying its own serialized trajectory slice, so RL trainers consuming individual trajectories don't have to demux batches or thread RL-only fields through the main generate path.
  2. Rollout denoising environment backpass. RolloutDenoisingMixin hooks DenoisingStage and, when rollout_return_denoising_env / rollout_return_dit_trajectory are set, returns the frozen transformer kwargs (text embeds, image embeds, guidance) and the per-step DiT trajectory (raw noisy x_{t_i} + final latent + timesteps, [T+1, ...]) alongside the existing log-probs / debug tensors — everything a trainer needs to re-run the exact forward pass SGLang-D ran.
  3. SP-aligned log-prob. Previously rollout_log_probs under SP drifted from the single-GPU reference by exactly 1 bf16 ulp at each step's log-prob magnitude (Qwen-Image SP=2 max over 50 steps: 4.88e-4 = 2⁻¹¹ for SDE, 1.95e-3 = 2⁻⁹ for CPS). Root cause: 0-dim fp32 noise_std_dev is silently demoted to bf16 by PyTorch wrapped-scalar promotion when multiplied against an N-dim bf16 variance_noise on enable_autocast=False pipelines, and bf16 sum is non-associative across shards. Fix: each rank computes log-prob on the full pre-shard noise buffer (no all_reduce), plus flowGRPO's fp32 entry-cast policy for SDE/CPS (see §2 Precision Policy).
  4. Assorted bug fixes bundled along the way — missing final-step latent in dit_trajectory, wrong capture tensor, redundant _maybe_collect_rollout_log_probs call, variance_noise aliasing on SP=1, CLI-defaults leak, gather_latents_for_sp kwarg drift. Full list in §2 "Bug Fixes".

Design Principles

  • Per-sample granularity. The rollout API splits batch outputs into per-sample RolloutResponse objects, each carrying its own serialized trajectory slice, so RL trainers that process individual trajectories don't need to demux the batch.
  • Opt-in fields. rollout_return_denoising_env and rollout_return_dit_trajectory are opt-in flags. When disabled, no extra memory or computation is consumed — strict zero overhead on the main path.
  • SP-transparent gathering. All SP-sharded tensors (latent model inputs, cond kwargs, DiT trajectory) are gathered to full shape before being returned, using the same gather_latents_for_sp / gather_dit_env_static_for_sp infrastructure as the core pipeline. Non-rollout inference is strictly unchanged.
  • Bit-exact SP log-prob, zero extra collectives. Each SP rank generates the full variance noise into rollout_session_data.noise_buffer with a common seed, then computes log_prob_no_const_val on that full buffer. Every rank arrives at the same per-step sum, so no all_reduce is needed; this beats the previous "local sum + all-reduce" path on both SP/single-GPU bit-exactness and communication cost (see the sketch after this list).
  • ODE ≡ non-rollout bit-exact, preserved. The SDE/CPS branches cast model_output.float() at entry (matching flowGRPO sd3_sde_with_logprob.py), but the ODE branch is deliberately left uncasted so that rollout(sde_type="ode") produces bit-identical prev_sample to the non-rollout deterministic step in scheduling_flow_match_euler_discrete.step. A dedicated unit test locks this in with torch.equal.
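
For reference, a minimal sketch of the full-buffer scheme from the bullet above, assuming a scalar per-step noise_std_dev; the helper name full_buffer_log_prob and the shapes are illustrative, not the actual SchedulerRLMixin internals:

```python
import math

import torch

def full_buffer_log_prob(
    noise_buffer: torch.Tensor,   # full pre-shard eps, identical on every rank
    noise_std_dev: torch.Tensor,  # 0-dim per-step std
    no_const: bool = True,
) -> torch.Tensor:
    # With prev_sample = prev_sample_mean + noise_std_dev * eps, the Gaussian
    # log-prob collapses to a function of (eps, std) alone. Since every SP rank
    # seeds the same generator and holds the FULL eps buffer, each rank computes
    # the identical fp32 sum with zero collectives.
    eps = noise_buffer.float()
    log_prob = -0.5 * eps.pow(2).sum()
    if not no_const:
        log_prob = log_prob - eps.numel() * (
            torch.log(noise_std_dev.float()) + 0.5 * math.log(2 * math.pi)
        )
    return log_prob

# Every rank samples the full buffer with a common seed, shards it for its own
# denoising slice, but keeps the full buffer around for the log-prob sum.
gen = torch.Generator().manual_seed(42)
noise = torch.randn(1, 16, 128, 128, generator=gen, dtype=torch.float32)
print(full_buffer_log_prob(noise, torch.tensor(0.3)))
```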

Module Structure

runtime/entrypoints/post_training/
  io_struct.py        # RolloutRequest / RolloutResponse Pydantic models
  rollout_api.py      # POST /rollout/generate endpoint + batch→per-sample serialization
  utils.py            # tensor_to_base64, base64_to_tensor, _maybe_serialize, _maybe_deserialize

runtime/post_training/
  rollout_denoising_mixin.py              # DenoisingStage hooks for env/trajectory collection
  sp_utils.py                             # SP collective helpers (gather_stacked_latents_for_sp, ...)
  rl_dataclasses.py                       # RolloutTrajectoryData, RolloutDenoisingEnv, RolloutDitTrajectory
  scheduler_rl_mixin.py                   # SchedulerRLMixin — SP-aligned log-prob on full buffer
  scheduler_rl_debug_mixin.py             # SchedulerRLDebugMixin — debug tensor accumulation
  pipeline_configs/
    zimage_rollout_pipeline_mixin.py      # Z-Image cond-kwargs gather under SP
    qwen_image_rollout_pipeline_mixin.py  # Qwen-Image cond-kwargs gather under SP

test/unit/
  test_rollout_api.py                     # 31 tests: tensor serialization, response building, trajectory slicing
  test_scheduler_rollout_unit.py          #  5 tests: ODE determinism, ODE bit-exact vs non-rollout,
                                          #           flowGRPO alignment, dtype regression

Data Flow

POST /rollout/generate
  → build SamplingParams (rollout=True, rollout_return_denoising_env, rollout_return_dit_trajectory)
  → pipeline.forward(req)
      → DenoisingStage:
          _maybe_prepare_rollout(batch)
          _maybe_init_denoising_env_collection(batch, image_kwargs, pos_cond_kwargs, ...)
          for each step:
              _maybe_append_dit_trajectory_step(batch, latents, timestep)   # raw x_{t_i}, before scale
              scheduler.step(batch=batch, ...)
                → flow_sde_sampling(model_output, sample, sigma, sigma_next, generator)
                    → SDE/CPS: full-buffer log_prob on rollout_session_data.noise_buffer (fp32)
                    → ODE   : prev_sample = sample + dt * model_output (dtype-preserved)
          _postprocess_rollout_outputs
              → _maybe_collect_rollout_log_probs(batch)
              → _maybe_finalize_dit_env_collection(batch)       # appends final latent, SP-gathers env + trajectory
  → _build_response: split batch → per-sample RolloutResponse with base64-serialized tensors
  → ORJSONResponse

2. Features, API & Reliability

Rollout API Parameters (RolloutRequest + SamplingParams extensions)

| Parameter | Type | Default | Description |
|---|---|---|---|
| rollout | bool | True | Enable rollout log-prob computation |
| rollout_sde_type | str | "sde" | Step strategy: "sde", "cps", or "ode" |
| rollout_noise_level | float | 0.7 | Noise level for SDE/CPS |
| rollout_log_prob_no_const | bool | False | Omit constant terms (common for RL loss) |
| rollout_debug_mode | bool | True | Return intermediate debug tensors |
| rollout_return_denoising_env | bool | False | Return frozen transformer kwargs for replay |
| rollout_return_dit_trajectory | bool | False | Return per-step latents x_{t_i} + timesteps |
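
A hypothetical end-to-end client call wiring the parameters above together; the port and the exact top-level payload layout (prompt plus a nested sampling_params object) are assumptions for illustration rather than the verbatim RolloutRequest schema:

```python
import requests

payload = {
    "prompt": "A fluffy silver-gray cat wearing small round sunglasses",
    "sampling_params": {
        "rollout": True,                        # enable rollout log-probs
        "rollout_sde_type": "sde",              # "sde" | "cps" | "ode"
        "rollout_noise_level": 0.7,
        "rollout_return_denoising_env": True,   # frozen transformer kwargs
        "rollout_return_dit_trajectory": True,  # per-step x_{t_i} + timesteps
    },
}
resp = requests.post("http://localhost:30000/rollout/generate", json=payload)
resp.raise_for_status()
for sample in resp.json():  # one RolloutResponse per sample (sample-granular)
    print(sorted(sample.keys()))
```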

Response Structure (RolloutResponse, per sample)

  • generated_output: Base64-serialized output image / video tensor.
  • rollout_log_probs: [T] per-step log-probability, fp32.
  • rollout_debug_tensors: {variance_noise, prev_sample_mean, noise_std_dev, model_output}, each [T, ...], SP-gathered to full shape.
  • denoising_env: {image_kwargs, pos_cond_kwargs, neg_cond_kwargs, guidance}, SP-gathered.
  • dit_trajectory: {latents [T+1, ...], timesteps [T]}, SP-gathered.
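
On the trainer side, the base64 tensor fields round-trip through helpers shaped like utils.py's tensor_to_base64 / base64_to_tensor; the wire format sketched here (torch.save bytes, base64-encoded) is an assumption — check utils.py for the actual encoding:

```python
import base64
import io

import torch

def tensor_to_base64(t: torch.Tensor) -> str:
    # Serialize to an in-memory buffer, then base64-encode the raw bytes.
    buf = io.BytesIO()
    torch.save(t.detach().cpu(), buf)
    return base64.b64encode(buf.getvalue()).decode("ascii")

def base64_to_tensor(s: str) -> torch.Tensor:
    return torch.load(io.BytesIO(base64.b64decode(s)))

log_probs = torch.randn(50, dtype=torch.float32)  # e.g. [T] rollout_log_probs
assert torch.equal(base64_to_tensor(tensor_to_base64(log_probs)), log_probs)
```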

Precision Policy (flowGRPO alignment)

| sde_type | model_output cast at entry | variance_noise dtype | log_prob_no_const_val dtype | Matches |
|---|---|---|---|---|
| sde | .float() | fp32 | fp32 | flowGRPO sd3_sde_with_logprob.py |
| cps | .float() | fp32 | fp32 | flowGRPO sd3_sde_with_logprob.py |
| ode | unchanged (kept at native dtype) | zeros | fp32 zeros | scheduling_flow_match_euler_discrete.step non-rollout branch, bit-exact |

The fp32 entry-cast on SDE/CPS is adopted from FlowGRPO's reference sde_step_with_logprob, whose comment reads "bf16 can overflow here when compute prev_sample_mean, we must convert all variable to fp32". This also avoids the PyTorch wrapped-scalar promotion trap where a 0-dim fp32 noise_std_dev is silently demoted to bf16 when multiplied against an N-dim bf16 variance_noise, which was the actual root cause of the pre-fix SP log-prob drift on enable_autocast=False pipelines (Qwen-Image, Z-Image).
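
The trap is easy to reproduce standalone, since PyTorch's type promotion treats a 0-dim tensor like a Python scalar, so the fp32 operand loses:

```python
import torch

noise_std_dev = torch.tensor(0.123, dtype=torch.float32)  # 0-dim fp32
variance_noise = torch.randn(8, dtype=torch.bfloat16)     # N-dim bf16

# Wrapped-scalar promotion: the 0-dim fp32 operand is demoted, not promoted.
print((noise_std_dev * variance_noise).dtype)             # torch.bfloat16

# The flowGRPO-style fp32 entry-cast makes the promotion a no-op.
print((noise_std_dev * variance_noise.float()).dtype)     # torch.float32
```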

Bug Fixes

  • Missing final-step latent in dit_trajectory. _postprocess_rollout_outputs now appends the last scheduler.step output as the (T+1)-th entry and is reordered to run before _post_denoising_loop so it can be SP-gathered uniformly with the per-step trajectory.
  • Wrong capture tensor for dit_trajectory. The per-step capture hook is moved to the top of the denoising loop and now stores the raw pre-scale / pre-I2V-concat latent x_{t_i} instead of the post-scale latent_model_input. Field renamed latent_model_inputs → latents (shape [B, T+1, ...]).
  • SP log-prob mismatch. Full-buffer log-prob computation + flowGRPO fp32 entry-cast (see §3.2 / §3.3 — was max_abs_diff=4.88e-4 SDE / 1.95e-3 CPS on Qwen-Image SP=2, now 0).
  • variance_noise aliasing across debug-append calls. append_local_rollout_debug_tensors now clones before appending; on SP=1, _rollout_variance_noise returns a view of the reusable noise_buffer, so without the clone every appended entry held the last step's noise after the buffer was overwritten (see the snippet after this list).
  • Redundant _maybe_collect_rollout_log_probs call. Removed the duplicate call from _post_denoising_loop — it was already invoked from _postprocess_rollout_outputs, and the first one crashed on SP runs because _rollout_session_data had already been released.
  • Rollout CLI defaults leaking into sampling_params. Removed explicit defaults from rollout CLI args so test_get_cli_args_drops_unset_sampling_params passes again.
  • gather_latents_for_sp keyword-arg drift in Z-Image / Qwen-Image rollout pipeline mixins and the debug mixin — small positional/keyword arg fix.
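
The aliasing fix in particular reduces to a one-line clone. A minimal repro with hypothetical names (the real code appends views of rollout_session_data.noise_buffer):

```python
import torch

noise_buffer = torch.empty(4)

buggy = []
for _ in range(3):
    noise_buffer.normal_()          # buffer is overwritten in place each step
    buggy.append(noise_buffer[:])   # BUG: appends a view of the same storage
# Every entry now holds the LAST step's noise:
assert all(torch.equal(t, buggy[-1]) for t in buggy)

fixed = []
for _ in range(3):
    noise_buffer.normal_()
    fixed.append(noise_buffer.clone())  # FIX: snapshot before the next overwrite
```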

Reliability Testing

Three levels of testing validate correctness:

  1. Unit tests — the full CI unit suite (python3 sglang/multimodal_gen/test/run_suite.py --suite unit) runs 125 tests (124 passed, 1 env-dependent skip: test_integration_with_moto). Of those, 36 are rollout-specific (31 in test_rollout_api.py + 5 in test_scheduler_rollout_unit.py); see §3.1.
  2. FlowGRPO reference alignment: SDE/CPS single-step computation is compared against a verbatim copy of FlowGRPO's sde_step_with_logprob across 4 seeds, for all four output quantities (prev_sample, prev_sample_mean, noise_std_dev, log_prob). All < 1e-6.
  3. Multi-parallel-config trajectory debug on Qwen-Image (50-step, 1024×1024, 2 GPUs) under TP1-SP2 / TP2-SP1 / TP1-SP1-CFGP, comparing all debug tensors + log-prob against the single-GPU reference — see §3.2. A secondary Z-Image-Turbo run with the same harness is in Appendix A.

3. Detailed Test Results

3.1 Unit Tests — Full CI unit suite 124 passed / 1 skipped

Command (matches .github/workflows/pr-test-multimodal-gen.yml::multimodal-gen-unit-test):

cd python && python3 sglang/multimodal_gen/test/run_suite.py --suite unit
# => 124 passed, 1 skipped, 3 warnings, 8 subtests passed in 9.91s
# (skipped = test_storage.py::test_integration_with_moto, env-dependent S3-moto integration)

Rollout-specific breakdown (36/36):

python/sglang/multimodal_gen/test/unit/test_rollout_api.py             31 passed
python/sglang/multimodal_gen/test/unit/test_scheduler_rollout_unit.py   5 passed

The 5 scheduler rollout tests lock in the following invariants:

| Test | Invariant |
|---|---|
| test_ode_step_does_not_call_variance_noise_sampler | ODE never samples variance noise (behavior contract) |
| test_ode_debug_tensors_have_shape_safe_noise_std | ODE debug tensors have the right shapes and are zero-valued |
| test_ode_bit_exact_with_non_rollout_path (new) | ODE rollout branch produces torch.equal(rollout_prev, sample + dt * model_output) with bf16 model_output, locking in ODE ≡ non-rollout bit-exactness |
| test_single_step_matches_flowgrpo_reference | SDE/CPS single step matches a verbatim FlowGRPO reference across 4 seeds, max_abs_diff < 1e-6 for prev_sample, prev_sample_mean, noise_std_dev, log_prob |
| test_sde_cps_force_fp32_with_bf16_model_output (new) | Passing bf16 model_output to SDE/CPS must still yield fp32 log_prob_sum and fp32 noise_buffer, locking in the fp32 entry-cast and guarding against PyTorch wrapped-scalar demotion regressions |
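
Why the ODE branch must stay uncasted (third row above) comes down to double rounding: native bf16 arithmetic rounds once per op, while compute-in-fp32-then-cast rounds once at the end, and the two generally disagree in the last ulp. A standalone illustration, with illustrative values:

```python
import torch

torch.manual_seed(0)
sample = torch.randn(4096, dtype=torch.bfloat16)
model_output = torch.randn(4096, dtype=torch.bfloat16)
dt = -0.02

native = sample + dt * model_output  # bf16 end to end, like the non-rollout step
via_fp32 = (sample.float() + dt * model_output.float()).to(torch.bfloat16)

# An fp32 round trip in the ODE branch would therefore break torch.equal
# against the non-rollout deterministic step on some elements.
print(torch.equal(native, via_fp32))  # typically False
```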

3.2 Qwen-Image Parallel Consistency (primary)

Model: Qwen/Qwen-Image, seed=42, noise_level=0.5, 50 steps, 1024×1024, 2 GPUs.
Reference: single-GPU (TP1, SP1, no CFGP).
Prompt: "A fluffy silver-gray cat rests on a light surface indoors, wearing small round sunglasses with gold frames and dark blue lenses. …"

SDE mode

| Config | Tensor | Max Abs Diff | First Step MAD | Last Step MAD | Last Step Cosine |
|---|---|---|---|---|---|
| TP1 SP2 | variance_noise / prev_sample_mean / noise_std_dev / model_output | 0 | 0 | 0 | 1.000000 |
| TP1 SP2 | log_prob | 0 | 0 | 0 | 1.000000 |
| TP2 SP1 | all tensors | 0 | 0 | 0 | 1.000000 |
| TP2 SP1 | log_prob | 0 | 0 | 0 | 1.000000 |
| CFGP | prev_sample_mean | 0.392 | 3.86e-5 | 8.90e-3 | 0.999656 |
| CFGP | model_output | 1.938 | 3.75e-3 | 9.43e-2 | 0.987609 |
| CFGP | log_prob | 0 | 0 | 0 | 1.000000 |

CPS mode

| Config | Tensor | Max Abs Diff | First Step MAD | Last Step MAD | Last Step Cosine |
|---|---|---|---|---|---|
| TP1 SP2 | variance_noise / prev_sample_mean / noise_std_dev / model_output | 0 | 0 | 0 | 1.000000 |
| TP1 SP2 | log_prob | 0 | 0 | 0 | 1.000000 |
| TP2 SP1 | all tensors | 0 | 0 | 0 | 1.000000 |
| TP2 SP1 | log_prob | 0 | 0 | 0 | 1.000000 |
| CFGP | prev_sample_mean | 0.591 | 3.86e-5 | 1.65e-2 | 0.998944 |
| CFGP | model_output | 1.953 | 3.75e-3 | 1.27e-1 | 0.971950 |
| CFGP | log_prob | 0 | 0 | 0 | 1.000000 |

ODE mode

| Config | Tensor | Max Abs Diff | Last Step Cosine |
|---|---|---|---|
| TP1 SP2 | all tensors + log_prob | 0 | 1.000000 |
| TP2 SP1 | all tensors + log_prob | 0 | 1.000000 |
| CFGP | prev_sample_mean | 0.547 | 0.999262 |
| CFGP | model_output | 4.391 | 0.972642 |
| CFGP | log_prob | 0 | 1.000000 |

3.3 log_prob before vs after the SP fix

Historical Qwen-Image TP1-SP2 run (pre-fix, from outputs/rollout_trajectory_debug_compare/20260411-080016/trajectory_debug_report.md):

| Mode | log_prob max_abs_diff (pre-fix) | log_prob max_abs_diff (this PR) |
|---|---|---|
| sde | 4.88e-4 (= 2⁻¹¹) | 0 |
| cps | 1.95e-3 (= 2⁻⁹) | 0 |
| ode | 0 | 0 |

Key observations (Qwen-Image, primary)

  • Qwen-Image SP2 is bit-exact on every tensor including log_prob.
  • Qwen-Image TP2-SP1 is also bit-exact across all 5 metrics.
  • CFGP (tp1_sp1_cfg1) shows small prev_sample_mean and model_output differences from non-deterministic reduction order in the CFG-parallel transformer forward — not the rollout engine. log_prob, variance_noise and noise_std_dev are all bit-exact.
  • ODE rollout ↔ non-rollout deterministic step is bit-exact at the step-function level (guarded by test_ode_bit_exact_with_non_rollout_path). Any residual drift under TP2/CFGP is exactly what the non-rollout path would also see, consistent with the baseline of PR #21204 ([Diffusion] Revamp Rollout Log-Prob Support with SDE/CPS for RL Post-Training).
  • Appendix A replicates the same experiment on Z-Image-Turbo and shows the same pattern (SP2 bit-exact, log-prob bit-exact on every config; slightly more DiT-side FP drift under TP2-SP1 because of a different DiT architecture).

4. Summary

  • Standalone POST /rollout/generate API (new FastAPI router, separate from /generate), sample-granular, returns per-sample serialized rollout trajectory data (log-probs, debug tensors, denoising env, DiT trajectory) for T2I pipelines.
  • Denoising-environment backpass: RolloutDenoisingMixin captures frozen transformer kwargs + per-step x_{t_i} + final latent, SP-gathered for replay.
  • SP-aligned rollout_log_probs, bit-exact with the single-GPU reference on Qwen-Image and Z-Image-Turbo across all three modes (sde / cps / ode). Achieved by computing log-prob on the full pre-shard noise buffer + flowGRPO's fp32 entry-cast policy for SDE/CPS, while ODE preserves bit-exactness with the non-rollout deterministic path (locked in by unit test).
  • Assorted bug fixes — see §2 "Bug Fixes" for the full list.
  • Full CI unit suite (run_suite.py --suite unit) 124 passed / 1 env-skipped, including the 36 rollout-specific tests. Two new invariants lock in (a) ODE ≡ non-rollout bit-exactness and (b) the fp32 entry-cast regression guard against PyTorch wrapped-scalar promotion. SDE/CPS single-step also matches a verbatim FlowGRPO reference across 4 seeds within 1e-6.
  • Total python diff against upstream/main: 20 files, +1276 / −78 lines (excluding unrelated upstream changes).

Appendix A — Z-Image-Turbo parallel consistency (secondary verification)

Z-Image-Turbo was run through the same test_rollout_trajectory_debug_parallel.py harness as Qwen-Image to make sure the SP fix generalizes beyond a single model. Same prompt/seed/noise_level as §3.2; 9 steps instead of 50 because Z-Image-Turbo is turbo-distilled.

Model: Tongyi-MAI/Z-Image-Turbo, seed=42, noise_level=0.5, 9 steps, 1024×1024, 2 GPUs.
Reference: single-GPU (TP1, SP1, no CFGP). max_abs_diff is the max over all 9 steps.

SDE mode

| Config | Tensor | Max Abs Diff | First Step MAD | Last Step MAD | Last Step Cosine |
|---|---|---|---|---|---|
| TP1 SP2 | variance_noise | 0 | 0 | 0 | 1.000000 |
| TP1 SP2 | prev_sample_mean | 0 | 0 | 0 | 1.000000 |
| TP1 SP2 | noise_std_dev | 0 | 0 | 0 | 1.000000 |
| TP1 SP2 | model_output | 0 | 0 | 0 | 1.000000 |
| TP1 SP2 | log_prob | 0 | 0 | 0 | 1.000000 |
| TP2 SP1 | variance_noise | 0 | 0 | 0 | 1.000000 |
| TP2 SP1 | prev_sample_mean | 2.123 | 3.13e-3 | 5.55e-2 | 0.997898 |
| TP2 SP1 | noise_std_dev | 0 | 0 | 0 | 1.000000 |
| TP2 SP1 | model_output | 6.469 | 6.91e-2 | 2.19e-1 | 0.982503 |
| TP2 SP1 | log_prob | 0 | 0 | 0 | 1.000000 |
| CFGP | variance_noise | 0 | 0 | 0 | 1.000000 |
| CFGP | prev_sample_mean | 1.747 | 4.15e-4 | 2.45e-2 | 0.999589 |
| CFGP | noise_std_dev | 0 | 0 | 0 | 1.000000 |
| CFGP | model_output | 3.828 | 9.16e-3 | 1.55e-1 | 0.990435 |
| CFGP | log_prob | 0 | 0 | 0 | 1.000000 |

CPS mode

| Config | Tensor | Max Abs Diff | First Step MAD | Last Step MAD | Last Step Cosine |
|---|---|---|---|---|---|
| TP1 SP2 | variance_noise / prev_sample_mean / noise_std_dev / model_output | 0 | 0 | 0 | 1.000000 |
| TP1 SP2 | log_prob | 0 | 0 | 0 | 1.000000 |
| TP2 SP1 | prev_sample_mean | 2.880 | 3.13e-3 | 6.35e-2 | 0.996891 |
| TP2 SP1 | model_output | 6.766 | 6.91e-2 | 2.95e-1 | 0.958455 |
| TP2 SP1 | log_prob | 0 | 0 | 0 | 1.000000 |
| CFGP | prev_sample_mean | 2.146 | 4.15e-4 | 3.48e-2 | 0.998968 |
| CFGP | model_output | 5.672 | 9.16e-3 | 2.34e-1 | 0.970720 |
| CFGP | log_prob | 0 | 0 | 0 | 1.000000 |

ODE mode

| Config | Tensor | Max Abs Diff | Last Step Cosine |
|---|---|---|---|
| TP1 SP2 | all tensors + log_prob | 0 | 1.000000 |
| TP2 SP1 | prev_sample_mean | 2.541 | 0.998600 |
| TP2 SP1 | model_output | 4.375 | 0.983847 |
| TP2 SP1 | log_prob | 0 | 1.000000 |
| CFGP | prev_sample_mean | 1.409 | 0.999698 |
| CFGP | model_output | 3.992 | 0.990114 |
| CFGP | log_prob | 0 | 1.000000 |

Z-Image-Turbo observation: SP2 is bit-exact on every tensor including log_prob. TP2-SP1 shows slightly more prev_sample_mean / model_output drift than Qwen-Image does (reduction-order non-determinism in a different DiT architecture), but log_prob, variance_noise and noise_std_dev are all still bit-exact, confirming the full-buffer log-prob path is insulated from downstream DiT non-determinism regardless of the backbone.

MikukuOvO and others added 30 commits April 9, 2026 08:06
Replace the original monolithic flow_matching_with_logprob patch with a
modular mixin-based architecture:

- SchedulerRLMixin: core rollout logic (prepare, SDE/CPS/ODE sampling,
  log-prob accumulation, resource lifecycle)
- SchedulerRLDebugMixin: optional debug tensor collection
- RolloutSessionData: per-batch state dataclass stored on batch object
- All rollout state lives on the batch, keeping the scheduler stateless

Made-with: Cursor
- Pass batch object through scheduler.step() to enable per-request rollout
- Add _maybe_prepare_rollout / _maybe_collect_rollout_log_probs lifecycle
  hooks in the denoising stage
- Wire rollout flow through decoding and denoising_dmd stages

Made-with: Cursor
- Add rollout_sde_type, rollout_noise_level, rollout_log_prob_no_const,
  rollout_debug_mode to SamplingParams with validation
- Propagate parameters through OpenAI-compatible image/video endpoints
- Wire through diffusion_generator and gpu_worker

Made-with: Cursor
- ODE mode: bit-exact alignment against FlowGRPO reference implementation
- SDE/CPS mode: verify log-prob sign, shape, noise injection behavior
- Validate prepare/consume/release lifecycle and edge cases

Made-with: Cursor
…ults

- Run isort/black/ruff formatting on all changed files
- Remove unused TeaCacheParams imports from schedule_batch.py (F401)
- Rewrite FlowGRPO alignment test: use verbatim reference from
  sd3_sde_with_logprob.py, verify log_prob at atol=1e-6
- Match FlowGRPO convention: SDE uses full Gaussian log-prob
  (no_const=False), CPS uses no_const=True
- Remove explicit defaults from rollout CLI args to fix
  test_get_cli_args_drops_unset_sampling_params

Made-with: Cursor
Rollout is an internal post-training feature; it should not be exposed
through the standardized OpenAI image/video generation endpoints.
Parameters remain accessible via SamplingParams CLI and direct generator API.

Made-with: Cursor
…rollout_unit

The file tests ODE, SDE, and CPS modes — the old name was misleading.

Made-with: Cursor
Rockdu and others added 7 commits April 9, 2026 09:50
Rename ``_maybe_append_dit_env_step`` to ``_maybe_append_dit_trajectory_step``
and switch it from capturing the post-scale/post-I2V-concat
``latent_model_input`` to the raw pre-processing ``latents`` (x_{t_i}). The
call is moved to the top of each denoising step and both it and
``_postprocess_rollout_outputs`` are now gated with ``batch.rollout`` at the
call site to keep the non-rollout main path strictly untouched.

``_postprocess_rollout_outputs`` additionally appends the last
``scheduler.step`` output as the (T+1)-th latent, and is reordered to run
before ``_post_denoising_loop`` so that the final latent is still SP-sharded
and can be gathered uniformly alongside the per-step trajectory via
``gather_stacked_latents_for_sp``.

Field renames propagated through ``RolloutDitTrajectory``, rollout_api
serialization, and the unit tests: ``latent_model_inputs`` -> ``latents``
(shape ``[B, T+1, ...]``); internal state keys
``trajectory_latent_model_inputs`` -> ``step_latents``.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Fix existing flowGRPO alignment mock to also populate
  rollout_session_data.noise_buffer (required by full-buffer log_prob path)
- Add test_ode_bit_exact_with_non_rollout_path: locks in that ODE rollout
  produces the same prev_sample as scheduling_flow_match_euler_discrete's
  non-rollout deterministic branch using bf16 model_output
- Add test_sde_cps_force_fp32_with_bf16_model_output: regression test for
  the wrapped-scalar promotion trap; asserts that SDE/CPS branches keep
  log_prob and noise buffer in fp32 even when given bf16 model_output

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@github-actions github-actions Bot added the diffusion SGLang Diffusion label Apr 11, 2026
@Rockdu Rockdu changed the title Feat/denoising backpass [Diffusion] Standalone Rollout API + Denoising Environment Backpass + SP-Aligned Log-Prob for T2I Post-Training Apr 11, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a new Rollout HTTP API for post-training and reinforcement learning workflows, featuring new request/response structures, tensor serialization utilities, and a RolloutDenoisingMixin for collecting denoising environments and trajectories. The update also refactors sequence parallelism handling and ensures numerical precision during sampling by casting to float32. Feedback suggests making the _kwargs_to_cpu helper fully recursive to correctly handle nested sequences and preserve tuple types.

Comment thread on python/sglang/multimodal_gen/runtime/post_training/rollout_denoising_mixin.py (outdated)
Rockdu and others added 2 commits April 11, 2026 23:41
- Remove QwenImageEditRolloutPipelineMixin and its Edit pipeline wiring;
  rollout is T2I-only for now, and the base gather_dit_env_static_for_sp
  on QwenImagePipelineConfig is reused.
- Generalize _kwargs_to_cpu in rollout_denoising_mixin.py to a recursive
  walker handling tensor / dict / list / tuple uniformly.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Matches CI pre-commit output: 3 isort fixes (http_server.py, rollout_api.py,
zimage_rollout_pipeline_mixin.py) and 5 black fixes (sampling_params.py,
rollout_api.py, sp_utils.py, rollout_denoising_mixin.py, test_rollout_api.py).
No logic changes.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
A Collaborator commented on the new-file diff hunk (@@ -0,0 +1,27 @@, # SPDX-License-Identifier: Apache-2.0):

pipeline_configs/mixins

@yhyang201 (Collaborator): /tag-and-rerun-ci

@yhyang201 (Collaborator): /tag-and-rerun-ci

…pass

# Conflicts:
#	python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py
#	python/sglang/multimodal_gen/test/run_suite.py
@yhyang201 (Collaborator): /tag-and-rerun-ci

@yhyang201 (Collaborator): /rerun-failed-ci

1 similar comment

@mickqian mickqian merged commit 47ac830 into sgl-project:main Apr 15, 2026
120 of 130 checks passed
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
… backpass and sp-aligned log-prob for T2I post-training (sgl-project#22604)

Co-authored-by: MikukuOvO <mikukuovo@gmail.com>

Labels: diffusion (SGLang Diffusion), run-ci