
a possible hack for FSMT's SinusoidalPositionalEmbedding peculiarity #7229


Description

@stas00

(Since the normal CI doesn't run with USE_CUDA=1, I completely missed testing this, and found one issue with the torchscript tests that I need help with.)

We are talking about FSMT, the ported fairseq transformer model.

If I understand correctly, their SinusoidalPositionalEmbedding was designed so that its weights won't be part of the model params:
https://github.com/pytorch/fairseq/blob/master/fairseq/modules/sinusoidal_positional_embedding.py#L25
most likely so that they won't be part of the state_dict, saving space in their already huge 3.3GB model dump (well, 13GB actually, as they use an ensemble of 4 models). I could be wrong about the reason for this design choice.

I had to copy their implementation rather than use Bart's version, since the pretrained weights rely on it and the positions it produces are different from Bart's.

So their SinusoidalPositionalEmbedding's self.weights is a normal variable (not a buffer and not an nn.parameter.Parameter). They create a dummy buffer self._float_tensor to hold the device, so when model.to() is called, self._float_tensor gets the right device. During forward, self.weights gets moved with .to(self._float_tensor) and all is good. So self.weights is kind of a ghost variable: now you see me and now you don't.
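Here is a minimal sketch of the pattern as I understand it (simplified; the real fairseq module also builds the actual sinusoidal table and handles resizing and padding_idx):

```python
import torch
import torch.nn as nn

class GhostEmbedding(nn.Module):
    """Illustrative stand-in for fairseq's SinusoidalPositionalEmbedding."""

    def __init__(self, num_positions, embedding_dim):
        super().__init__()
        # plain attribute: invisible to state_dict(), parameters() and .to()
        self.weights = torch.zeros(num_positions, embedding_dim)
        # dummy buffer whose only job is to track the model's device/dtype
        self.register_buffer("_float_tensor", torch.FloatTensor(1))

    def forward(self, positions):
        # lazily drag the ghost tensor onto whatever device the model is on
        self.weights = self.weights.to(self._float_tensor)
        return self.weights[positions]
```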

This approach works just fine until we get to torchscript - in particular 2 common tests:

    def test_torchscript_output_attentions(self):
    def test_torchscript_output_hidden_state(self):

which blow up under USE_CUDA=1, with:

Comparison exception:   Expected all tensors to be on the same device, 
but found at least two devices, cuda:0 and cpu!

Everything is on cuda:0, but SinusoidalPositionalEmbedding's self.weights is still on cpu at this point.

The first time torchscript encounters self.weights inside forward, before it gets a chance to be moved to the right device, it blows up. It wants all variables to be on the same device before forward runs.

Solution 1

So, I solved this problem with the following hack:

class FSMTForConditionalGeneration(PretrainedFSMTModel):
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        # bridge 1: explicitly pass the call on to FSMTModel
        self.base_model.to(*args, **kwargs)
        return self

class FSMTModel(PretrainedFSMTModel):
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        # bridge 2: pass the call on to the two positional embeddings
        self.encoder.embed_positions.to(*args, **kwargs)
        self.decoder.embed_positions.to(*args, **kwargs)
        return self

class SinusoidalPositionalEmbedding(nn.Module):
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        # bridge 3: finally move the ghost variable itself
        self.weights = self.weights.to(*args, **kwargs)
        return self

It's absolutely crazy, but it works.

Basically it forwards the model.to() call to SinusoidalPositionalEmbedding's self.weights, via 3 "bridges".

I thought that each torch module got its to() called, but that doesn't seem to be the case: it looks like nn.Module.to() traverses the model structure and moves parameters and buffers directly, without calling to() on each submodule. Hence the 2 extra classes that bridge the call down.

(And there is also half() that needs to be dealt with, since model.half() won't get forwarded to this non-parameter variable either.)
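You can see this behavior with a tiny standalone sketch (not FSMT code): an overridden to() on a submodule is never invoked when to() is called on the parent, because nn.Module.to() moves parameters and buffers through the internal _apply() traversal instead:

```python
import torch.nn as nn

class Child(nn.Module):
    def to(self, *args, **kwargs):
        print("Child.to() called")  # never printed by the code below
        return super().to(*args, **kwargs)

class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.child = Child()

Parent().to("cpu")  # tensors are moved via _apply(); Child.to() is NOT called
```

Hence each class in the chain has to pass the call along explicitly.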

Solution 2

The second solution is to make SinusoidalPositionalEmbedding's self.weights a parameter, but then we have to hack save/load to not save, and to ignore on load, the model.encoder.embed_positions.* and model.decoder.embed_positions.* keys.
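A rough sketch of what that hack could look like (hypothetical helper names, not an actual transformers API; assumes self.weights was turned into an nn.Parameter with requires_grad=False):

```python
import torch

# state_dict key prefixes to drop on save / tolerate on load
IGNORE_PREFIXES = ("encoder.embed_positions.", "decoder.embed_positions.")

def save_without_positions(model, path):
    # filter out the deterministic position-embedding keys before saving
    state = {k: v for k, v in model.state_dict().items()
             if not k.startswith(IGNORE_PREFIXES)}
    torch.save(state, path)

def load_without_positions(model, path):
    # strict=False tolerates the missing keys; the embeddings keep the
    # deterministic values computed in __init__
    model.load_state_dict(torch.load(path), strict=False)
```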

Solution 3

The third solution is to save the useless weights (useless as they aren't trained and get calculated deterministically).
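That would amount to registering the table as an ordinary buffer, e.g. (a sketch; the real __init__ would fill the table with the sinusoidal values):

```python
import torch
import torch.nn as nn

class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, num_positions, embedding_dim):
        super().__init__()
        # a registered buffer is saved in state_dict() and follows
        # .to()/.half() automatically, at the cost of bloating the
        # checkpoint with deterministic values
        self.register_buffer("weights", torch.zeros(num_positions, embedding_dim))

    def forward(self, positions):
        return self.weights[positions]
```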

Perhaps you can think of other solutions.

Thank you.

@sgugger, @patrickvonplaten, @sshleifer, @LysandreJik
