Optimize LTX2 feed-forward tensor parallelism #23221
Merged
BBuf merged 1 commit into sgl-project:main Apr 21, 2026
Conversation
Contributor
Code Review
This pull request optimizes the LTX2FeedForward module for Tensor Parallelism by ensuring intermediate activations remain sharded. It updates proj_in to disable output gathering and changes proj_out to a RowParallelLinear layer. A new unit test using AST parsing has been added to verify these configurations. I have no feedback to provide.
Force-pushed from 50c4f9c to 9bf02a1 (compare)
mickqian approved these changes Apr 20, 2026
Collaborator
/tag-and-rerun-ci
Force-pushed from 9bf02a1 to 5110854 (compare)
zhangying098 pushed a commit to zhangying098/sglang that referenced this pull request Apr 23, 2026
Summary
This PR keeps the LTX2 feed-forward intermediate activation sharded under tensor parallelism:
- proj_in: ColumnParallelLinear(..., gather_output=False)
- proj_out: RowParallelLinear(..., input_is_parallel=True)

The old path gathered the expanded FFN hidden state across TP ranks before GELU and the output projection. The new path applies GELU on the local shard and uses a row-parallel output projection to reduce back to the full hidden size. This removes the large FFN AllGather path while preserving the checkpoint layout through the existing row-parallel weight loader.
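For intuition, here is a minimal single-process sketch in plain PyTorch (hypothetical shapes, not sglang's actual parallel layers) of why GELU can stay on the local shard: GELU is elementwise, so it commutes with the column split of proj_in, and the row-parallel proj_out needs only a final AllReduce:

```python
import torch
import torch.nn.functional as F

# Single-process sketch: "ranks" are simulated by chunking the weights.
torch.manual_seed(0)
hidden, inner, tp = 64, 256, 4
x = torch.randn(8, hidden)
w_in = torch.randn(inner, hidden)   # proj_in weight, column-parallel along `inner`
w_out = torch.randn(hidden, inner)  # proj_out weight, row-parallel along `inner`

# Reference: unsharded feed-forward.
ref = F.gelu(x @ w_in.t()) @ w_out.t()

# Sharded: each rank holds inner/tp rows of w_in and inner/tp columns of w_out.
partials = []
for w_in_shard, w_out_shard in zip(w_in.chunk(tp, dim=0), w_out.chunk(tp, dim=1)):
    h_local = F.gelu(x @ w_in_shard.t())        # GELU on the local shard, no AllGather
    partials.append(h_local @ w_out_shard.t())  # local partial of the output projection
out = torch.stack(partials).sum(dim=0)          # stands in for RowParallelLinear's AllReduce

print(torch.allclose(ref, out, atol=1e-3))  # True: one AllReduce, no AllGather
```

The same algebra is what gather_output=False plus input_is_parallel=True express in the real layers.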
Validation Command
Benchmarked and profiled this exact workflow on 4x H100 80GB, using physical GPUs 4-7:
```bash
CUDA_VISIBLE_DEVICES=4,5,6,7 sglang generate \
  --model-path Lightricks/LTX-2.3 \
  --pipeline-class-name LTX2TwoStagePipeline \
  --num-gpus 4 \
  --tp-size 4 \
  --ltx2-two-stage-device-mode resident \
  --enable-torch-compile \
  --prompt "SpongeBob talking with patrick" \
  --width 768 --height 512 \
  --num-frames 121 --warmup --text-encoder-cpu-offload false
```

The benchmark baseline was origin/main@1ebe1c57e; the optimized run was the same baseline plus this patch. This PR branch is rebased on the current origin/main@69eb95f20 at PR creation time; the intervening main commits do not touch the LTX2 diffusion path.

Benchmark
3 repetitions, warmup enabled. The table uses the warmup-excluded request time reported by sglang generate; stage times are from the command logs. Full logs and perf JSONs are archived here:
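A small script along these lines can summarize the archived perf JSONs; the glob patterns and the request_time_s key are illustrative assumptions, not sglang's actual output schema:

```python
import json
import statistics
from pathlib import Path

# Hypothetical summarizer for the archived perf JSONs; file names and the
# "request_time_s" key are assumed, not sglang's actual schema.
def mean_request_time(pattern: str) -> float:
    times = [json.loads(p.read_text())["request_time_s"] for p in Path(".").glob(pattern)]
    return statistics.mean(times)

base = mean_request_time("perf_main_rep*.json")       # 3 warmup-excluded repetitions
opt = mean_request_time("perf_optimized_rep*.json")
print(f"main {base:.2f}s, optimized {opt:.2f}s, speedup {base / opt:.3f}x")
```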
Nsight Systems Kernel Summary
Captured full-workflow Nsight Systems traces for main and optimized with the same command. The profiler adds overhead, so the latency table above is the source of truth for speedup. The kernel summary still shows the intended communication shift clearly.
Top kernels:
- ncclDevKernel_AllReduce_Sum_bf16_RING_LL
- ncclDevKernel_AllReduce_Sum_f32_RING_LL
- ncclDevKernel_AllGather_RING_LL
- nvjet_tst_128x192_64x5_2x1_v_bz_coopB_bias_TNN
- at::native::vol2col_kernel<c10::BFloat16>
- at::native::vectorized_elementwise_kernel
- triton_poi_fused_clone_permute_view_0

Profiler CSVs:
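To quantify the communication shift from the exported kernel summaries, something like the following works (the file names and CSV column names are assumptions about the nsys export format):

```python
import csv
from collections import defaultdict

# Hypothetical aggregator for an nsys kernel-summary CSV; the "Name" and
# "Total Time (ns)" column names are assumed, adjust to the actual export.
def nccl_time_by_kind(path: str) -> dict[str, int]:
    totals: dict[str, int] = defaultdict(int)
    with open(path, newline="") as f:
        for row in csv.DictReader(f):
            ns = int(row["Total Time (ns)"].replace(",", ""))
            if "AllGather" in row["Name"]:
                totals["AllGather"] += ns
            elif "AllReduce" in row["Name"]:
                totals["AllReduce"] += ns
    return dict(totals)

for label in ("main", "optimized"):  # placeholder file names
    print(label, nccl_time_by_kind(f"{label}_kern_sum.csv"))
```

The expected pattern is that the AllGather bucket shrinks to roughly zero in the optimized trace while AllReduce time stays comparable.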
Output Videos And Visual Check
The full pipeline completes successfully with the LTX2.3 stage-2 distilled LoRA path; logs show the adapter merged into 1660 layers.
Representative videos:
Preview frame, r1 side-by-side:
I do not claim frame-level bitwise identity. With this workflow, repeated main runs with the same seed are already non-identical. The measured main-vs-optimized difference is in the same range as repeated main/candidate runs:
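One way to put a number on "same range" is a per-frame difference metric between output videos; a sketch with placeholder file names, assuming torchvision is available:

```python
import torch
from torchvision.io import read_video

# Per-frame mean absolute difference between two output videos (placeholder paths).
def mean_abs_diff(path_a: str, path_b: str) -> float:
    a, _, _ = read_video(path_a, pts_unit="sec", output_format="TCHW")
    b, _, _ = read_video(path_b, pts_unit="sec", output_format="TCHW")
    n = min(len(a), len(b))
    return (a[:n].float() - b[:n].float()).abs().mean().item()

# Same-seed main-vs-main sets the noise floor; main-vs-optimized should match it.
print("main vs main     :", mean_abs_diff("main_run1.mp4", "main_run2.mp4"))
print("main vs optimized:", mean_abs_diff("main_run1.mp4", "optimized_run1.mp4"))
```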
Tests
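The configuration test mentioned in the review parses the source rather than instantiating the model; a minimal sketch of that AST approach, with the module path assumed rather than taken from the actual sglang layout:

```python
import ast
import inspect

import sglang.srt.models.ltx2 as ltx2  # hypothetical module path

# Verify that ColumnParallelLinear calls pass gather_output=False and
# RowParallelLinear calls pass input_is_parallel=True, without running the model.
def constant_kwarg(call: ast.Call, name: str):
    for kw in call.keywords:
        if kw.arg == name and isinstance(kw.value, ast.Constant):
            return kw.value.value
    return None

tree = ast.parse(inspect.getsource(ltx2))
for node in ast.walk(tree):
    if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
        if node.func.id == "ColumnParallelLinear":
            assert constant_kwarg(node, "gather_output") is False
        elif node.func.id == "RowParallelLinear":
            assert constant_kwarg(node, "input_is_parallel") is True
```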