|
60 | 60 | from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \ |
61 | 61 | skipIfRocm, suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, \ |
62 | 62 | freeze_rng_state, set_rng_seed, slowTest, TemporaryFileName, skipIfCompiledWithoutNumpy, \ |
63 | | - enable_profiling_mode |
| 63 | + enable_profiling_mode, TEST_MKL |
64 | 64 | from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \ |
65 | 65 | _trace, enable_cpu_fuser_if, do_input_map, get_execution_plan, \ |
66 | 66 | execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \ |
@@ -10246,6 +10246,15 @@ def test_pack_unpack_state(self): |
10246 | 10246 | self.assertTrue(imported.unpack_called.item()) |
10247 | 10247 | torch.testing.assert_allclose(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float))) |
10248 | 10248 |
|
| 10249 | + @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support") |
| 10250 | + def test_torch_functional(self): |
| 10251 | + def foo(input, n_fft): |
| 10252 | + # type: (Tensor, int) -> Tensor |
| 10253 | + return torch.stft(input, n_fft) |
| 10254 | + |
| 10255 | + inps = (torch.randn(10), 7) |
| 10256 | + self.assertEqual(foo(*inps), torch.jit.script(foo)(*inps)) |
| 10257 | + |
10249 | 10258 | def test_missing_getstate(self): |
10250 | 10259 | class Foo(torch.nn.Module): |
10251 | 10260 | def __init__(self): |
|
0 commit comments