
Commit fddf732

Elias Ellison authored and facebook-github-bot committed
[JIT] fix resolving of functions in torch/functional. fix compilation of torch.stft (#33504)
Summary:
Pull Request resolved: #33504

Fix resolution of functions that are bound onto torch in torch/functional.py. This does not fix compilation of all of those functions; those will be done in follow-ups. Does torch.stft as a start.

Fixes #21478

Test Plan: Imported from OSS

Differential Revision: D20014591

Pulled By: eellison

fbshipit-source-id: bb362f1b5479adbb890e72a54111ef716679d127
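For illustration, a minimal sketch of what this change enables, mirroring the test added in test/test_jit.py below (it assumes an MKL-enabled build of this vintage, as the test's skip condition suggests; later releases also require a return_complex argument to stft):

```python
import torch

# Before this fix, scripting failed because torch.stft resolved straight to
# the "aten::stft" builtin, skipping the Python wrapper in torch/functional.py
# that implements centering and padding. Now the JIT compiles that wrapper.
def foo(input, n_fft):
    # type: (Tensor, int) -> Tensor
    return torch.stft(input, n_fft)

inps = (torch.randn(10), 7)
eager = foo(*inps)
scripted = torch.jit.script(foo)(*inps)
assert torch.allclose(eager, scripted)
```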
1 parent 057fd5e, commit fddf732

9 files changed: 35 additions & 10 deletions


test/test_jit.py: 10 additions & 1 deletion
@@ -60,7 +60,7 @@
 from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
     skipIfRocm, suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, \
     freeze_rng_state, set_rng_seed, slowTest, TemporaryFileName, skipIfCompiledWithoutNumpy, \
-    enable_profiling_mode
+    enable_profiling_mode, TEST_MKL
 from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \
     _trace, enable_cpu_fuser_if, do_input_map, get_execution_plan, \
     execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \
@@ -10246,6 +10246,15 @@ def test_pack_unpack_state(self):
         self.assertTrue(imported.unpack_called.item())
         torch.testing.assert_allclose(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
 
+    @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
+    def test_torch_functional(self):
+        def foo(input, n_fft):
+            # type: (Tensor, int) -> Tensor
+            return torch.stft(input, n_fft)
+
+        inps = (torch.randn(10), 7)
+        self.assertEqual(foo(*inps), torch.jit.script(foo)(*inps))
+
     def test_missing_getstate(self):
         class Foo(torch.nn.Module):
             def __init__(self):
File renamed without changes.

torch/functional.py: 2 additions & 2 deletions
@@ -5,6 +5,7 @@
 from ._overrides import has_torch_function, handle_torch_function
 
 Tensor = torch.Tensor
+from torch import _VF
 
 __all__ = [
     'align_tensors',
@@ -301,7 +302,6 @@ def meshgrid(*tensors, **kwargs):
     return torch._C._VariableFunctions.meshgrid(tensors)
 
 
-
 def stft(input, n_fft, hop_length=None, win_length=None, window=None,
          center=True, pad_mode='reflect', normalized=False, onesided=True):
     # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor
@@ -399,7 +399,7 @@ def stft(input, n_fft, hop_length=None, win_length=None, window=None,
     pad = int(n_fft // 2)
     input = F.pad(input.view(extended_shape), (pad, pad), pad_mode)
     input = input.view(input.shape[-signal_dim:])
-    return torch._C._VariableFunctions.stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
+    return _VF.stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
 
 
 del torch.unique_dim
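The switch to _VF.stft here, and the import churn in the files below (from torch.nn's _VF to a top-level torch._VF), reflect _VF being a thin proxy module that forwards attribute lookups to torch._C._VariableFunctions. A minimal sketch of that proxy pattern, as the body of a hypothetical _VF.py; treat the details as an assumption rather than a quote of the renamed file:

```python
# _VF.py (sketch): proxy module forwarding lookups to the aten bindings.
import sys
import types

import torch

class VFModule(types.ModuleType):
    def __init__(self, name):
        super(VFModule, self).__init__(name)
        # The variable-function namespace generated by the C++ bindings.
        self.vf = torch._C._VariableFunctions

    def __getattr__(self, attr):
        # Only consulted for attributes not found normally, so every
        # `_VF.<op>` lookup falls through to the corresponding aten op.
        return getattr(self.vf, attr)

# Installing the instance in sys.modules makes `from torch import _VF`
# hand back this proxy object.
sys.modules[__name__] = VFModule(__name__)
```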

torch/jit/_builtins.py: 18 additions & 1 deletion
@@ -83,8 +83,25 @@
     (torch.nn.init._no_grad_zero_, "aten::_no_grad_zero_"),
     (torch._C._get_tracing_state, "aten::_get_tracing_state"),
     (warnings.warn, "aten::warn"),
+    (torch._VF.stft, "aten::stft"),
 ]
 
+# ops in torch.functional are bound to torch
+# in these cases, we want to resolve the function to its python implementation
+# instead of looking up a builtin "aten::" schema
+
+def _gen_torch_functional_registered_ops():
+    # eventually ops should encompass all of torch/functional.py (torch.functional.__all__),
+    # but we are currently only able to compile some of the functions. additionally,
+    # some functions directly map to their aten:: implementations.
+    # TODO: add support for more ops
+    ops = ["stft"]
+    return set(getattr(torch.functional, name) for name in ops)
+
+_functional_registered_ops = _gen_torch_functional_registered_ops()
+
+def _is_special_functional_bound_op(fn):
+    return fn in _functional_registered_ops
 
 # lazily built to ensure the correct initialization order
 def _get_builtin_table():
@@ -96,7 +113,7 @@ def _get_builtin_table():
     def register_all(mod):
         for name in dir(mod):
             v = getattr(mod, name)
-            if callable(v):
+            if callable(v) and not _is_special_functional_bound_op(v):
                 _builtin_ops.append((v, "aten::" + name))
     for mod in _modules_containing_builtins:
         register_all(mod)
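The effect of the exclusion above: torch.stft and torch.functional.stft are the same function object, so leaving it out of the blanket registration lets the script compiler recurse into its Python body, whose trailing _VF.stft call is what resolves to the aten::stft builtin. A quick sanity sketch using the internal _find_builtin helper from this file (internal API, so treat it as illustrative):

```python
import torch
from torch.jit._builtins import _find_builtin

# torch.stft is bound onto the torch module from torch/functional.py,
# so both names refer to a single function object.
assert torch.stft is torch.functional.stft

# Excluded from the builtin table: the JIT compiles the Python body instead.
assert _find_builtin(torch.stft) is None

# The _VF.stft call inside that body still maps to the aten builtin.
assert _find_builtin(torch._VF.stft) == "aten::stft"
```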

torch/jit/quantized.py: 1 addition & 2 deletions
@@ -2,8 +2,7 @@
 
 from torch._jit_internal import Tuple, Optional, List  # noqa: F401
 
-from torch import Tensor  # noqa: F401
-from torch.nn import _VF
+from torch import Tensor, _VF  # noqa: F401
 
 from torch.nn.utils.rnn import PackedSequence
 
torch/nn/functional.py: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 from .modules import utils
 from .modules.utils import _single, _pair, _triple, _list_with_default
 from . import grad  # noqa: F401
-from . import _VF
+from torch import _VF
 from .._jit_internal import boolean_dispatch, List
 from .._overrides import has_torch_function, handle_torch_function
 
torch/nn/modules/rnn.py: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 from ..parameter import Parameter
 from ..utils.rnn import PackedSequence
 from .. import init
-from .. import _VF
+from ... import _VF
 
 _rnn_impls = {
     'RNN_TANH': _VF.rnn_tanh,

torch/nn/quantized/dynamic/modules/rnn.py: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 import torch
 import torch.nn as nn
 from torch import Tensor  # noqa: F401
-from torch.nn import _VF
+from torch import _VF
 from torch._jit_internal import Tuple, Optional, List  # noqa: F401
 from torch.nn.utils.rnn import PackedSequence
 import numbers

torch/nn/utils/rnn.py: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 import warnings
 
 import torch
-from .. import _VF
+from ... import _VF
 from ..._jit_internal import Optional
 
 