@@ -3055,6 +3055,13 @@ def test_embedding_functional(self):
         res_F = F.embedding(a, embeddings)
         self.assertEqual(res_old, res_F)

+        embed_old = torch.nn.Embedding(4, 3)
+        embed_old = embed_old.from_pretrained(embeddings, padding_idx=2)
+        res_old = embed_old(a)
+        res_F = F.embedding(a, embeddings, padding_idx=2)
+
+        self.assertEqual(res_old, res_F)
+
     @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                          'Linear_FP16_weight requires FBGEMM. FBGEMM is only optimized for CPUs'
                          ' with instruction set support avx2 or newer.')
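A minimal standalone sketch (not part of the patch) of what the new assertions in this hunk check: `nn.Embedding.from_pretrained` and `F.embedding` should agree when given the same weight matrix and `padding_idx`. The weight shape and lookup indices below are illustrative, not taken from the test.

```python
import torch
import torch.nn.functional as F

# Illustrative weights and lookup indices.
embeddings = torch.rand(4, 3)
a = torch.tensor([[0, 1], [2, 3]])

# Module path: build the embedding from the existing weight matrix.
module_out = torch.nn.Embedding.from_pretrained(embeddings, padding_idx=2)(a)

# Functional path: look up directly against the same weight matrix.
functional_out = F.embedding(a, embeddings, padding_idx=2)

# padding_idx only affects the backward pass, so the forward results match.
assert torch.equal(module_out, functional_out)
```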
@@ -10738,6 +10745,15 @@ def fn(weight):
         fn = fn_wrapper(device)
         _assertGradAndGradgradChecks(self, fn, (weight, ))

+        def fn_wrapper(device):
+            def padding_fn(weight):
+                inp = torch.tensor([[0, 1, 1, 2], [1, 1, 0, 2]], dtype=torch.long).to(device)
+                return torch.nn.functional.embedding(inp, weight, padding_idx=1)
+            return padding_fn
+
+        fn = fn_wrapper(device)
+        _assertGradAndGradgradChecks(self, fn, (weight, ))
+
     def test_embedding_scalar_weight_error(self, device):
         indices = torch.rand(2, 2, device=device).long()
         weight = torch.tensor(1.0, device=device)
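A rough standalone sketch (again, not part of the diff) of the property the added gradcheck exercises: with `padding_idx` set, the padding row of the weight receives no gradient. The shapes and the direct use of `torch.autograd.gradcheck` here are illustrative stand-ins for the suite's `_assertGradAndGradgradChecks` helper.

```python
import torch
import torch.nn.functional as F

# Double precision and requires_grad are needed for gradcheck.
weight = torch.randn(10, 3, dtype=torch.double, requires_grad=True)
inp = torch.tensor([[0, 1, 1, 2], [1, 1, 0, 2]], dtype=torch.long)

out = F.embedding(inp, weight, padding_idx=1)
out.sum().backward()

# Index 1 is the padding index, so its weight row never accumulates gradient.
assert torch.all(weight.grad[1] == 0)

# Analytic gradients should also match numerically estimated ones.
assert torch.autograd.gradcheck(
    lambda w: F.embedding(inp, w, padding_idx=1), (weight,))
```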