Fix remaining issues in beam score calculation#27808

Merged
ArthurZucker merged 6 commits into huggingface:main from VsonicV:beamhyp_fix
Dec 8, 2023

Conversation

@VsonicV (Contributor) commented Dec 3, 2023

What does this PR do?

This PR fixes the remaining issues in beam score calculation, following up on #27351.
More specifically:

  1. When adding a new hypothesis, the hyp passed in process does not include the next generated token, on which the current beam score is calculated, whereas the hyp passed in finalize includes all of the generated tokens so far. This inconsistency is resolved by changing the add function of BeamHypotheses: the current length of the generated tokens is now passed directly to add.
  2. When calculating the best possible beam score in the is_done function of BeamHypotheses, max_length was used directly without deducting decoder_prompt_len. This is now fixed.
  3. The test expectations are updated accordingly.
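
As a hedged illustration of fixes (1) and (2) above, the two corrected formulas can be sketched as follows. Names such as generated_len, decoder_prompt_len, and length_penalty mirror the PR discussion; the functions themselves are illustrative sketches, not the actual transformers code.

```python
def beam_score(sum_logprobs: float, generated_len: int, length_penalty: float) -> float:
    # Fix (1): normalize by the number of *generated* tokens, which is now
    # passed explicitly to BeamHypotheses.add instead of being re-derived
    # inconsistently in process vs. finalize.
    return sum_logprobs / (generated_len ** length_penalty)


def best_possible_score(best_sum_logprobs: float, max_length: int,
                        decoder_prompt_len: int, length_penalty: float) -> float:
    # Fix (2): in is_done, deduct the decoder prompt length from max_length
    # before applying the length penalty, so the bound is computed over
    # generated tokens only.
    return best_sum_logprobs / ((max_length - decoder_prompt_len) ** length_penalty)
```

For example, a cumulative log-probability of -6.0 over 3 generated tokens with length_penalty=1.0 yields a beam score of -2.0.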

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@gante

@gante (Contributor) left a comment

Good catch, thanks for fixing!

BTW, run RUN_SLOW=1 py.test tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py::ViT2GPT2ModelIntegrationTest::test_inference_coco_en -- this test may need an update in its value

Contributor

If I'm not mistaken, this is the same as cur_len (L228). I'd suggest renaming cur_len into generated_len, which is more representative of the variable contents!

Contributor Author

Yes, the input_ids[batch_beam_idx].shape[-1] should be same for each batch_beam_idx, so we can simply use cur_len here.

Contributor

I'd add a note that the else case here exists for retrocompatibility reasons :)

Contributor Author

Thanks for the reminder. Added!

@VsonicV (Contributor Author) commented Dec 6, 2023

> Good catch, thanks for fixing!
>
> BTW, run RUN_SLOW=1 py.test tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py::ViT2GPT2ModelIntegrationTest::test_inference_coco_en -- this test may need an update in its value

@gante I have updated the expectation value for this test. I have also incorporated all your suggestions. Ready to go!

@VsonicV (Contributor Author) commented Dec 6, 2023

@gante I have also updated the usage of cur_len in this PyTorch version, following your suggestions in #27814; it now represents the length of the entire sequence, including the decoder prompt, which is consistent with the rest of the codebase.
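
The convention described here reduces to simple arithmetic, sketched below with illustrative example values (the variable names follow the discussion in this PR and #27814):

```python
# cur_len counts the entire sequence, decoder prompt included, so the
# length used for the length penalty is recovered by subtraction.
decoder_prompt_len = 2                # tokens in the decoder prompt (example value)
generated_tokens = [101, 7, 42]       # tokens produced by beam search (example values)

cur_len = decoder_prompt_len + len(generated_tokens)  # full sequence length
generated_len = cur_len - decoder_prompt_len          # length fed to the length penalty
```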

@gante (Contributor) left a comment

Perfect, thanks for iterating 💛

Note for @ArthurZucker: there may be slow CI failures due to this change, though I suspect there won't be, since the correction is small. In any case, I'll keep an eye on the CI after this gets merged.

Contributor

nit: we can actually remove this if/else, if the result is back to being the same :P

Contributor Author

Redundant if/else removed.

@VsonicV (Contributor Author) commented Dec 8, 2023

All suggested changes are incorporated. Ready to go! @gante @ArthurZucker

@ArthurZucker (Collaborator) left a comment

thanks for this! 🤗

```python
if generated_len is not None:
    score = sum_logprobs / (generated_len**self.length_penalty)
# This 'else' case exists for retrocompatibility
else:
    score = sum_logprobs / ((hyp.shape[-1] - decoder_prompt_len) ** self.length_penalty)
```
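
A minimal sketch of the fallback logic in this snippet, using the names shown above (this is an illustration under those assumptions, not the actual transformers source):

```python
def length_penalized_score(sum_logprobs, hyp_len, decoder_prompt_len,
                           length_penalty, generated_len=None):
    # Preferred path: the caller passes the generated length explicitly,
    # as BeamHypotheses.add does after this PR.
    if generated_len is not None:
        return sum_logprobs / (generated_len ** length_penalty)
    # This 'else' case exists for retrocompatibility: derive the generated
    # length from the full hypothesis length minus the decoder prompt.
    return sum_logprobs / ((hyp_len - decoder_prompt_len) ** length_penalty)
```

Both branches agree whenever generated_len equals hyp_len - decoder_prompt_len; the fallback only matters for callers that predate the new argument.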
Collaborator

Suggested change:

```diff
-# This 'else' case exists for retrocompatibility
+# This 'else' case exists for backward compatibility
```

Contributor Author

Oops, the PR is already merged; maybe let's stay with it for now?

Collaborator

of course no worries

Comment on lines -636 to -640:

```python
if is_pt:
    expectation = 20
else:
    # TODO (joao): fix me
    expectation = 13
```
Collaborator

nice! 🔥

@ArthurZucker ArthurZucker merged commit b31905d into huggingface:main Dec 8, 2023
@VsonicV VsonicV deleted the beamhyp_fix branch December 8, 2023 13:34