@@ -1272,7 +1272,15 @@ class AOTConfig:
12721272 num_params_buffers : int
12731273 aot_id : int
12741274 keep_inference_input_mutations : bool
1275- dynamic_shapes : bool
1275+ # If None, defer to config
1276+ _dynamic_shapes : Optional [bool ] = None
1277+
1278+ @property
1279+ def dynamic_shapes (self ):
1280+ if self ._dynamic_shapes is None :
1281+ return config .use_dynamic_shapes
1282+ else :
1283+ return self ._dynamic_shapes
12761284
12771285def aot_dispatch_base (flat_fn , flat_args : List [Tensor ], aot_config : AOTConfig ):
12781286 with enable_python_dispatcher ():
@@ -2323,7 +2331,7 @@ def call_compiled_backward():
23232331 aot_config .bw_compiler , None , None ,
23242332 aot_config .decompositions , 0 , aot_config .aot_id ,
23252333 aot_config .keep_inference_input_mutations ,
2326- aot_config .dynamic_shapes
2334+ aot_config ._dynamic_shapes
23272335 )
23282336 )
23292337 else :
@@ -2456,7 +2464,7 @@ def create_aot_dispatcher_function(
24562464 shape_env = fake_mode .shape_env
24572465 break
24582466 else :
2459- shape_env = ShapeEnv () if config . use_dynamic_shapes or aot_config .dynamic_shapes else None
2467+ shape_env = ShapeEnv () if aot_config .dynamic_shapes else None
24602468 fake_mode = (
24612469 FakeTensorMode (shape_env = shape_env )
24622470 if config .use_fake_tensor
@@ -2622,7 +2630,6 @@ def aot_function(
26222630 num_params_buffers = num_params_buffers ,
26232631 aot_id = next (AOT_COUNTER ),
26242632 keep_inference_input_mutations = keep_inference_input_mutations ,
2625- dynamic_shapes = config .dynamic_shapes ,
26262633 )
26272634 cached_res = None
26282635
@@ -2826,7 +2833,7 @@ def functional_call(*args, **kwargs):
28262833 num_params_buffers = params_len ,
28272834 aot_id = next (AOT_COUNTER ),
28282835 keep_inference_input_mutations = keep_inference_input_mutations ,
2829- dynamic_shapes = dynamic_shapes
2836+ _dynamic_shapes = dynamic_shapes
28302837 )
28312838
28322839 compiled_fn = create_aot_dispatcher_function (
0 commit comments