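I am trying to run DeepSeek-V3 with SimpleFSDP and torch.compile enabled, but compilation fails. The error appears to be caused by data-dependent tensor shapes (the number of tokens routed to each expert) in the MoE layer. In `_run_experts_for_loop`, the slice `x[: sum(num_tokens_per_expert_list)]` has an unbacked symbolic length (`u9 + u10 + … + u16`). AOTAutograd's `rebind_unbacked` then fails with "should have been constant".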
File "/home/axel/ctml/torchtitan/torchtitan/train.py", line 696, in train
self.train_step(data_iterator)
File "/home/axel/ctml/torchtitan/torchtitan/train.py", line 588, in train_step
loss = self.forward_backward_step(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/ctml/torchtitan/torchtitan/train.py", line 536, in forward_backward_step
pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 469, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1006, in compile_wrapper
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 2575, in _call_user_compiler
raise BackendCompilerFailed(
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 2550, in _call_user_compiler
compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/__init__.py", line 2530, in __call__
return self.compiler_fn(model_, inputs_, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/ctml/torchtitan/torchtitan/experiments/simple_fsdp/backend.py", line 153, in simple_fsdp_custom_pass
return backend(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/backends/inductor.py", line 31, in inductor
return compile_fx(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2564, in compile_fx
return _maybe_wrap_and_compile_fx_main(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2641, in _maybe_wrap_and_compile_fx_main
return _compile_fx_main(
^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2836, in _compile_fx_main
return aot_autograd(
^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 124, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1149, in aot_module_simplified
aot_state = create_aot_state(
^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 583, in create_aot_state
fw_metadata = run_functionalized_fw_and_collect_metadata(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 221, in inner
flat_f_outs = f(*flat_f_args)
^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1414, in functional_call
out = PropagateUnbackedSymInts(mod).run(*args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/fx/interpreter.py", line 200, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7957, in run_node
rebind_unbacked(fake_mode.shape_env, n, result)
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 673, in rebind_unbacked
assert not raw_u1.free_symbols, (
^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='simple_fsdp_custom_pass' raised:
AssertionError: should have been constant, but got u10 + u11 + u12 + u13 + u14 + u15 + u16 + u9
While executing %getitem_83 : [num_users=2] = call_function[target=operator.getitem](args = (%routed_input_1, slice(None, sym_sum_11, None)), kwargs = {})
Original traceback:
File "/home/axel/ctml/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 502, in forward
h = layer(h, self.freqs_cis, attention_masks, positions)
File "/home/axel/ctml/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 391, in forward
x = x + self.moe(self.ffn_norm(x))
File "/home/axel/ctml/torchtitan/torchtitan/models/moe/moe.py", line 518, in forward
routed_output = self.experts(routed_input, num_tokens_per_expert)
File "/home/axel/ctml/torchtitan/torchtitan/models/moe/moe.py", line 177, in forward
return _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert)
File "/home/axel/ctml/torchtitan/torchtitan/models/moe/moe.py", line 91, in _run_experts_for_loop
x[: sum(num_tokens_per_expert_list)],
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+dynamo,+inductor,graph_code" \
TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHDYNAMO_DISABLE_CACHE=1 TORCHINDUCTOR_CACHE_DISABLE=1 \
torchrun --nproc_per_node=1 -m torchtitan.train --model.name simple_fsdp.deepseek_v3 \
--job.config_file=torchtitan/models/deepseek_v3/train_configs/debug_model.toml \
--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config \
--compile.enable \
--parallelism.data_parallel_shard_degree 1 \
--parallelism.tensor_parallel_degree 1 \
--parallelism.expert_parallel_degree=1 \
--training.mixed_precision_param=float32 \
--training.local_batch_size 1 \
--training.seq_len 256 \
--activation_checkpoint.mode "none"
Bug description
Description
I am trying to run DeepSeek-V3 with SimpleFSDP and torch.compile enabled, but compilation fails. The error appears to be related to data-dependent tensor shapes (the number of tokens per expert) in the MoE layer.
Error log (see the full traceback at the top of this report)
Versions
Environment
GPU: NVIDIA TITAN V
Python 3.12.12
torchtitan version: 02661e8
torch version: 2.11.0.dev20260202+cu126
Reproduce: run the `torchrun` command shown above (including the debug environment variables)