Different behaviour in sparse matmul #88053
Description
The same code works with a dense tensor but raises an error when a sparse tensor is used.
The following is a working example without a sparse tensor:
import torch
# Dense case: a is a batch of two 3x3 matrices, b is a single 3x1 matrix.
a = torch.randn(2, 3, 3).requires_grad_(True)
print(a)
b = torch.randn(3, 1, requires_grad=True)
print(b)
y = torch.matmul(a, b)
print(y)
y.sum().backward()
print(a.grad)
This is the same example with a sparse tensor:
import torch
# Sparse case: the same batch, converted to a sparse COO tensor.
a = torch.randn(2, 3, 3).to_sparse().requires_grad_(True)
print(a)
b = torch.randn(3, 1, requires_grad=True)
print(b)
y = torch.matmul(a, b)
print(y)
y.sum().backward()
print(a.grad)
The sparse variant fails with the following error:
NotImplementedError Traceback (most recent call last)
Cell 22 in <cell line: 7>()
[4] b = torch.randn(3, 1, requires_grad=True)
[5] print(b)
----> [7] y = torch.matmul(a, b)
[8] print(y)
[10] y.sum().backward()
NotImplementedError: Tensors of type SparseTensorImpl do not have is_contiguous
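For reference, the dense behaviour that the sparse call would need to reproduce can be spelled out explicitly. The sketch below uses dense tensors only and is seeded for reproducibility; the names `y_bmm` and `expected_grad` are illustrative and not part of the original report. It shows that `torch.matmul` broadcasts the 2-D `b` over the batch dimension of the 3-D `a`, and that every slice of `a.grad` equals `ones(3, 1) @ b.T`.

```python
import torch

torch.manual_seed(0)

# Dense reference: a is a batch of two 3x3 matrices, b is one 3x1 matrix.
a = torch.randn(2, 3, 3, requires_grad=True)
b = torch.randn(3, 1, requires_grad=True)

# torch.matmul broadcasts the 2-D b across the batch dimension of a ...
y = torch.matmul(a, b)
# ... which matches an explicit batched product with b expanded to (2, 3, 1).
y_bmm = torch.bmm(a, b.unsqueeze(0).expand(2, 3, 1))
print(y.shape)  # torch.Size([2, 3, 1])

y.sum().backward()

# d(sum(a @ b)) / d a[i] = ones(3, 1) @ b.T for every batch element i,
# i.e. every row of every 3x3 gradient slice equals b.T.
expected_grad = (torch.ones(3, 1) @ b.detach().T).expand(2, 3, 3)
print(torch.allclose(a.grad, expected_grad))
```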
I'm not completely sure I'm not missing something, but I think matmul should behave the same way for dense and sparse tensors.
If I have misunderstood something, how can I reproduce the dense behaviour in the sparse case?
Thank you!
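One possible workaround for the question above, offered as a hedged sketch rather than a confirmed fix: the `torch.sparse.mm` documentation describes backward support for a sparse COO first argument multiplied by a dense second argument, so the batch can be kept as a list of 2-D sparse leaf tensors (`a_slices` is a hypothetical name) and the per-slice products stacked. This assumes a PyTorch version whose `torch.sparse.mm` supports that backward path; it does not make `torch.matmul` itself accept the 3-D sparse tensor.

```python
import torch

torch.manual_seed(0)

# Hypothetical workaround: store the batch as a list of 2-D sparse COO leaf
# tensors instead of a single 3-D sparse tensor.
a_slices = [torch.randn(3, 3).to_sparse().requires_grad_(True) for _ in range(2)]
b = torch.randn(3, 1, requires_grad=True)

# torch.sparse.mm(sparse, dense) returns a dense (3, 1) result per slice;
# stacking restores the (2, 3, 1) shape that torch.matmul gives for dense a.
y = torch.stack([torch.sparse.mm(a_i, b) for a_i in a_slices])
print(y.shape)  # torch.Size([2, 3, 1])

y.sum().backward()

for a_i in a_slices:
    # The torch.sparse.mm docs describe the gradient w.r.t. the sparse
    # argument as a sparse tensor; densify it only for printing.
    grad = a_i.grad.to_dense() if a_i.grad.is_sparse else a_i.grad
    print(grad)
print(b.grad)
```

The trade-off is a Python-level loop over the batch dimension. Unlike the dense example, the gradient of each slice may come back as a sparse tensor rather than a dense one.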
Versions
Collecting environment information...
PyTorch version: 1.11.0+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Could not collect
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A
Python version: 3.9.12 (tags/v3.9.12:b28265d, Mar 23 2022, 23:52:46) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19041-SP0
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to:
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect