@@ -3794,14 +3794,17 @@ def forward(self, input):
37943794 @disableScriptTest ()
37953795 def test_chunk (self ):
37963796 class ChunkModel (torch .nn .Module ):
3797- def __init__ (self ):
3797+ def __init__ (self , dim = 1 ):
37983798 super (ChunkModel , self ).__init__ ()
3799+ self .dim = dim
37993800
38003801 def forward (self , x ):
3801- return torch .chunk (x , 3 , dim = 1 )
3802+ return torch .chunk (x , 3 , dim = self . dim )
38023803
38033804 model = ChunkModel ()
38043805 model .eval ()
3806+ model_neg_dim = ChunkModel (- 1 )
3807+ model_neg_dim .eval ()
38053808 x = torch .randn (1 , 18 )
38063809
38073810 for dim_size_ in range (13 , 16 ):
@@ -3810,6 +3813,10 @@ def forward(self, x):
38103813 input_names = ['x' ],
38113814 dynamic_axes = {'x' : {0 : 'batch_size' , 1 : 'dims' }})
38123815
3816+ self .run_test (model_neg_dim , x , test_with_inputs = [y ],
3817+ input_names = ['x' ],
3818+ dynamic_axes = {'x' : {0 : 'batch_size' , 1 : 'dims' }})
3819+
38133820 def test_concat (self ):
38143821 class ConcatModel (torch .nn .Module ):
38153822 def forward (self , x , y , z ):
@@ -5823,6 +5830,22 @@ def make_input(batch_size):
58235830 other_input = make_input (RNN_BATCH_SIZE + 1 )
58245831 self .run_test (model , other_input , batch_size = RNN_BATCH_SIZE + 1 )
58255832
5833+ @disableScriptTest () # TODO: RuntimeError: Exporting the operator __is_ to ONNX is not supported
5834+ def test_transformer_encoder (self ):
5835+ from torch .nn import TransformerEncoderLayer , TransformerEncoder
5836+
5837+ class MyModule (torch .nn .Module ):
5838+ def __init__ (self , ninp , nhead , nhid , dropout , nlayers ):
5839+ super (MyModule , self ).__init__ ()
5840+ encoder_layers = TransformerEncoderLayer (ninp , nhead , nhid , dropout )
5841+ self .transformer_encoder = TransformerEncoder (encoder_layers , nlayers )
5842+
5843+ def forward (self , input ):
5844+ return self .transformer_encoder (input )
5845+
5846+ x = torch .rand (10 , 32 , 512 )
5847+ self .run_test (MyModule (512 , 8 , 2048 , 0. , 3 ), (x ,), atol = 1e-6 )
5848+
58265849 @skipIfUnsupportedMinOpsetVersion (10 )
58275850 def test_fake_quantize_per_tensor (self ):
58285851 class FakeQuantizePerTensorModel (torch .nn .Module ):
0 commit comments