
Fix Dynamo lru_cache warnings during torch.compile #13384

Merged
sayakpaul merged 4 commits into huggingface:main from jiqing-feng:compile
Apr 3, 2026

Conversation

@jiqing-feng
Contributor

What does this PR do?

Fixes Dynamo lru_cache warnings when using torch.compile on diffusion pipelines. Two changes:

1. attention_dispatch.py: dispatch_attention_fn calls is_torch_version(">=", "2.5.0") at runtime, and that helper is @lru_cache-wrapped. Replace the call with the existing module-level constant _CAN_USE_FLEX_ATTN so Dynamo never traces into the wrapper (a sketch follows this list).

2. torch_utils.py: lru_cache_unless_export previously bypassed lru_cache only during torch.export (is_exporting). Add an is_compiling check so torch.compile also bypasses the cache wrapper (sketched after the before/after output below).
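
For illustration, a minimal sketch of the first change, assuming a packaging-based version check as a stand-in for diffusers' is_torch_version helper; the body of dispatch_attention_fn is elided to the guard that matters:

from packaging import version
import torch
import torch.nn.functional as F

# Evaluated once at import time. At trace time Dynamo sees a plain bool,
# so it never steps into an lru_cache-wrapped version-check helper.
_CAN_USE_FLEX_ATTN = version.parse(torch.__version__).release >= (2, 5)

def dispatch_attention_fn(query, key, value):
    # Illustrative body: the real function dispatches between several
    # attention backends; only the guard is relevant here.
    if _CAN_USE_FLEX_ATTN:
        pass  # a flex-attention-capable path would be selected here
    return F.scaled_dot_product_attention(query, key, value)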

Reproduce

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.2-klein-4B", torch_dtype=torch.bfloat16).to("cpu")
pipe.transformer = torch.compile(pipe.transformer, backend="inductor")
pipe(prompt="a cat", height=256, width=256, num_inference_steps=1, generator=torch.Generator().manual_seed(0))
# Before: UserWarning about lru_cache in attention_dispatch.py / torch_utils.py
# After: no warning

Output before the fix:

/opt/venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:2435: UserWarning: Dynamo detected a call to a `functools.lru_cache`-wrapped function at 'attention_dispatch.py:426'. Dynamo ignores the cache wrapper and directly traces the wrapped function. Silent incorrectness is only a *potential* risk, not something we have observed. Enable TORCH_LOGS=+dynamo for a DEBUG stack trace.

This call originates from:
  File "/home/jiqing/diffusers/src/diffusers/models/attention_dispatch.py", line 426, in dispatch_attention_fn
    if is_torch_version(">=", "2.5.0"):

  torch._dynamo.utils.warn_once(msg)

Flux2PipelineOutput(images=[<PIL.Image.Image image mode=RGB size=256x256 at 0x7A540282FDA0>])

Output after the fix:

Flux2PipelineOutput(images=[<PIL.Image.Image image mode=RGB size=256x256 at 0x7A540282FDA0>])
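
A minimal sketch of the second change, assuming a decorator shaped roughly like diffusers' lru_cache_unless_export (the real helper may differ in signature and detail); torch.compiler.is_compiling requires a reasonably recent PyTorch, and is_exporting is newer still, hence the getattr fallback:

import functools
import torch

def _dynamo_is_tracing() -> bool:
    # is_compiling() covers torch.compile; is_exporting() only exists on
    # recent PyTorch releases, so fall back to False when it is missing.
    is_exporting = getattr(torch.compiler, "is_exporting", lambda: False)
    return torch.compiler.is_compiling() or is_exporting()

def lru_cache_unless_export(fn):
    cached = functools.lru_cache(maxsize=None)(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Bypass the cache wrapper whenever Dynamo is tracing (compile or
        # export), so it never encounters an lru_cache-wrapped call.
        if _dynamo_is_tracing():
            return fn(*args, **kwargs)
        return cached(*args, **kwargs)

    return wrapper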

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng
Contributor Author

Hi @sayakpaul. Would you please review this PR? Thanks!

"_parallel_config": parallel_config,
}
if is_torch_version(">=", "2.5.0"):
if _CAN_USE_FLEX_ATTN:
Member

Is this a safe replacement? If so, could you elaborate further?

Contributor Author

Added comments for it.

Member
@sayakpaul left a comment

Thanks for the PR. Left one comment.

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@jiqing-feng
Contributor Author

jiqing-feng commented Apr 3, 2026

Hi @sayakpaul. These failures are unrelated to this PR. They are caused by missing keys ('llava', 'qwen2_vl') in peft==0.18.2.dev0's _MOE_TARGET_MODULE_MAPPING, which is a pre-existing issue in the PEFT dev build. My changes only touch attention_dispatch.py (version check) and torch_utils.py (compile bypass), neither of which is in the LoRA/PEFT code path.

Member
@sayakpaul left a comment

Thanks for the PR! Failing test is unrelated.

@sayakpaul merged commit a05c8e9 into huggingface:main on Apr 3, 2026
10 of 11 checks passed
terarachang pushed a commit to terarachang/diffusers that referenced this pull request Apr 30, 2026
…3384)

* fix compile issue

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* compile friendly

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* add comments

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>