Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Dim,
dynamic_dim,
export,
unflatten,
)
from torch.export._trace import (
_export_to_torch_ir,
Expand Down Expand Up @@ -2166,6 +2167,220 @@ def forward(self, x):
self.assertEqual(v, ep.state_dict[k])
self.assertTrue(torch.allclose(ep(test_inp), orig_eager(test_inp)))

def test_nn_module_stack(self):
    """Export a small module hierarchy (Foo -> Bar -> Leaf) in both strict and
    non-strict mode; check that unflatten() reconstructs the submodule
    hierarchy (attributes and buffers) and that the flattened and unflattened
    modules agree across both export modes."""

    class Leaf(torch.nn.Module):
        # Innermost submodule: a single Linear layer.
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    class Bar(torch.nn.Module):
        # Middle module: owns a Leaf plus a registered (non-parameter) buffer.
        def __init__(self):
            super().__init__()
            self.leaf = Leaf()
            self.register_buffer("buffer", torch.randn(4, 4))

        def forward(self, x):
            return self.buffer.sum() + self.leaf(x).sum()

    class Foo(torch.nn.Module):
        # Top level; also reads the child's buffer directly in forward.
        def __init__(self):
            super().__init__()
            self.bar = Bar()

        def forward(self, x):
            y = self.bar.buffer + x
            return (self.bar(x) + y.sum(),)

    inp = (torch.randn(4, 4),)
    mod = Foo()
    # Export twice: strict (dynamo-based) and non-strict tracing.
    ep_strict = torch.export.export(mod, inp)
    ep_non_strict = torch.export.export(mod, inp, strict=False)

    # Unflattening the non-strict export should rebuild the module tree.
    gm_unflat_non_strict = unflatten(ep_non_strict)
    self.assertTrue(hasattr(gm_unflat_non_strict, "bar"))
    self.assertTrue(hasattr(gm_unflat_non_strict.bar, "buffer"))
    self.assertTrue(hasattr(gm_unflat_non_strict.bar, "leaf"))

    gm_unflat_strict = unflatten(ep_strict)

    # Both unflattened modules must produce identical outputs.
    self.assertEqual(gm_unflat_non_strict(*inp), gm_unflat_strict(*inp))
    # NOTE(review): the expected-graph string below appears to have lost its
    # original in-string indentation in this paste -- confirm against the
    # actual assertExpectedInline output.
    self.assertExpectedInline(
        str(gm_unflat_non_strict.bar.leaf.linear.graph).strip(), """\
graph():
%arg3_1 : [num_users=1] = placeholder[target=arg3_1]
%bias : [num_users=1] = get_attr[target=bias]
%weight : [num_users=1] = get_attr[target=weight]
%t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%weight,), kwargs = {})
%addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%bias, %arg3_1, %t), kwargs = {})
return addmm"""
    )

    gm_flat_non_strict = ep_non_strict.module()
    gm_flat_strict = ep_strict.module()

    # Flattened (single-graph) modules must also agree across modes.
    self.assertEqual(gm_flat_non_strict(*inp), gm_flat_strict(*inp))

def test_nn_module_stack_shared_submodule(self):
    """Like test_nn_module_stack, but with two sibling submodules (Bar and
    BarDifferent) that each own their own Leaf instance, and BarDifferent
    calls its leaf twice. Verifies unflatten() keeps the distinct call sites
    and submodule instances separate in both export modes."""

    class Leaf(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    class Bar(torch.nn.Module):
        # Owns a Leaf plus a registered (non-parameter) buffer.
        def __init__(self):
            super().__init__()
            self.leaf = Leaf()
            self.register_buffer("buffer", torch.randn(4, 4))

        def forward(self, x):
            return self.buffer.sum() + self.leaf(x).sum()

    class BarDifferent(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.leaf = Leaf()

        def forward(self, x):
            # Same submodule invoked twice -> two call sites into one Leaf.
            a = self.leaf(x).sum()
            b = self.leaf(x).sum()
            return a + b

    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bar = Bar()
            self.bar_different = BarDifferent()

        def forward(self, x):
            y = self.bar.buffer + x
            return (self.bar(x) + self.bar_different(x + 2), y.sum(),)

    inp = (torch.randn(4, 4),)
    mod = Foo()
    ep_strict = torch.export.export(mod, inp)
    ep_non_strict = torch.export.export(mod, inp, strict=False)

    # The rebuilt module tree must expose both siblings and their children.
    gm_unflat_non_strict = unflatten(ep_non_strict)
    self.assertTrue(hasattr(gm_unflat_non_strict, "bar"))
    self.assertTrue(hasattr(gm_unflat_non_strict.bar, "buffer"))
    self.assertTrue(hasattr(gm_unflat_non_strict.bar, "leaf"))
    self.assertTrue(hasattr(gm_unflat_non_strict.bar_different, "leaf"))

    gm_unflat_strict = unflatten(ep_strict)

    self.assertEqual(gm_unflat_non_strict(*inp), gm_unflat_strict(*inp))
    # NOTE(review): the expected-graph strings below appear to have lost their
    # original in-string indentation in this paste -- confirm against the
    # actual assertExpectedInline output.
    self.assertExpectedInline(
        str(gm_unflat_non_strict.bar.leaf.linear.graph).strip(), """\
graph():
%arg5_1 : [num_users=1] = placeholder[target=arg5_1]
%bias : [num_users=1] = get_attr[target=bias]
%weight : [num_users=1] = get_attr[target=weight]
%t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%weight,), kwargs = {})
%addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%bias, %arg5_1, %t), kwargs = {})
return addmm"""
    )
    # The second call site gets distinct node names (t_1 / addmm_1).
    self.assertExpectedInline(
        str(gm_unflat_non_strict.bar_different.leaf.linear.graph).strip(), """\
graph():
%add_2 : [num_users=1] = placeholder[target=add_2]
%bias : [num_users=1] = get_attr[target=bias]
%weight : [num_users=1] = get_attr[target=weight]
%t_1 : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%weight,), kwargs = {})
%addmm_1 : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%bias, %add_2, %t_1), kwargs = {})
return addmm_1"""
    )

    gm_flat_non_strict = ep_non_strict.module()
    gm_flat_strict = ep_strict.module()

    self.assertEqual(gm_flat_non_strict(*inp), gm_flat_strict(*inp))

def test_cond_with_module_stack_export_with(self):
    """Non-strict export of a module whose submodule uses torch.cond.

    Checks the exported top-level graph code, and that the node referencing
    the lifted true-branch subgraph carries nn_module_stack metadata that
    names Bar (the submodule that invoked torch.cond)."""

    class Bar(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            # Branch closures are lifted by export into separate subgraphs
            # (true_graph_0 / false_graph_0) on the top-level graph module.
            def true_fn(x):
                return self.linear(x).cos()

            def false_fn(x):
                return self.linear(x).sin()

            return torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])

    class CondExport(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bar = Bar()

        def forward(self, x):
            return x.cos() + self.bar(x)

    inp = (torch.randn(4, 4),)
    ep = torch.export.export(CondExport(), inp, strict=False)
    # NOTE(review): the expected code string below appears to have lost its
    # original in-string indentation in this paste -- confirm against the
    # actual assertExpectedInline output.
    self.assertExpectedInline(ep.graph_module.code.strip(), """\
def forward(self, arg0_1, arg1_1, arg2_1):
cos = torch.ops.aten.cos.default(arg2_1)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg1_1, arg0_1, arg2_1]); true_graph_0 = false_graph_0 = arg1_1 = arg0_1 = arg2_1 = None
getitem = conditional[0]; conditional = None
add = torch.ops.aten.add.Tensor(cos, getitem); cos = getitem = None
return (add,)""")

    # nn_module_stack of the node that references the true-branch subgraph.
    cond_top_level_nn_module_stack = [
        node.meta["nn_module_stack"]
        for node in ep.graph.nodes
        if node.name == "true_graph_0"
    ][0]

    self.assertTrue("test_cond_with_module_stack_export_with.<locals>.Bar" in str(cond_top_level_nn_module_stack))

# TODO: See https://github.com/pytorch/pytorch/issues/115790
@unittest.expectedFailure
def test_cond_with_module_stack_export_with_unflatten(self):
    """Expected failure: torch.cond subgraphs do not yet get per-node
    nn_module_stack metadata, so unflatten() on such an export fails."""

    class Bar(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            def true_fn(x):
                return self.linear(x).cos()

            def false_fn(x):
                return self.linear(x).sin()

            return torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])

    class CondExport(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bar = Bar()

        def forward(self, x):
            return x.cos() + self.bar(x)

    example_inputs = (torch.randn(4, 4),)
    ep = torch.export.export(CondExport(), example_inputs, strict=False)

    # Metadata of the node that references the lifted true-branch subgraph.
    cond_top_level_nn_module_stack = next(
        node.meta["nn_module_stack"]
        for node in ep.graph.nodes
        if node.name == "true_graph_0"
    )

    # we can't preserve nn_module_stack for the subgraphs for now.
    for subgraph_node in ep.graph_module.true_graph_0.graph.nodes:
        self.assertEqual(
            subgraph_node.meta["nn_module_stack"],
            cond_top_level_nn_module_stack,
        )

    # this doesn't work today
    gm_unflat_strict = unflatten(ep)


if __name__ == '__main__':
run_tests()
49 changes: 49 additions & 0 deletions test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import warnings
import operator
from collections.abc import Iterable
from torch.nn.utils import stateless
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_methods_invocations import op_db, skip, xfail, skipOps
from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException, FakeTensorMode
Expand Down Expand Up @@ -1456,6 +1457,54 @@ def forward(self, a_1):
div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None
return div""")

def test_make_fx_with_custom_tracer_preserving_nn_module_stack(self):
    """make_fx records nn_module_stack metadata only when both conditions
    hold: record_module_stack=True is passed AND the traced callable exposes
    the original user module via a `_orig_mod` attribute. Verifies all three
    combinations (neither, both, attribute-only)."""

    class Bar(torch.nn.Module):
        def forward(self, x):
            return x + 1

    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bar = Bar()

        def forward(self, x):
            return x + self.bar(x)

    # Tracing the module directly (default settings) records no module stack.
    gm = make_fx(Foo())(torch.randn(4, 4))
    for node in gm.graph.nodes:
        self.assertTrue("nn_module_stack" not in node.meta)

    foo = Foo()

    def functional_call(*args, **kwargs):
        with stateless._reparametrize_module(foo, {}):
            return foo(*args, **kwargs)

    # Stash the original module so make_fx can recover module-stack info
    # (mirrors what create_functional_call does for non-strict export).
    functional_call._orig_mod = foo

    gm_with_stack = make_fx(functional_call, record_module_stack=True)(torch.randn(4, 4))
    found = False
    for node in gm_with_stack.graph.nodes:
        if "nn_module_stack" in node.meta:
            if len(node.meta["nn_module_stack"]) == 1:
                # Depth 1: op executed directly inside Foo.forward.
                self.assertTrue("custom_tracer_preserving_nn_module_stack.<locals>.Foo" in str(node.meta["nn_module_stack"]))
                found = True
            elif len(node.meta["nn_module_stack"]) == 2:
                # Depth 2: op executed inside Bar (Foo -> Bar).
                self.assertTrue("preserving_nn_module_stack.<locals>.Bar" in str(node.meta["nn_module_stack"]))
                found = True
            else:
                # Foo -> Bar is the deepest nesting in this model, so any
                # other depth indicates broken stack bookkeeping.
                self.fail(
                    f"unexpected nn_module_stack depth: {len(node.meta['nn_module_stack'])}"
                )

    self.assertTrue(found)

    # Without record_module_stack, even a _orig_mod-carrying callable
    # records nothing.
    gm_without_stack = make_fx(functional_call)(torch.randn(4, 4))
    for node in gm_without_stack.graph.nodes:
        self.assertTrue("nn_module_stack" not in node.meta)

def test_symint_to_tensor(self):
def f(a):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@

def _create_graph(f, args, *, aot_config: AOTConfig) -> torch.fx.GraphModule:
    """Trace `f` applied to `args` into an FX GraphModule.

    Runs make_fx under the Python dispatcher with the decomposition table
    from `aot_config`; record_module_stack=True preserves nn_module_stack
    metadata for the non-strict export flow.

    Fix: the previous text retained the pre-change call
    `make_fx(f, decomposition_table=...)(*args)` alongside its replacement,
    tracing `f` twice and discarding the first result. Only the single,
    record_module_stack-enabled trace is kept.
    """
    with enable_python_dispatcher():
        fx_g = make_fx(
            f,
            decomposition_table=aot_config.decompositions,
            record_module_stack=True,
        )(*args)

    return fx_g

Expand Down
20 changes: 15 additions & 5 deletions torch/_functorch/_aot_autograd/traced_function_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import warnings
from contextlib import nullcontext
from functools import wraps
from typing import Any, Callable, List, Tuple, Union
from unittest.mock import patch

Expand Down Expand Up @@ -63,6 +64,7 @@ def fn_input_mutations_to_outputs(
meta: ViewAndMutationMeta,
keep_data_input_mutations: bool,
) -> Any:
@wraps(fn)
def inner_fn(*args):
outs = fn(*args)
assert len(meta.output_info) == len(outs)
Expand Down Expand Up @@ -95,6 +97,7 @@ def fn_prepped_for_autograd(
fn: Callable,
meta: ViewAndMutationMeta,
) -> Any:
@wraps(fn)
def inner_fn(*args):
args_maybe_cloned = [
maybe_to_fresh_input(i, t, meta) for i, t in enumerate(args)
Expand Down Expand Up @@ -342,6 +345,7 @@ def create_functionalized_fn(
aot_config: AOTConfig,
trace_joint: bool,
) -> Any:
@wraps(fn)
def _functionalized_f_helper(*args):
# Wrap inputs into functional wrappers
f_args = pytree.tree_map(to_fun, args)
Expand Down Expand Up @@ -459,13 +463,11 @@ def _functionalized_f_helper(*args):

# Kinda annoying, but needed to make sure that the fx graph we trace out has "primals"
# and "tangents" as its input names (which are special-cased by the partitioner)
# TODO (tmanlaibaatar) revisit this if we ever need to turn on non-strict joint graph export
def joint_helper(primals, tangents):
return _functionalized_f_helper(primals, tangents)

def fwd_helper(*args):
return _functionalized_f_helper(*args)

helper = joint_helper if trace_joint else fwd_helper
helper = joint_helper if trace_joint else _functionalized_f_helper
if config.functionalize_rng_ops:
# Setup the wrapper for functionalization of rng ops
helper, args = create_functionalized_rng_ops_wrapper(helper, args, trace_joint)
Expand Down Expand Up @@ -586,7 +588,7 @@ def metadata_fn(*primals):
)


def create_functional_call(mod, params_spec, params_len):
def create_functional_call(mod, params_spec, params_len, store_orig_mod=False):
# Redundant with dynamo, but worth having in case this gets invoked elsewhere.
# https://github.com/pytorch/pytorch/issues/103569

Expand All @@ -612,4 +614,12 @@ def functional_call(*args, **kwargs):
)
return out

# Note [Preserving the nn module stack metadata during export non-strict mode]
# This path is currently only used by the non-strict export flow,
# where we cannot rely on dynamo to preserve nn stack metadata in our captured graph.
# Instead, we stash the original user nn module here, and rely on `make_fx` to grab
# this stashed module and use it to track nn module stack metadata
if store_orig_mod and not hasattr(functional_call, "_orig_mod"):
functional_call._orig_mod = mod # type: ignore[attr-defined]

return functional_call
5 changes: 5 additions & 0 deletions torch/_functorch/_aot_autograd/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ def flat_fn(*flat_args):
out_spec.set(spec)
return flat_out

# Can't use functools.wraps here because the wrapper has different
# calling convention
if hasattr(fn, "_orig_mod"):
flat_fn._orig_mod = fn._orig_mod # type: ignore[attr-defined]

return flat_fn, out_spec


Expand Down
Loading