@@ -1143,6 +1143,131 @@ def test_dynamic_axes_unchange(self):
11431143 opset_version = 12 ,
11441144 )
11451145
1146+ def test_stft_default (self ):
1147+ """Test STFT with default parameters"""
1148+ m1 = torch .randn ((1 , 32 ))
1149+ n_fft = 16
1150+ self .assertONNX (
1151+ lambda x : torch .stft (x , n_fft = n_fft , center = False , return_complex = False ),
1152+ (m1 ,),
1153+ opset_version = 17 ,
1154+ )
1155+
1156+ def test_stft_hop_length (self ):
1157+ """Test STFT with custom hop length"""
1158+ m1 = torch .randn ((1 , 32 ))
1159+ n_fft = 16
1160+ hop_length = 4
1161+ self .assertONNX (
1162+ lambda x : torch .stft (
1163+ x ,
1164+ n_fft = n_fft ,
1165+ center = False ,
1166+ hop_length = hop_length ,
1167+ return_complex = False ,
1168+ ),
1169+ (m1 ,),
1170+ opset_version = 17 ,
1171+ )
1172+
1173+ def test_stft_non_divisible_hop_length (self ):
1174+ """Test STFT with non-divisible custom hop length"""
1175+ m1 = torch .randn ((1 , 32 ))
1176+ n_fft = 16
1177+ hop_length = 5
1178+ self .assertONNX (
1179+ lambda x : torch .stft (
1180+ x ,
1181+ n_fft = n_fft ,
1182+ center = False ,
1183+ hop_length = hop_length ,
1184+ return_complex = False ,
1185+ ),
1186+ (m1 ,),
1187+ opset_version = 17 ,
1188+ )
1189+
1190+ def test_stft_window_int_same_size (self ):
1191+ """Test STFT with specific window length equals n_fft"""
1192+ m1 = torch .randn ((1 , 32 ))
1193+ n_fft = 16
1194+ win_length = 16
1195+ self .assertONNX (
1196+ lambda x : torch .stft (
1197+ x ,
1198+ n_fft = n_fft ,
1199+ center = False ,
1200+ win_length = win_length ,
1201+ return_complex = False ,
1202+ ),
1203+ (m1 ,),
1204+ opset_version = 17 ,
1205+ )
1206+
1207+ def test_stft_window_int_different_size (self ):
1208+ """Test STFT with specific window length different than n_fft"""
1209+ m1 = torch .randn ((1 , 32 ))
1210+ n_fft = 16
1211+ win_length = 9
1212+ self .assertONNX (
1213+ lambda x : torch .stft (
1214+ x ,
1215+ n_fft = n_fft ,
1216+ center = False ,
1217+ win_length = win_length ,
1218+ return_complex = False ,
1219+ ),
1220+ (m1 ,),
1221+ opset_version = 17 ,
1222+ )
1223+
1224+ def test_stft_window_custom (self ):
1225+ """Test STFT with a custom window"""
1226+ m1 = torch .randn ((1 , 32 ))
1227+ n_fft = 16
1228+ window = torch .hann_window (16 )
1229+ self .assertONNX (
1230+ lambda x : torch .stft (
1231+ x , n_fft = n_fft , center = False , window = window , return_complex = False
1232+ ),
1233+ (m1 ,),
1234+ opset_version = 17 ,
1235+ )
1236+
1237+ def test_stft_one_dimension (self ):
1238+ """Test STFT with a single dimension"""
1239+ m1 = torch .randn ((32 ))
1240+ n_fft = 16
1241+ self .assertONNX (
1242+ lambda x : torch .stft (x , n_fft = n_fft , center = False , return_complex = False ),
1243+ (m1 ,),
1244+ opset_version = 17 ,
1245+ )
1246+
1247+ def test_stft_normalize (self ):
1248+ """Test STFT with normalization"""
1249+ m1 = torch .randn ((32 ))
1250+ n_fft = 16
1251+ self .assertONNX (
1252+ lambda x : torch .stft (
1253+ x , n_fft = n_fft , center = False , normalized = True , return_complex = False
1254+ ),
1255+ (m1 ,),
1256+ opset_version = 17 ,
1257+ )
1258+
1259+ def test_stft_not_onesided (self ):
1260+ """Test STFT without returning a single side"""
1261+ m1 = torch .randn ((32 ))
1262+ n_fft = 16
1263+ self .assertONNX (
1264+ lambda x : torch .stft (
1265+ x , n_fft = n_fft , center = False , onesided = False , return_complex = False
1266+ ),
1267+ (m1 ,),
1268+ opset_version = 17 ,
1269+ )
1270+
11461271 def test_aten_embedding_1 (self ):
11471272 _onnx_opset_version = 12
11481273
0 commit comments