
Optimize LTX2 feed-forward tensor parallelism #23221

Merged
BBuf merged 1 commit into sgl-project:main from BBuf:codex/ltx2-row-parallel-ffn
Apr 21, 2026

Conversation

@BBuf
Collaborator

@BBuf BBuf commented Apr 20, 2026

Summary

This PR keeps the LTX2 feed-forward intermediate activation sharded under tensor parallelism:

  • proj_in: ColumnParallelLinear(..., gather_output=False)
  • proj_out: RowParallelLinear(..., input_is_parallel=True)

The old path gathered the expanded FFN hidden state across TP ranks before GELU and the output projection. The new path applies GELU on the local shard and uses a row-parallel output projection to reduce back to the full hidden size. This removes the large FFN AllGather path while preserving the checkpoint layout through the existing row-parallel weight loader.
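
A minimal single-process sketch of why this is safe (plain PyTorch, no real TP group; shapes and names are illustrative, not the PR's code): GELU is elementwise, so it commutes with the column sharding of proj_in, and summing the per-rank row-parallel partial outputs (the all-reduce done by the row-parallel layer) recovers the unsharded result, so the expanded activation never needs to be gathered.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
tp_size, hidden, inner = 4, 64, 256
x = torch.randn(8, hidden)
w_in = torch.randn(inner, hidden)    # proj_in weight, shape [inner, hidden]
w_out = torch.randn(hidden, inner)   # proj_out weight, shape [hidden, inner]

# Reference: the old path, i.e. GELU and proj_out on the full (gathered) activation.
ref = F.gelu(x @ w_in.t()) @ w_out.t()

# New path: rank r holds a column shard of proj_in and the matching row shard of
# proj_out; GELU runs on the local shard, and the partial outputs are summed
# (this sum stands in for the NCCL all-reduce of the row-parallel layer).
partials = []
for r in range(tp_size):
    w_in_r = w_in.chunk(tp_size, dim=0)[r]     # local slice of proj_in output dim
    w_out_r = w_out.chunk(tp_size, dim=1)[r]   # matching slice of proj_out input dim
    local = F.gelu(x @ w_in_r.t())             # GELU on the local shard only
    partials.append(local @ w_out_r.t())       # partial [8, hidden] output
sharded = torch.stack(partials).sum(dim=0)

print(torch.allclose(ref, sharded, rtol=1e-4, atol=1e-4))  # True up to fp32 rounding
```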

Validation Command

Benchmarked and profiled this exact workflow on 4x H100 80GB, using physical GPUs 4-7:

```shell
CUDA_VISIBLE_DEVICES=4,5,6,7 sglang generate \
    --model-path Lightricks/LTX-2.3 \
    --pipeline-class-name LTX2TwoStagePipeline \
    --num-gpus 4 \
    --tp-size 4 \
    --ltx2-two-stage-device-mode resident \
    --enable-torch-compile \
    --prompt "SpongeBob talking with patrick" \
    --width 768 --height 512 \
    --num-frames 121 --warmup --text-encoder-cpu-offload false
```

The benchmark baseline was origin/main@1ebe1c57e; the optimized run was the same baseline plus this patch. This PR branch is rebased on the current origin/main@69eb95f20 at PR creation time; the intervening main commits do not touch the LTX2 diffusion path.

Benchmark

3 repetitions, warmup enabled. The table uses the warmup-excluded request time reported by sglang generate; stage times are from the command logs.

| Metric | Main mean | Optimized mean | Speedup | Main runs | Optimized runs |
| --- | --- | --- | --- | --- | --- |
| Warmup-excluded request | 15.5367 s | 14.6433 s | 1.061x | 16.14, 15.64, 14.83 | 14.11, 15.19, 14.63 |
| Denoise stage | 11.6071 s | 11.2188 s | 1.035x | 11.6979, 11.6847, 11.4388 | 10.9778, 11.5430, 11.1355 |
| Denoise per step | 0.3868 s | 0.3739 s | 1.035x | 0.3898, 0.3894, 0.3812 | 0.3658, 0.3847, 0.3711 |
| Refinement stage | 1.3980 s | 1.1083 s | 1.261x | 1.1266, 1.8456, 1.2217 | 1.0088, 0.9659, 1.3503 |
| Refinement per step | 0.4649 s | 0.3684 s | 1.262x | 0.3744, 0.6141, 0.4061 | 0.3352, 0.3209, 0.4491 |
| Decode stage | 1.6787 s | 1.6951 s | 0.990x | 1.7150, 1.6655, 1.6555 | 1.6607, 1.7771, 1.6476 |
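
As a sanity check on the arithmetic in the table, the means and speedups can be recomputed from the per-run times (a standalone sketch, not part of the PR; the per-step rows follow the same pattern):

```python
# Recompute the benchmark means and speedups from the per-run times above.
runs = {
    "Warmup-excluded request": ([16.14, 15.64, 14.83], [14.11, 15.19, 14.63]),
    "Denoise stage": ([11.6979, 11.6847, 11.4388], [10.9778, 11.5430, 11.1355]),
    "Refinement stage": ([1.1266, 1.8456, 1.2217], [1.0088, 0.9659, 1.3503]),
    "Decode stage": ([1.7150, 1.6655, 1.6555], [1.6607, 1.7771, 1.6476]),
}
for name, (main, opt) in runs.items():
    main_mean, opt_mean = sum(main) / len(main), sum(opt) / len(opt)
    print(f"{name}: {main_mean:.4f} s -> {opt_mean:.4f} s ({main_mean / opt_mean:.3f}x)")
```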

Full logs and perf JSONs are archived here:

Nsight Systems Kernel Summary

Captured full-workflow Nsight Systems traces for main and optimized with the same command. The profiler adds overhead, so the latency table above is the source of truth for speedup. The kernel summary still shows the intended communication shift clearly.

| Kernel group | Main GPU time / share | Optimized GPU time / share | Instance change |
| --- | --- | --- | --- |
| NCCL AllGather | 10865.5 ms / 12.2% | 4929.0 ms / 5.4% | 32584 -> 4168 |
| NCCL AllReduce bf16 | 21611.0 ms / 24.2% | 31572.8 ms / 34.9% | 44176 -> 58384 |
| NCCL AllReduce f32 | 12153.9 ms / 13.6% | 13789.9 ms / 15.2% | 85260 -> 85260 |
| Triton pointwise | 15075.8 ms / 16.9% | 11221.4 ms / 12.4% | 579364 -> 562054 |
| nvJitLink GEMM | 12256.4 ms / 13.7% | 12126.6 ms / 13.4% | 257936 -> 257936 |
| VAE vol2col BF16 | 4104.1 ms / 4.6% | 4101.9 ms / 4.5% | 1488 -> 1488 |

Top kernels:

| Main share | Main GPU time | Optimized share | Optimized GPU time | Kernel |
| --- | --- | --- | --- | --- |
| 24.2% | 21611.0 ms | 34.9% | 31572.8 ms | ncclDevKernel_AllReduce_Sum_bf16_RING_LL |
| 13.6% | 12153.9 ms | 15.2% | 13789.9 ms | ncclDevKernel_AllReduce_Sum_f32_RING_LL |
| 12.2% | 10865.5 ms | 5.4% | 4929.0 ms | ncclDevKernel_AllGather_RING_LL |
| 5.2% | 4638.5 ms | 3.1% | 2785.3 ms | nvjet_tst_128x192_64x5_2x1_v_bz_coopB_bias_TNN |
| 4.6% | 4104.1 ms | 4.5% | 4101.9 ms | at::native::vol2col_kernel<c10::BFloat16> |
| 4.1% | 3637.1 ms | 3.8% | 3463.2 ms | at::native::vectorized_elementwise_kernel |
| 3.9% | 3500.7 ms | - | - | triton_poi_fused_clone_permute_view_0 |

Profiler CSVs:

Output Videos And Visual Check

The full pipeline completes successfully with the LTX2.3 stage-2 distilled LoRA path; logs show the adapter merged into 1660 layers.

Representative videos:

Preview frame, r1 side-by-side:

LTX2 row-parallel FFN side-by-side preview

I do not claim frame-level bitwise identity. With this workflow, repeated main runs with the same seed are already non-identical. The measured main-vs-optimized difference is in the same range as the run-to-run variation within main and within the optimized candidate:

| Comparison | PSNR average | SSIM all |
| --- | --- | --- |
| Main r1 vs main r3 | 23.1434 | 0.7972 |
| Optimized r1 vs optimized r3 | 24.8673 | 0.8383 |
| Main r3 vs optimized r3 | 23.7439 | 0.8158 |
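
For reference, average per-frame PSNR and SSIM between two decoded videos can be computed along these lines. This is an illustrative sketch using imageio and scikit-image, not the exact script behind the table above, and the file names are placeholders:

```python
import imageio.v3 as iio
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def video_metrics(path_a: str, path_b: str):
    """Average per-frame PSNR and SSIM between two same-length uint8 videos."""
    psnrs, ssims = [], []
    for fa, fb in zip(iio.imiter(path_a, plugin="pyav"),
                      iio.imiter(path_b, plugin="pyav")):
        psnrs.append(peak_signal_noise_ratio(fa, fb, data_range=255))
        ssims.append(structural_similarity(fa, fb, channel_axis=-1, data_range=255))
    return float(np.mean(psnrs)), float(np.mean(ssims))

# e.g. video_metrics("main_r3.mp4", "optimized_r3.mp4") -> (psnr_avg, ssim_all)
```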

Tests

```shell
python3 -m py_compile python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py
git diff --check
```

@github-actions github-actions Bot added the diffusion SGLang Diffusion label Apr 20, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request optimizes the LTX2FeedForward module for Tensor Parallelism by ensuring intermediate activations remain sharded. It updates proj_in to disable output gathering and changes proj_out to a RowParallelLinear layer. A new unit test using AST parsing has been added to verify these configurations. I have no feedback to provide.
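
(For illustration only, not the test this PR adds: an AST-based check of those constructor arguments could look roughly like the sketch below, assuming the ltx_2.py path from the Tests section.)

```python
# Illustrative sketch of an AST-based check that the LTX2 FFN builds proj_in
# with gather_output=False and proj_out with input_is_parallel=True.
import ast
from pathlib import Path

SOURCE = Path("python/sglang/multimodal_gen/runtime/models/dits/ltx_2.py")

def call_has_kwarg(call: ast.Call, name: str, value: bool) -> bool:
    # True if the call passes `name=<value>` as a literal keyword argument.
    return any(
        kw.arg == name and isinstance(kw.value, ast.Constant) and kw.value.value is value
        for kw in call.keywords
    )

def find_calls(tree: ast.AST, class_name: str):
    # Yield every call to `class_name`, whether referenced bare or as an attribute.
    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            called = getattr(node.func, "id", None) or getattr(node.func, "attr", None)
            if called == class_name:
                yield node

tree = ast.parse(SOURCE.read_text())
assert any(call_has_kwarg(c, "gather_output", False)
           for c in find_calls(tree, "ColumnParallelLinear"))
assert any(call_has_kwarg(c, "input_is_parallel", True)
           for c in find_calls(tree, "RowParallelLinear"))
```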

@BBuf BBuf force-pushed the codex/ltx2-row-parallel-ffn branch from 50c4f9c to 9bf02a1 on April 20, 2026 07:14
@mickqian
Collaborator

/tag-and-rerun-ci

@BBuf BBuf force-pushed the codex/ltx2-row-parallel-ffn branch from 9bf02a1 to 5110854 on April 20, 2026 07:52
@BBuf BBuf merged commit 0d69012 into sgl-project:main Apr 21, 2026
154 of 196 checks passed
zhangying098 pushed a commit to zhangying098/sglang that referenced this pull request Apr 23, 2026

Labels

diffusion SGLang Diffusion run-ci
