
Conversation


datquocnguyen commented Sep 22, 2021

Regarding line 62 in the encode function of the BART hub_interface: in many cases (e.g. when using a monolingual vocabulary reduced from an existing multilingual one), an OOV token should be mapped to the <unk> index rather than always being added to the vocabulary as a new token type.

Current code: https://github.com/pytorch/fairseq/blob/main/fairseq/models/bart/hub_interface.py

    def encode(
        self, sentence: str, *addl_sentences, no_separator=True
    ) -> torch.LongTensor:
        tokens = self.bpe.encode(sentence)
        if len(tokens.split(" ")) > min(self.max_positions) - 2:
            tokens = " ".join(tokens.split(" ")[: min(self.max_positions) - 2])
        bpe_sentence = "<s> " + tokens + " </s>"
        for s in addl_sentences:
            bpe_sentence += " </s>" if not no_separator else ""
            bpe_sentence += " " + self.bpe.encode(s) + " </s>"
        tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=False) # Always add OOV token as new type
        return tokens.long()
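
For reference, here is a minimal sketch of the underlying dictionary behaviour (assuming fairseq's fairseq.data.Dictionary API; the emoji is just an illustrative OOV piece):

    from fairseq.data import Dictionary

    d = Dictionary()
    d.add_symbol("hello")

    # With add_if_not_exist=False, the unseen piece "😀" maps to the <unk> index
    # and the dictionary size stays fixed.
    ids_unk = d.encode_line("hello 😀", append_eos=False, add_if_not_exist=False)

    # Default behaviour (add_if_not_exist=True): the same unseen piece is appended
    # to the dictionary as a new type, so the vocabulary silently grows at encode time.
    ids_grow = d.encode_line("hello 😀", append_eos=False)

    print(len(d), ids_unk.tolist(), ids_grow.tolist())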

Suggested change (https://github.com/datquocnguyen/fairseq/blob/main/fairseq/models/bart/hub_interface.py):

    def encode(
        self, 
        sentence: str, 
        *addl_sentences, 
        no_separator=True,
        add_if_not_exist=True # Add an extra option
    ) -> torch.LongTensor:
        tokens = self.bpe.encode(sentence)
        if len(tokens.split(" ")) > min(self.max_positions) - 2:
            tokens = " ".join(tokens.split(" ")[: min(self.max_positions) - 2])
        bpe_sentence = "<s> " + tokens + " </s>"
        for s in addl_sentences:
            bpe_sentence += " </s>" if not no_separator else ""
            bpe_sentence += " " + self.bpe.encode(s) + " </s>"
        tokens = self.task.source_dictionary.encode_line(
            bpe_sentence, append_eos=False, add_if_not_exist=add_if_not_exist
        )
        return tokens.long()

With this suggested code, encoding in the reduced-vocabulary case above becomes: fairseq_model.encode(sentence, add_if_not_exist=False)

For mBART and the like, the default call fairseq_model.encode(sentence) still adds extra token types to the vocabulary (e.g. when training on new languages), exactly as before.
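
A usage sketch with the patched encode() above (the checkpoint path and file name below are placeholders, not from this PR):

    from fairseq.models.bart import BARTModel

    # Hypothetical BART checkpoint whose monolingual vocabulary was reduced from a
    # multilingual one; path and file name are placeholders.
    bart = BARTModel.from_pretrained(
        "path/to/reduced-vocab-bart", checkpoint_file="model.pt"
    )

    # Reduced vocabulary: map any OOV BPE piece (e.g. an emoji) to <unk> instead of
    # growing the dictionary.
    tokens = bart.encode("a sentence with an emoji 😀", add_if_not_exist=False)

    # mBART-style use (default): unseen pieces are still added as new token types.
    tokens_default = bart.encode("a sentence in a new language")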

This provides an extra option to map OOV tokens to <unk> rather than always adding them to the dictionary.
datquocnguyen changed the title from "Update BART hub_interface to add an extra option for handling OOV tokens (e.g. emojis)" to "Update the encode function in BART hub_interface: to add an extra option for not always adding OOV tokens into vocabulary" on Sep 22, 2021

stale bot commented Mar 2, 2022

This pull request has been automatically marked as stale. If this pull request is still relevant, please leave any comment (for example, "bump"), and we'll keep it open. We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated.

stale bot added the stale label on Mar 2, 2022
datquocnguyen (Author) commented

bump

stale bot removed the stale label on Mar 3, 2022