
Commit 1d233d7

kshitij12345 authored and facebook-github-bot committed
[fix] torch.nn.functional.embedding -> padding_idx behavior (#46714)
Summary:
Reference #46585

Fix for second snippet in the mentioned issue.

```python
predefined_weights = torch.rand(10, 3)
result = torch.nn.functional.embedding(torch.LongTensor([1,2,0]), predefined_weights, padding_idx=0)
```

Pull Request resolved: #46714

Reviewed By: VitalyFedyunin

Differential Revision: D24593352

Pulled By: albanD

fbshipit-source-id: 655b69d9ec57891871e26feeda2aa0dcff73beba
1 parent 3e499e4 commit 1d233d7
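
To make the fix concrete, here is a small sketch built on the snippet from the summary (values are illustrative, and the printed result describes the behavior with this patch applied, not the previous behavior): rows of the output whose index equals `padding_idx` now come back zero-filled.

```python
import torch
import torch.nn.functional as F

predefined_weights = torch.rand(10, 3)
indices = torch.LongTensor([1, 2, 0])

# padding_idx=0, so the row produced for index 0 (position 2 here)
# is zero-filled by the patched forward pass.
result = F.embedding(indices, predefined_weights, padding_idx=0)
print(result[2])  # tensor([0., 0., 0.]) with this patch applied
```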

5 files changed: 37 additions & 8 deletions


aten/src/ATen/native/Embedding.cpp

Lines changed: 13 additions & 2 deletions
```diff
@@ -17,16 +17,27 @@ Tensor embedding(const Tensor & weight, const Tensor & indices,
   auto indices_arg = TensorArg(indices, "indices", 1);
   checkScalarType("embedding", indices_arg, kLong);

+  auto zerofill_padding = [&](Tensor& embedding) {
+    if (padding_idx >= 0) {
+      embedding.masked_fill_((indices == padding_idx).reshape({-1, 1}), 0);
+    }
+  };
+
   // TODO: use tensor.index() after improving perf
   if (indices.dim() == 1) {
-    return weight.index_select(0, indices);
+    auto out = weight.index_select(0, indices);
+    zerofill_padding(out);
+    return out;
   }

   auto size = indices.sizes().vec();
   for (auto d : weight.sizes().slice(1)) {
     size.push_back(d);
   }
-  return weight.index_select(0, indices.reshape(-1)).view(size);
+
+  auto out = weight.index_select(0, indices.reshape(-1));
+  zerofill_padding(out);
+  return out.view(size);
 }

 Tensor embedding_backward(
```
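
For readers who do not follow the ATen side, a rough Python equivalent of the patched forward path looks like the sketch below (`embedding_forward_sketch` is a hypothetical helper for illustration, not part of the codebase). The mask is reshaped to `(-1, 1)` so it broadcasts across the embedding dimension of each selected row, which is the same reason the C++ code reshapes it to `{-1, 1}`.

```python
import torch

def embedding_forward_sketch(weight, indices, padding_idx=-1):
    # Gather one weight row per (flattened) index.
    out = weight.index_select(0, indices.reshape(-1))
    # zerofill_padding: the (numel, 1) mask broadcasts across the
    # embedding dimension, zeroing rows whose index == padding_idx.
    if padding_idx >= 0:
        out.masked_fill_((indices.reshape(-1) == padding_idx).reshape(-1, 1), 0)
    # Restore the original index shape, plus the embedding dimension.
    return out.view(*indices.shape, weight.size(1))
```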

test/test_nn.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -3055,6 +3055,13 @@ def test_embedding_functional(self):
         res_F = F.embedding(a, embeddings)
         self.assertEqual(res_old, res_F)

+        embed_old = torch.nn.Embedding(4, 3)
+        embed_old = embed_old.from_pretrained(embeddings, padding_idx=2)
+        res_old = embed_old(a)
+        res_F = F.embedding(a, embeddings, padding_idx=2)
+
+        self.assertEqual(res_old, res_F)
+
     @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                          'Linear_FP16_weight requires FBGEMM. FBGEMM is only optimized for CPUs'
                          ' with instruction set support avx2 or newer.')
@@ -10738,6 +10745,15 @@ def fn(weight):
         fn = fn_wrapper(device)
         _assertGradAndGradgradChecks(self, fn, (weight, ))

+        def fn_wrapper(device):
+            def padding_fn(weight):
+                inp = torch.tensor([[0, 1, 1, 2], [1, 1, 0, 2]], dtype=torch.long).to(device)
+                return torch.nn.functional.embedding(inp, weight, padding_idx=1)
+            return padding_fn
+
+        fn = fn_wrapper(device)
+        _assertGradAndGradgradChecks(self, fn, (weight, ))
+
     def test_embedding_scalar_weight_error(self, device):
         indices = torch.rand(2, 2, device=device).long()
         weight = torch.tensor(1.0, device=device)
```
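
The second hunk checks first- and second-order gradients through `padding_idx`. A standalone version of the same check might look like the sketch below; the weight shape is illustrative, since the test reuses a `weight` defined earlier in the file.

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

# Double precision keeps the numerical checks stable.
weight = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
inp = torch.tensor([[0, 1, 1, 2], [1, 1, 0, 2]], dtype=torch.long)

def padding_fn(w):
    return torch.nn.functional.embedding(inp, w, padding_idx=1)

# Both should pass once the double backward masks out the padding rows.
assert gradcheck(padding_fn, (weight,))
assert gradgradcheck(padding_fn, (weight,))
```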

tools/autograd/derivatives.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1186,7 +1186,7 @@
   weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse)

 - name: embedding_dense_backward(Tensor grad_output, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor
-  grad_output: embedding_dense_double_backward(grad, indices)
+  grad_output: embedding_dense_double_backward(grad, indices, padding_idx)
   indices: non_differentiable

 - name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)
```

torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 6 additions & 4 deletions
```diff
@@ -2694,16 +2694,18 @@ Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) {
   return at::constant_pad_nd(grad, negated_pad, 0);
 }

-Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices) {
-  // since first backward takes care of padding_idx
-  // and scaling by frequency, we don't need to worry
-  // about it here.
+Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) {
+  // since first backward takes care of scaling by frequency,
+  // we don't need to worry about it here.
   auto gg_weight = grad.index_select(0, indices.reshape(-1));

   // reshape gradient as per the shape of indices
   auto size = indices.sizes().vec();
   size.push_back(-1);

+  if (padding_idx >= 0) {
+    gg_weight.masked_fill_((indices == padding_idx).reshape({-1, 1}), 0);
+  }
   return gg_weight.view(size);
 }
```
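
The masking matters for second-order gradients: the `derivatives.yaml` change above routes the gradient of `grad_output` through `embedding_dense_double_backward(grad, indices, padding_idx)` in the default dense path. A small sketch of how that path gets exercised from Python (tensor values are illustrative):

```python
import torch
import torch.nn.functional as F

weight = torch.rand(4, 3, requires_grad=True)
inp = torch.tensor([0, 1, 1, 2])

out = F.embedding(inp, weight, padding_idx=1)
grad_out = torch.rand_like(out, requires_grad=True)

# First backward: gradient of the output w.r.t. the weight.
(grad_weight,) = torch.autograd.grad(out, weight, grad_outputs=grad_out, create_graph=True)

# Second backward: differentiating grad_weight w.r.t. grad_out goes through
# embedding_dense_double_backward(grad, indices, padding_idx).
(gg,) = torch.autograd.grad(grad_weight.sum(), grad_out)

# With this patch, the rows of gg at padding positions (inp == 1) are zero.
print(gg)
```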

torch/csrc/autograd/FunctionsManual.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -118,7 +118,7 @@ at::Tensor logdet_backward(const at::Tensor & grad, const at::Tensor& self, cons
 at::Tensor slogdet_backward(const at::Tensor& grad_logabsdet, const at::Tensor& self, const at::Tensor& signdet, const at::Tensor& logabsdet);
 at::Tensor log1p_backward(const at::Tensor& grad, const at::Tensor& self);
 at::Tensor sparse_constructor_values_backward(const at::Tensor& sparse_grad_out, const at::Tensor& indices, at::IntArrayRef values_shape);
-at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Tensor & indices);
+at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Tensor & indices, int64_t padding_idx);
 at::Tensor index_backward(at::Tensor zeros_like_self, at::TensorList indices, const at::Tensor& grad);
 at::Tensor _cudnn_ctc_loss_backward(const at::Tensor& grad_out, const at::Tensor& loss, const at::Tensor& raw_grad, bool zero_infinity);
```
