Skip to content

Commit 0ee6bde

Browse files
committed
Delete torch._functorch.config.use_dynamic_shapes
As requested in #95975 (comment). Signed-off-by: Edward Z. Yang <ezyang@meta.com> [ghstack-poisoned]
1 parent 36a6e2c commit 0ee6bde

4 files changed

Lines changed: 35 additions & 65 deletions

File tree

test/functorch/test_aotdispatch.py

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def verify_aot_autograd(
245245
*,
246246
test_mutation: bool = False,
247247
decompositions: Optional[Dict] = None,
248+
dynamic: bool = False,
248249
):
249250
for keep_input_mutations in [True, False]:
250251
# Some tests pass in a callable for inp, to generate the inputs
@@ -294,15 +295,17 @@ def verify_aot_autograd(
294295
fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
295296
bw_compiler=nop,
296297
decompositions=decompositions,
297-
keep_inference_input_mutations=keep_input_mutations
298+
keep_inference_input_mutations=keep_input_mutations,
299+
dynamic=dynamic
298300
)
299301
else:
300302
compiled_f = aot_function(
301303
f,
302304
fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
303305
bw_compiler=nop,
304306
decompositions=decompositions,
305-
keep_inference_input_mutations=keep_input_mutations
307+
keep_inference_input_mutations=keep_input_mutations,
308+
dynamic=dynamic
306309
)
307310
ref_out, ref_grad = outs_and_grads(f, graph_inps, inp)
308311
test_out, test_grad = outs_and_grads(compiled_f, graph_inps_copy, inp_copy)
@@ -366,21 +369,17 @@ def f(a, b):
366369
self.verify_aot_autograd(f, inp)
367370

368371
# Test for bug occurring at the intersection of fake tensors & functionalization.
369-
@patch("torch._functorch.config.use_dynamic_shapes", True)
370-
@patch("torch._functorch.config.use_fake_tensor", True)
371372
def test_squeeze_mutation(self):
372373
def f(a):
373374
b = a.clone().squeeze(-1)
374375
b.add_(1.)
375376
return a + b
376377

377378
inp = [torch.randn(3, 1, requires_grad=True)]
378-
self.verify_aot_autograd(f, inp)
379+
self.verify_aot_autograd(f, inp, dynamic=True)
379380
inp = [torch.randn(3, 1, requires_grad=False)]
380-
self.verify_aot_autograd(f, inp)
381+
self.verify_aot_autograd(f, inp, dynamic=True)
381382

382-
@patch("torch._functorch.config.use_dynamic_shapes", True)
383-
@patch("torch._functorch.config.use_fake_tensor", True)
384383
def test_embedding_bag_view(self):
385384
# Backwards pass tries to wrap a sparse tensor in a FunctionalTensorWrapper;
386385
# test that this works even though the sparse tensor has no storage.
@@ -395,7 +394,7 @@ def forward(self, x, y):
395394

396395
x = torch.arange(3)
397396
y = torch.arange(3)
398-
self.verify_aot_autograd(F(), [x, y])
397+
self.verify_aot_autograd(F(), [x, y], dynamic=True)
399398

400399
@patch("functorch.compile.config.use_fake_tensor", True)
401400
def test_input_mutation_simple(self):
@@ -1714,8 +1713,6 @@ def bn(x):
17141713
for a, b in zip(ref, res):
17151714
assert torch.allclose(a, b)
17161715

1717-
@patch("functorch.compile.config.use_dynamic_shapes", True)
1718-
@patch("functorch.compile.config.use_fake_tensor", True)
17191716
def test_output_op_depending_on_symint(self):
17201717
"""
17211718
It won't be obvious from reading this test what it's testing for. We should probably make it into a more
@@ -1738,12 +1735,10 @@ def f(x):
17381735
# TODO: assert outputs of fwd graph trace to correct symint
17391736

17401737
# e2e test that fails without symint clone fix
1741-
af = aot_function(f, nop, partition_fn=partial(min_cut_rematerialization_partition, compiler="inductor"))
1738+
af = aot_function(f, nop, partition_fn=partial(min_cut_rematerialization_partition, compiler="inductor"), dynamic=True)
17421739
out = af(inp)
17431740
self.assertEqual(out, f(inp))
17441741

1745-
@patch("functorch.compile.config.use_dynamic_shapes", True)
1746-
@patch("functorch.compile.config.use_fake_tensor", True)
17471742
def test_default_partitioner_saves_symints_not_tensors_for_bw(self):
17481743
"""
17491744
In this test, the important thing is that primals_1 is **only** needed in the backward
@@ -1764,7 +1759,7 @@ def f(a):
17641759
d = b.masked_fill_(c, 0)
17651760
return d
17661761

