Generate: improve assisted generation tests #27540
Conversation
```python
# This for loop is a naive and temporary effort to make the test less flaky.
failed = 0
for i in range(10):
```
This was essentially the same as `@is_flaky`, but (IMO) less elegant.
Now that we understand the cause for the mismatch (matmul with different shapes), and know that there is no workaround, it is safe to confirm that this test is indeed flaky :)
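For context, a retry decorator in the spirit of `@is_flaky` can be sketched as follows. This is a minimal illustration, not the actual `transformers.testing_utils.is_flaky` implementation; the `max_attempts` parameter name is an assumption:

```python
import functools


def is_flaky(max_attempts: int = 5):
    """Retry a test a few times before declaring it failed.

    Minimal sketch (hypothetical, not the real transformers helper): rerun the
    wrapped test and only raise if every attempt fails, which is equivalent to
    the `for i in range(10)` loop it replaces.
    """

    def decorator(test_func):
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            last_error = None
            for _ in range(max_attempts):
                try:
                    return test_func(*args, **kwargs)
                except AssertionError as err:
                    last_error = err
            raise last_error

        return wrapper

    return decorator
```

A decorator also documents intent at the test definition site, instead of burying the retry logic inside the test body.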
tests/generation/test_utils.py
Outdated
```diff
  if any(
      model_name in model_class.__name__.lower()
-     for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
+     for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"]
```
Note: `seamlessm4t` was already in the skip list of `test_assisted_decoding_sample`, probably for the same post-mortem reasons
amyeroberts left a comment
Thanks for adding!
Great comments to provide context in the tests 🙏 My only comment is about having `config.is_decoder` set for all these tests. Is the case when `config.is_encoder_decoder` is set fully covered?
```diff
@@ -1599,18 +1609,27 @@ def test_assisted_decoding_sample(self):
         config.use_cache = True
         config.is_decoder = True
```
Can we also have a test for when `config.is_encoder_decoder` is set, to make sure any relevant logic is handled there?
@amyeroberts It is also not mutually exclusive with `config.is_encoder_decoder`. All tests that require caching, such as the assisted generation ones, have to set `config.is_decoder = True`.
@gante Thanks for explaining! I thought they were mutually exclusive
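The point about the two flags not being mutually exclusive can be illustrated with a small sketch. These are hypothetical stand-in configs, not the real `transformers.PretrainedConfig`, and the flag semantics are paraphrased: roughly, `is_decoder` marks a component that runs causal self-attention (and can therefore cache past key/values), while `is_encoder_decoder` describes the overall architecture.

```python
from types import SimpleNamespace

# Hypothetical stand-ins for model configs (NOT transformers.PretrainedConfig),
# showing that the two flags are independent axes rather than alternatives.
decoder_only = SimpleNamespace(is_decoder=True, is_encoder_decoder=False, use_cache=True)
encoder_decoder = SimpleNamespace(is_decoder=True, is_encoder_decoder=True, use_cache=True)

# An encoder-decoder model still contains a decoder, so both flags can be
# True at the same time; caching tests need is_decoder regardless.
for cfg in (decoder_only, encoder_decoder):
    assert cfg.is_decoder and cfg.use_cache
```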

What does this PR do?
Strengthens the test suite for assisted generation. With these modifications, previously found API problems will be properly caught in advance.
Post mortem
Why weren't API problems caught before?
Assisted generation has two loops: the loop to obtain the candidate tokens from the assistant model (inner loop), and the loop to generate the final tokens from the main model (outer loop). Both loops are slightly different depending on whether the main model accepts the matches or not -- there are different code paths depending on whether `n_matches > 0` or not.

The following cases were being tested and had no API issues:
- `n_matches == 0`
- `n_matches > 0`, but we only run 1 iteration of the outer loop

👉 We weren't explicitly testing the case where `n_matches > 0` AND we ran more than 1 outer loop iteration.

If we weren't testing that case, why was the CI randomly red?
Each individual test had a ~97% chance of being green. The (random) assistant model was building the candidate sequence from the most likely tokens from its vocabulary (size = 99), and the main model was comparing the candidate sequence against sampling from its logits. Most of the time, `n_matches == 0`, so the test passed. However, sometimes we had `n_matches > 0`, but not to the point where it was enough to complete assisted generation in 1 outer loop.

👉 There was a low chance (per test) of hitting the failing case, resulting in inconsistent CI failures
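The two loops described above can be sketched with a self-contained toy example. Random token pickers stand in for both models, and all helper names are hypothetical; this is not the transformers implementation, just an illustration of where the `n_matches` branch point sits:

```python
import random

VOCAB = list(range(99))  # toy vocabulary of size 99, as in the description above


def assistant_next_token(seq, rng):
    # Stand-in for the assistant model's next-token prediction.
    return rng.choice(VOCAB)


def main_model_sample(seq, rng):
    # Stand-in for sampling from the main model's logits.
    return rng.choice(VOCAB)


def assisted_generation(prompt, max_new_tokens, num_candidates, seed=0):
    rng = random.Random(seed)
    seq = list(prompt)
    outer_iterations = 0
    while len(seq) - len(prompt) < max_new_tokens:  # outer loop
        outer_iterations += 1
        # Inner loop: the assistant proposes a short candidate continuation.
        candidates = []
        for _ in range(num_candidates):
            candidates.append(assistant_next_token(seq + candidates, rng))
        # The main model validates candidates left to right; n_matches is the
        # length of the accepted prefix -- the branch point from the post
        # mortem (n_matches == 0 vs n_matches > 0).
        n_matches = 0
        for token in candidates:
            if main_model_sample(seq + candidates[:n_matches], rng) == token:
                n_matches += 1
            else:
                break
        # Keep the accepted prefix plus one token from the main model, so the
        # sequence advances even when n_matches == 0.
        seq.extend(candidates[:n_matches])
        seq.append(main_model_sample(seq, rng))
    return seq, outer_iterations
```

With uniform random "models" over a 99-token vocabulary, each candidate matches with probability ~1%, so `n_matches > 0` is rare. That mirrors the post mortem: the failing path (`n_matches > 0` combined with more than one outer-loop iteration) was only hit occasionally, producing intermittent rather than deterministic CI failures.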