Is your feature request related to a problem? Please describe.
It would be great to be able to load a LoRA into a model compiled with `torch.compile`.
Describe the solution you'd like.
Call `load_lora_weights` with a compiled `pipe` (ideally without triggering recompilation).
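For illustration, the ideal flow would be to compile once and then load or swap adapters freely. A hedged sketch (the second repo id is a placeholder; `set_adapters` is the existing diffusers API for switching between loaded LoRAs):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# compile once up front...
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")

# ...then load and hot-swap LoRAs without paying for recompilation
pipe.load_lora_weights("multimodalart/flux-tarot-v1", adapter_name="tarot")
pipe.load_lora_weights("user/another-lora", adapter_name="other")  # placeholder repo id
pipe.set_adapters(["tarot"])
```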
Currently, running this code:

```python
import torch
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe = pipe.to(device)
pipe.transformer = pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")
pipe.load_lora_weights("multimodalart/flux-tarot-v1")
```
It errors:

```
Loading adapter weights from state_dict led to unexpected keys not found in the model: ['single_transformer_blocks.0.attn.to_k.lora_A.default_3.weight', 'single_transformer_blocks.0.attn.to_k.lora_B.default_3.weight', 'single_transformer_blocks.0.attn.to_q.lora_A.default_3.weight', 'single_transformer_blocks.0.attn.to_q.lora_B.default_3.weight', ...
```
When compiled, the model's state dict seems to gain a `_orig_mod.` prefix on every key:

```
odict_keys(['_orig_mod.time_text_embed.timestep_embedder.linear_1.weight', ...
```
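One possible direction (a minimal sketch, assuming the mismatch is purely in the key names and that the adapter keys are matched against the module's state dict) would be to normalize the prefix before matching:

```python
def strip_compile_prefix(state_dict, prefix="_orig_mod."):
    # torch.compile wraps the module in an OptimizedModule whose state
    # dict prefixes every key with `_orig_mod.`; stripping it restores
    # the key names of the underlying, uncompiled module
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
```

Equivalently, the underlying module is reachable as `pipe.transformer._orig_mod`, so the adapter could in principle be injected there, though whether that avoids recompilation is exactly the open question.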
Describe alternatives you've considered.
An alternative is to fuse the LoRA into the model and then compile; however, this does not allow hot-swapping LoRAs, as a new pipeline and a new compilation are needed for every LoRA (see the sketch below).
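Concretely, the fuse-then-compile flow looks something like this (a sketch using the existing `fuse_lora`/`unload_lora_weights` APIs):

```python
# bake the LoRA into the base weights so the compiled graph never
# sees adapter modules; switching LoRAs afterwards means rebuilding
# the pipeline and recompiling
pipe.load_lora_weights("multimodalart/flux-tarot-v1")
pipe.fuse_lora()
pipe.unload_lora_weights()  # drop the adapter modules, keeping the fused weights
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")
```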
Additional context.
This seems to have been achieved by @chengzeyi, author of the now-paused https://github.com/chengzeyi/stable-fast; however, it appears to be part of FAL's non-open-source optimized inference stack (that said, if you'd like to contribute this upstream, feel free!).