
Commit 8e14013

support Z-Image-Omni-Base (#226)
* support Z-Image-Omni-Base
* support non-batch-cfg
* fix parallel
* fix cfg
1 parent 618f2e4 commit 8e14013

File tree

9 files changed: +1775 -0 lines changed

diffsynth_engine/configs/pipeline.py

Lines changed: 5 additions & 0 deletions
@@ -307,13 +307,16 @@ class ZImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig,
     vae_dtype: torch.dtype = torch.bfloat16
     encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
     encoder_dtype: torch.dtype = torch.bfloat16
+    image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
+    image_encoder_dtype: torch.dtype = torch.bfloat16
 
     @classmethod
     def basic_config(
         cls,
         model_path: str | os.PathLike | List[str | os.PathLike],
         encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
         vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
+        image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
         device: str = "cuda",
         parallelism: int = 1,
         offload_mode: Optional[str] = None,
@@ -324,6 +327,7 @@ def basic_config(
             device=device,
             encoder_path=encoder_path,
             vae_path=vae_path,
+            image_encoder_path=image_encoder_path,
             parallelism=parallelism,
             use_cfg_parallel=True if parallelism > 1 else False,
             use_fsdp=True if parallelism > 1 else False,
@@ -391,6 +395,7 @@ class ZImageStateDicts:
     model: Dict[str, torch.Tensor]
     encoder: Dict[str, torch.Tensor]
     vae: Dict[str, torch.Tensor]
+    image_encoder: Optional[Dict[str, torch.Tensor]] = None
 
 
 def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig | ZImagePipelineConfig):
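The new image_encoder_path and image_encoder_dtype fields flow through basic_config the same way as encoder_path and vae_path. A minimal sketch of how a caller might build a config with the new argument; the checkpoint paths are hypothetical placeholders, not files shipped by this commit:

from diffsynth_engine.configs.pipeline import ZImagePipelineConfig

# All paths below are placeholders; only the keyword arguments come from this diff.
config = ZImagePipelineConfig.basic_config(
    model_path="models/z-image-omni-base/dit.safetensors",
    encoder_path="models/z-image-omni-base/text_encoder.safetensors",
    vae_path="models/z-image-omni-base/vae.safetensors",
    image_encoder_path="models/z-image-omni-base/siglip2_image_encoder.safetensors",  # new in this commit
    device="cuda",
    parallelism=1,
)

Per the hunk above, passing parallelism > 1 also switches use_cfg_parallel and use_fsdp on automatically.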

diffsynth_engine/models/z_image/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -3,9 +3,13 @@
     Qwen3Config,
 )
 from .z_image_dit import ZImageDiT
+from .z_image_dit_omni_base import ZImageOmniBaseDiT
+from .siglip import Siglip2ImageEncoder
 
 __all__ = [
     "Qwen3Model",
     "Qwen3Config",
     "ZImageDiT",
+    "ZImageOmniBaseDiT",
+    "Siglip2ImageEncoder",
 ]
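Both new classes are re-exported here, so downstream code can import them from the subpackage directly:

from diffsynth_engine.models.z_image import ZImageOmniBaseDiT, Siglip2ImageEncoder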
diffsynth_engine/models/z_image/siglip.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+from transformers import Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
+import torch
+
+
+class Siglip2ImageEncoder(Siglip2VisionModel):
+    def __init__(self, **kwargs):
+        config = Siglip2VisionConfig(
+            attention_dropout = 0.0,
+            dtype = "bfloat16",
+            hidden_act = "gelu_pytorch_tanh",
+            hidden_size = 1152,
+            intermediate_size = 4304,
+            layer_norm_eps = 1e-06,
+            model_type = "siglip2_vision_model",
+            num_attention_heads = 16,
+            num_channels = 3,
+            num_hidden_layers = 27,
+            num_patches = 256,
+            patch_size = 16,
+            transformers_version = "4.57.1"
+        )
+        super().__init__(config)
+        self.processor = Siglip2ImageProcessorFast(
+            **{
+                "data_format": "channels_first",
+                "default_to_square": True,
+                "device": None,
+                "disable_grouping": None,
+                "do_convert_rgb": None,
+                "do_normalize": True,
+                "do_pad": None,
+                "do_rescale": True,
+                "do_resize": True,
+                "image_mean": [
+                    0.5,
+                    0.5,
+                    0.5
+                ],
+                "image_processor_type": "Siglip2ImageProcessorFast",
+                "image_std": [
+                    0.5,
+                    0.5,
+                    0.5
+                ],
+                "input_data_format": None,
+                "max_num_patches": 256,
+                "pad_size": None,
+                "patch_size": 16,
+                "processor_class": "Siglip2Processor",
+                "resample": 2,
+                "rescale_factor": 0.00392156862745098,
+                "return_tensors": None,
+            }
+        )
+
+    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
+        siglip_inputs = self.processor(images=[image], return_tensors="pt").to(device)
+        shape = siglip_inputs.spatial_shapes[0]
+        hidden_state = super().forward(**siglip_inputs).last_hidden_state
+        B, N, C = hidden_state.shape
+        hidden_state = hidden_state[:, : shape[0] * shape[1]]
+        hidden_state = hidden_state.view(shape[0], shape[1], C)
+        hidden_state = hidden_state.to(torch_dtype)
+        return hidden_state
+
+    @classmethod
+    def from_state_dict(cls, state_dict, device: str, dtype: torch.dtype):
+        model = cls()
+        model.requires_grad_(False)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, dtype=dtype, non_blocking=True)
+        return model
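For orientation, a minimal sketch of driving this encoder on its own, assuming a safetensors checkpoint whose keys match the Siglip2VisionModel layout hard-coded in __init__; the file paths are hypothetical placeholders:

import torch
from PIL import Image
from safetensors.torch import load_file

from diffsynth_engine.models.z_image import Siglip2ImageEncoder

# Placeholder path; the state dict must match the SigLIP2 vision config baked into the class.
state_dict = load_file("models/siglip2_image_encoder.safetensors")
encoder = Siglip2ImageEncoder.from_state_dict(state_dict, device="cuda", dtype=torch.bfloat16)

image = Image.open("reference.png").convert("RGB")
# forward() runs the fast image processor internally and returns a single image's
# (patch_rows, patch_cols, hidden_size) feature grid, cast to the requested dtype.
features = encoder(image, torch_dtype=torch.bfloat16, device="cuda")
print(features.shape)  # e.g. torch.Size([16, 16, 1152]) for a square input at max_num_patches=256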
