Skip to content

Commit a95ffd7

Browse files
committed
Added ONNX STFT support, including unit tests. Addressed all CR
comments.
1 parent 65b9983 commit a95ffd7

4 files changed

Lines changed: 495 additions & 8 deletions

File tree

test/onnx/test_op_consistency.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ def reason_flaky() -> str:
302302
[
303303
"ceil",
304304
"sqrt",
305+
"stft",
305306
"t",
306307
]
307308
)

test/onnx/test_operators.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,131 @@ def test_dynamic_axes_unchange(self):
11431143
opset_version=12,
11441144
)
11451145

1146+
def test_stft_default(self):
1147+
"""Test STFT with default parameters"""
1148+
m1 = torch.randn((1, 32))
1149+
n_fft = 16
1150+
self.assertONNX(
1151+
lambda x: torch.stft(x, n_fft=n_fft, center=False, return_complex=False),
1152+
(m1,),
1153+
opset_version=17,
1154+
)
1155+
1156+
def test_stft_hop_length(self):
1157+
"""Test STFT with custom hop length"""
1158+
m1 = torch.randn((1, 32))
1159+
n_fft = 16
1160+
hop_length = 4
1161+
self.assertONNX(
1162+
lambda x: torch.stft(
1163+
x,
1164+
n_fft=n_fft,
1165+
center=False,
1166+
hop_length=hop_length,
1167+
return_complex=False,
1168+
),
1169+
(m1,),
1170+
opset_version=17,
1171+
)
1172+
1173+
def test_stft_non_divisible_hop_length(self):
1174+
"""Test STFT with non-divisible custom hop length"""
1175+
m1 = torch.randn((1, 32))
1176+
n_fft = 16
1177+
hop_length = 5
1178+
self.assertONNX(
1179+
lambda x: torch.stft(
1180+
x,
1181+
n_fft=n_fft,
1182+
center=False,
1183+
hop_length=hop_length,
1184+
return_complex=False,
1185+
),
1186+
(m1,),
1187+
opset_version=17,
1188+
)
1189+
1190+
def test_stft_window_int_same_size(self):
1191+
"""Test STFT with specific window length equals n_fft"""
1192+
m1 = torch.randn((1, 32))
1193+
n_fft = 16
1194+
win_length = 16
1195+
self.assertONNX(
1196+
lambda x: torch.stft(
1197+
x,
1198+
n_fft=n_fft,
1199+
center=False,
1200+
win_length=win_length,
1201+
return_complex=False,
1202+
),
1203+
(m1,),
1204+
opset_version=17,
1205+
)
1206+
1207+
def test_stft_window_int_different_size(self):
1208+
"""Test STFT with specific window length different than n_fft"""
1209+
m1 = torch.randn((1, 32))
1210+
n_fft = 16
1211+
win_length = 9
1212+
self.assertONNX(
1213+
lambda x: torch.stft(
1214+
x,
1215+
n_fft=n_fft,
1216+
center=False,
1217+
win_length=win_length,
1218+
return_complex=False,
1219+
),
1220+
(m1,),
1221+
opset_version=17,
1222+
)
1223+
1224+
def test_stft_window_custom(self):
1225+
"""Test STFT with a custom window"""
1226+
m1 = torch.randn((1, 32))
1227+
n_fft = 16
1228+
window = torch.hann_window(16)
1229+
self.assertONNX(
1230+
lambda x: torch.stft(
1231+
x, n_fft=n_fft, center=False, window=window, return_complex=False
1232+
),
1233+
(m1,),
1234+
opset_version=17,
1235+
)
1236+
1237+
def test_stft_one_dimension(self):
1238+
"""Test STFT with a single dimension"""
1239+
m1 = torch.randn((32))
1240+
n_fft = 16
1241+
self.assertONNX(
1242+
lambda x: torch.stft(x, n_fft=n_fft, center=False, return_complex=False),
1243+
(m1,),
1244+
opset_version=17,
1245+
)
1246+
1247+
def test_stft_normalize(self):
1248+
"""Test STFT with normalization"""
1249+
m1 = torch.randn((32))
1250+
n_fft = 16
1251+
self.assertONNX(
1252+
lambda x: torch.stft(
1253+
x, n_fft=n_fft, center=False, normalized=True, return_complex=False
1254+
),
1255+
(m1,),
1256+
opset_version=17,
1257+
)
1258+
1259+
def test_stft_not_onesided(self):
1260+
"""Test STFT without returning a single side"""
1261+
m1 = torch.randn((32))
1262+
n_fft = 16
1263+
self.assertONNX(
1264+
lambda x: torch.stft(
1265+
x, n_fft=n_fft, center=False, onesided=False, return_complex=False
1266+
),
1267+
(m1,),
1268+
opset_version=17,
1269+
)
1270+
11461271
def test_aten_embedding_1(self):
11471272
_onnx_opset_version = 12
11481273

0 commit comments

Comments
 (0)