TF generate refactor - XLA sample#16713
Conversation
tests/gpt2/test_modeling_tf_gpt2.py
This test was pretty much the same as test_lm_generate_gpt2 (the only difference was the starting input_ids)
The documentation is not available anymore as the PR was closed or merged.
If I understand this correctly, if the user passes a seed tuple, the same seed is used on every sampling step for the entire generation run? That's not a problem, just making sure I got it right!
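For context, a minimal sketch of what "the same seed on every sampling step" means with TF's stateless ops (toy logits and a hypothetical loop — not the PR's actual code). With a constant seed, each draw is fully determined by the seed and the logits, so in a real generation run the outputs still vary step to step only because the logits change:

```python
import tensorflow as tf

logits = tf.math.log([[0.5, 0.3, 0.2]])  # toy next-token distribution
seed = (42, 0)  # a user-provided seed pair, reused on every step

tokens = []
for step in range(4):
    # The same `seed` is passed on each call; since the toy logits never
    # change, every step samples the exact same token.
    next_id = tf.random.stateless_categorical(logits, num_samples=1, seed=seed)
    tokens.append(int(next_id[0, 0]))
```

In real generation the logits at each step depend on the previously sampled tokens, so reusing the seed still yields varied (but reproducible) sequences.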
(nit) I'd prefer to move input_ids_length before model_kwargs -> kwargs or model_kwargs is usually the last function arg
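In other words, the suggestion is an argument ordering along these lines (hypothetical names, only illustrating the convention of keeping the catch-all dict last):

```python
import inspect

# Hypothetical signature sketch: input_ids_length moved before the
# catch-all model_kwargs, which stays as the final argument.
def sample(input_ids, seed=None, input_ids_length=None, model_kwargs=None):
    return input_ids

params = list(inspect.signature(sample).parameters)
# -> ['input_ids', 'seed', 'input_ids_length', 'model_kwargs']
```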
patrickvonplaten left a comment
Great! Awesome to see such a speed-up!
Looks good to me. The only thing that is not very intuitive to me is that seed is a list of integers - why is this? Is this normal in TF?
@patrickvonplaten it's because TF's stateless random ops take the seed as a pair of integers. If you think it will be unintuitive for users, I can change it so that our API takes a single integer instead.
I see - ok, maybe better to leave it as is then, to be aligned with TF.
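For reference, the two-integer format comes from TF itself: stateless RNG ops take the seed as a shape-`[2]` integer tensor, and a single integer is rejected. A minimal illustration:

```python
import tensorflow as tf

# Stateless ops require a pair of integers as the seed.
a = tf.random.stateless_uniform(shape=(3,), seed=[1, 2])
b = tf.random.stateless_uniform(shape=(3,), seed=[1, 2])
c = tf.random.stateless_uniform(shape=(3,), seed=[3, 4])

# Same seed pair -> identical draws; a different pair gives different draws.
```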
tests/t5/test_modeling_tf_t5.py
Small comment but "schöner" is correct here - I hope it's just sampling issues and not a sign of a bug that the XLA version gets it wrong!
Could you try, as a kind of stupid manual once-off test, asking it to translate a batch of sentences from English to Portuguese and make sure that even if they're different, the quality of the XLA ones is similar to the quality of the manual ones? Numerical bugs can be very annoying to catch, but if the quality is similar then that would make me confident that the XLA implementation is not worse.
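On the "numerical bugs are annoying to catch" point: one cheap invariant that can be checked automatically, alongside the manual quality comparison, is that the XLA-compiled sampling step is reproducible for a fixed seed. A toy sketch (standing in for the real generation loop, not the PR's code):

```python
import tensorflow as tf

@tf.function(jit_compile=True)
def xla_sample_step(logits, seed):
    # One XLA-compiled sampling step; stateless ops keep it reproducible.
    return tf.random.stateless_categorical(
        logits, num_samples=1, seed=seed, dtype=tf.int32
    )

logits = tf.math.log([[0.1, 0.2, 0.3, 0.4]])
seed = tf.constant([7, 0], dtype=tf.int32)

out1 = xla_sample_step(logits, seed)
out2 = xla_sample_step(logits, seed)
# Same seed, same sample - a silent numerical bug in the XLA path would
# degrade quality without ever crashing, which is why checks like this help.
```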
Rocketknight1 left a comment
Overall this looks great! I have a couple of very nitpicky nitpicks, mostly because I had a bug in my XLA implementation of greedy that I didn't notice for a long time because it only degraded the quality of sampling, so now I'm paranoid about catching bugs like that by checking output quality.
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
While running tests for T5 (as suggested by @Rocketknight1), I found that our XLA code is not behaving properly for T5, for both greedy_search and sample.

What does this PR do?
This PR brings XLA to sample in generate. Four important details before reviewing:
1. it is built on top of the XLA greedy_search work, and I will rebase as soon as the other PR gets merged (the changes were bundled to confirm that it passes all generate tests);
2. beam_search is not covered here;
3. sampling relies on stateless functions;
4. finally, tests have been run for the usual models (gpt2, t5, rag, speech2text, encoder_decoder, vision_encoder_decoder, bart).

I've also run a quick sanity check on GPU. Using GPT2 + sample, on an Nvidia T4:
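A generic version of that eager-vs-XLA timing harness might look like the sketch below (a toy stateless-sampling function with hypothetical sizes; the PR's actual benchmark used GPT2 via model.generate):

```python
import time
import tensorflow as tf

VOCAB = 128

def sample_tokens(logits, seed):
    # Stand-in for a full generation call: draw a batch of token samples.
    return tf.random.stateless_categorical(logits, num_samples=16, seed=seed)

# Same function, compiled with XLA.
xla_sample_tokens = tf.function(sample_tokens, jit_compile=True)

logits = tf.random.stateless_normal((8, VOCAB), seed=[0, 1])
seed = tf.constant([42, 0])

xla_sample_tokens(logits, seed)  # warm-up call triggers XLA compilation

start = time.perf_counter()
for _ in range(100):
    sample_tokens(logits, seed)
eager_s = time.perf_counter() - start

start = time.perf_counter()
for _ in range(100):
    xla_sample_tokens(logits, seed)
xla_s = time.perf_counter() - start

print(f"eager: {eager_s:.3f}s, xla: {xla_s:.3f}s")
```

Note that the warm-up call matters: the first XLA invocation pays the compilation cost, so it should be excluded from the timed loop.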