MPS torch.where() is giving objectively incorrect results, leading to critical calculation errors #122916
Closed
Labels
module: 64-bit (problems related to incorrectly using 32-bit integers when 64-bit is needed, e.g., 8G tensors)
module: correctness (silent) (issue that returns an incorrect result silently)
module: mps (related to Apple Metal Performance Shaders framework)
triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Describe the bug
I think I have an example showing that MPS can produce completely different results from the CPU. Hopefully the simplicity of this example makes the problem clear. This may be related to a previously reported issue (#84936).
import numpy as np
import torch

mps_device = torch.device("mps")

## Create a numpy matrix with many zeros
np.random.seed(0)
Numpy_Test = np.random.random(200000000)
indices = np.random.choice(np.arange(Numpy_Test.size), replace=False, size=int(Numpy_Test.size * 0.6))
Numpy_Test[indices] = 0
Numpy_Matrix = Numpy_Test.reshape((20000, 10000))

## Get the indices of non-zero values in the matrix, and convert these indices into a numpy array
indices = np.where(Numpy_Matrix != 0)
indices = np.asarray(indices)

## Use numpy, torch, or a tensor on the MPS device to find where indices[1] == 8000
# Using np.where
np.where(indices[1] == 8000)[0]
# array([ 19165, 27061, 39165, ..., 79979029, 79987021, 79995171])

# Using torch.where
torch.where(torch.from_numpy(indices)[1] == 8000)[0]
# tensor([ 19165, 27061, 39165, ..., 79979029, 79987021, 79995171])

# Using torch.where with an MPS tensor
torch.where(torch.from_numpy(indices)[1].to(mps_device) == 8000)[0]
# tensor([ 19165, 27061, 39165, ..., 79979032, 79987024, 79995168], device='mps:0')

Notice how the first two examples (np.where, and torch.where on the CPU) give the same results, but the tensor moved to MPS gives different results: the last three indices come back as 79979032, 79987024 and 79995168 instead of 79979029, 79987021 and 79995171.
If I haven't made an obvious mistake, this is a clear example of how MPS can silently ruin calculations: the returned indices change, so all downstream calculations become meaningless.
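The wrong MPS indices follow a recognizable pattern: each one equals the correct CPU index rounded to the nearest 32-bit float. The matrix has roughly 80 million non-zero entries, so positions in indices reach about 80,000,000, far beyond 2**24 = 16,777,216, the point above which float32 can no longer represent every integer. This is consistent with, though it does not prove, the indices passing through a 32-bit type somewhere in the MPS path, which matches the module: 64-bit label. The sketch below checks only that arithmetic on the CPU using the standard library. The helper names to_float32 and float32_spacing are hypothetical and are not part of PyTorch.

```python
import struct


def to_float32(x: int) -> float:
    """Round x to the nearest IEEE 754 single-precision value.

    Packing with struct's 'f' format stores the value in 4 bytes (rounding
    to the nearest float32); unpacking converts it back to a Python float.
    """
    return struct.unpack("<f", struct.pack("<f", float(x)))[0]


def float32_spacing(x: int) -> int:
    """Gap between adjacent float32 values near a positive integer x.

    float32 has a 24-bit significand, so for x in [2**e, 2**(e + 1)) the
    gap is 2**(e - 23). Returns 1 when every integer near x is exact.
    """
    exponent = x.bit_length() - 1
    return 2 ** max(exponent - 23, 0)


# Indices copied from the outputs above: CPU (np.where / torch.where) vs. MPS.
cpu_head = [19165, 27061, 39165]
cpu_tail = [79979029, 79987021, 79995171]
mps_tail = [79979032, 79987024, 79995168]

# Small indices are exact in float32, which is why the first values agree.
for cpu in cpu_head:
    print(f"{cpu} -> {int(to_float32(cpu))} (exact)")

# Large indices land on a grid of multiples of 8 and are rounded onto it.
for cpu, mps in zip(cpu_tail, mps_tail):
    rounded = int(to_float32(cpu))
    print(f"{cpu} -> {rounded} (MPS: {mps}, float32 spacing: {float32_spacing(cpu)})")

# The first integer float32 cannot represent: 2**24 + 1 rounds down to 2**24.
print(2**24 + 1, "->", int(to_float32(2**24 + 1)))
```

If this rounding is indeed the cause, a plausible workaround (an assumption, not a confirmed fix) is to run the index-producing step on the CPU, for example by moving the boolean mask to the CPU before calling torch.where.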
Versions
torch version v0.2.1 and v0.2.0