
[diffusion]: add ERNIE-Image #22439

Merged
mickqian merged 26 commits into sgl-project:main from dyhsup:ernie-image-dev
Apr 11, 2026

Conversation

@dyhsup
Contributor

@dyhsup dyhsup commented Apr 9, 2026

Motivation

We have introduced a new text-to-image model called ERNIE-Image, which will soon be open-sourced to the community. This PR includes the model architecture definition, the pipeline, the configs, and the new PE (prompt enhance) module.

Modifications

New files (all additions, no existing files modified except where noted):

  • configs/models/dits/ernie_image.py (+50): ErnieImageArchConfig (36-layer, 32-head, hidden 4096) and ErnieImageDitConfig; param_names_mapping fuses HF gate_proj/up_proj into a single gate_up_proj weight at load time.
  • configs/models/encoders/mistral3.py (+72): Mistral3EncoderArchConfig / Mistral3EncoderConfig for the Mistral-3 text encoder (26 layers, hidden 3072); declares QKV / gate-up stacked-param mappings for TP loading.
  • configs/models/vaes/ernie_image.py (+57): ErnieImageVAEArchConfig / ErnieImageVAEConfig (8× spatial VAE; use_feature_cache=True by default).
  • configs/pipeline_configs/ernie_image.py (+206): ErnieImagePipelineConfig; includes _patchify_latents / _unpatchify_latents (2×2 pixel-shuffle, 4× channel expansion; see the sketch after this list), ernie_image_postprocess_text (extracts hidden_states[-2] from Mistral-3), and get_decode_scale_and_shift (reads BatchNorm running stats from the VAE for proper latent rescaling).
  • configs/sample/ernie_image.py (+15): ErnieImageSamplingParams — overrides guidance_scale=5.0, num_inference_steps=50, negative_prompt=" ", use_pe=True.
  • runtime/models/dits/ernie_image.py (+477): Full DiT implementation — EmbedND3 (3D RoPE), ErnieImageSelfAttention (TP ColumnParallelLinear Q/K/V + RowParallelLinear out, QK-RMSNorm), ErnieImageMLP (MergedColumnParallelLinear gate-up fusion), ErnieImageSharedAdaLNBlock (single-stream AdaLN block), ErnieImageTransformer2DModel (inherits CachableDiT + OffloadableDiTMixin).
  • runtime/pipelines/ernie_image.py (+218): ErnieImagePipeline — detects PE module from model_index.json, reads model_max_length from tokenizer/tokenizer_config.json at load time, wires stages: input validation → (optional) PE → text encoding → denoising → decoding.
  • runtime/pipelines_core/stages/model_specific_stages/ernie_image_pe.py (+98): PromptEnhancementStage — wraps prompt + resolution into a JSON {"prompt", "width", "height"} user message, calls PE model, replaces batch.prompt with enhanced output; skipped when use_pe=False.
  • runtime/loader/component_loaders/pe_loader.py (+161): PELoader — loads Mistral-3 causal LM via AutoModelForCausalLM, prefers Flash Attention 2 with SDPA fallback, reads model_max_length from tokenizer_config.json, returns PEModelWrapper with a unified generate() interface.
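
As a reference for the pipeline-config bullet above, here is a minimal sketch of the 2×2 pixel-shuffle patchify/unpatchify (function names and tensor layout are illustrative assumptions, not the PR's exact code):

import torch

def patchify_latents(latents: torch.Tensor) -> torch.Tensor:
    # (B, C, H, W) -> (B, 4C, H/2, W/2): fold each 2x2 spatial patch
    # into the channel dimension (4x channel expansion).
    b, c, h, w = latents.shape
    latents = latents.reshape(b, c, h // 2, 2, w // 2, 2)
    return latents.permute(0, 1, 3, 5, 2, 4).reshape(b, 4 * c, h // 2, w // 2)

def unpatchify_latents(latents: torch.Tensor) -> torch.Tensor:
    # Inverse mapping: (B, 4C, H/2, W/2) -> (B, C, H, W).
    b, c4, h2, w2 = latents.shape
    c = c4 // 4
    latents = latents.reshape(b, c, 2, 2, h2, w2)
    return latents.permute(0, 1, 4, 2, 5, 3).reshape(b, c, 2 * h2, 2 * w2)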

Modified files:

  • configs/sample/sampling_params.py (+3): Adds use_pe: bool | None = None field to base SamplingParams for PE toggle pass-through.
  • runtime/entrypoints/openai/protocol.py (+2): Adds use_pe: Optional[bool] to ImageGenerationsRequest.
  • runtime/entrypoints/openai/image_api.py (+1): Forwards use_pe from the HTTP request to sampling params (see the request sketch after this list).
  • registry.py (+17): Registers ErnieImagePipelineConfig + ErnieImageSamplingParams for baidu/ERNIE-Image and baidu/ERNIE-Image-Turbo with a case-insensitive ernie-image detector.
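
As a usage sketch for the use_pe plumbing above (the payload shape mirrors the benchmark script below; host/port and prompt text are illustrative assumptions):

import requests

resp = requests.post(
    "http://0.0.0.0:8842/v1/images/generations",
    json={
        "prompt": "a watercolor fox in a bamboo forest",
        "size": "1024x1024",
        "response_format": "b64_json",
        "use_pe": True,  # new field added in this PR; forwarded to sampling params
    },
    timeout=600,
)
resp.raise_for_status()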

Accuracy Tests

ERNIE-Image is a brand-new diffusion text-to-image model, so this PR does not affect the output of any existing models.

Speed Tests and Profiling

We ran inference experiments on a dataset of 100 proprietary prompts of varying lengths, with the model deployed via the command sglang serve --model-path baidu/ERNIE-Image. Below are the benchmark script and results, comparing configurations with the PE module enabled and disabled, benchmarked against our Diffusers implementation.

Code:

#!/usr/bin/env python3
"""
ERNIE-Image use_pe comparison benchmark.
Usage: python benchmark_pe.py
Reads 100 prompts from prompt.txt and measures the total and average latency
of all successfully completed requests under use_pe=False and use_pe=True.
"""

import argparse
import json
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import requests


def load_prompts(path: str) -> list[str]:
    p = Path(path)
    if not p.exists():
        print(f"Error: prompt file {path} not found")
        sys.exit(1)
    prompts = [line.strip() for line in p.read_text(encoding="utf-8").splitlines() if line.strip()]
    if not prompts:
        print(f"Error: no valid prompts in {path}")
        sys.exit(1)
    print(f"Loaded {len(prompts)} prompts")
    return prompts


def send_request(url: str, payload: dict, timeout: int) -> tuple[bool, float]:
    """Send a single request; return (success, elapsed seconds)."""
    start = time.time()
    try:
        resp = requests.post(url, json=payload, timeout=timeout)
        elapsed = time.time() - start
        if resp.status_code == 200:
            return True, elapsed
        print(f"  [FAIL] status={resp.status_code}, prompt={payload['prompt'][:50]}...")
        return False, elapsed
    except requests.exceptions.Timeout:
        elapsed = time.time() - start
        print(f"  [TIMEOUT] prompt={payload['prompt'][:50]}...")
        return False, elapsed
    except Exception as e:
        elapsed = time.time() - start
        print(f"  [ERROR] {e}, prompt={payload['prompt'][:50]}...")
        return False, elapsed


def run_benchmark(
    prompts: list[str],
    server_url: str,
    use_pe: bool,
    steps: int,
    width: int,
    height: int,
    guidance_scale: float,
    seed: int,
    timeout: int,
    concurrency: int,
) -> dict:
    """Run one benchmark round and return the aggregated results."""
    url = f"{server_url.rstrip('/')}/v1/images/generations"
    label = "use_pe=True" if use_pe else "use_pe=False"
    print(f"\n{'='*60}")
    print(f"Starting benchmark: {label}  |  {len(prompts)} prompts, concurrency={concurrency}")
    print(f"{'='*60}")

    success_count = 0
    fail_count = 0
    total_time = 0.0
    timings: list[float] = []

    def task(prompt: str, idx: int):
        payload = {
            "prompt": prompt,
            "size": f"{width}x{height}",
            "num_inference_steps": steps,
            "guidance_scale": guidance_scale,
            "seed": seed,
            "negative_prompt": " ",
            "response_format": "b64_json",
            "n": 1,
            "use_pe": use_pe,
        }
        ok, elapsed = send_request(url, payload, timeout)
        return idx, ok, elapsed

    wall_start = time.time()

    with ThreadPoolExecutor(max_workers=concurrency) as pool:
        futures = {pool.submit(task, p, i): i for i, p in enumerate(prompts)}
        for future in as_completed(futures):
            idx, ok, elapsed = future.result()
            if ok:
                success_count += 1
                timings.append(elapsed)
            else:
                fail_count += 1
            total_time += elapsed
            # Progress report every 10 completed requests
            done = success_count + fail_count
            if done % 10 == 0 or done == len(prompts):
                print(f"  Progress: {done}/{len(prompts)}  ok={success_count}  failed={fail_count}")

    wall_elapsed = time.time() - wall_start

    result = {
        "label": label,
        "use_pe": use_pe,
        "total_prompts": len(prompts),
        "success": success_count,
        "fail": fail_count,
        "wall_time": wall_elapsed,
        "avg_success_time": sum(timings) / len(timings) if timings else 0,
        "min_time": min(timings) if timings else 0,
        "max_time": max(timings) if timings else 0,
        "sum_request_time": sum(timings),
    }
    return result


def print_result(r: dict):
    print(f"\n--- {r['label']} results ---")
    print(f"  Total requests:            {r['total_prompts']}")
    print(f"  Succeeded:                 {r['success']}")
    print(f"  Failed:                    {r['fail']}")
    print(f"  Wall time:                 {r['wall_time']:.2f}s")
    print(f"  Total successful time:     {r['sum_request_time']:.2f}s")
    print(f"  Avg successful time:       {r['avg_success_time']:.2f}s")
    print(f"  Fastest:                   {r['min_time']:.2f}s")
    print(f"  Slowest:                   {r['max_time']:.2f}s")


def main():
    parser = argparse.ArgumentParser(description="ERNIE-Image use_pe comparison benchmark")
    parser.add_argument("--prompt-file", type=str, default="prompt.txt", help="path to the prompt file")
    parser.add_argument("--server-url", type=str, default="http://0.0.0.0:8842", help="server URL")
    parser.add_argument("--steps", type=int, default=50, help="number of inference steps")
    parser.add_argument("--width", type=int, default=1024, help="image width")
    parser.add_argument("--height", type=int, default=1024, help="image height")
    parser.add_argument("--guidance-scale", type=float, default=5.0, help="guidance scale")
    parser.add_argument("--seed", type=int, default=43, help="random seed")
    parser.add_argument("--timeout", type=int, default=600, help="per-request timeout (seconds)")
    parser.add_argument("--concurrency", type=int, default=1, help="number of concurrent requests")
    args = parser.parse_args()

    prompts = load_prompts(args.prompt_file)

    # Run use_pe=True first
    result_true = run_benchmark(
        prompts, args.server_url, use_pe=True,
        steps=args.steps, width=args.width, height=args.height,
        guidance_scale=args.guidance_scale, seed=args.seed,
        timeout=args.timeout, concurrency=args.concurrency,
    )
    print_result(result_true)

    # Then run use_pe=False
    result_false = run_benchmark(
        prompts, args.server_url, use_pe=False,
        steps=args.steps, width=args.width, height=args.height,
        guidance_scale=args.guidance_scale, seed=args.seed,
        timeout=args.timeout, concurrency=args.concurrency,
    )
    print_result(result_false)

    # Comparison summary
    print(f"\n{'='*60}")
    print("Comparison Summary")
    print(f"{'='*60}")
    print(f"{'Metric':<36} {'use_pe=False':>15} {'use_pe=True':>15}")
    print(f"{'-'*66}")
    print(f"{'Successful requests':<36} {result_false['success']:>15} {result_true['success']:>15}")
    print(f"{'Failed requests':<36} {result_false['fail']:>15} {result_true['fail']:>15}")
    print(f"{'Wall time':<36} {result_false['wall_time']:>14.2f}s {result_true['wall_time']:>14.2f}s")
    print(f"{'Avg successful request time':<36} {result_false['avg_success_time']:>14.2f}s {result_true['avg_success_time']:>14.2f}s")
    print(f"{'Fastest request':<36} {result_false['min_time']:>14.2f}s {result_true['min_time']:>14.2f}s")
    print(f"{'Slowest request':<36} {result_false['max_time']:>14.2f}s {result_true['max_time']:>14.2f}s")

    # Save the results to JSON
    output = {
        "params": {
            "server_url": args.server_url,
            "steps": args.steps,
            "size": f"{args.width}x{args.height}",
            "guidance_scale": args.guidance_scale,
            "seed": args.seed,
            "timeout": args.timeout,
            "concurrency": args.concurrency,
        },
        "use_pe_false": result_false,
        "use_pe_true": result_true,
    }
    out_file = Path(f"benchmark_result_{time.strftime('%Y%m%d_%H%M%S')}.json")
    out_file.write_text(json.dumps(output, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"\nResults saved to: {out_file}")


if __name__ == "__main__":
    main()

Results:

Comparison Summary

Metric                                 use_pe=False   use_pe=True
Successful requests                             100           100
Failed requests                                   0             0
Wall time                                  1612.74s      2896.73s
Average time per successful request          16.13s        28.97s
Fastest request                              15.63s        21.55s
Slowest request                              17.99s        42.09s

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

@github-actions github-actions Bot added the diffusion SGLang Diffusion label Apr 9, 2026
@dyhsup dyhsup changed the title from "Ernie image dev" to "[diffusion]: add ERNIE-Image" on Apr 9, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces comprehensive support for ErnieImage models, including new configuration files for its DiT architecture, VAE, and sampling parameters, along with the integration of a Mistral3-based text encoder. A key feature added is Prompt Enhancement (PE), which involves a new PE model loader, a dedicated pipeline, and a stage to process and enhance user prompts. The review feedback suggests improving code robustness by explicitly specifying UTF-8 encoding for file operations in several locations and enhancing code organization by moving a local import to the top-level imports.

if pe_model is not None:
    pe_tokenizer = getattr(pe_model, "pe_tokenizer", None)
    if pe_tokenizer is None:
        from transformers import AutoTokenizer
Contributor


medium

Local imports should generally be avoided for better code readability and to make dependencies explicit at the top of the file. Please move from transformers import AutoTokenizer to the top-level imports of this module.

dyhsup and others added 5 commits April 9, 2026 17:49
…pe_loader.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
# Performance profiling
perf_dump_path: Optional[str] = None
# Prompt enhancement (ErnieImage)
use_pe: Optional[bool] = None
Collaborator


Could we avoid modifying the OpenAI endpoint?

Contributor Author


Good point on keeping the OpenAI endpoint clean. What alternative approaches would you suggest for exposing the prompt enhance feature?

Collaborator


something like:

  client.images.generate(
      model="baidu/ERNIE-Image",
      prompt="...",
      extra_body={"use_pe": False},
  )

Contributor Author


Thanks for the review!
Fixed and pushed.

@mickqian
Collaborator

mickqian commented Apr 9, 2026

/tag-and-rerun-ci

@github-actions github-actions Bot added the run-ci label Apr 9, 2026

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

Collaborator


nit: Consider using get_local_torch_device() instead of hardcoding torch.device("cuda"), to stay consistent with all other component loaders in the repo.

The current code works fine after the framework has initialized (i.e., torch.cuda.set_device(local_rank) has been called), but there are two potential risks:

  1. It will fail on non-CUDA platforms (NPU / MUSA, etc.)
  2. If this loader is invoked before torch.cuda.set_device, it will default to cuda:0

Every other loader (transformer_loader, vae_loader, text_encoder_loader, etc.) uses get_local_torch_device(). Suggested change:

from sglang.multimodal_gen.runtime.distributed import get_local_torch_device

device = get_local_torch_device()
model = model.to(device).eval()

Contributor Author


Thanks for the review!
Fixed and pushed.

@yhyang201
Collaborator

/tag-and-rerun-ci

@yhyang201
Collaborator

/tag-and-rerun-ci

@yhyang201
Collaborator

/tag-and-rerun-ci

@yhyang201
Collaborator

/rerun-failed-ci

@mickqian mickqian merged commit 8cca974 into sgl-project:main Apr 11, 2026
101 of 107 checks passed
pyc96 pushed a commit to pyc96/sglang that referenced this pull request Apr 14, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026

Labels

diffusion SGLang Diffusion run-ci