1767-
compiled_f = aot_function(f, nop)
1762+
compiled_f = aot_function(f, nop, dynamic=True)
17681763
inp_ref = torch.ones(2, 2, requires_grad=True)
17691764
inp_test = torch.ones(2, 2, requires_grad=True)
17701765

@@ -1859,14 +1854,15 @@ def get_num_ins_outs(fx_g):
18591854
return tuple(len(i) for i in get_ins_outs(fx_g))
18601855

18611856

1862-
def get_fw_bw_graph(f, inps, partitioner=min_cut_rematerialization_partition):
1857+
def get_fw_bw_graph(f, inps, partitioner=min_cut_rematerialization_partition, dynamic=False):
18631858
fw_graph_cell = [None]
18641859
bw_graph_cell = [None]
18651860
aot_function(f,
18661861
fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
18671862
bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
18681863
partition_fn=partitioner,
1869-
decompositions=default_decompositions)(*inps).sum().backward()
1864+
decompositions=default_decompositions,
1865+
dynamic=dynamic)(*inps).sum().backward()
18701866
return (fw_graph_cell[0], bw_graph_cell[0])
18711867

18721868

@@ -1933,8 +1929,6 @@ def f(x, mod_weight, mod_bias):
19331929
self.assertEqual(get_num_ins_outs(fw_graph), (3, 6))
19341930
self.assertEqual(get_num_ins_outs(bw_graph), (6, 3))
19351931

1936-
@patch("functorch.compile.config.use_dynamic_shapes", True)
1937-
@patch("functorch.compile.config.use_fake_tensor", True)
19381932
@unittest.skipIf(not USE_NETWORKX, "networkx not available")
19391933
def test_min_cut_partitioner_save_shape(self):
19401934

@@ -1943,7 +1937,7 @@ def f(x):
19431937
return s
19441938

19451939
inp = [torch.ones([10, 10], requires_grad=True)]
1946-
fw_graph, bw_graph = get_fw_bw_graph(f, inp)
1940+
fw_graph, bw_graph = get_fw_bw_graph(f, inp, dynamic=True)
19471941
_, fw_output = get_ins_outs(fw_graph)
19481942
self.assertEqual(get_num_ins_outs(fw_graph), (1, 3))
19491943
self.assertEqual(get_num_ins_outs(bw_graph), (3, 1))
@@ -1968,14 +1962,12 @@ def f(a, b, c):
19681962
x = sb[0] + sc[0]
19691963
a_sz = (x, a.size(0))
19701964
return torch.cat([a.expand(a_sz), b, c])
1971-
fw_graph, bw_graph = get_fw_bw_graph(f, inp)
1965+
fw_graph, bw_graph = get_fw_bw_graph(f, inp, dynamic=True)
19721966
self.assertEqual(get_num_ins_outs(fw_graph), (3, 5))
19731967
self.assertEqual(get_num_ins_outs(bw_graph), (5, 3))
19741968
_, outs = get_ins_outs(fw_graph)
19751969
self.assertTrue(all([is_sym_node(n) for n in outs[1:]]))
19761970

1977-
@patch("functorch.compile.config.use_dynamic_shapes", True)
1978-
@patch("functorch.compile.config.use_fake_tensor", True)
19791971
def test_default_partitioner_output_tensor_shape_tensor(self):
19801972

19811973
inp = [
@@ -2004,7 +1996,8 @@ def f(a, b, c, d):
20041996
fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
20051997
bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
20061998
partition_fn=default_partition,
2007-
decompositions=default_decompositions)(*inp)
1999+
decompositions=default_decompositions,
2000+
dynamic=True)(*inp)
20082001
fw_graph = fw_graph_cell[0]
20092002
(compiled_outs[0].sum() + compiled_outs[2].sum()).backward()
20102003
bw_graph = bw_graph_cell[0]
@@ -2037,8 +2030,6 @@ def f(a, b, c, d):
20372030
# TODO(whc) we should learn to return torch.Sizes
20382031
self.assertFalse(isinstance(compiled_outs[1], torch.Size))
20392032

2040-
@patch("functorch.compile.config.use_dynamic_shapes", True)
2041-
@patch("functorch.compile.config.use_fake_tensor", True)
20422033
@unittest.skipIf(not USE_NETWORKX, "networkx not available")
20432034
def test_min_cut_partitioner_output_tensor_shape_tensor(self):
20442035

