
JIT torch.stft calls aten::stft which throws error #21478

@jamarshon

Description


When I call torch.stft or torch.functional.stft inside a TorchScript function, the JIT resolves the call directly to aten::stft.

The Python-level function has the signature:
type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor

while aten::stft has the signature:
type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, bool) -> Tensor

so compilation fails because the argument types do not match: the str pad_mode argument has no corresponding slot in the aten::stft schema.

Example:

@torch.jit.script
def f():
  torch.stft(...)
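A fuller repro sketch (the signal shape and STFT parameters below are illustrative, not from the original report; return_complex=True is an assumption for newer PyTorch builds, where it is required for real inputs):

```python
import torch

# Illustrative values, not taken from the original report.
sig = torch.randn(2, 1000)
n_fft, hop, ws = 400, 200, 400
window = torch.hann_window(ws)

# In eager mode this works: the Python-level torch.stft consumes the
# str pad_mode ('reflect') itself, pads the signal, and only then
# dispatches to aten::stft.
spec = torch.stft(sig, n_fft, hop, ws, window,
                  center=True, pad_mode='reflect', normalized=False,
                  onesided=True, return_complex=True)
print(spec.shape)  # (batch, n_fft // 2 + 1, frames)

# Wrapping the same call in @torch.jit.script instead resolves it
# against the aten::stft schema, where the 'reflect' string lands on a
# bool parameter and compilation fails with the error below.
```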

Stack Trace

  File "/home/travis/build/jamarshon/audio/torchaudio/functional.py", line 103, in <module>
    @torch.jit.script
  File "/home/travis/miniconda3/envs/testenv/lib/python3.6/site-packages/torch/jit/__init__.py", line 1050, in script
    fn = torch._C._jit_script_compile(ast, _rcb, get_default_args(obj))
RuntimeError: 
arguments for call are not valid:
  
  for operator aten::stft(Tensor self, int n_fft, int? hop_length=<default>, int? win_length=<default>, Tensor? window=<default>, bool normalized=<default>, bool onesided=<default>) -> Tensor:
  expected a value of type 'bool' for argument 'onesided' but instead found type 'str'.
  at /home/travis/build/jamarshon/audio/torchaudio/functional.py:134:31
      """
      assert sig.dim() == 2
  
      if pad > 0:
          # TODO add "with torch.no_grad():" back when JIT supports it
          sig = torch.nn.functional.pad(sig, (pad, pad), "constant")
  
      # default values are consistent with librosa.core.spectrum._spectrogram
      spec_f = torch.stft(sig, n_fft, hop, ws, window,
                          True, 'reflect', False, True).transpose(1, 2)
                                ~ <--- HERE
  
      if normalize:
          spec_f /= window.pow(2).sum().sqrt()
      spec_f = spec_f.pow(power).sum(-1)  # get power of "complex" tensor (c, l, n_fft)
      return spec_f

Expected behavior

The JIT should resolve the call to the Python-level torch.stft (which handles the str pad_mode argument before dispatching), not directly to aten::stft.
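Until that works, one possible workaround sketch is to perform the 'reflect' padding manually in Python and call stft with center=False, so no str argument ever has to cross into aten::stft. The function name and parameters here are hypothetical, and return_complex=True is again an assumption for newer PyTorch builds:

```python
import torch
import torch.nn.functional as F

def spectrogram(sig: torch.Tensor, n_fft: int, hop: int) -> torch.Tensor:
    # Hypothetical workaround: replicate the padding torch.stft would do
    # with center=True / pad_mode='reflect', then disable centering so
    # the str pad_mode argument is never needed.
    pad = n_fft // 2
    # F.pad with mode='reflect' needs a 3-D input for 1-D padding.
    sig = F.pad(sig.unsqueeze(1), (pad, pad), mode='reflect').squeeze(1)
    window = torch.hann_window(n_fft)
    spec = torch.stft(sig, n_fft, hop, n_fft, window,
                      center=False, return_complex=True)
    return spec.abs().pow(2)  # power spectrogram
```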

Environment

PyTorch version: 1.2.0a0+1252899
Is debug build: Yes
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.14.0

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 410.79
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.16.2
[pip] numpydoc==0.8.0
[pip] pytorch-sphinx-theme==0.0.24
[pip] torch==1.2.0a0+1252899
[pip] torchaudio==0.2
[conda] blas 1.0 mkl
[conda] magma-cuda100 2.5.0 1 pytorch
[conda] mkl 2019.3 199
[conda] mkl-include 2019.3 199
[conda] mkl-service 1.1.2 py37he904b0f_5
[conda] mkl_fft 1.0.10 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch-sphinx-theme 0.0.24 dev_0
[conda] torch 1.2.0a0+1252899 dev_0
[conda] torchaudio 0.2 dev_0

Labels

oncall: jit (add this issue/PR to JIT oncall triage queue)
triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
