Paper | 🤗 HF Models | Modelscope
Emu3 is a new suite of state-of-the-art multimodal models trained solely with next-token prediction. By tokenizing images, text, and videos into a discrete space, a single transformer is trained from scratch on a mixture of multimodal sequences.
As a multimodal LLM, Emu3 uses vector quantization to tokenize images into discrete tokens. These image tokens are fused with text token ids into a single sequence for both text and image generation; the model generates images simply by predicting image token ids.
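A conceptual sketch of this unified sequence (illustrative only; the token ids and special tokens below are made up and do not reflect Emu3's actual vocabulary):

```python
BOS, BOI, EOI, EOS = 0, 1, 2, 3      # hypothetical special tokens (begin/end of sequence/image)
text_ids = [812, 45, 9031]           # prompt ids from the text tokenizer
vision_ids = [70412, 70997, 71320]   # discrete codes from the vision tokenizer (VQ)

# One flat multimodal sequence; a single transformer is trained to predict
# sequence[t + 1] from sequence[: t + 1] with plain cross-entropy.
sequence = [BOS] + text_ids + [BOI] + vision_ids + [EOI, EOS]
```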
Emu3 outperforms several well-established task-specific models in both generation and perception tasks, surpassing flagship open models such as SDXL, LLaVA-1.6 and OpenSora-1.2, while eliminating the need for diffusion or compositional architectures.
- Emu3 generates high-quality images from text input simply by predicting the next vision token, and naturally supports flexible resolutions and styles.
- Emu3 shows strong vision-language understanding of the physical world and provides coherent text responses, notably without relying on CLIP or a pretrained LLM.
- Emu3 generates videos causally by predicting the next token in a video sequence, unlike video diffusion models such as Sora. Given a video as context, Emu3 can also naturally extend it and predict what will happen next.
- Model weights of Vision Tokenizer, Emu3-Stage1, Emu3-Chat and Emu3-Gen.
- Inference code.
- Training scripts for SFT.
Text to Image Generation:

| input | generated image |
|---|---|
| prompt: "a portrait of young girl"<br>image config: {ratio: 1:1, image_area: 720x720} | ![]() |
| prompt: "a shiba inu"<br>image config: {ratio: 16:9, image_area: 720x720} | ![]() |
Image VQA:

Requirements:
| mindspore | ascend driver | cann |
|---|---|---|
| >=2.6.0 | >=24.1.RC3 | >=8.1.RC1 |
Enter this directory and install required packages:
cd mindone
pip install -e .[training]
cd examples/emu3
pip install -r requirements.txt
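To sanity-check the environment (assuming an Ascend NPU and a matching driver are available), you can run:

```python
import mindspore as ms

ms.set_context(device_target="Ascend")  # select the Ascend backend
ms.run_check()                          # prints the installed version and runs a quick functional test
```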
| Model name | HF Weight | Modelscope | Wisemodel |
|---|---|---|---|
| Emu3-Stage1 | 🤗 HF link | Modelscope link | |
| Emu3-Chat | 🤗 HF link | Modelscope link | Wisemodel link |
| Emu3-Gen | 🤗 HF link | Modelscope link | Wisemodel link |
| Emu3-VisionTokenizer | 🤗 HF link | Modelscope link | Wisemodel link |
An inference script is provided in scripts/infer_img_gen.sh.
An example of generating an image is as follows:
from emu3.mllm import Emu3ForCausalLM, Emu3Processor, Emu3Tokenizer
from emu3.tokenizer import Emu3VisionVQImageProcessor, Emu3VisionVQModel
from PIL import Image
from transformers.generation.configuration_utils import GenerationConfig
import mindspore as ms
from mindspore import Tensor, mint, ops
from mindone.transformers.generation.logits_process import (
LogitsProcessorList,
PrefixConstrainedLogitsProcessor,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
from mindone.utils.amp import auto_mixed_precision
# model paths and data types
EMU_HUB = "BAAI/Emu3-Gen"
VQ_HUB = "BAAI/Emu3-VisionTokenizer"
EMU_DTYPE = ms.bfloat16
VQ_DTYPE = ms.bfloat16
# prepare model and processor
model = Emu3ForCausalLM.from_pretrained(
EMU_HUB,
mindspore_dtype=EMU_DTYPE,
use_safetensors=True,
attn_implementation="flash_attention_2",
).set_train(False)
tokenizer = Emu3Tokenizer.from_pretrained(EMU_HUB, padding_side="left")
image_processor = Emu3VisionVQImageProcessor.from_pretrained(VQ_HUB)
image_tokenizer = Emu3VisionVQModel.from_pretrained(
VQ_HUB,
use_safetensors=True,
mindspore_dtype=VQ_DTYPE
).set_train(False)
image_tokenizer = auto_mixed_precision(
image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
)
processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
# prepare input
POSITIVE_PROMPT = " masterpiece, film grained, best quality."
NEGATIVE_PROMPT = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry."
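# Classifier-free guidance combines conditional and unconditional logits:
# logits = uncond + cfg_scale * (cond - uncond); larger values follow the prompt more closely.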
classifier_free_guidance = 3.0
prompt = "a portrait of young girl."
prompt += POSITIVE_PROMPT
kwargs = dict(
    mode='G',  # 'G': generation, 'U': understanding
ratio="1:1",
image_area=model.config.image_area,
return_tensors="np",
padding="longest",
)
pos_inputs = processor(text=prompt, **kwargs)
neg_inputs = processor(text=NEGATIVE_PROMPT, **kwargs)
# prepare hyper parameters
GENERATION_CONFIG = GenerationConfig(
use_cache=True,
bos_token_id=model.config.bos_token_id,
eos_token_id=model.config.eos_token_id,
pad_token_id=model.config.pad_token_id,
max_new_tokens=40960,
do_sample=True,
top_k=2048,
)
h = pos_inputs.image_size[:, 0]
w = pos_inputs.image_size[:, 1]
constrained_fn = processor.build_prefix_constrained_fn(h, w)
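# The prefix-constrained function restricts sampling to tokens that keep the output a valid
# h x w image-token grid, so the generated ids can always be decoded back into an image.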
logits_processor = LogitsProcessorList([
UnbatchedClassifierFreeGuidanceLogitsProcessor(
classifier_free_guidance,
model,
        unconditional_ids=Tensor(neg_inputs.input_ids, dtype=ms.int32),
),
PrefixConstrainedLogitsProcessor(
        constrained_fn,
num_beams=1,
),
])
# generate
outputs = model.generate(
Tensor(pos_inputs.input_ids, dtype=ms.int32),
GENERATION_CONFIG,
logits_processor=logits_processor,
attention_mask=Tensor(pos_inputs.attention_mask),
)
out = outputs[0]
if model.config.img_token_id not in out:  # img_token_id may have been trimmed from the generate() output
    out = ops.cat((Tensor([model.config.img_token_id], dtype=out.dtype), out))
mm_list = processor.decode(out)
for idx, im in enumerate(mm_list):
if not isinstance(im, Image.Image):
continue
im.save(f"result_{idx}.png")An inference script is provided in scripts/infer_vqa.sh.
An example is as follows:
from emu3.mllm import Emu3ForCausalLM, Emu3Tokenizer
from emu3.mllm.processing_emu3 import Emu3Processor
from emu3.tokenizer import Emu3VisionVQImageProcessor, Emu3VisionVQModel
from PIL import Image
from transformers.generation.configuration_utils import GenerationConfig
import mindspore as ms
from mindspore import Tensor, mint
from mindone.utils.amp import auto_mixed_precision
# model path
EMU_HUB = "BAAI/Emu3-Chat"
VQ_HUB = "BAAI/Emu3-VisionTokenizer"
EMU_DTYPE = ms.bfloat16
VQ_DTYPE = ms.bfloat16
# prepare model and processor
model = Emu3ForCausalLM.from_pretrained(
EMU_HUB,
mindspore_dtype=EMU_DTYPE,
use_safetensors=True,
attn_implementation="flash_attention_2", # optional: "eager"
).set_train(False)
tokenizer = Emu3Tokenizer.from_pretrained(EMU_HUB, padding_side="left")
image_processor = Emu3VisionVQImageProcessor.from_pretrained(VQ_HUB)
image_tokenizer = Emu3VisionVQModel.from_pretrained(
VQ_HUB,
use_safetensors=True,
mindspore_dtype=VQ_DTYPE
).set_train(False)
image_tokenizer = auto_mixed_precision(
image_tokenizer, amp_level="O2", dtype=VQ_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
)
processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
# prepare input
text = ["Please describe the image", "请描述该图片"]
image = Image.open("assets/demo.png") # TODO: you need to modify the path here
image = [image, image]
inputs = processor(
text=text,
image=image,
mode="U",
padding_image=True,
padding="longest",
return_tensors="np",
)
# prepare hyper parameters
GENERATION_CONFIG = GenerationConfig(
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id
)
# generate
outputs = model.generate(
Tensor(inputs.input_ids, dtype=ms.int32),
GENERATION_CONFIG,
max_new_tokens=1024,
attention_mask=Tensor(inputs.attention_mask),
)
answers = processor.batch_decode(outputs, skip_special_tokens=True)
for ans in answers:
    print(ans)

An inference script is provided in scripts/infer_img_rec.sh.
An example of reconstructing an image or a video is as follows:
import os
from PIL import Image
import mindspore as ms
from mindspore import Tensor, mint
from mindone.utils.amp import auto_mixed_precision
from emu3.tokenizer import Emu3VisionVQImageProcessor, Emu3VisionVQModel
# TODO: you need to modify the path here
MODEL_HUB = "BAAI/Emu3-VisionTokenizer"
MS_DTYPE = ms.bfloat16
model = Emu3VisionVQModel.from_pretrained(
MODEL_HUB,
use_safetensors=True,
mindspore_dtype=MS_DTYPE
).set_train(False)
model = auto_mixed_precision(
model, amp_level="O2", dtype=MS_DTYPE, custom_fp32_cells=[mint.nn.BatchNorm3d]
)
processor = Emu3VisionVQImageProcessor.from_pretrained(MODEL_HUB)
# TODO: you need to modify the path here
VIDEO_FRAMES_PATH = "YOUR_VIDEO_FRAMES_PATH"
video = os.listdir(VIDEO_FRAMES_PATH)
video.sort()
video = [Image.open(os.path.join(VIDEO_FRAMES_PATH, v)) for v in video]
images = processor(video, return_tensors="np")["pixel_values"]
images = Tensor(images).unsqueeze(0)
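# images: (1, T, C, H, W) after adding the batch dimension, where T is the number of frames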
# image autoencode
image = images[:, 0]
print(image.shape)
codes = model.encode(image)
recon = model.decode(codes)
recon = recon.view(-1, *recon.shape[2:])
recon_image = processor.postprocess(recon)["pixel_values"][0]
recon_image.save("recon_image.png")
# video autoencode
# NOTE: number of frames must be multiple of `model.config.temporal_downsample_factor`
images = images.view(
-1, # if OOM, reduce batch size
model.config.temporal_downsample_factor,
*images.shape[2:],
)
codes = model.encode(images)
recon = model.decode(codes)
recon = recon.view(-1, *recon.shape[2:])
recon_images = processor.postprocess(recon)["pixel_values"]
for idx, im in enumerate(recon_images):
im.save(f"recon_video_{idx}.png")Emu3-Stage1 can be fine-tuned on text-to-image(T2I) or vision-question answering (VQA) tasks.
The model continues training with the next token prediction task using standard cross-entropy loss.
For the T2I task, supervision is applied exclusively to vision tokens; for the VQA task, it is applied exclusively to response text tokens, as sketched below.
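A minimal sketch of this supervision masking (a hypothetical helper, assuming the common convention that label -100 is ignored by the cross-entropy loss):

```python
import numpy as np

IGNORE_INDEX = -100  # positions with this label are excluded from the loss

def build_labels(input_ids: np.ndarray, supervised: np.ndarray) -> np.ndarray:
    """Turn input ids into next-token labels, masking unsupervised positions.

    `supervised` is a boolean mask that is True on vision tokens for the T2I task,
    or True on response text tokens for the VQA task.
    """
    labels = input_ids.copy()
    labels[~supervised] = IGNORE_INDEX
    return labels
```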
Some SFT scripts are provided in scripts/XX_sft_seq_parallel.sh. To fine-tune Emu3-Stage1, run the following script:
bash scripts/t2i_sft_seq_parallel.sh # for T2I task
bash scripts/vqa_sft_seq_parallel.sh # for VQA task
Each training sample stores the input and output vision and text tokens.
train/prepare_vision_data.py and train/prepare_vqa_data.py format T2I and VQA data, respectively, into the following layout:
DATA_DIR
├── list
│   └── train.json
└── feature
    ├── 0000.ckpt
    ├── 0001.ckpt
    ├── ...
    ├── N-2.ckpt
    └── N-1.ckpt

train.json lists the locations of all *.ckpt files; each *.ckpt stores:
{
"name": name,
"images": token_ids,
"texts": input_prompt,
"response": answer_prompt # optional
}
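The vision token ids in each record come from the vision tokenizer. A minimal sketch of building one record (example.png and the prompt are placeholders; the actual *.ckpt serialization is handled by the prepare scripts above):

```python
from PIL import Image
import mindspore as ms
from mindspore import Tensor
from emu3.tokenizer import Emu3VisionVQImageProcessor, Emu3VisionVQModel

# Tokenize one image into discrete vision codes, mirroring the reconstruction example above.
image_processor = Emu3VisionVQImageProcessor.from_pretrained("BAAI/Emu3-VisionTokenizer")
vq_model = Emu3VisionVQModel.from_pretrained(
    "BAAI/Emu3-VisionTokenizer", use_safetensors=True, mindspore_dtype=ms.bfloat16
).set_train(False)

pixel_values = image_processor([Image.open("example.png")], return_tensors="np")["pixel_values"]
token_ids = vq_model.encode(Tensor(pixel_values))  # grid of codebook indices

record = {
    "name": "0000",                       # sample id
    "images": token_ids.asnumpy(),        # discrete vision tokens
    "texts": "a portrait of young girl",  # input prompt
    # "response": "...",                  # answer text, VQA samples only
}
```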
Input an image or a clip of video frames; output the reconstructed image(s).
Experiments are tested on Ascend Atlas 800T A2 machines.
- mindspore 2.6.0
| mode | model name | precision* | cards | batch size | resolution | s/step | img/s |
|---|---|---|---|---|---|---|---|
| pynative | Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.42 | 0.41 |
| pynative | Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 0.95 | 4.21 |
| graph | Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 3.06 | 0.33 |
| graph | Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 2.70 | 1.48 |
- mindspore 2.7.0
| mode | model name | precision* | cards | batch size | resolution | s/step | img/s |
|---|---|---|---|---|---|---|---|
| pynative | Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.46 | 0.41 |
| pynative | Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 1.23 | 3.25 |
| graph | Emu3-VisionTokenizer | bfloat16 | 1 | 1 | 768x1360 | 2.76 | 0.36 |
| graph | Emu3-VisionTokenizer | bfloat16 | 1 | 4 (video) | 768x1360 | 2.70 | 1.48 |
*note: mixed precision is used; BatchNorm3d runs in fp32, while Conv3d and Flash Attention run in fp16.
Input a text prompt, output an image.
Experiments are tested on Ascend Atlas 800T A2 machines with pynative mode.
- mindspore 2.6.0
| model name | precision* | cards | batch size | resolution | flash attn | tokens/s | step |
|---|---|---|---|---|---|---|---|
| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | OFF | 1.68 | 8193 |
| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | ON | 2.13 | 8193 |
- mindspore 2.7.0
| model name | precision* | cards | batch size | resolution | flash attn | tokens/s | step |
|---|---|---|---|---|---|---|---|
| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | OFF | 1.85 | 8193 |
| Emu3-Gen | bfloat16 | 1 | 1 | 720x720 | ON | 2.33 | 8193 |
*note: mixed precision is used; BatchNorm3d and Emu3RMSNorm run in fp32, while Conv3d and Flash Attention run in fp16.
Input an image and a text prompt; output a textual response.
Experiments are tested on Ascend Atlas 800T A2 machines with pynative mode.
- mindspore 2.6.0
| model name | precision* | cards | batch size | resolution | flash attn | tokens/s | step |
|---|---|---|---|---|---|---|---|
| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | OFF | 4.12 | 659 |
| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | ON | 4.37 | 652 |
- mindspore 2.7.0
| model name | precision* | cards | batch size | resolution | flash attn | tokens/s | step |
|---|---|---|---|---|---|---|---|
| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | OFF | 5.15 | 659 |
| Emu3-Chat | bfloat16 | 1 | 1 | 384x384 | ON | 5.16 | 652 |
*note: mixed precision is used; BatchNorm3d and Emu3RMSNorm run in fp32, while Conv3d and Flash Attention run in fp16.
Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.7.0*.
| mode | stage | pre-trained model | precision* | cards | batch size | resolution | max token | init lr | recompute | zero stage | grad accu | flash attn | sequence parallel | s/step | step | sample/s |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| pynative | stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 1.79 | 400 | 0.56 |
| pynative | stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 4 shards | 1.79 | 400 | 0.56 |
| graph | stage2-T2I | Emu3-Stage1 | float16 | 8 | 1 | 512x512 | 4200 | 1e-6 | ON | 3 | 1 | ON | 8 shards | 34.11 | 400 | 0.03 |
| graph | stage2-VQA | Emu3-Stage1 | float16 | 4 | 1 | 384x384 | 2560 | 1e-5 | ON | 3 | 1 | ON | 4 shards | 20.10 | 400 | 0.05 |
*note: training is currently supported with mindspore 2.7.0 only.
Mixed precision is used; BatchNorm3d and Emu3RMSNorm run in fp32.


