[Cache] Don't initialize the cache on meta device #36543

gante merged 29 commits into huggingface:main from non_meta_multi_device_cache

Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the "Ready for review" button.

Are there concrete plans for landing this PR and, if so, which release will it get into? #35164 has caused some trouble for our torch.compile backend, which currently prevents us from using transformers v4.49.0.

Btw @gante, this issue was originally fixing a case in Gemma2-like models where we init the cache within the model's forward. I probably didn't add a test for the multi-GPU setting, but we need to make sure that case doesn't fail before merging, since it would cause all forward calls to the model to fail on multi-GPU.

@zucchini-nlp ready for review :) I've added regression tests, including for the Gemma2 case you mentioned!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
zucchini-nlp left a comment
I feel like rolling back to manually indicating the device for each layer is not the solution we want. Also, I am not sure why the test for Gemma2 passes now, since we just rolled back to the previous version where it was failing.

I can't say for sure, but if a better solution exists that keeps both torch tracing and auto-device allocation, that would be super cool.
tests/utils/test_cache_utils.py
Outdated
    inputs = tokenizer("Today is a beautiful day!", return_tensors="pt").to(0)
    _ = model(**inputs)
    _ = model.generate(**inputs, max_new_tokens=2, cache_implementation="static")
I don't know how or why this test passes, because in the forward pass the Gemma2 cache defaults to self.device and thus initializes the cache on "cuda:0". Can you check why it failed before in the same situation, but now it doesn't? Or maybe it is not using the cache, which would also be interesting 🤔

Btw, for generate I think it is hybrid, no?
You're right, this test is not passing; I made incompatible changes after an initial fix that had made it pass. I'll ping you again when it's fixed.

> Btw, for generate I think it is hybrid, no?

Correct, updating it. (For testing purposes, however, it was the same, as they share the initialization.)
…ormers into non_meta_multi_device_cache
@zucchini-nlp Why must the address of the cache tensors be static? It's a requirement for high throughput when we have the cache as an input (as opposed to an internal parameter of the
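For context, a minimal sketch of the static-address property being discussed (my own illustration with assumed shapes and names, not the transformers implementation): a static cache preallocates its buffers once and writes new key/value states in place, so the tensor's storage address never changes. CUDA graph capture (e.g. via `torch.compile(mode="reduce-overhead")`) records fixed device addresses, so a cache that reallocated its tensors on every step would invalidate the captured graph.

```python
import torch

# Illustrative only: a preallocated cache buffer updated in place.
# Shapes and variable names are assumptions, not the transformers API.
num_heads, max_seq_len, head_dim = 4, 8, 16
cache = torch.zeros(1, num_heads, max_seq_len, head_dim)
address_before = cache.data_ptr()

# Write two new positions into the buffer without reallocating it.
new_states = torch.randn(1, num_heads, 2, head_dim)
positions = torch.tensor([0, 1])
cache.index_copy_(2, positions, new_states)  # in-place along the seq_len dim

# The storage address is unchanged, which is what CUDA graph replay relies on.
assert cache.data_ptr() == address_before
```

An `.update()` that instead did `cache = torch.cat([cache, new_states], dim=2)` would produce a fresh allocation every step, which is exactly what graph capture cannot tolerate.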
@zucchini-nlp the multi-GPU test is not yet fixed, I'll tag you for review when it's ready :)

@zucchini-nlp ready now :) The current solution simultaneously:
    self.assertListEqual(expected_output_sentence, batch_out_sentence)
    self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])

    @slow
Unrelated to the original problem, but this test should be marked slow (it uses a non-standard 0.5B model with TF) and is causing test timeouts.
… breakages (#57174)

This applies the unmerged upstream PR huggingface/transformers#36543 on top of the 4.49 release, to avoid the graph breaks that currently crash the torch.compile backend.

Co-authored-by: Dan Moldovan <26628547+mdanatg@users.noreply.github.com>
MODULAR_ORIG_COMMIT_REV_ID: 3b6f3cd3b2f21943a0d2401d489c70ada592cc6e
zucchini-nlp left a comment
Thanks, makes sense to me. One question: when we stop a deprecation midway, do we just remove the warning, with no logging needed from us?

Yes, because no action is needed from the user (this specific case was about removing an argument that will no longer be removed; if it was a rename, we would need to deprecate the new name).
@ZolotukhinM I missed your comment :D
What does this PR do?
Context — two recent changes:

- the cache is now initialized on the `meta` device by default. This was done to fix multi-device inference with cache, by postponing device allocation (#35164);
- the `torch.nn.Module` inheritance was removed from `Cache`, to enable tracing (e.g. for `optimum`).

Sadly, these two changes conflict. On `main`, we see graph breaks, because we assign new tensors in `.update()` when the device is `meta`.

The fix
Somewhere between the removal of the `torch.nn.Module` inheritance and other fixes merged to `main`, reverting most of #35164 solves both the graph breaks and the issue that PR was meant to fix. In other words, the following example works with this PR (run with 2 GPUs). Note that the last line was not working with #35164, but is working with this PR. It gets a graph break, but that is expected (we can't compile multi-device operations).
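The conflicting pattern described above can be sketched as a toy (assumed names and shapes, not the actual `Cache` code): tensors start on `"meta"` with no storage, and the first `.update()` call swaps in a real tensor. That Python-side tensor swap is the device-dependent branch that torch.compile cannot trace, which surfaces as a graph break.

```python
import torch

# Illustrative toy of the pre-fix lazy-allocation pattern (not the real API).
class LazyLayerCache:
    def __init__(self, shape):
        # "meta" tensors carry shape/dtype but no storage: allocation is postponed
        self.keys = torch.zeros(shape, device="meta")

    def update(self, new_keys, positions):
        if self.keys.device.type == "meta":
            # first call: materialize on the device of the incoming states
            self.keys = torch.zeros(self.keys.shape, device=new_keys.device)
        self.keys.index_copy_(2, positions, new_keys)  # in-place after that
        return self.keys

cache = LazyLayerCache((1, 4, 8, 16))
out = cache.update(torch.randn(1, 4, 2, 16), torch.tensor([0, 1]))
assert out.device.type == "cpu" and tuple(out.shape) == (1, 4, 8, 16)
```

With this PR, the cache is allocated on a real device up front, so `.update()` is a pure in-place write with no materialization branch and stays traceable.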
Benchmarks/Tests
✅ slow llama tests (vs main)
✅ slow cache tests (vs main)
✅ no slowdowns (vs main)
(main)

(this PR)
