[TF] Fix Tensorflow XLA Generation on limited seq_len models #33903
ArthurZucker merged 3 commits into huggingface:main
Conversation
The changed test code reads the sequence length from the model inputs:

```python
inputs = inputs_dict["input_ids"] if is_input_ids else inputs_dict["input_features"]

# fix config for models with additional sequence-length limiting settings
seq_len = inputs.get_shape()[1]
```
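As a minimal sketch of the shape assumption above (the concrete shapes here are illustrative, not taken from the test suite):

```python
# Assumed illustrative layouts: text models pass (batch, seq_len) token ids,
# speech models typically pass (batch, seq_len, feature_dim) features.
input_ids_shape = (2, 16)           # (batch, seq_len)
input_features_shape = (2, 16, 80)  # (batch, seq_len, feature_dim)

# In both layouts the sequence length sits on axis 1, which is what
# inputs.get_shape()[1] relies on in the TF test.
seq_len = input_ids_shape[1]
print(seq_len)  # 16
print(input_features_shape[1])  # 16
```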
Not sure if input_features also has its sequence length on axis 1. Just submitting this quickly before I run out of time.
|
@vasqu I suspect the reason that these errors are silent on GPU is that
Either way, reviewing the rest of this now!
|
@vasqu can you push another empty commit with the
|
Oh wow, thanks for the insights into this. This is rather counter-intuitive to me, though :D Even some archaic CUDA/GPU errors would suffice for me, but it is what it is.
|
|
tests pass :) gentle ping @Rocketknight1
Rocketknight1 left a comment:
LGTM!
For core maintainers: This generation test fails on some models because it generates past the limit of the model's max_position_embeddings. Because the model is initialized from a config, not a pretrained checkpoint, this PR simply increases config.max_position_embeddings or config.max_target_positions, which seems like the right fix to me.
cc @LysandreJik for core maintainer review!
Merged as huggingface#33903:
* fix tf xla generation on limited seq_len models
* [run-slow] opt
* [run-slow] opt
What does this PR do?
Fixes XLA generation for models with a limited seq_len that they can generate. This was recently (re)discovered in #33298: slow runs fail due to indexing issues with absolute positional embeddings (i.e., an index out-of-bounds error). The issue can only be observed on CPU, not on GPU (I don't know why, but the error is just silently swallowed there).
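A toy illustration of the failure mode described above (all numbers hypothetical): with absolute positional embeddings of size `max_position_embeddings`, generating past that limit produces a position index outside the embedding table, which CPU kernels reject:

```python
max_position_embeddings = 20     # hypothetical model limit
seq_len, max_new_tokens = 16, 8  # hypothetical prompt/generation lengths

# One row per absolute position the model can embed.
position_embeddings = [[0.0] * 8 for _ in range(max_position_embeddings)]

# Generation walks positions 0 .. seq_len + max_new_tokens - 1 = 23.
last_position = seq_len + max_new_tokens - 1
try:
    _ = position_embeddings[last_position]  # 23 >= 20: out of bounds
except IndexError:
    print("position index out of bounds")   # what the CPU run surfaces
```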
The solution is to extend the embedding limit not only by `max_new_tokens` but also by the current `seq_len` at hand. At the moment, `seq_len` hasn't been considered yet.

No fixes, but a reference to #33298 (possibly adding an issue if wanted).
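The fix described above can be sketched as follows (a simplified stand-in: `SimpleNamespace` replaces the real model config, and `extend_position_limits` is a hypothetical helper name; the attribute names `max_position_embeddings` / `max_target_positions` come from the PR discussion):

```python
from types import SimpleNamespace


def extend_position_limits(config, seq_len, max_new_tokens):
    """Grow sequence-length-limiting config fields so generation covers
    the prompt plus all newly generated tokens (hypothetical helper)."""
    needed = seq_len + max_new_tokens
    for attr in ("max_position_embeddings", "max_target_positions"):
        if hasattr(config, attr):
            setattr(config, attr, max(getattr(config, attr), needed))
    return config


config = SimpleNamespace(max_position_embeddings=20)
extend_position_limits(config, seq_len=16, max_new_tokens=8)
print(config.max_position_embeddings)  # 24
```

Because the test builds the model from a config rather than a pretrained checkpoint, bumping these limits only enlarges a randomly initialized embedding table; it does not change any trained weights.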
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@Rocketknight1 @amyeroberts