@@ -2068,7 +2059,8 @@ def f(a, b, c, d):
20682059
fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
20692060
bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
20702061
partition_fn=min_cut_rematerialization_partition,
2071-
decompositions=default_decompositions)(*inp)
2062+
decompositions=default_decompositions,
2063+
dynamic=True)(*inp)
20722064
fw_graph = fw_graph_cell[0]
20732065
(compiled_outs[0].sum() + compiled_outs[2].sum()).backward()
20742066
bw_graph = bw_graph_cell[0]
@@ -2617,7 +2609,7 @@ def create_new_arg(x):
26172609
except DynamicOutputShapeException:
26182610
self.skipTest("Dynamic output shape operation in trace")
26192611

2620-
def _test_aot_autograd_helper(self, device, dtype, op):
2612+
def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False):
26212613
if not op.supports_autograd:
26222614
self.skipTest("Op does not support autograd")
26232615

@@ -2639,7 +2631,7 @@ def f(args):
26392631
c_args, c_kwargs = pytree.tree_unflatten(cur_flat_args, args_spec)
26402632
return op.op(*c_args, **c_kwargs)
26412633

2642-
compiled_f = compiled_function(f, nop, nop)
2634+
compiled_f = compiled_function(f, nop, nop, dynamic=dynamic)
26432635
try:
26442636
_test_aot_autograd_forwards_backwards_helper(self, f, compiled_f, args)
26452637
except GuardOnDataDependentSymNode:
@@ -2651,7 +2643,7 @@ def f(args):
26512643
else:
26522644
raise
26532645

2654-
def _test_aot_autograd_module_helper(self, device, dtype, training, module_info):
2646+
def _test_aot_autograd_module_helper(self, device, dtype, training, module_info, *, dynamic=False):
26552647
module_cls = module_info.module_cls
26562648
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
26572649
requires_grad=True, training=training)
@@ -2696,7 +2688,7 @@ def f(params_buffers_args):
26962688
named_params = dict(m.named_parameters(remove_duplicate=False))
26972689
named_buffers = dict(m.named_buffers(remove_duplicate=False))
26982690
num_params_buffers = len(named_params) + len(named_buffers)
2699-
compiled_f = aot_function(f, nop, num_params_buffers=num_params_buffers)
2691+
compiled_f = aot_function(f, nop, num_params_buffers=num_params_buffers, dynamic=dynamic)
27002692
params_buffers_args = [named_params, named_buffers, args]
27012693
_test_aot_autograd_forwards_backwards_helper(self, f, compiled_f, params_buffers_args)
27022694

@@ -2708,13 +2700,11 @@ def test_aot_autograd_exhaustive(self, device, dtype, op):
27082700
_test_aot_autograd_helper(self, device, dtype, op)
27092701

27102702
@ops(op_db, allowed_dtypes=(torch.float,))
2711-
@patch("functorch.compile.config.use_dynamic_shapes", True)
2712-
@patch("functorch.compile.config.use_fake_tensor", True)
27132703
@patch("functorch.compile.config.use_functionalize", True)
27142704
@skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_symbolic_exhaustive',
27152705
aot_autograd_failures | symbolic_aot_autograd_failures)
27162706
def test_aot_autograd_symbolic_exhaustive(self, device, dtype, op):
2717-
_test_aot_autograd_helper(self, device, dtype, op)
2707+
_test_aot_autograd_helper(self, device, dtype, op, dynamic=True)
27182708

27192709

