Add global_attention_mask to gen_kwargs in Seq2SeqTrainer.prediction_step (#16485)
Conversation
If global_attention_mask is found in the model's inputs (it is used by certain models, such as LED) in the prediction_step method of Seq2SeqTrainer, it is added to the gen_kwargs, which are passed to model.generate(). This allows us to properly set the global attention when decoding.
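The change described above can be sketched as follows. This is a minimal, self-contained illustration rather than the exact upstream diff; the helper name build_gen_kwargs and its default arguments are hypothetical, standing in for the kwarg-collection logic inside Seq2SeqTrainer.prediction_step:

```python
def build_gen_kwargs(inputs, max_length=128, num_beams=1):
    """Collect generation kwargs from the model inputs.

    Hypothetical helper mirroring what Seq2SeqTrainer.prediction_step does:
    forward attention_mask as before, and (new in this PR) also forward
    global_attention_mask when the model's inputs provide one.
    """
    gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
    # Existing behavior: pass the regular attention mask through to generate().
    if "attention_mask" in inputs:
        gen_kwargs["attention_mask"] = inputs["attention_mask"]
    # New behavior: models like LED compute global attention on selected
    # tokens during decoding, so their global_attention_mask must be
    # forwarded as well.
    if "global_attention_mask" in inputs:
        gen_kwargs["global_attention_mask"] = inputs["global_attention_mask"]
    return gen_kwargs
```

With this in place, any global_attention_mask present in the evaluation batch reaches model.generate() without further user intervention.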
Thanks a lot for your PR!
patrickvonplaten
left a comment
This looks good to me - this will indeed enable generation for LED.
If @sgugger is ok with adding this somewhat model-specific line to the Trainer, the PR is good to go for me.
Hi @JohnGiorgi and @patrickvonplaten, just in case this is correct, should I open a new pull request for this? Thanks
Good point @caesar-one ! Yes, it would be nice if you could open a new PR for this |
What does this PR do?
Certain Seq2Seq models (e.g. LED-based models such as PRIMERA) need to pass the
global_attention_mask to model.generate() so that global attention is computed for particular tokens when decoding. This does not currently happen in Seq2SeqTrainer, but can easily be added by looking for global_attention_mask in the provided inputs and adding it to gen_kwargs, much the same way as the regular attention_mask is currently handled. This PR does exactly that.
Other changes
transformers/src/transformers/trainer_seq2seq.py
Fixes # (issue)
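From the user's side, once the trainer forwards the mask, it is enough for the evaluation dataset to carry a global_attention_mask column. Below is a hedged sketch of how such a column could be produced; the function add_global_attention is hypothetical, and it follows the common LED convention of placing global attention on the first token:

```python
def add_global_attention(example):
    """Attach a global_attention_mask to a tokenized example.

    Hypothetical preprocessing step: mark only the first token (typically
    the <s> token for LED) for global attention; all other positions use
    local attention (0).
    """
    global_attention_mask = [0] * len(example["input_ids"])
    global_attention_mask[0] = 1
    example["global_attention_mask"] = global_attention_mask
    return example
```

A mapping step like this (e.g. via datasets.Dataset.map) would let Seq2SeqTrainer pick up the mask automatically during prediction_step.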
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sgugger, @patrickvonplaten