The torch.sum() function writes its output to the wrong locations when the out kwarg is given a non-contiguous view such as ac[:, i]. Repro:
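import torch
i = 1  # column of ac to write into
a = torch.zeros(5, 3)
b = torch.randn(3, 5)
# Reference: sum into a temporary, then copy it into the column view (correct result).
ac = a.clone()
ac[:, i].copy_(b.sum(0))
print(ac)
# Bug: pass the same non-contiguous column view directly as out.
ac = a.clone()
b.sum(0, out=ac[:, i])
print(ac)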
Output:
lvdmaaten-mbp:Desktop lvdmaaten$ python bug.py
0.0000 -0.1515 0.0000
0.0000 1.8761 0.0000
0.0000 -1.7563 0.0000
0.0000 0.5194 0.0000
0.0000 1.5322 0.0000
[torch.FloatTensor of size 5x3]
0.0000 -0.1515 1.8761
-1.7563 0.5194 1.5322
0.0000 0.0000 0.0000
0.0000 0.0000 0.0000
0.0000 0.0000 0.0000
[torch.FloatTensor of size 5x3]
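I presume that torch.sum() does not respect the stride of the out tensor ac[:, i] when writing its result: the five sums land in consecutive storage positions starting at the view's offset, instead of going down column i.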
Tested on PyTorch version 0.3.0.post4.
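
To make the stride hypothesis concrete, here is a minimal sketch that reproduces the erroneous layout by hand, without relying on the buggy call. It assumes the bug amounts to writing the result with stride 1 from the view's storage offset; that is a guess about the internals, not confirmed behaviour. The names col_sums, expected and buggy are illustrative only.

```python
import torch

i = 1
a = torch.zeros(5, 3)
b = torch.randn(3, 5)
col_sums = b.sum(0)  # five values, one per column of b

# The column view shares storage with its parent and steps 3 elements at a time.
view = a.clone()[:, i]
print(view.stride(), view.storage_offset(), view.is_contiguous())  # (3,) 1 False

# Correct behaviour: honour the stride of the view.
expected = a.clone()
expected[:, i].copy_(col_sums)

# Hypothesised buggy behaviour (assumption): start at the view's storage
# offset but write with stride 1, as if the destination were contiguous.
buggy = a.clone()
buggy.view(-1)[i:i + col_sums.numel()].copy_(col_sums)

print(expected)
print(buggy)
```

With this simulation, buggy has the five sums in flat positions 1 to 5 (the last two entries of row 0 and all of row 1), which is the same pattern as the second tensor in the output above.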