TF generate refactor - XLA sample#16713

Merged
gante merged 9 commits into huggingface:main from gante:xla_sample
Apr 18, 2022

Conversation


@gante gante commented Apr 11, 2022

What does this PR do?

This PR brings XLA support to sample in generate. Four important details before reviewing:

  1. The diff includes the changes from "TF beam search: handle case without past" #16704; please review that PR first :) It fixes a test from beam_search. I will rebase as soon as the other PR gets merged (the changes were bundled to confirm that the branch passes all generate tests).
  2. The body is mostly copied from greedy_search;
  3. The sample step was changed from the previous implementation -- if we want seeded sampling under XLA, we need to use the stateless random functions;
  4. The XLA sample tests do not compare all generated tokens to their non-XLA counterparts, due to the numerical instabilities discussed on Slack. We do compare the first tokens, which match.

Finally, tests have been run for the usual models (gpt2, t5, rag, speech2text, encoder_decoder, vision_encoder_decoder, bart).
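Point 3 above (stateless random functions) can be illustrated with a minimal sketch. This is not the PR's actual code, just the general shape: `tf.random.stateless_categorical` takes an explicit two-integer seed instead of mutating hidden RNG state, which is what makes seeded sampling traceable and reproducible under XLA.

```python
import tensorflow as tf

# Minimal sketch (not the PR's actual implementation): stateless sampling
# works under XLA because the RNG state is passed in explicitly as a
# (key, counter) pair instead of being mutated as a side effect.
@tf.function(jit_compile=True)
def sample_next_token(logits, seed):
    # `seed` is a pair of integers; the same seed always yields the same draw.
    return tf.random.stateless_categorical(logits, num_samples=1, seed=seed)

logits = tf.constant([[0.1, 2.0, 0.3, 4.0]])
tok_a = sample_next_token(logits, seed=[42, 0])
tok_b = sample_next_token(logits, seed=[42, 0])
assert int(tok_a[0, 0]) == int(tok_b[0, 0])  # reproducible under XLA
```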


I've also run a quick sanity check on GPU. Using GPT2+sample, on an Nvidia T4:

  • eager TF: ~1.7s
  • XLA TF: ~54ms (~22s compile time) 👉 31x speedup
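The shape of that benchmark can be reproduced with a toy function (this sketch uses a plain matmul stand-in, not GPT2): the first XLA calls pay the compilation cost, and subsequent calls hit the cached compiled program, which is where the speedup comes from.

```python
import time

import tensorflow as tf

def avg_ms(fn, *args, warmup=2, runs=10):
    for _ in range(warmup):  # warm-up absorbs tracing + XLA compile time
        fn(*args)
    start = time.perf_counter()
    for _ in range(runs):
        fn(*args)
    return (time.perf_counter() - start) / runs * 1000.0

x = tf.random.normal((256, 256))
step = lambda a: tf.matmul(a, a) + 1.0      # stand-in for a decoding step
xla_step = tf.function(step, jit_compile=True)  # XLA-compiled version

print(f"eager: {avg_ms(step, x):.3f} ms, XLA: {avg_ms(xla_step, x):.3f} ms")
```

The warm-up calls matter: timing the first XLA call would measure compilation, not steady-state throughput.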

Comment on lines 450 to 462
Contributor Author

This test was pretty much the same as test_lm_generate_gpt2 (the only difference was the starting input_ids).


HuggingFaceDocBuilderDev commented Apr 11, 2022

The documentation is not available anymore as the PR was closed or merged.

Contributor

nice!

Member

If I understand this correctly, if the user passes a seed tuple, the same seed is used on every sampling for the entire generation run? That's not a problem, just making sure I got it right!

Contributor

(nit) I'd prefer to move input_ids_length before model_kwargs, since kwargs or model_kwargs is usually the last function argument.

Contributor

@patrickvonplaten patrickvonplaten left a comment

Great! Awesome to see such a speed-up!

Looks good to me. The only thing that is not very intuitive to me is that seed is a list of integers - why is this? Is this normal in TF?


gante commented Apr 12, 2022

@patrickvonplaten the stateless TF functions accept a seed argument that is a tuple of two integers 😅 Not very intuitive, I agree. They correspond to the key and counter used in the internal RNG algorithms (source).

If you think it will be unintuitive for users, I can change it so that our seed argument corresponds to the key of the tuple (i.e. a single integer), and fix the counter to 0. For practical purposes, it should be the same thing.
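The alternative floated here (a single user-facing seed, counter fixed to 0) could look like the following. `sample_with_user_seed` is a hypothetical helper for illustration, not an actual transformers function:

```python
import tensorflow as tf

# Hypothetical helper (not part of transformers): expose one integer seed to
# the user and fix the counter half of TF's (key, counter) stateless seed to 0.
def sample_with_user_seed(logits, user_seed):
    seed_pair = (user_seed, 0)  # (key, counter)
    return tf.random.stateless_categorical(logits, num_samples=1, seed=seed_pair)

logits = tf.math.log(tf.constant([[0.1, 0.2, 0.3, 0.4]]))
t1 = sample_with_user_seed(logits, user_seed=123)
t2 = sample_with_user_seed(logits, user_seed=123)
assert int(t1[0, 0]) == int(t2[0, 0])  # same key, fixed counter -> same draw
```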

@patrickvonplaten
Contributor

> @patrickvonplaten the stateless TF functions accept a seed argument that is a tuple of two integers 😅 Not very intuitive, I agree. They correspond to the key and counter used in the internal RNG algorithms (source).
>
> If you think it will be unintuitive for users, I can change it so that our seed argument corresponds to the key of the tuple (i.e. a single integer), and fix the counter to 0. For practical purposes, it should be the same thing.

I see - ok, maybe better to leave it as is then, to stay aligned with TF.

Comment on lines 536 to 537
Member

Small comment, but "schöner" is correct here - I hope it's just sampling noise and not a sign of a bug that the XLA version gets it wrong!

Member

@Rocketknight1 Rocketknight1 Apr 13, 2022

Could you try, as a kind of stupid manual one-off test, asking it to translate a batch of sentences from English to Portuguese, and make sure that even if the outputs differ, the quality of the XLA ones is similar to the quality of the non-XLA ones? Numerical bugs can be very annoying to catch, but if the quality is similar, that would make me confident that the XLA implementation is not worse.

Member

@Rocketknight1 Rocketknight1 left a comment

Overall this looks great! I have a couple of very nitpicky nitpicks, mostly because I had a bug in my XLA implementation of greedy that I didn't notice for a long time because it only degraded the quality of sampling, so now I'm paranoid about catching bugs like that by checking output quality.


gante commented Apr 18, 2022

While running tests for T5 (as suggested by @Rocketknight1), I found out that our XLA code is not behaving properly for T5, for both sample and greedy_search. Because the problem is not exclusive to sample, I'm merging this PR and fixing the issue in a future one.

(example screenshot omitted)

@gante gante merged commit b4ddd26 into huggingface:main Apr 18, 2022
@gante gante deleted the xla_sample branch April 18, 2022 09:58
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022