Skip to content

[Dynamo] Add operator support to run UNet2DConditionModel from diffusers#151

Merged
yaoyaoding merged 22 commits intohidet-org:mainfrom
xinli-git:dynamo_unet
Apr 7, 2023
Merged

[Dynamo] Add operator support to run UNet2DConditionModel from diffusers#151
yaoyaoding merged 22 commits intohidet-org:mainfrom
xinli-git:dynamo_unet

Conversation

@xinli-git
Copy link
Copy Markdown
Contributor

@xinli-git xinli-git commented Mar 28, 2023

Stable diffusion leverages UNet2DConditionModel for the diffusion process.

This is a popular model and because of its size, the torch -> onnx -> hidet workflow is difficult and unnatural to work with.

This PR adds operator support required by torch.compile for UNet2DConditionModel

e.g.

device= 'cuda'
model_dtype = torch.float16

unet = (
  UNet2DConditionModel.from_pretrained(
      'CompVis/stable-diffusion-v1-4',
      subfolder="unet",
      revision="fp16",
  )
  .eval()
  .to(device)
)

hidet_model = torch.compile(unet, backend="hidet")

batch_size = 1
UNET_INPUTS_CHANNEL = 4
height = width = 512

tokinizer_max_len = 64 
embedding_hidden_size = 768

latents = torch.randn(
    (batch_size * 2, UNET_INPUTS_CHANNEL, height // 8, width // 8),
    device=device,
    dtype=model_dtype
)
t = torch.ones(1, dtype=torch.int64, device=device)
text_embedding = torch.randn(
    batch_size * 2,
    tokinizer_max_len,
    embedding_hidden_size,
    dtype=model_dtype,
    device=device,
)
inputs = (latents, t, text_embedding)

hidet_model(*inputs)

@xinli-git xinli-git changed the title [Dynamo] Add operator support to run UNet [Dynamo] Add operator support to run UNet2DConditionModel from diffusers Mar 28, 2023
Copy link
Copy Markdown
Member

@yaoyaoding yaoyaoding left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @xinli-git!

Could you also add some tests to the added operators (under tests/frontends/torch)?

And we can also consider to add some model-level tests like the tests in tests/unit_tests/test_frontend_onnx.py but for pytorch frontend. And place the tests in a test script like tests/frontends/torch/models/test_unet.py).

def baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out: Optional[Tensor] = None) -> Tensor:
if out is not None:
raise NotImplementedError("hidet: does not support torch.bmm(..., out=...)")
return beta * input + alpha * ops.matmul(batch1, batch2)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to check whether alpha==1 and beta==1 and do not perform the multiplication as much as possible.

Otherwise, we need to write some graph-level pattern rewrite rules to do this simplification.

@xinli-git
Copy link
Copy Markdown
Contributor Author

Thanks Yaoyao! Will add tests shortly and let you know

@xinli-git
Copy link
Copy Markdown
Contributor Author

Hi @AndreSlavescu can you also review this PR to see if all the stuff makes sense? I will add the tests shortly following your PR merge

@AndreSlavescu
Copy link
Copy Markdown
Contributor

AndreSlavescu commented Mar 30, 2023

Hey @xinli-git , looks good. I can also review fully when testcases are added.

  • For the Group Norm test, can you add it in /hidet/tests/operators/test_norm.py
  • For the Interpolation test, can you make a new file for vision functions called test_vision.py
    and add it under the same directory as shown above.

Once the PR is merged, please update #132 with the modules and operators supported

@xinli-git
Copy link
Copy Markdown
Contributor Author

Hi @yaoyaoding, this PR is ready for a final review

@yaoyaoding
Copy link
Copy Markdown
Member

Thanks @xinli-git!

Looks good to me. Merge this PR now.

If you want to track the performance of stable diffusion model, you could add the model to our benchmark cli
and add a line here. The performance will be tracked at this issue.

@yaoyaoding yaoyaoding merged commit 68faaa5 into hidet-org:main Apr 7, 2023
@xinli-git
Copy link
Copy Markdown
Contributor Author

Thanks! will check that out

@xinli-git xinli-git deleted the dynamo_unet branch April 21, 2023 17:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants