JIT: scripted torch.stft resolves to aten::stft, which raises an error
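When I call torch.stft or torch.functional.stft inside a scripted function, the JIT resolves the call to aten::stft instead of the Python function.
The Python function has the signature:
type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor
while the aten::stft operator has the signature:
type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, bool) -> Tensor
The Python signature takes two extra arguments (a bool and a str) that the operator does not, so the positional arguments no longer line up and compilation fails with a type mismatch.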
Example:
import torch

@torch.jit.script
def f():
    torch.stft(...)
Stack Trace
  File "/home/travis/build/jamarshon/audio/torchaudio/functional.py", line 103, in <module>
    @torch.jit.script
  File "/home/travis/miniconda3/envs/testenv/lib/python3.6/site-packages/torch/jit/__init__.py", line 1050, in script
    fn = torch._C._jit_script_compile(ast, _rcb, get_default_args(obj))
RuntimeError:
arguments for call are not valid:
  for operator aten::stft(Tensor self, int n_fft, int? hop_length=<default>, int? win_length=<default>, Tensor? window=<default>, bool normalized=<default>, bool onesided=<default>) -> Tensor:
  expected a value of type 'bool' for argument 'onesided' but instead found type 'str'.
  at /home/travis/build/jamarshon/audio/torchaudio/functional.py:134:31
    """
    assert sig.dim() == 2
    if pad > 0:
        # TODO add "with torch.no_grad():" back when JIT supports it
        sig = torch.nn.functional.pad(sig, (pad, pad), "constant")
    # default values are consistent with librosa.core.spectrum._spectrogram
    spec_f = torch.stft(sig, n_fft, hop, ws, window,
                        True, 'reflect', False, True).transpose(1, 2)
                              ~ <--- HERE
    if normalize:
        spec_f /= window.pow(2).sum().sqrt()
    spec_f = spec_f.pow(power).sum(-1)  # get power of "complex" tensor (c, l, n_fft)
    return spec_f
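To make the failure in the trace concrete, here is a small stand-alone sketch of how the nine positional arguments line up against each signature. It is only an illustration: the matcher is a toy, not TorchScript's actual schema resolution, the Tensor class is a stand-in so the snippet runs without PyTorch, and the Python-level parameter names (center, pad_mode) are taken from torch.functional.stft as I understand it for this version rather than from the trace.

```python
from typing import List, Optional, Sequence, Tuple


class Tensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""


NoneType = type(None)
Schema = List[Tuple[str, tuple]]

# (parameter name, accepted Python types); NoneType marks an Optional parameter.
PYTHON_STFT: Schema = [
    ("input", (Tensor,)),
    ("n_fft", (int,)),
    ("hop_length", (int, NoneType)),
    ("win_length", (int, NoneType)),
    ("window", (Tensor, NoneType)),
    ("center", (bool,)),      # assumed name, not shown in the trace
    ("pad_mode", (str,)),     # assumed name, not shown in the trace
    ("normalized", (bool,)),
    ("onesided", (bool,)),
]

# Parameter names as printed in the RuntimeError message.
ATEN_STFT: Schema = [
    ("self", (Tensor,)),
    ("n_fft", (int,)),
    ("hop_length", (int, NoneType)),
    ("win_length", (int, NoneType)),
    ("window", (Tensor, NoneType)),
    ("normalized", (bool,)),
    ("onesided", (bool,)),
]


def first_mismatch(schema: Schema, args: Sequence[object]) -> Optional[Tuple[str, str, str]]:
    """Return (parameter, expected type, found type) for the first bad positional arg."""
    for (name, types), arg in zip(schema, args):
        if not isinstance(arg, types):
            return (name, types[0].__name__, type(arg).__name__)
    if len(args) > len(schema):
        # More positional arguments than the schema has parameters.
        return ("<extra>", "nothing", type(args[len(schema)]).__name__)
    return None


# Mirrors the call in the trace: torch.stft(sig, n_fft, hop, ws, window, True, 'reflect', False, True)
CALL_ARGS = (Tensor(), 400, 200, 400, Tensor(), True, "reflect", False, True)

print(first_mismatch(PYTHON_STFT, CALL_ARGS))  # None: every argument fits
print(first_mismatch(ATEN_STFT, CALL_ARGS))    # ('onesided', 'bool', 'str')
```

Against the operator schema, True lands on normalized and 'reflect' lands on onesided, which matches the message "expected a value of type 'bool' for argument 'onesided' but instead found type 'str'". The numeric values in CALL_ARGS are arbitrary placeholders.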
Expected behavior
The JIT should resolve the call to the Python-level torch.stft (torch.functional.stft), not directly to the aten::stft operator.
Environment
PyTorch version: 1.2.0a0+1252899
Is debug build: Yes
CUDA used to build PyTorch: 10.0.130
OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.14.0
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100
Nvidia driver version: 410.79
cuDNN version: Could not collect
Versions of relevant libraries:
[pip] numpy==1.16.2
[pip] numpydoc==0.8.0
[pip] pytorch-sphinx-theme==0.0.24
[pip] torch==1.2.0a0+1252899
[pip] torchaudio==0.2
[conda] blas 1.0 mkl
[conda] magma-cuda100 2.5.0 1 pytorch
[conda] mkl 2019.3 199
[conda] mkl-include 2019.3 199
[conda] mkl-service 1.1.2 py37he904b0f_5
[conda] mkl_fft 1.0.10 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch-sphinx-theme 0.0.24 dev_0
[conda] torch 1.2.0a0+1252899 dev_0
[conda] torchaudio 0.2 dev_0