[np.random] shuffle is not compatible with pytorch tensor #18206
Closed
Description
Shuffling a PyTorch tensor with numpy.random.shuffle produces an incorrect result: some values are duplicated and others are lost. So far this has happened on every attempt.
Reproducing code example:
import torch as th
import numpy as np
N = 10
x = th.arange(N)
print(x)
np.random.shuffle(x)
print(x)
assert(len(set(x.tolist())) == N)
Error message:
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([0, 0, 2, 1, 3, 4, 6, 3, 3, 9])
Traceback (most recent call last):
File "x.py", line 9, in <module>
assert(len(set(x.tolist())) == N)
AssertionError
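In the shuffled result, 0 appears twice and 3 appears three times, while 5, 7, and 8 are missing. The output is therefore not a permutation of the input, which is clearly wrong.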
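Here is one plausible explanation. It assumes that, for inputs that are not NumPy arrays, shuffle falls back to swapping items with x[i], x[j] = x[j], x[i]. Indexing a tensor returns a 0-d tensor that shares memory with the original rather than a copy, so the second assignment reads a value that the first assignment has already overwritten. The NumPy-only sketch below reproduces the same aliasing with the rows of a 2-D array, which are also views. tuple_swap is a hypothetical helper for illustration, not NumPy's actual shuffle code.

```python
import numpy as np


def tuple_swap(seq, i, j):
    # Swap two items the way a generic in-place shuffle might (assumption).
    # The right-hand side is evaluated first, but if seq[j] and seq[i]
    # return views instead of copies, the second assignment reads data
    # that the first assignment has already overwritten.
    seq[i], seq[j] = seq[j], seq[i]


# Indexing a row of a 2-D ndarray returns a view, similar to how indexing
# a torch tensor returns a 0-d tensor sharing memory with the original.
rows = np.arange(6).reshape(3, 2)
tuple_swap(rows, 0, 1)
print(rows.tolist())  # → [[2, 3], [2, 3], [4, 5]]  (row 0 duplicated, [0, 1] lost)

# Swapping explicit copies breaks the aliasing and gives a correct swap.
fixed = np.arange(6).reshape(3, 2)
fixed[0], fixed[1] = fixed[1].copy(), fixed[0].copy()
print(fixed.tolist())  # → [[2, 3], [0, 1], [4, 5]]
```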
NumPy/Python version information:
NumPy 1.18.5, installed through Anaconda.
See also: pytorch/pytorch#50880
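A common way to avoid in-place swapping entirely is to shuffle by indexing with a random permutation. This builds a new object instead of writing items through aliases. The sketch below shows the idea with NumPy only. For a PyTorch tensor, the analogous expression would be x[torch.randperm(len(x))]. permuted_copy is a hypothetical helper name, not a NumPy API.

```python
import numpy as np


def permuted_copy(a, rng):
    # Shuffle along axis 0 by gathering with a random permutation of the
    # indices. A new array is built, so no item is written through a view
    # of another item that has not been moved yet.
    perm = rng.permutation(len(a))
    return a[perm]


rng = np.random.default_rng(seed=0)  # seeded only to make the demo repeatable
rows = np.arange(6).reshape(3, 2)
shuffled = permuted_copy(rows, rng)
print(shuffled.tolist())  # same three rows, each exactly once, in random order
```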