
[Cache] Don't initialize the cache on meta device#36543

Merged
gante merged 29 commits intohuggingface:mainfrom
gante:non_meta_multi_device_cache
Mar 13, 2025

Conversation

@gante (Contributor) commented Mar 4, 2025

What does this PR do?

Context:

  1. In "Init cache on meta device" (#35164), we started initializing the cache on the meta device by default. This was done to fix multi-device inference with the cache, by postponing device allocation;
  2. In "Make cache traceable" (#35873), we removed the torch.nn.Module inheritance from Cache, to enable tracing (e.g. for optimum).

Sadly, these two changes conflict. On main, we see graph breaks, because .update() assigns new tensors when the device is meta:

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_id = "meta-llama/Llama-3.2-1B-Instruct"

device = "cuda"
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

generation_config = GenerationConfig(use_cache=True, cache_implementation="static")

processed = tokenizer('Anything', return_tensors="pt")
input_ids = processed["input_ids"].to(device)
output = model.generate(input_ids, generation_config=generation_config)
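The failure mode above can be illustrated without torch at all. The sketch below is plain Python; EagerCache and LazyCache are hypothetical names for illustration, not transformers classes. It shows why a cache that defers allocation swaps out its buffers on the first update, while an eagerly allocated cache keeps a stable identity — the property torch.compile relies on when cache addresses are tagged as static:

```python
class EagerCache:
    """Allocates real buffers up front; their identity never changes."""

    def __init__(self, num_layers):
        self.key_cache = [bytearray(16) for _ in range(num_layers)]

    def update(self, layer_idx, new_bytes):
        # In-place write: the buffer object (and its address) stays the same.
        self.key_cache[layer_idx][: len(new_bytes)] = new_bytes
        return self.key_cache[layer_idx]


class LazyCache:
    """Defers allocation (like a tensor on the meta device): the placeholder
    is replaced by a fresh buffer on the first update, changing its identity."""

    def __init__(self, num_layers):
        self.key_cache = [None] * num_layers  # placeholders, no storage yet

    def update(self, layer_idx, new_bytes):
        if self.key_cache[layer_idx] is None:
            # A brand-new object appears mid-forward: identity changes.
            self.key_cache[layer_idx] = bytearray(16)
        self.key_cache[layer_idx][: len(new_bytes)] = new_bytes
        return self.key_cache[layer_idx]


eager = EagerCache(1)
before = id(eager.key_cache[0])
eager.update(0, b"k")
assert id(eager.key_cache[0]) == before  # same object: compile-friendly

lazy = LazyCache(1)
placeholder = lazy.key_cache[0]
lazy.update(0, b"k")
assert lazy.key_cache[0] is not placeholder  # identity changed: graph break
```

In the real Cache, the in-place write is an index_copy_ into a preallocated tensor, and the identity change is the assignment of a freshly materialized tensor over the meta placeholder.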

The fix

Somewhere between removing the torch.nn.Module inheritance and other fixes merged to main, reverting most of #35164 now solves both the graph breaks and the issue that PR was meant to fix. In other words, the following example works with this PR.

Note that the last line was not working with #35164, but works with this PR. It hits a graph break, but that is expected (we can't compile multi-device operations).

(run with 2 GPUs)

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_id = 'google/gemma-2-2b-it'
tokenizer = AutoTokenizer.from_pretrained(model_id)

device_map = {"model.embed_tokens": 0, "model.norm": 1, "model.rotary_emb": 1, "lm_head": 0}
num_hidden_layers = 26
for i in range(num_hidden_layers):
    device_map[f"model.layers.{i}"] = 0 if i < 13 else 1

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="bfloat16",
    device_map=device_map,
)

inputs = tokenizer("Today is a beautiful day!", return_tensors='pt').to(0)
out = model(**inputs)

out_2 = model.generate(**inputs, max_new_tokens=2, cache_implementation="static")
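For reference, a per-layer cache device map that mirrors the model device_map above can be derived mechanically, so cache tensors land on the same device as the layer weights they serve. This is an illustrative sketch only; layer_device_map_from_model_map is a hypothetical helper, not a transformers API:

```python
def layer_device_map_from_model_map(device_map, num_hidden_layers):
    """Map each decoder-layer index to the device its weights live on,
    so per-layer cache buffers can be allocated alongside them."""
    return {
        i: device_map[f"model.layers.{i}"]
        for i in range(num_hidden_layers)
    }


# Same split as the example above: layers 0-12 on GPU 0, 13-25 on GPU 1.
device_map = {"model.embed_tokens": 0, "model.norm": 1, "model.rotary_emb": 1, "lm_head": 0}
num_hidden_layers = 26
for i in range(num_hidden_layers):
    device_map[f"model.layers.{i}"] = 0 if i < 13 else 1

cache_map = layer_device_map_from_model_map(device_map, num_hidden_layers)
assert cache_map[0] == 0 and cache_map[25] == 1
```

Without such a map, a cache initialized inside the model's forward has nothing to go on but self.device, which is the Gemma2 multi-GPU case discussed below.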

Benchmarks/Tests

✅ slow llama tests (vs main)
✅ slow cache tests (vs main)
✅ no slowdowns (vs main)

(main) [benchmark screenshot, 2025-03-04]

(this PR) [benchmark screenshot, 2025-03-04]

@github-actions github-actions bot marked this pull request as draft March 4, 2025 17:27
github-actions bot commented Mar 4, 2025

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

@gante gante requested a review from zucchini-nlp March 4, 2025 17:27
@ZolotukhinM

Are there concrete plans for landing this PR, and if so, which release will it get into? #35164 has caused some trouble for our torch.compile backend, which currently prevents us from using transformers v4.49.0.

@zucchini-nlp (Member)

Btw, @gante this issue was originally fixing a case in Gemma2-like models where we init the cache within the model's forward. In a multi-GPU setting we don't build any layer_device_map, so all cache layers get assigned to self.device.

I probably didn't add a test, but we need to make sure that case doesn't fail before merging, otherwise every forward call to the model would fail on multi-GPU.

@gante gante marked this pull request as ready for review March 6, 2025 10:52
@gante (Contributor, Author) commented Mar 6, 2025

@zucchini-nlp ready for review :)

I've added regression tests, including for the gemma 2 case you mentioned!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante gante force-pushed the non_meta_multi_device_cache branch from e146911 to 9d9e0d6 Compare March 6, 2025 13:16
@zucchini-nlp (Member) left a comment


I feel like rolling back to manually indicating the device for each layer is not the solution we want. Also, I am not sure why the test for Gemma2 passes now, since we just rolled back to the previous version where it was failing.

I can't say for sure, but if a better solution exists that keeps both torch tracing and automatic device allocation, that would be super cool.

Comment on lines +694 to +696
inputs = tokenizer("Today is a beautiful day!", return_tensors="pt").to(0)
_ = model(**inputs)
_ = model.generate(**inputs, max_new_tokens=2, cache_implementation="static")
@zucchini-nlp (Member):

I don't know how or why this test passes, because in the forward pass the Gemma2 cache defaults to self.device and thus initializes the cache on "cuda:0". Can you check why it failed before in the same situation, but doesn't now? Or maybe it is not using the cache, which would also be interesting 🤔

Btw, for generate I think it is hybrid, no?

@gante (Contributor, Author) commented Mar 6, 2025

You're right, this test is not passing; I made incompatible changes after an initial fix that made it pass. I'll ping you again when it's fixed.

Btw, for generate I think it is hybrid, no?

Correct, updating it. (For testing purposes, however, it was the same, as they share the initialization.)

@gante (Contributor, Author) commented Mar 6, 2025

@zucchini-nlp
We can't have both automatic device placement and compilation: if we don't do explicit device allocation, we have to create and assign new non-meta tensors in the first forward pass (or move devices, which ultimately leads to the same problem). This breaks compilation (a graph break), because the address of the cache tensors was tagged as static.

Why must the address of the cache tensors be static? It's a requirement for high throughput when we pass the cache as an input (as opposed to an internal parameter of the nn.Module, our first implementation of the static cache). You can read more about it here (read from "In order to mitigate this problem, the mark_static_address" until the end of that paragraph).

To alleviate the difficulty of having to find the right devices, I've exposed the function to find the cache devices :)
EDIT: I've added a .to operation instead in the classes that strictly need it.
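The static-address requirement can be mimicked with a toy model (plain Python, no torch; ToyCompiler is a made-up name for illustration, not the torch.compile machinery): a compiler that specializes on the identity of the cache buffer must recompile — the analogue of a graph break — whenever a fresh buffer shows up:

```python
class ToyCompiler:
    """Specializes a 'kernel' on the id() of the cache buffer, loosely
    mimicking torch.compile with mark_static_address on cache tensors."""

    def __init__(self):
        self.compiled_for = None  # identity the kernel was specialized on
        self.recompiles = 0

    def run(self, cache_buffer, value):
        if self.compiled_for != id(cache_buffer):
            self.recompiles += 1  # analogous to a graph break / recompile
            self.compiled_for = id(cache_buffer)
        cache_buffer.append(value)  # the 'kernel' body


# A preallocated cache with a stable identity compiles once and is reused.
compiler = ToyCompiler()
stable_cache = []
for step in range(3):
    compiler.run(stable_cache, step)
assert compiler.recompiles == 1

# A cache whose buffer is replaced each step recompiles every time.
compiler = ToyCompiler()
fresh_buffers = [[] for _ in range(3)]  # distinct objects, kept alive
for step, buf in enumerate(fresh_buffers):
    compiler.run(buf, step)
assert compiler.recompiles == 3
```

This is why materializing meta-device cache tensors inside the first forward (the LazyCache pattern) and static-address compilation cannot coexist.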

@gante (Contributor, Author) commented Mar 6, 2025

@zucchini-nlp the multi-GPU test is not yet corrected; I will tag you for review when it's ready :)

@gante (Contributor, Author) commented Mar 6, 2025

@zucchini-nlp ready now :)

The current solution simultaneously:

  • is compatible with multi-GPU, even when initializing the cache inside the model (although that approach will never be compatible with torch.compile, so I've added a note there). Added tests to prevent regressions;
  • has no CUDA graph breaks. Added a test to prevent regressions;
  • is compatible with torch.export. Confirmed by RUN_SLOW=1 py.test tests/utils/test_cache_utils.py -k test_static_cache_exportability.

self.assertListEqual(expected_output_sentence, batch_out_sentence)
self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])

@slow
@gante (Contributor, Author):

Unrelated to the original problem, but this test should be marked slow (it uses a non-standard 0.5B model with TF) and is causing test timeouts.

modularbot pushed a commit to modular/modular that referenced this pull request Mar 8, 2025
… breakages (#57174)

This applies the unmerged upstream PR
huggingface/transformers#36543 on top of the
4.49 release, to avoid the graph breaks that currently crash the
torch.compile backend.

---------

Co-authored-by: Dan Moldovan <26628547+mdanatg@users.noreply.github.com>
MODULAR_ORIG_COMMIT_REV_ID: 3b6f3cd3b2f21943a0d2401d489c70ada592cc6e
@zucchini-nlp (Member) left a comment

Thanks, makes sense to me. One question: when we stop a deprecation midway, do we just remove the warning, with no logging needed from us?

@gante (Contributor, Author) commented Mar 11, 2025

we just remove warning and no logging needed from us?

Yes, because no action is needed from the user

(This specific case was about removing an argument that will no longer be removed. If it were a rename, we would need to deprecate the new name.)

@gante gante force-pushed the non_meta_multi_device_cache branch from 8e91047 to 2650f0c Compare March 11, 2025 10:59
@gante (Contributor, Author) commented Mar 11, 2025

@ZolotukhinM I missed your comment :D v4.50 should include this fix, but I don't know when we will release it.

@gante gante merged commit c416123 into huggingface:main Mar 13, 2025
23 checks passed
@gante gante deleted the non_meta_multi_device_cache branch March 13, 2025 10:13
modularbot pushed a commit to modular/modular that referenced this pull request Mar 25, 2025
@gante gante mentioned this pull request Apr 11, 2025
2 tasks
modularbot pushed a commit to modular/modular that referenced this pull request Dec 9, 2025