27202710
aot_autograd_module_failures = set({
@@ -2754,13 +2744,11 @@ def test_aot_autograd_module_exhaustive(self, device, dtype, training, module_in
27542744
_test_aot_autograd_module_helper(self, device, dtype, training, module_info)
27552745

27562746
@modules(module_db, allowed_dtypes=(torch.float,))
2757-
@patch("functorch.compile.config.use_dynamic_shapes", True)
2758-
@patch("functorch.compile.config.use_fake_tensor", True)
27592747
@patch("functorch.compile.config.use_functionalize", True)
27602748
@decorateForModules(unittest.expectedFailure,
27612749
aot_autograd_module_failures | symbolic_aot_autograd_module_failures)
27622750
def test_aot_autograd_symbolic_module_exhaustive(self, device, dtype, training, module_info):
2763-
_test_aot_autograd_module_helper(self, device, dtype, training, module_info)
2751+
_test_aot_autograd_module_helper(self, device, dtype, training, module_info, dynamic=True)
27642752

27652753

27662754
only_for = ("cpu")

torch/_export/__init__.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,19 +116,9 @@ def set_state_proxies(state_args):
116116
num_params_buffers=params_len,
117117
aot_id=-1,
118118
keep_inference_input_mutations=False,
119+
dynamic_shapes=True,
119120
)
120121

121-
@contextlib.contextmanager
122-
def setup_dynamic_shape():
123-
prev, torch._functorch.config.use_dynamic_shapes = (
124-
torch._functorch.config.use_dynamic_shapes,
125-
True,
126-
)
127-
try:
128-
yield
129-
finally:
130-
torch._functorch.config.use_dynamic_shapes = prev
131-
132122
def exported_call(*args):
133123
state_args = args[:params_len]
134124
unwrapped_state_args = _unwrap_all_tensors_from_functional(
@@ -141,7 +131,7 @@ def exported_call(*args):
141131
outputs, out_spec = pytree.tree_flatten(outputs)
142132
return outputs
143133

144-
with torch.enable_grad(), setup_dynamic_shape():
134+
with torch.enable_grad():
145135
create_aot_dispatcher_function(
146136
exported_call,
147137
full_args,

torch/_functorch/aot_autograd.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,15 +1272,7 @@ class AOTConfig:
12721272
num_params_buffers: int
12731273
aot_id: int
12741274
keep_inference_input_mutations: bool
1275-
# If None, defer to config
1276-
_dynamic_shapes: Optional[bool] = None
1277-
1278-
@property
1279-
def dynamic_shapes(self):
1280-
if self._dynamic_shapes is None:
1281-
return config.use_dynamic_shapes
1282-
else:
1283-
return self._dynamic_shapes
1275+
dynamic_shapes: bool = False
12841276

12851277
def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
12861278
with enable_python_dispatcher():
@@ -2331,7 +2323,7 @@ def call_compiled_backward():
23312323
aot_config.bw_compiler, None, None,
23322324
aot_config.decompositions, 0, aot_config.aot_id,
23332325
aot_config.keep_inference_input_mutations,
2334-
aot_config._dynamic_shapes
2326+
aot_config.dynamic_shapes
23352327
)
23362328
)
23372329
else:
@@ -2563,7 +2555,10 @@ def aot_function(
25632555
num_params_buffers: int = 0,
25642556
hasher_type=None, # deprecated
25652557
static_argnums: Optional[Tuple[int]] = None, # deprecated
2566-
keep_inference_input_mutations: bool = False
2558+
keep_inference_input_mutations: bool = False,
2559+
*,
2560+
# Whether or not to trace with dynamic shapes
2561+
dynamic=False,
25672562
) -> Callable:
25682563
"""
25692564
Traces the forward and backward graph of :attr:`fn` using torch dispatch
@@ -2630,6 +2625,7 @@ def aot_function(
26302625
num_params_buffers=num_params_buffers,
26312626
aot_id=next(AOT_COUNTER),
26322627
keep_inference_input_mutations=keep_inference_input_mutations,
2628+
dynamic_shapes=dynamic,
26332629
)
26342630
cached_res = None
26352631

@@ -2822,8 +2818,6 @@ def functional_call(*args, **kwargs):
28222818
if isinstance(x, FakeTensor):
28232819
dynamic_shapes = x.fake_mode.shape_env is not None
28242820
break
2825-
else:
2826-
dynamic_shapes = config.use_dynamic_shapes
28272821

28282822
aot_config = AOTConfig(
28292823
fw_compiler=fw_compiler,
@@ -2833,7 +2827,7 @@ def functional_call(*args, **kwargs):
28332827
num_params_buffers=params_len,
28342828
aot_id=next(AOT_COUNTER),
28352829
keep_inference_input_mutations=keep_inference_input_mutations,
2836-
_dynamic_shapes=dynamic_shapes
2830+
dynamic_shapes=dynamic_shapes
28372831
)
28382832

28392833
compiled_fn = create_aot_dispatcher_function(

torch/_functorch/config.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
3131
# Prints out joint graph traced, before partitioning
3232
debug_joint = os.environ.get("AOT_FX_GRAPHS_JOINT", False)
3333

34-
use_dynamic_shapes = os.getenv("AOT_DYNAMIC_SHAPES", False)
35-
3634
static_weight_shapes = True
3735

3836
# Applies CSE to the graph before partitioning

0 commit comments

Comments (0)