
Conversation

@LawJarp-A LawJarp-A commented Nov 13, 2025

What does this PR do?

What is TeaCache?

TeaCache (Timestep Embedding Aware Cache) is a training-free caching technique that speeds up diffusion model inference by 1.5x-2.6x by reusing transformer block computations when consecutive timestep embeddings are similar.

Architecture & Design

TeaCache uses a ModelHook to intercept transformer forward passes without modifying model code. The algorithm (a rough code sketch follows the list):

  1. Extracts modulated input from first transformer block (after norm1 + timestep embedding)
  2. Computes relative L1 distance vs previous timestep
  3. Applies model-specific polynomial rescaling: c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]
  4. Accumulates rescaled distance across timesteps
  5. If accumulated < threshold → Reuses cached residual (FAST)
  6. If accumulated >= threshold → Full transformer pass (SLOW, update cache)
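
The decision logic can be sketched roughly as follows. This is an illustrative sketch, not the PR's actual code; the names `teacache_step`, `state`, and `is_first_or_last` are placeholders.

# Illustrative sketch of the per-timestep TeaCache decision (placeholder names).
def teacache_step(state, modulated_inp, coefficients, rel_l1_thresh, is_first_or_last):
    if is_first_or_last or state.previous_modulated_input is None:
        return "compute"  # boundary timesteps are always computed fully

    # Relative L1 distance between the current and previous modulated inputs
    prev = state.previous_modulated_input
    rel_dist = ((modulated_inp - prev).abs().mean() / prev.abs().mean()).item()

    # Model-specific polynomial rescaling: c0*x^4 + c1*x^3 + c2*x^2 + c3*x + c4
    c0, c1, c2, c3, c4 = coefficients
    rescaled = c0 * rel_dist**4 + c1 * rel_dist**3 + c2 * rel_dist**2 + c3 * rel_dist + c4

    state.accumulated_distance += rescaled
    if state.accumulated_distance < rel_l1_thresh:
        return "reuse_cached_residual"  # fast path: add the cached residual to the input
    state.accumulated_distance = 0.0  # reset the accumulator and recompute
    return "compute"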

Key Design Features:

  • Hook-based: Integrates with HookRegistry and CacheMixin for lifecycle management
  • State Isolation: StateManager with context-aware state for CFG conditional/unconditional branches
  • Model Auto-Detection: Detects model type from class name and config path (specific variants checked first)
  • Boundary Guarantee: First and last timesteps always computed fully for quality
  • Specialized Strategies: Dual residual caching (CogVideoX), per-sequence-length caching (Lumina2)

Supported Models

| Model | Coefficients | Status |
|---|---|---|
| FLUX | Auto-detected | Tested |
| FLUX-Kontext | Auto-detected | Ready |
| Mochi | Auto-detected | Ready |
| Lumina2 | Auto-detected | Ready |
| CogVideoX (2b/5b/1.5-5B) | Auto-detected | Ready |

All models support automatic coefficient detection based on model class name and config path. Custom coefficients can also be provided via TeaCacheConfig.
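
For example, passing custom coefficients is a one-liner. The values below are placeholders for illustration only (real coefficients come from calibration), and the pipe object is the pipeline from the Usage section below.

from diffusers.hooks import TeaCacheConfig

# Placeholder coefficients (c0..c4) for illustration; replace with calibrated values.
custom_coefficients = [1.0, 0.0, 0.0, 0.0, 0.0]

config = TeaCacheConfig(rel_l1_thresh=0.4, coefficients=custom_coefficients)
pipe.transformer.enable_cache(config)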


Benchmark Results (FLUX.1-dev)

| Threshold | Time | Speedup |
|---|---|---|
| Baseline | 9.26s | 1.00x |
| 0.2 | 6.85s | 1.35x |
| 0.4 | 5.24s | 1.77x |
| 0.6 | 4.64s | 2.00x |
| 0.8 | 4.18s | 2.22x |

Benchmark Results (Lumina2)

| Threshold | Time | Speedup |
|---|---|---|
| Baseline | 3.45s | 1.00x |
| 0.2 | 3.07s | 1.12x |
| 0.4 | 2.27s | 1.52x |
| 0.6 | 1.84s | 1.88x |

Benchmark Results (CogVideoX-2b)

| Threshold | Time | Speedup |
|---|---|---|
| Baseline | 26.27s | 1.00x |
| 0.3 | 23.97s | 1.10x |
| 0.5 | 22.57s | 1.16x |
| 0.7 | 19.31s | 1.36x |
| 0.9 | 17.38s | 1.51x |

Benchmark Results (Mochi)

| Threshold | Time | Speedup |
|---|---|---|
| Baseline | 7.71s | 1.00x |
| 0.05 | 6.27s | 1.23x |
| 0.06 | 6.03s | 1.28x |
| 0.08 | 5.73s | 1.35x |
| 0.10 | 5.41s | 1.42x |

Test Hardware: NVIDIA H100
Framework: Diffusers with TeaCache hooks
All tests: Same seed (42) for reproducibility

Usage

import torch

from diffusers import FluxPipeline
from diffusers.hooks import TeaCacheConfig

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Enable TeaCache (1.75x speedup with 0.4 threshold)
config = TeaCacheConfig(rel_l1_thresh=0.4)
pipe.transformer.enable_cache(config)

image = pipe("A dragon on a crystal mountain", num_inference_steps=20).images[0]

pipe.transformer.disable_cache()

Configuration Options

The TeaCacheConfig supports the following parameters:

  • rel_l1_thresh (float, default=0.2): Threshold for accumulated relative L1 distance. Recommended values: 0.25 for ~1.5x speedup, 0.4 for ~1.8x, 0.6 for ~2.0x. Mochi models require lower thresholds (0.06-0.09).
  • coefficients (List[float], optional): Polynomial coefficients for rescaling L1 distance. Auto-detected based on model type if not provided.
  • num_inference_steps (int, optional): Total inference steps. Ensures first/last timesteps are always computed. Auto-detected if not provided.
  • num_inference_steps_callback (Callable[[], int], optional): Callback returning total inference steps. Alternative to num_inference_steps.
  • current_timestep_callback (Callable[[], int], optional): Callback returning current timestep. Used for debugging/statistics.
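
A short example combining these options. This is a sketch based on the parameters listed above; the pipe object is the one from the Usage snippet, and the callback wiring is illustrative.

from diffusers.hooks import TeaCacheConfig

num_steps = 28  # illustrative value

config = TeaCacheConfig(
    rel_l1_thresh=0.4,
    num_inference_steps=num_steps,  # guarantees the first/last steps are computed fully
    # Alternatively, pass a callback instead of a fixed value:
    # num_inference_steps_callback=lambda: num_steps,
)
pipe.transformer.enable_cache(config)
image = pipe("A dragon on a crystal mountain", num_inference_steps=num_steps).images[0]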

Files Changed

  • src/diffusers/hooks/teacache.py - Core implementation with model-specific forward functions
  • src/diffusers/models/cache_utils.py - CacheMixin integration
  • src/diffusers/hooks/__init__.py - Export TeaCacheConfig and apply_teacache
  • tests/hooks/test_teacache.py - Comprehensive unit tests

Fixes #12589
Fixes #12635


Who can review?

@sayakpaul @yiyixuxu @DN6

@sayakpaul sayakpaul requested a review from DN6 November 13, 2025 16:49

LawJarp-A commented Nov 13, 2025

Work done

  • Implement TeaCache for the FLUX architecture using hooks (only FLUX for now)
  • Add logging
  • Add compatible tests

Waiting for feedback and review :)
cc: @dhruvrnaik @sayakpaul @yiyixuxu

@LawJarp-A LawJarp-A marked this pull request as ready for review November 14, 2025 08:23
@LawJarp-A (Author)

Hi @sayakpaul @dhruvrnaik any updates?

@sayakpaul (Member)

@LawJarp-A sorry about the delay on our end. @DN6 will review it soon.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


DN6 commented Nov 24, 2025

Hi @LawJarp-A I think we would need TeaCache to be implemented in a model agnostic way in order to merge the PR. The First Block Cache implementation is a good reference for this.

@LawJarp-A (Author)

> Hi @LawJarp-A I think we would need TeaCache to be implemented in a model agnostic way in order to merge the PR. The First Block Cache implementation is a good reference for this.

Yep @DN6, I agree. I wanted to first implement it for a single model and get feedback on that before working on the full model-agnostic implementation. I'm sort of working on it already, just haven't pushed it yet. I'll take a look at First Block Cache for reference as well.
On the same note, let me know if there is anything to add to the current implementation.


LawJarp-A commented Nov 26, 2025

@DN6 updated it in a more model-agnostic way.
Requesting review and feedback.

@LawJarp-A (Author)

Added multi-model support; still testing it thoroughly.

@LawJarp-A (Author)

Hi @DN6 @sayakpaul
Two questions. I'm almost done testing; I'll update the PR with more descriptive results and changes, and do the final cleanup/merging.

  1. Any tests I should write, and anything I can refer to for the same?
  2. I've added support for other models; should I add image comparisons with speedup and threshold to the PR as well?

In the meantime, any feedback would be appreciated.

@sayakpaul (Member)

Thanks @LawJarp-A!

> Any tests I should write, and anything I can refer to for the same?

You can refer to #12569 for testing

> I've added support for other models; should I add image comparisons with speedup and threshold to the PR as well?

Yes, I think that is informative for users.


@sayakpaul sayakpaul left a comment


Some initial feedback. The most important question: it seems like we need to craft different logic for each model? Can we not keep it model-agnostic?


_TEACACHE_HOOK = "teacache"

# Model-specific polynomial coefficients from TeaCache paper/reference implementations

Do we know if these coefficients depend only on the model, or are there other dependencies as well (for example num_inference_steps, guidance_scale, etc.)?

Also, can we add a calibration step similar to #12648 so that users can log these coefficients for other models?


LawJarp-A commented Dec 8, 2025

> Some initial feedback. The most important question: it seems like we need to craft different logic for each model? Can we not keep it model-agnostic?

I am trying to think of ways we can avoid having a separate forward function for each model. Initially that seemed like the best approach: it was fine when I wrote it for FLUX, but Lumina needed multi-stage preprocessing.
I am still thinking about how to avoid it, but keeping a single generic forward might not work very well :/
FirstBlockCache and the other block-level caches work at the block level, but TeaCache is more model-level.
Definitely open to ideas :)

LawJarp-A and others added 4 commits December 9, 2025 12:15
…torch.compile support, and clean up coefficient flow


LawJarp-A commented Dec 10, 2025

@sayakpaul
Added a FLUX image example to the PR description.
Tested it with Lumina and CogVideoX as well.
Could not test with Mochi because of GPU constraints; I can maybe try with CPU offloading.


LawJarp-A commented Dec 11, 2025

@sayakpaul @DN6 I got the core logic working and tested it on the models my GPU can handle.
Right now I have gone for a simple monolithic approach: each model's forward handlers and extractors live in one file. I tried to abstract it as much as possible, but since TeaCache works at the model level rather than at the block level (like most of the current caches: Taylor, FirstBlock, etc.), it has proven a bit difficult to make it model-agnostic.

The current implementation puts all model handlers in a single teacache.py file. This works but has scaling concerns.
I was thinking, since we have to add model-specific functions anyway, we could make them a bit more modular design-wise.

Potential refactor: Registry + Handler pattern

diffusers/hooks/
├── teacache/
│   ├── __init__.py           # Public API
│   ├── config.py             # TeaCacheConfig
│   ├── hook.py               # TeaCacheHook (core logic)
│   ├── registry.py           # Handler registry
│   └── handlers/
│       ├── __init__.py       # Auto-imports all handlers
│       ├── base.py           # BaseTeaCacheHandler ABC
│       ├── flux.py
│       ├── mochi.py
│       ├── lumina2.py
│       └── cogvideox.py

Each handler self-registers and encapsulates its logic:

# handlers/flux.py
from .base import BaseTeaCacheHandler
from ..registry import register_handler

@register_handler("Flux", "FluxKontext")
class FluxHandler(BaseTeaCacheHandler):
    coefficients = [4.98651651e02, -2.83781631e02, ...]
    
    def extract_modulated_input(self, module, hidden_states, temb):
        return module.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
    
    def handle_forward(self, module, *args, **kwargs):
        # FLUX-specific forward with ControlNet, LoRA, etc.
        ...

# registry.py
_HANDLER_REGISTRY = {}

def register_handler(*model_names):
    def decorator(cls):
        for name in model_names:
            _HANDLER_REGISTRY[name] = cls
        return cls
    return decorator

def get_handler(module) -> BaseTeaCacheHandler:
    for name, handler_cls in _HANDLER_REGISTRY.items():
        if name in module.__class__.__name__:
            return handler_cls()
    raise ValueError(f"No TeaCache handler for {module.__class__.__name__}")

This is similar to how attention processors and schedulers are organized. Happy to refactor if you think it's worth it, or we can keep it simple as it is now. Since this has proven a bit more of a challenge to integrate than I thought xD, I would be happy to hear if you have any ideas.

@LawJarp-A (Author)

Hey @DN6 @sayakpaul, any updates? :)

@LawJarp-A LawJarp-A requested a review from sayakpaul December 16, 2025 06:42
@LawJarp-A (Author)

@sayakpaul @DN6 checking in again :)

@DN6 DN6 left a comment


Some high-level feedback on the design. The control flow is hard to follow as it switches between the hook object and the adapter. The adapters themselves are thin wrappers around a modified forward function, so it would be better to just define them as standalone functions, e.g.:

def _flux_forward(
    state: "TeaCacheState", # pass the state to the function not the hook object
    coefficients: List[float],
    rel_l1_thresh: float,
    module: torch.nn.Module,
    hidden_states: torch.Tensor,
    timestep: torch.Tensor,
    pooled_projections: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    txt_ids: torch.Tensor,
    img_ids: torch.Tensor,
    return_dict: bool = True,
    **kwargs,
):

    if _should_use_cache(state, modulated_inp, coefficients, rel_l1_thresh):
        hidden_states = _apply_cached_residual(state, hidden_states, modulated_inp)
    else:
        # run compute
        _update_cache(state, hidden_states, original_hidden_states, modulated_inp)

Since we're hooking the top level forward of the model, we can map this forward function using the class name during hook initialization.

    def initialize_hook(self, module):
        """Initialize hook with model-specific configuration."""
        model_config = _MODEL_CONFIG.get(module.__class__.__name__)
        if model_config is None:
            raise ValueError

        if self.config.coefficients is not None:
            self.coefficients = self.config.coefficients
        else:
            self.coefficients = model_config["coefficients"]

        # Initialize state
        self.state_manager = StateManager(TeaCacheState)
        self.forward_fn = model_config["forward_func"]

        return module

Where _MODEL_CONFIG is just a mapping for the forward functions and coefficients

_MODEL_CONFIG = {
    "FluxTransformer2DModel": {
        "forward_func": _flux_forward,
        "coefficients": [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01],
    },
}

Similarly, the methods defined in the hook object could also be turned into utility functions.

def _compute_rescaled_distance(rel_distance: float, coefficients: List[float]) -> float:
    return (
        coefficients[0] * rel_distance**4
        + coefficients[1] * rel_distance**3
        + coefficients[2] * rel_distance**2
        + coefficients[3] * rel_distance
        + coefficients[4]
    )
    
def _should_use_cache(state: "TeaCacheState", ...):
    # Return True or False based on whether to use the cache.
    return

def _update_cache(state: "TeaCacheState", ...):
    return

def _apply_cached_residual(
    state: "TeaCacheState", input_base: torch.Tensor, modulated_inp: torch.Tensor
) -> torch.Tensor:
    """
    Apply cached residual to input (fast path).
    """
    output = input_base + state.previous_residual
    state.previous_modulated_input = modulated_inp
    state.cnt += 1
    return output

Let's remove passing cache_fn and compute_fn between the hook and the adapter. Use operations directly on the cache state + globally available utility methods. We can also remove the modulation extractors and move that logic into the model specific forward functions.

)
if self.rel_l1_thresh < 0.05:
    import warnings
    warnings.warn(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use logger.warning


registry._set_context(None)

def enable_teacache(self, rel_l1_thresh: float = 0.2, num_inference_steps: int = None, **kwargs):

Caching should only be enabled through enable_cache by passing the relevant config. Cache-specific enable methods are not supported.

pipe.to("cuda")
# Enable TeaCache with auto-detection (1.5x speedup)
pipe.transformer.enable_teacache(rel_l1_thresh=0.2)

Should be

        pipe.transformer.enable_cache(...)

We don't enable specific caching methods directly.

logger.info(f"TeaCache: Using {state.num_steps} inference steps")

def initialize_hook(self, module):
    self.state_manager.set_context("teacache")

Cache context is typically set in the denoising loop? I think in this case, both conditional and unconditional branches would write to the same cache state when using CFG.


def _flux_modulated_input_extractor(module, hidden_states, timestep_emb):
    """Extract modulated input for FLUX models."""
    return module.transformer_blocks[0].norm1(hidden_states, emb=timestep_emb)[0]

I think these extractor functions can be folded into the adapter functions of each model. They're thin wrappers around a single line of code.

self.model_type = None

@staticmethod
def _create_rescale_func(coefficients):

Why do we need to create a rescale func? If we have coefficients set, we should be able to just call the function directly?

def rescale_fn(self, x):
    return self.coefficients[0] * x**4 + self.coefficients[1] * x**3 + self.coefficients[2] * x**2 + self.coefficients[3] * x + self.coefficients[4]

@LawJarp-A (Author)

Thanks for the feedback @DN6.
I'll take this week to rework it.
I had left some redundant code in while trying to figure out the organization; will clean it up.

@LawJarp-A (Author)

The per-model forward code is unavoidable due to different model architectures. The adapter pattern was an attempt to organize this, but I agree standalone functions would be cleaner. I'll refactor.

@LawJarp-A (Author)

Hi @DN6, I've updated the implementation as you requested:

  • Replaced adapter classes with standalone forward functions
  • Created _MODEL_CONFIG mapping for forward functions and coefficients
  • Removed cache_fn/compute_fn closures - now using direct if/else logic in each forward
  • Extracted utility functions: _should_compute(), _update_state(), _apply_cached_residual()
  • Removed enable_teacache() - now only enable_cache(TeaCacheConfig(...))
  • Inlined modulation extractors into forward functions

This does introduce some code duplication - each forward function now has the same if/else pattern:

  if _should_compute(state, modulated_inp, hook.coefficients, hook.config.rel_l1_thresh):
      # compute full transformer
      _update_state(state, output, original, modulated_inp)
  else:
      output = _apply_cached_residual(state, input, modulated_inp)

But the control flow is now much clearer - you can read each forward function top-to-bottom without jumping between closures and hook methods.

Let me know if you'd like any further changes!

@sayakpaul sayakpaul requested a review from DN6 January 8, 2026 06:33

LawJarp-A commented Jan 12, 2026

@DN6 @sayakpaul I spent the weekend going over the code again to understand and simplify it.

  • Updated the cache context to be set in the denoising loop itself
  • Removed redundant code
  • Tested it with all models on an H100 and updated the PR description

I have kept the per-model forward functions as you requested, instead of the common adapter pattern I was using before.
Please review it now; I think it addresses all the recent feedback I have received.

Btw, below are the images generated with and without the cache.

Mochi: [image comparison]

Lumina2: [image comparison]

FLUX: [image comparison]

CogVideoX: [image comparison]

LawJarp-A and others added 2 commits January 12, 2026 16:31