[diffusion]: add ERNIE-Image #22439
Conversation
Code Review
This pull request introduces comprehensive support for ErnieImage models, including new configuration files for its DiT architecture, VAE, and sampling parameters, along with the integration of a Mistral3-based text encoder. A key feature added is Prompt Enhancement (PE), which involves a new PE model loader, a dedicated pipeline, and a stage to process and enhance user prompts. The review feedback suggests improving code robustness by explicitly specifying UTF-8 encoding for file operations in several locations and enhancing code organization by moving a local import to the top-level imports.
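To make the encoding suggestion concrete, here is a generic sketch (not the PR's actual diff; the helper name and path are placeholders):

```python
import json

def read_tokenizer_config(path: str = "tokenizer_config.json") -> dict:
    # Pin UTF-8 explicitly instead of relying on the platform default,
    # which is locale-dependent on some systems, per the review feedback.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
```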
```python
if pe_model is not None:
    pe_tokenizer = getattr(pe_model, "pe_tokenizer", None)
    if pe_tokenizer is None:
        from transformers import AutoTokenizer
```
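For context, a self-contained version of this fallback might look like the sketch below; `model_path` is a hypothetical parameter (the real loader presumably resolves the tokenizer path itself):

```python
def resolve_pe_tokenizer(pe_model, model_path: str):
    """Return the tokenizer attached to the PE model, loading one if absent."""
    pe_tokenizer = getattr(pe_model, "pe_tokenizer", None)
    if pe_tokenizer is None:
        # Lazy import, mirroring the snippet under review; the reviewer
        # suggested hoisting this to a top-level import instead.
        from transformers import AutoTokenizer
        pe_tokenizer = AutoTokenizer.from_pretrained(model_path)
    return pe_tokenizer
```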
…pe_loader.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
```python
# Performance profiling
perf_dump_path: Optional[str] = None
# Prompt enhancement (ErnieImage)
use_pe: Optional[bool] = None
```
Could we avoid modifying the OpenAI endpoint?
Good point on keeping the OpenAI endpoint clean. What alternative approaches would you suggest for exposing the prompt enhance feature?
Something like:

```python
client.images.generate(
    model="baidu/ERNIE-Image",
    prompt="...",
    extra_body={"use_pe": False},
)
```
Thanks for the review!
Fixed and pushed.
/tag-and-rerun-ci
```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()
```
nit: Consider using get_local_torch_device() instead of hardcoding torch.device("cuda"), to stay consistent with all other component loaders in the repo.
The current code works fine after the framework has initialized (i.e., torch.cuda.set_device(local_rank) has been called), but there are two potential risks:
- It will fail on non-CUDA platforms (NPU / MUSA, etc.)
- If this loader is invoked before `torch.cuda.set_device`, it will default to `cuda:0`
Every other loader (transformer_loader, vae_loader, text_encoder_loader, etc.) uses get_local_torch_device(). Suggested change:
```python
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device

device = get_local_torch_device()
model = model.to(device).eval()
```
Thanks for the review!
Fixed and pushed.
/tag-and-rerun-ci
/tag-and-rerun-ci
/tag-and-rerun-ci
/rerun-failed-ci
Motivation
We have introduced a new text-to-image model called ERNIE-Image, which will soon be open-sourced to the community. This PR includes the model architecture definition, the pipeline, the config, as well as the new PE (prompt enhance) module.
Modifications
New files (all additions, no existing files modified except where noted):

- `configs/models/dits/ernie_image.py` (+50): `ErnieImageArchConfig` (36-layer, 32-head, hidden 4096) and `ErnieImageDitConfig`; `param_names_mapping` fuses HF `gate_proj`/`up_proj` into a single `gate_up_proj` weight at load time.
- `configs/models/encoders/mistral3.py` (+72): `Mistral3EncoderArchConfig` / `Mistral3EncoderConfig` for the Mistral-3 text encoder (26 layers, hidden 3072); declares QKV / gate-up stacked-param mappings for TP loading.
- `configs/models/vaes/ernie_image.py` (+57): `ErnieImageVAEArchConfig` / `ErnieImageVAEConfig` (8× spatial VAE; `use_feature_cache=True` by default).
- `configs/pipeline_configs/ernie_image.py` (+206): `ErnieImagePipelineConfig`; includes `_patchify_latents` / `_unpatchify_latents` (2×2 pixel-shuffle, 4× channel expansion), `ernie_image_postprocess_text` (extracts `hidden_states[-2]` from Mistral-3), and `get_decode_scale_and_shift` (reads BatchNorm running stats from the VAE for proper latent rescaling).
- `configs/sample/ernie_image.py` (+15): `ErnieImageSamplingParams` — overrides `guidance_scale=5.0`, `num_inference_steps=50`, `negative_prompt=" "`, `use_pe=True`.
- `runtime/models/dits/ernie_image.py` (+477): full DiT implementation — `EmbedND3` (3D RoPE), `ErnieImageSelfAttention` (TP `ColumnParallelLinear` Q/K/V + `RowParallelLinear` out, QK-RMSNorm), `ErnieImageMLP` (`MergedColumnParallelLinear` gate-up fusion), `ErnieImageSharedAdaLNBlock` (single-stream AdaLN block), `ErnieImageTransformer2DModel` (inherits `CachableDiT` + `OffloadableDiTMixin`).
- `runtime/pipelines/ernie_image.py` (+218): `ErnieImagePipeline` — detects the PE module from `model_index.json`, reads `model_max_length` from `tokenizer/tokenizer_config.json` at load time, wires stages: input validation → (optional) PE → text encoding → denoising → decoding.
- `runtime/pipelines_core/stages/model_specific_stages/ernie_image_pe.py` (+98): `PromptEnhancementStage` — wraps prompt + resolution into a JSON `{"prompt", "width", "height"}` user message, calls the PE model, replaces `batch.prompt` with the enhanced output; skipped when `use_pe=False`.
- `runtime/loader/component_loaders/pe_loader.py` (+161): `PELoader` — loads a Mistral-3 causal LM via `AutoModelForCausalLM`, prefers Flash Attention 2 with SDPA fallback, reads `model_max_length` from `tokenizer_config.json`, returns a `PEModelWrapper` with a unified `generate()` interface.

Modified files:

- `configs/sample/sampling_params.py` (+3): adds a `use_pe: bool | None = None` field to the base `SamplingParams` for PE toggle pass-through.
- `runtime/entrypoints/openai/protocol.py` (+2): adds `use_pe: Optional[bool]` to `ImageGenerationsRequest`.
- `runtime/entrypoints/openai/image_api.py` (+1): forwards `use_pe` from the HTTP request to sampling params.
- `registry.py` (+17): registers `ErnieImagePipelineConfig` + `ErnieImageSamplingParams` for `baidu/ERNIE-Image` and `baidu/ERNIE-Image-Turbo` with a case-insensitive `ernie-image` detector.

Accuracy Tests
This is a brand new diffusion text-to-image model that will not affect the output of any existing models.
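The 2×2 pixel-shuffle performed by `_patchify_latents` / `_unpatchify_latents` can be sketched in NumPy terms (a layout illustration only; the actual implementation operates on torch tensors and its exact channel ordering may differ):

```python
import numpy as np

def patchify_latents(x: np.ndarray) -> np.ndarray:
    """(B, C, H, W) -> (B, 4C, H/2, W/2): fold each 2x2 pixel block into channels."""
    b, c, h, w = x.shape
    x = x.reshape(b, c, h // 2, 2, w // 2, 2)
    x = x.transpose(0, 1, 3, 5, 2, 4)  # (B, C, 2, 2, H/2, W/2)
    return x.reshape(b, c * 4, h // 2, w // 2)

def unpatchify_latents(x: np.ndarray) -> np.ndarray:
    """(B, 4C, H/2, W/2) -> (B, C, H, W): inverse of the pixel-shuffle above."""
    b, c4, h2, w2 = x.shape
    c = c4 // 4
    x = x.reshape(b, c, 2, 2, h2, w2)
    x = x.transpose(0, 1, 4, 2, 5, 3)  # (B, C, H/2, 2, W/2, 2)
    return x.reshape(b, c, h2 * 2, w2 * 2)
```

Patchify then unpatchify is an exact round trip, which is the property the pipeline relies on.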
Speed Tests and Profiling
We ran inference experiments on a dataset of 100 proprietary prompts of variable length. The model was served with `sglang serve --model-path baidu/ERNIE-Image`. Below are our inference results and the Python benchmark code, comparing runs with the PE module enabled and disabled, benchmarked against our Diffusers implementation.
code:
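The benchmark script itself is not inlined here; a minimal sequential harness consistent with the description might look like this (the `client` is assumed to be an OpenAI-compatible client, e.g. `openai.OpenAI(base_url=...)` pointed at the sglang server; function and stat names are illustrative):

```python
import time

def run_benchmark(client, prompts, use_pe: bool) -> dict:
    """Sequentially issue image requests and collect per-request latency stats."""
    latencies = []
    for prompt in prompts:
        t0 = time.perf_counter()
        client.images.generate(
            model="baidu/ERNIE-Image",
            prompt=prompt,
            extra_body={"use_pe": use_pe},  # PE toggle, as suggested in review
        )
        latencies.append(time.perf_counter() - t0)
    return {
        "success": len(latencies),
        "avg": sum(latencies) / len(latencies),
        "fastest": min(latencies),
        "slowest": max(latencies),
    }
```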
results:
Comparison Summary
| Metric | use_pe=False | use_pe=True |
| --- | --- | --- |
| Successful Requests | 100 | 100 |
| Failed Requests | 0 | 0 |
| Actual Time Elapsed (Wall Time) | 1612.74 s | 2896.73 s |
| Average Time per Successful Request | 16.13 s | 28.97 s |
| Fastest Request | 15.63 s | 21.55 s |
| Slowest Request | 17.99 s | 42.09 s |
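A quick sanity check on the table: dividing wall time by the number of successful requests reproduces the reported per-request average in both configurations, suggesting the requests were processed sequentially:

```python
# Wall time / successful requests should match the reported average latency.
for wall, avg in [(1612.74, 16.13), (2896.73, 28.97)]:
    assert abs(wall / 100 - avg) < 0.005
```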
Checklist
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci