[Diffusion] modelopt diffusion fp8 support for flux1/flux2 and wan2.2 #22365
Conversation
Code Review
This pull request implements support for NVIDIA ModelOpt FP8 and NVFP4 quantization in SGLang Diffusion, introducing new runtime layers, loading adapters, and tools for checkpoint conversion and accuracy validation. Feedback focuses on generalizing the layer exclusion logic to avoid LLM-specific assumptions and ensuring that parameter metadata is preserved during weight processing by using appropriate utility functions.
```python
import regex as re

fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
prefix_split = prefix.split(".")
for pattern in self.exclude_modules:
    regex_str = pattern.replace(".", r"\.").replace("*", r".*")
    pattern_split = pattern.split(".")
    if re.fullmatch(regex_str, prefix):
        return True
    if (
        pattern_split[-1] in fused_patterns
        and pattern_split[-1] in prefix_split[-1]
    ):
        assert len(prefix_split) == 5 and len(pattern_split) == 5
        return True
return False
```
The `is_layer_excluded` method contains logic and assertions that are specific to LLM layer structures in sglang.srt (e.g., `fused_patterns` entries like `q_a_proj` and `assert len(prefix_split) == 5`). These are likely not applicable to diffusion models and could cause runtime errors or incorrect exclusion behavior. Additionally, it's recommended to use the standard `re` library instead of `regex` for these simple patterns.
```python
import re

for pattern in self.exclude_modules:
    regex_str = pattern.replace(".", r"\.").replace("*", r".*")
    if re.fullmatch(regex_str, prefix):
        return True
return False
```

```python
layer.weight = Parameter(quantized_weight.t(), requires_grad=False)
if self.cutlass_fp8_supported:
    max_w_scale = convert_to_channelwise(max_w_scale, layer.logical_widths)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
```
In `process_weights_after_loading`, replacing `layer.weight`, `layer.weight_scale`, and `layer.input_scale` with plain `Parameter` objects removes the custom metadata and attributes (like `weight_loader`, `input_dim`, etc.) associated with `ModelWeightParameter` and `PerTensorScaleParameter`. It is safer to use `copy_or_rebind_param` to update the data while preserving the parameter types and their metadata.
```diff
-layer.weight = Parameter(quantized_weight.t(), requires_grad=False)
-if self.cutlass_fp8_supported:
-    max_w_scale = convert_to_channelwise(max_w_scale, layer.logical_widths)
-layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
-layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
+copy_or_rebind_param(layer, "weight", quantized_weight.t())
+if self.cutlass_fp8_supported:
+    max_w_scale = convert_to_channelwise(max_w_scale, layer.logical_widths)
+copy_or_rebind_param(layer, "weight_scale", max_w_scale)
+copy_or_rebind_param(layer, "input_scale", layer.input_scale.max())
```
/tag-and-rerun-ci |
mickqian left a comment:
some TODOs:
- adapt quantization doc if necessary
- add at least one testcase for modelopt fp8
| """ | ||
| quant_config = get_quant_config(hf_config, component_model_path) | ||
| if quant_config is None and server_args.transformer_weights_path: | ||
| override_quantized_path = maybe_download_model( |
maybe extract to a dedicated function here to better illustrate the quant load logic
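Something like the sketch below would capture it (signatures and call sites are inferred from the diff above, not from the final implementation):

```python
# Sketch of the suggested extraction; names and signatures are inferred
# from the quoted diff, the merged code may differ.
def _resolve_quant_config_from_transformer_override(
    hf_config, component_model_path, server_args
):
    quant_config = get_quant_config(hf_config, component_model_path)
    if quant_config is None and server_args.transformer_weights_path:
        # The override may be a remote repo id; resolve it to a local path.
        override_quantized_path = maybe_download_model(
            server_args.transformer_weights_path
        )
        quant_config = get_quant_config(hf_config, override_quantized_path)
    return quant_config
```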
Split the ModelOpt FP8 skill and helper tooling out into stacked PR #22492 so this PR stays focused on the runtime / loader / test changes. This PR now only keeps the runtime-side code, docs, and the diffusion FP8 correctness test.
/tag-and-rerun-ci |
Mock `maybe_download_model` in `test_resolve_transformer_quant_load_spec_keeps_nunchaku_hook` to prevent it from trying to download a fake local path as an HF repo. #22365 added `_resolve_quant_config_from_transformer_override`, which calls `maybe_download_model` on the `transformer_weights_path`, but the test uses a non-existent `/tmp` path that fails HF Hub validation.
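A minimal form of that fix, assuming a pytest-style test (the patch target's module path below is an assumption):

```python
# Patch maybe_download_model so the fake /tmp override path is returned
# as-is instead of being validated as an HF repo id. The patch target
# module path here is assumed; adjust it to wherever the loader imports it.
from unittest import mock

def test_resolve_transformer_quant_load_spec_keeps_nunchaku_hook():
    with mock.patch(
        "sglang.multimodal_gen.loader.maybe_download_model",
        side_effect=lambda path, **kwargs: path,
    ):
        ...  # original test body unchanged
```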
…gl-project#22560) Co-authored-by: Alison Shao <alison.shao@MacBook-Pro-D2W773R9CD.local>


## Summary
This PR adds a diffusion-side ModelOpt FP8 loading path for SGLang and a reusable workflow for converting ModelOpt diffusers exports into SGLang-loadable checkpoints.
The main goal is to make ModelOpt FP8 practical for SGLang diffusion models without requiring users to manually reconstruct FP8 checkpoints from `backbone.pt` every time.

## What changed
### Runtime support
- New `modelopt_fp8` quantization path for diffusion models
- Maps `quant_method=modelopt` + `quant_algo=FP8` into the SGLang diffusion FP8 runtime path
- Disables `dit_cpu_offload` and `dit_layerwise_offload` for ModelOpt FP8 checkpoints
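For orientation, the routing in the second bullet can be pictured roughly like this (a sketch only; the helper name is made up here, while the `quant_method` / `quant_algo` keys are the ones named above):

```python
# Illustrative only: route ModelOpt FP8 exports to the diffusion
# modelopt_fp8 runtime path based on the exported quantization config.
def _maybe_modelopt_fp8(quantization_config: dict | None) -> str | None:
    if not quantization_config:
        return None
    if (
        quantization_config.get("quant_method") == "modelopt"
        and quantization_config.get("quant_algo") == "FP8"
    ):
        return "modelopt_fp8"
    return None
```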
Why offload is disabled:

### FP8 checkpoint conversion
- New tool: `python -m sglang.multimodal_gen.tools.convert_modelopt_fp8_checkpoint`
- Reads the ModelOpt `backbone.pt` export
- Materializes the `weight_scale` / `input_scale` tensors and `float8_e4m3fn` weights
- Keeps `ignore` layers in their original dtype

The converter is generic in its core flow. The only model-family-specific part is an optional BF16 fallback profile; today the validated built-in fallback profile is for FLUX.2.
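For intuition, the per-weight core of such a conversion is standard per-tensor FP8 quantization. A minimal sketch, assuming `backbone.pt` holds higher-precision weights alongside ModelOpt scales (the tool's actual internals may differ):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quantize_weight(weight: torch.Tensor, weight_scale: torch.Tensor) -> torch.Tensor:
    # Scale down by the ModelOpt weight_scale, clamp to the e4m3 range, cast.
    return (weight.float() / weight_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
```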
### Validation helper
- `python -m sglang.multimodal_gen.tools.compare_diffusion_trajectory_similarity`
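Conceptually, the helper reduces to comparing the BF16 and FP8 latent trajectories step by step; a rough sketch of that idea (not the tool's actual code):

```python
import torch
import torch.nn.functional as F

def trajectory_similarity(bf16_latents, fp8_latents):
    # Mean cosine similarity across denoising steps; 1.0 means identical trajectories.
    sims = [
        F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0)
        for a, b in zip(bf16_latents, fp8_latents)
    ]
    return torch.stack(sims).mean().item()
```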
### Skill

- `python/sglang/multimodal_gen/.claude/skills/sglang-diffusion-modelopt-quant/SKILL.md`

## Notes on ModelOpt formats
FP8 currently needs an extra SGLang-side conversion step.
Why:
- the ModelOpt diffusers export keeps the `weight_scale` and `input_scale` tensors inside `backbone.pt`
- SGLang expects real `float8_e4m3fn` weights in the converted checkpoint

NVFP4 is different:
## Published checkpoints
The following converted checkpoints are already published so users do not need to run ModelOpt export + SGLang conversion themselves.
- `BBuf/flux2-dev-modelopt-fp8-sglang-transformer`
- `BBuf/wan22-t2v-a14b-modelopt-fp8-sglang-transformer`

Example usage:
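A hedged sketch of the usage pattern (the entrypoint and flag spelling are assumptions derived from `server_args.transformer_weights_path`; the base model id is a placeholder):

```bash
# Assumed invocation, not verified: load the base model in BF16 and override
# only the transformer component with the published FP8 checkpoint.
python -m sglang.multimodal_gen.launch_server \
  --model-path <base-diffusion-model> \
  --transformer-weights-path BBuf/flux2-dev-modelopt-fp8-sglang-transformer
```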
Note:
- `transformer` is the FP8 override currently used in our H100 validation runs
- `transformer_2` remains loaded from the base model in BF16 for that published recipe

## Validation
### FLUX.2
Validation was run on H100 with nightly-aligned settings and BF16/FP8 output comparisons.
Observed latency:
- BF16: `24.47 s` total, `23.21 s` denoising
- FP8: `17.13 s` total, `16.21 s` denoising
- Speedup: `30.0%` total and `30.1%` denoising

Reduced deterministic validation also showed high latent trajectory agreement:
- `0.9971`
- `5.7ms -> 2.5ms` in Profile 5 step's last layer
### Wan2.2
Validation was run on H100 with nightly-aligned settings using the validated primary-transformer FP8 override.
Observed latency:
- BF16: `212.19 s` total, `204.09 s` denoising
- FP8: `204.38 s` total, `196.28 s` denoising
- Speedup: `3.68%` total and `3.83%` denoising

Reduced deterministic validation also showed stable trajectory agreement:
- `0.9755`

wan22_bf16_nocompile.mp4

wan22_fp8_nocompile.mp4
## Artifacts
For both FLUX.2 and Wan2.2, I collected:
These artifacts were used during local validation and can be attached in review if needed.
## Scope
This PR focuses on:
It does not add ModelOpt mixed precision support.
## Modifications

## Accuracy Tests

## Speed Tests and Profiling

## Checklist

## Review and Merge Process
- `/tag-and-rerun-ci`, `/tag-run-ci-label`, `/rerun-failed-ci`