
torch.stft fails to center complex input tensors #50234

@peterbell10

Description


🐛 Bug

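torch.stft with the default pad_mode ('reflect') raises an error when centering complex input tensors. This is because centering relies on F.pad, which only partially supports complex types: its reflect mode dispatches to reflection_pad1d, which is not implemented for complex dtypes such as ComplexFloat.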
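As a possible workaround until complex reflect padding is supported, here is a minimal sketch. Reflection only reorders samples, so reflect-padding the real and imaginary parts separately (both are real tensors, which reflection_pad1d does support) should give the same result as reflect-padding the complex tensor. The padded signal is then passed to stft with center=False, mirroring what center=True does internally (pad n_fft // 2 on each side, as in the traceback below). The helper names `reflect_pad_complex` and `stft_centered_complex` are hypothetical, not PyTorch APIs, and the sketch assumes a 1D input.

```python
import torch
import torch.nn.functional as F


def reflect_pad_complex(x: torch.Tensor, pad: int) -> torch.Tensor:
    """Hypothetical helper: reflect-pad a 1D complex tensor on both sides.

    Reflection is a pure reordering of samples, so it can be applied to the
    real and imaginary parts independently and then recombined.
    """
    # F.pad's reflect mode expects a batched (N, C, W) input, as torch.stft uses.
    real = F.pad(x.real.reshape(1, 1, -1), [pad, pad], mode="reflect")
    imag = F.pad(x.imag.reshape(1, 1, -1), [pad, pad], mode="reflect")
    return torch.complex(real, imag).reshape(-1)


def stft_centered_complex(x: torch.Tensor, n_fft: int) -> torch.Tensor:
    # Hypothetical helper: mimic center=True with pad_mode='reflect' by padding
    # n_fft // 2 samples on each side, then running stft without centering.
    padded = reflect_pad_complex(x, n_fft // 2)
    return torch.stft(padded, n_fft=n_fft, center=False, return_complex=True)


x = torch.rand(100, dtype=torch.complex64)
spec = stft_centered_complex(x, n_fft=10)
print(spec.shape)  # torch.Size([10, 51])
```

With the default hop_length of n_fft // 4, the 110-sample padded signal yields 1 + (110 - 10) // 2 = 51 frames, and complex input gives a two-sided spectrum with n_fft = 10 frequency bins.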

To Reproduce

import torch
a = torch.rand(100, dtype=torch.complex64)
torch.stft(a, n_fft=10)

Produces the following traceback:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-9-220c388e1270> in <module>
----> 1 torch.stft(a, n_fft=10)

~/.conda/envs/pytorch-dev/lib/python3.8/site-packages/torch/functional.py in stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, return_complex)
    580         extended_shape = [1] * (3 - signal_dim) + list(input.size())
    581         pad = int(n_fft // 2)
--> 582         input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
    583         input = input.view(input.shape[-signal_dim:])
    584     return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore

~/.conda/envs/pytorch-dev/lib/python3.8/site-packages/torch/nn/functional.py in _pad(input, pad, mode, value)
   3563             assert len(pad) == 2, '3D tensors expect 2 values for padding'
   3564             if mode == 'reflect':
-> 3565                 return torch._C._nn.reflection_pad1d(input, pad)
   3566             elif mode == 'replicate':
   3567                 return torch._C._nn.replication_pad1d(input, pad)

RuntimeError: "reflection_pad1d" not implemented for 'ComplexFloat'

Calling stft with center=False or pad_mode='constant' works, however.
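A minimal sketch of the two working alternatives, assuming a PyTorch build whose torch.stft accepts the return_complex argument. Note that neither is a drop-in replacement for reflect centering: zero padding changes the edge frames, and center=False yields fewer frames. The last call checks that constant centering is equivalent to zero-padding n_fft // 2 samples by hand and disabling centering.

```python
import torch
import torch.nn.functional as F

x = torch.rand(100, dtype=torch.complex64)
n_fft = 10  # hop_length defaults to n_fft // 4 == 2

# Option 1: keep centering, but zero-pad instead of reflecting.
spec_const = torch.stft(x, n_fft=n_fft, pad_mode="constant", return_complex=True)

# Option 2: disable centering; frames start at sample 0, so fewer frames fit.
spec_nocenter = torch.stft(x, n_fft=n_fft, center=False, return_complex=True)

# Zero-padded centering equals padding n_fft // 2 zeros manually
# and running the transform with center=False.
manual = torch.stft(F.pad(x, [n_fft // 2, n_fft // 2]), n_fft=n_fft,
                    center=False, return_complex=True)

print(spec_const.shape)     # torch.Size([10, 51])
print(spec_nocenter.shape)  # torch.Size([10, 46])
print(torch.allclose(spec_const, manual))  # True
```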

Environment

PyTorch version: 1.8.0a0+22832fa
Is debug build: False
CUDA used to build PyTorch: 11.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.18.2

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: Quadro RTX 8000
GPU 1: Quadro RTX 8000

Nvidia driver version: 460.27.04
cuDNN version: Probably one of the following:
/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7
/usr/local/cuda-10.2.89/targets/x86_64-linux/lib/libcudnn.so.7
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] torch==1.8.0a0
[conda] magma-cuda112             2.5.2                         1    pytorch
[conda] mkl                       2020.2                      256    conda-forge
[conda] mkl-include               2020.2                      256    conda-forge
[conda] numpy                     1.19.2           py38hf89b668_1    conda-forge
[conda] torch                     1.8.0a0                  pypi_0    pypi

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @dylanbespalko @mruberry @peterbell10 @walterddr

Labels

high priority, module: complex, module: fft, triaged