Llama: make slow tests green 🟢 #33138
Conversation
```diff
 ) -> bool:
     """
-    Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
+    Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
```
This function had many >120 char lines, which did not fit on my screen :D The diff in docstrings/comments is exclusively due to breaking lines.
```python
is_tracing = (
    torch.jit.is_tracing()
    or isinstance(inputs_embeds, torch.fx.Proxy)
    or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
```
`is_torchdynamo_compiling()` is a more comprehensive check.
I've replaced all occurrences of this pattern.
Does `is_torchdynamo_compiling` take the version of torch into account?
Yes!
It uses a try/except: if the newer functions are not available, it falls back to `torch._dynamo.is_compiling()` (and, if that is not available either, it means torch can't compile at all)
```python
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
```
No longer true with #32227
Removed all occurrences of this comment.
```python
# compiled static cache (removes the cache initialized in the previous check, to confirm we can
# initialize the cache in full compiled mode)
model._cache = None
# end-to-end compiled dynamic cache
```
The reference for end-to-end compilation is with the dynamic cache -- I forgot to update this test before the last set of changes in the PR that introduced end-to-end compilation :)
This includes a recompile, no? Or are dynamic shapes handled now?
Yes, it issues a recompile
(but static cache doesn't work at the moment -- both things, recompilations and using static caches, need to be addressed)
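As a toy illustration of why a new input shape triggers a recompile: a shape-keyed compile cache "recompiles" every time it sees a length it hasn't cached yet, which is exactly what happens when the attention mask grows at each decode step. This models only the caching idea and is in no way how `torch.compile` is implemented internally.

```python
class ToyCompileCache:
    """Toy stand-in for a shape-keyed compilation cache (illustrative only)."""

    def __init__(self):
        self.compiled = {}
        self.recompiles = 0

    def run(self, seq_len: int) -> None:
        # A previously unseen sequence length means a new "graph" must be built.
        if seq_len not in self.compiled:
            self.recompiles += 1
            self.compiled[seq_len] = object()


cache = ToyCompileCache()
for step_len in (13, 14, 15):  # a mask that grows at each decode step
    cache.run(step_len)

# Three distinct lengths -> three recompiles; re-running a seen length is free.
assert cache.recompiles == 3
cache.run(13)
assert cache.recompiles == 3
```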
```diff
 out.logits[0, 0, :15],
-atol=1e-3,
+rtol=1e-3,
+atol=1e-2,
```
We have expected results built for the GPUs in our CI, T4 and A10. However, if I check out the original commit and install the torch version from the time (torch 2.3), the test fails on my RTX 4090.
This means there are probably tiny device-related differences whose explanation goes beyond the CUDA compute capability major version.
As such, I increased the tolerance. To be fair, 1e-2 is within the expected differences for 16-bit computations, just like 1e-5 is for 32-bit computations.
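For context, an allclose-style check of this kind passes when `|a - b| <= atol + rtol * |b|`. A tiny pure-Python sketch with made-up logit values (the helper and the numbers are illustrative, not taken from the test) shows how a small device-related difference fails at the old tolerance but passes at the loosened one:

```python
def close(a: float, b: float, atol: float, rtol: float) -> bool:
    """allclose-style comparison: |a - b| <= atol + rtol * |b|."""
    return abs(a - b) <= atol + rtol * abs(b)


# Made-up logit pair with a ~6e-3 device-related difference:
ref, got = 1.240, 1.234

assert not close(got, ref, atol=1e-3, rtol=1e-3)  # stricter tolerance: fails
assert close(got, ref, atol=1e-2, rtol=1e-3)      # loosened tolerance: passes
```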
What does this PR do?
Part 1 of #32685 -- update our tests to be sure we don't break things 🤗
Makes slow `llama` tests happy on my local environment. Some tests are still failing on our slow CI, mostly due to hardware (e.g. out of memory), but they do not have an impact on #32685 [in other words, on checking the correctness of the new changes]. The exception is `test_compile_static_cache`, which passes when run in isolation but fails if run with other tests, due to `accelerate` + `torch.compile` incompatibilities (see our internal discussion here).