🐛 Bug
Calling count_nonzero on a tensor that requires grad triggers an internal assert failure in autograd/functions/utils.h:64.
To Reproduce
Steps to reproduce the behavior:
import torch
foo = torch.empty(5).requires_grad_()
foo.count_nonzero()
Output:
>>> foo.count_nonzero()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/autograd/functions/utils.h":64, please report a bug to PyTorch.
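The failed assert names isDifferentiableType. One plausible reading is an assumption and is not confirmed in this report. Under that reading, autograd tried to record a graph node for the result of count_nonzero, but that result is an integer tensor, and autograd only differentiates floating-point and complex types. You can check the result dtype without hitting the bug by calling the op on a detached copy. The tensor values below are illustrative stand-ins for the uninitialized torch.empty(5):

```python
import torch

# Deterministic values instead of torch.empty(5), whose contents are uninitialized.
foo = torch.tensor([0.0, 1.0, 2.0, 0.0, 3.0]).requires_grad_()

# Call count_nonzero on a detached copy so autograd is not involved.
count = foo.detach().count_nonzero()
print(count.dtype)                     # torch.int64
print(count.dtype.is_floating_point)   # False: integer tensors cannot carry gradients
```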
Expected behavior
Expected a count of non-zero elements, which is what I get when calling it under torch.no_grad() or on tensors that do not require grad.
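Until this is fixed, the result can probably be obtained with a workaround. The sketch below uses illustrative tensor values and has not been checked against 1.7.1 specifically. It shows the torch.no_grad() path mentioned above. It also shows an equivalent count built from an elementwise comparison, which yields a bool tensor that autograd does not track:

```python
import torch

foo = torch.tensor([0.0, 1.0, 2.0, 0.0, 3.0]).requires_grad_()

# Workaround 1: disable gradient tracking for the call, as described in the report.
with torch.no_grad():
    n_no_grad = foo.count_nonzero()

# Workaround 2: count via a comparison; (foo != 0) is a bool tensor without grad,
# and summing it produces an int64 count.
n_compare = (foo != 0).sum()

print(n_no_grad.item(), n_compare.item())  # 3 3
```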
Environment
PyTorch version: 1.7.1
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: Pop!_OS 20.10 (x86_64)
GCC version: (Ubuntu 10.2.0-13ubuntu1) 10.2.0
Clang version: Could not collect
CMake version: version 3.16.3
Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: Quadro RTX 5000
Nvidia driver version: 455.38
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.19.4
[pip3] torch==1.7.1
[conda] blas 1.0 mkl
[conda] mkl 2020.0 166
[conda] mkl-service 2.3.0 py37he904b0f_0
[conda] mkl_fft 1.0.15 py37ha843d7b_0
[conda] mkl_random 1.1.0 py37hd6b4f25_0
[conda] numpy 1.18.1 py37h4f9e942_0
[conda] numpy-base 1.18.1 py37hde5b4d6_1
[conda] numpydoc 0.9.2 py_0
Additional context
cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @albanD @gqchen @pearu @nikitaved @soulitzer