Skip to content

Commit cfaa39b

Browse files
committed
Update on "Fix training enablement in AOTAutograd"
Signed-off-by: Edward Z. Yang <ezyangmeta.com> [ghstack-poisoned]
2 parents c3caef9 + e638f0c commit cfaa39b

1 file changed

Lines changed: 12 additions & 5 deletions

File tree

torch/_functorch/aot_autograd.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,7 +1272,15 @@ class AOTConfig:
12721272
num_params_buffers: int
12731273
aot_id: int
12741274
keep_inference_input_mutations: bool
1275-
dynamic_shapes: bool
1275+
# If None, defer to config
1276+
_dynamic_shapes: Optional[bool] = None
1277+
1278+
@property
1279+
def dynamic_shapes(self):
1280+
if self._dynamic_shapes is None:
1281+
return config.use_dynamic_shapes
1282+
else:
1283+
return self._dynamic_shapes
12761284

12771285
def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
12781286
with enable_python_dispatcher():
@@ -2323,7 +2331,7 @@ def call_compiled_backward():
23232331
aot_config.bw_compiler, None, None,
23242332
aot_config.decompositions, 0, aot_config.aot_id,
23252333
aot_config.keep_inference_input_mutations,
2326-
aot_config.dynamic_shapes
2334+
aot_config._dynamic_shapes
23272335
)
23282336
)
23292337
else:
@@ -2456,7 +2464,7 @@ def create_aot_dispatcher_function(
24562464
shape_env = fake_mode.shape_env
24572465
break
24582466
else:
2459-
shape_env = ShapeEnv() if config.use_dynamic_shapes or aot_config.dynamic_shapes else None
2467+
shape_env = ShapeEnv() if aot_config.dynamic_shapes else None
24602468
fake_mode = (
24612469
FakeTensorMode(shape_env=shape_env)
24622470
if config.use_fake_tensor
@@ -2622,7 +2630,6 @@ def aot_function(
26222630
num_params_buffers=num_params_buffers,
26232631
aot_id=next(AOT_COUNTER),
26242632
keep_inference_input_mutations=keep_inference_input_mutations,
2625-
dynamic_shapes=config.dynamic_shapes,
26262633
)
26272634
cached_res = None
26282635

@@ -2826,7 +2833,7 @@ def functional_call(*args, **kwargs):
28262833
num_params_buffers=params_len,
28272834
aot_id=next(AOT_COUNTER),
28282835
keep_inference_input_mutations=keep_inference_input_mutations,
2829-
dynamic_shapes=dynamic_shapes
2836+
_dynamic_shapes=dynamic_shapes
28302837
)
28312838

28322839
compiled_fn = create_aot_dispatcher_function(

0 commit comments

Comments
 (0)