(With normal CI runs not using USE_CUDA=1, I completely missed testing this, and found one issue with the torchscript tests that I need help with.)
We are talking about FSMT - the ported fairseq transformer model.
If I understand correctly, their SinusoidalPositionalEmbedding was designed so that it won't be part of the model params:
https://github.com/pytorch/fairseq/blob/master/fairseq/modules/sinusoidal_positional_embedding.py#L25
most likely so that it won't be part of the state_dict, to save space in their already huge 3.3GB model dump (well, 13GB actually, as they use an ensemble of 4 models). I could be wrong about the reason for this design choice.
I had to copy their implementation rather than use Bart's version, since the pretrained weights rely on it and the positions it produces are different.
So their SinusoidalPositionalEmbedding's self.weights is a normal variable (not a buffer and not an nn.parameter.Parameter). They create a dummy buffer self._float_tensor to hold the device, so when model.to() is called, self._float_tensor gets the right device. During forward, self.weights gets .to(self._float_tensor) and all is good. So self.weights is kind of a ghost variable - now you see me and now you don't.
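To make the pattern concrete, here is a minimal sketch of that design - simplified, not the actual fairseq code; _build_table and the exact sinusoid math are just placeholders:

import torch
import torch.nn as nn

class SinusoidalPositionalEmbedding(nn.Module):
    # simplified sketch of the fairseq pattern, not the real implementation
    def __init__(self, num_positions, embedding_dim):
        super().__init__()
        # plain attribute: neither a Parameter nor a buffer,
        # so it never appears in state_dict() or parameters()
        self.weights = self._build_table(num_positions, embedding_dim)
        # dummy buffer whose only job is to track the model's device/dtype
        self.register_buffer("_float_tensor", torch.FloatTensor(1))

    @staticmethod
    def _build_table(num_positions, dim):
        # stand-in for the real deterministic sinusoid table (assumes even dim)
        pos = torch.arange(num_positions, dtype=torch.float).unsqueeze(1)
        freq = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * (-1.0 / dim))
        table = torch.zeros(num_positions, dim)
        table[:, 0::2] = torch.sin(pos * freq)
        table[:, 1::2] = torch.cos(pos * freq)
        return table

    def forward(self, positions):
        # the "ghost" weights get moved lazily to whatever device/dtype
        # self._float_tensor ended up on after model.to(...)
        self.weights = self.weights.to(self._float_tensor)
        return self.weights[positions]

Since self.weights never shows up in parameters() or buffers(), model.to() never touches it - only the lazy .to() inside forward does.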
This approach works just fine until we get to torchscript - in particular 2 common tests:
def test_torchscript_output_attentions(self):
def test_torchscript_output_hidden_state(self):
which blow up under USE_CUDA=1, with:
Comparison exception: Expected all tensors to be on the same device,
but found at least two devices, cuda:0 and cpu!
Everything is on cuda:0, but SinusoidalPositionalEmbedding's self.weights is still on cpu at this point.
The first time it encounters self.weights inside forward, before it gets a chance to be moved to the device, torchscript blows up. It wants all variables to be on the same device before forward runs.
Solution 1
So, I solved this problem with the following hack:
class FSMTForConditionalGeneration(PretrainedFSMTModel):
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.base_model.to(*args, **kwargs)
        return self

class FSMTModel(PretrainedFSMTModel):
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.encoder.embed_positions.to(*args, **kwargs)
        self.decoder.embed_positions.to(*args, **kwargs)
        return self

class SinusoidalPositionalEmbedding(nn.Module):
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.weights = self.weights.to(*args, **kwargs)
        return self
It's absolutely crazy, but it works.
Basically it forwards the model.to() call to SinusoidalPositionalEmbedding's self.weights via 3 "bridges".
I thought that each torch module got to() called on it, but that doesn't seem to be the case - it looks like to() traverses the model structure itself instead of calling to() on each submodule. Hence the 2 extra classes involved to bridge it along.
(And there is also half() that needs to be dealt with, since model.half() won't get forwarded to this non-parameter variable either.)
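For completeness, a hypothetical sketch of the same treatment for half(), mirroring the to() override above (it would presumably need the same FSMTModel / FSMTForConditionalGeneration bridges, since half() also doesn't call each submodule's overridden method):

class SinusoidalPositionalEmbedding(nn.Module):
    # hypothetical extension of the hack above for model.half()
    def half(self):
        super().half()
        # forward the dtype change to the non-parameter, non-buffer weights too
        self.weights = self.weights.half()
        return self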
Solution 2
The second solution is to make SinusoidalPositionalEmbedding's self.weights a parameter, but then we have to hack save/load to not save the model.encoder.embed_positions.* and model.decoder.embed_positions.* keys, and to ignore them on load.
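A rough sketch of what that save/load hack might look like - the key names and the blanket strict=False are illustrative assumptions, not the actual transformers mechanism:

# hypothetical sketch: drop the deterministic position tables from the saved
# state_dict and tolerate their absence when loading
IGNORE_SUFFIXES = ("encoder.embed_positions.weights", "decoder.embed_positions.weights")

class FSMTModel(PretrainedFSMTModel):
    def state_dict(self, *args, **kwargs):
        state = super().state_dict(*args, **kwargs)
        return {k: v for k, v in state.items() if not k.endswith(IGNORE_SUFFIXES)}

    def load_state_dict(self, state_dict, strict=True):
        # the embed_positions weights are recomputed in __init__, so it's fine
        # for them to be missing from the checkpoint (hence strict=False)
        return super().load_state_dict(state_dict, strict=False)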
Solution 3
The third solution is to save the useless weights (useless as they aren't trained and get calculated deterministically).
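That would basically amount to registering the table as a buffer, e.g. in the __init__ of the sketch above:

# sketch: registering the table as a buffer instead of a plain attribute means
# it lands in state_dict() and model.to()/model.half() handle it automatically,
# at the cost of saving redundant, deterministically-computable data
self.register_buffer("weights", self._build_table(num_positions, embedding_dim))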
Perhaps you can think of other solutions.
Thank you.