Skip to content

Commit ee4ce8e

Browse files
BowenBao authored and facebook-github-bot committed
[ONNX] fix export of embedding with padding_idx (#53053) (#53530)
Summary: Pull Request resolved: #53530. Fix export of embedding with padding_idx.
Test Plan: Imported from OSS.
Reviewed By: navahgar, jamesr66a, malfet
Differential Revision: D26922420
Pulled By: SplitInfinity
fbshipit-source-id: b8b867a96a13cf810f9c0ae88fcc5c95072bb390
1 parent a572f70 commit ee4ce8e

3 files changed

Lines changed: 38 additions & 1 deletion

File tree

test/onnx/test_pytorch_onnx_caffe2.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -809,6 +809,7 @@ def forward(self, x):
809809
x = torch.randn(2, 3, 4)
810810
self.run_model_test(ArithmeticModule(), input=x, train=False, batch_size=BATCH_SIZE)
811811

812+
@skipIfUnsupportedMinOpsetVersion(9) # Where op not supported for lower opsets
812813
def test_embedding(self):
813814
model = nn.Embedding(10, 3, padding_idx=-1)
814815
input = torch.LongTensor(list(range(10))[::-1])

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6077,6 +6077,23 @@ def run_model():
60776077

60786078
self.assertRaises(TypeError, run_model)
60796079

6080+
@skipIfUnsupportedMinOpsetVersion(9)
def test_embedding(self):
    """Run the export test for functional embedding with ``padding_idx=1``.

    The same model is checked twice through ``self.run_test``: once with
    1-D indices and once with 3-D indices. In both cases some index
    entries are forced to 1 so that the padding index is guaranteed to
    appear in the input.
    """
    class EmbedModel(torch.nn.Module):
        # Thin wrapper so the embedding table is passed as a model input.
        def forward(self, input, emb):
            return torch.nn.functional.embedding(input, emb, padding_idx=1)

    model = EmbedModel()

    # Case 1: 1-D indices, with positions 0 and 2 set to the padding index.
    indices_1d = torch.randint(4, (4, ))
    indices_1d[0] = 1
    indices_1d[2] = 1
    embedding_matrix = torch.rand(10, 3)
    self.run_test(model, (indices_1d, embedding_matrix))

    # Case 2: 3-D indices, with a whole slice and one row set to the
    # padding index; the embedding table from case 1 is reused.
    indices_3d = torch.randint(4, (4, 3, 2))
    indices_3d[0][1] = 1
    indices_3d[2] = 1
    self.run_test(model, (indices_3d, embedding_matrix))
60806097
def _dispatch_rnn_test(self, name, *args, **kwargs):
60816098
if name == 'elman':
60826099
self._elman_rnn_test(*args, **kwargs)

torch/onnx/symbolic_opset9.py

Lines changed: 20 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -471,8 +471,27 @@ def expand_as(g, self, other):
471471
return g.op("Expand", self, shape)
472472

473473

474+
@parse_args('v', 'v', 'i', 'b', 'v')
def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
    """Build the ONNX graph for an embedding lookup.

    The lookup itself is a ``Gather`` of ``weight`` by ``indices``. When
    ``padding_idx`` is non-negative, every gathered row whose index equals
    ``padding_idx`` is additionally filled with zeros.

    Args:
        g: graph object used to emit ONNX ops.
        weight: embedding table value.
        indices: index value used to select rows of ``weight``.
        padding_idx: integer padding index; masking is skipped when negative.
        scale_grad_by_freq: boolean flag; rejected in training mode.
        sparse: accepted for signature compatibility; not read here.

    Returns:
        The gathered (and, if applicable, padding-masked) value.

    Raises:
        RuntimeError: if ``scale_grad_by_freq`` is set while exporting in
            training mode.
    """
    training_with_grad_scaling = scale_grad_by_freq and sym_help._training_mode
    if training_with_grad_scaling:
        raise RuntimeError('Unsupported: ONNX export of embedding with scale_grad_by_freq=True '
                           'for training mode. ONNX does not support scaling the gradients.')

    # Reference behavior of the torch operator for padding_idx that this
    # export mirrors:
    #   if (padding_idx >= 0) {
    #     embedding.masked_fill_((indices == padding_idx).reshape({-1, 1}), 0);
    #   }
    #   auto out = weight.index_select(0, indices.reshape(-1));
    #   zerofill_padding(out);
    #   return out.view(size);
    gathered = g.op("Gather", weight, indices)
    if padding_idx < 0:
        # No padding index to mask out: the plain lookup is the result.
        return gathered

    padding_const = g.op("Constant", value_t=torch.tensor(padding_idx))
    is_padding = eq(g, indices, padding_const)

    # Append a trailing axis to the mask before it is applied to the
    # gathered rows. The helper used depends on the export opset version.
    if sym_help._export_onnx_opset_version >= 11:
        is_padding = sym_help._unsqueeze_helper(g, is_padding, [-1])
    else:
        is_padding = unsqueeze(g, is_padding, -1)

    return masked_fill(g, gathered, is_padding, torch.tensor(0.))
476495

477496

478497
@parse_args('v', 'v', 'v', 'i', 'i', 'i', 'v', 'i')

0 commit comments

Comments
 (0)