[Dynamo] Add operator support to run UNet2DConditionModel from diffusers#151
[Dynamo] Add operator support to run UNet2DConditionModel from diffusers#151yaoyaoding merged 22 commits intohidet-org:mainfrom
Conversation
yaoyaoding
left a comment
There was a problem hiding this comment.
Thanks @xinli-git!
Could you also add some tests to the added operators (under tests/frontends/torch)?
And we can also consider to add some model-level tests like the tests in tests/unit_tests/test_frontend_onnx.py but for pytorch frontend. And place the tests in a test script like tests/frontends/torch/models/test_unet.py).
| def baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out: Optional[Tensor] = None) -> Tensor: | ||
| if out is not None: | ||
| raise NotImplementedError("hidet: does not support torch.bmm(..., out=...)") | ||
| return beta * input + alpha * ops.matmul(batch1, batch2) |
There was a problem hiding this comment.
Better to check whether alpha==1 and beta==1 and do not perform the multiplication as much as possible.
Otherwise, we need to write some graph-level pattern rewrite rules to do this simplification.
|
Thanks Yaoyao! Will add tests shortly and let you know |
|
Hi @AndreSlavescu can you also review this PR to see if all the stuff makes sense? I will add the tests shortly following your PR merge |
|
Hey @xinli-git , looks good. I can also review fully when testcases are added.
Once the PR is merged, please update #132 with the modules and operators supported |
This reverts commit a1e8df0.
|
Hi @yaoyaoding, this PR is ready for a final review |
|
Thanks @xinli-git! Looks good to me. Merge this PR now. If you want to track the performance of stable diffusion model, you could add the model to our benchmark cli |
|
Thanks! will check that out |
Stable diffusion leverages UNet2DConditionModel for the diffusion process.
This is a popular model and because of its size, the torch -> onnx -> hidet workflow is difficult and unnatural to work with.
This PR adds operator support required by
torch.compilefor UNet2DConditionModele.g.