Skip to content

Commit 6d47e2c

Browse files
BowenBao authored and facebook-github-bot committed
[ONNX] Fix opset 11 ConstantChunk with negative dim (#51396) (#51525)
Summary: Pull Request resolved: #51525 Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D26203115 Pulled By: SplitInfinity fbshipit-source-id: d76942f7cc5812c8a1cc16891e4956cc658283d8
1 parent ba824eb commit 6d47e2c

2 files changed

Lines changed: 26 additions & 4 deletions

File tree

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3794,14 +3794,17 @@ def forward(self, input):
37943794
@disableScriptTest()
37953795
def test_chunk(self):
37963796
class ChunkModel(torch.nn.Module):
3797-
def __init__(self):
3797+
def __init__(self, dim=1):
37983798
super(ChunkModel, self).__init__()
3799+
self.dim = dim
37993800

38003801
def forward(self, x):
3801-
return torch.chunk(x, 3, dim=1)
3802+
return torch.chunk(x, 3, dim=self.dim)
38023803

38033804
model = ChunkModel()
38043805
model.eval()
3806+
model_neg_dim = ChunkModel(-1)
3807+
model_neg_dim.eval()
38053808
x = torch.randn(1, 18)
38063809

38073810
for dim_size_ in range(13, 16):
@@ -3810,6 +3813,10 @@ def forward(self, x):
38103813
input_names=['x'],
38113814
dynamic_axes={'x': {0: 'batch_size', 1: 'dims'}})
38123815

3816+
self.run_test(model_neg_dim, x, test_with_inputs=[y],
3817+
input_names=['x'],
3818+
dynamic_axes={'x': {0: 'batch_size', 1: 'dims'}})
3819+
38133820
def test_concat(self):
38143821
class ConcatModel(torch.nn.Module):
38153822
def forward(self, x, y, z):
@@ -5823,6 +5830,22 @@ def make_input(batch_size):
58235830
other_input = make_input(RNN_BATCH_SIZE + 1)
58245831
self.run_test(model, other_input, batch_size=RNN_BATCH_SIZE + 1)
58255832

5833+
@disableScriptTest() # TODO: RuntimeError: Exporting the operator __is_ to ONNX is not supported
5834+
def test_transformer_encoder(self):
5835+
from torch.nn import TransformerEncoderLayer, TransformerEncoder
5836+
5837+
class MyModule(torch.nn.Module):
5838+
def __init__(self, ninp, nhead, nhid, dropout, nlayers):
5839+
super(MyModule, self).__init__()
5840+
encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
5841+
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
5842+
5843+
def forward(self, input):
5844+
return self.transformer_encoder(input)
5845+
5846+
x = torch.rand(10, 32, 512)
5847+
self.run_test(MyModule(512, 8, 2048 , 0., 3), (x,), atol=1e-6)
5848+
58265849
@skipIfUnsupportedMinOpsetVersion(10)
58275850
def test_fake_quantize_per_tensor(self):
58285851
class FakeQuantizePerTensorModel(torch.nn.Module):

torch/onnx/symbolic_opset11.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -863,8 +863,7 @@ def embedding_bag(g,
863863
def prim_ConstantChunk(g, self, chunks, dim):
864864
input_shape = g.op("Shape", self)
865865
axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
866-
axis_next = g.op("Constant", value_t=torch.tensor([dim + 1], dtype=torch.long))
867-
input_shape_dim = g.op("Slice", input_shape, axis, axis_next)
866+
input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0)
868867
start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
869868
chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long))
870869
chunk_size_minus_1 = g.op("Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long))

0 commit comments

Comments (0)