Llama4 bf16 compile issue: "wrong number of dimensions4 for op: torch.ops.aten._scaled_dot_product_cudnn_attention.default" #1713

@danielvegamyhre

Bug description

Summary

Repro command: NGPU=4 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --parallelism.data_parallel_shard_degree=4 --parallelism.tensor_parallel_degree=1 --model.print-after-conversion --metrics.log_freq=10 --training.steps=30 --compile.enable

Error:

[rank0]:[rank0]:   File "/home/danvm/.conda/envs/torch/lib/python3.12/site-packages/torch/_inductor/utils.py", line 2968, in run
[rank0]:[rank0]:     out = model(new_inputs)
[rank0]:[rank0]:           ^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/tmp/torchinductor_danvm/b3/cb33nhusgdmkhvubfh6rhdkeotoqijnyasb4qlxhc7oahyb3rpp7.py", line 1319, in call
[rank0]:[rank0]:     assert_size_stride(buf27, (8, 40, 2048), (81920, 2048, 1), 'torch.ops.aten._scaled_dot_product_cudnn_attention.default')
[rank0]:[rank0]: AssertionError: wrong number of dimensions4 for op: torch.ops.aten._scaled_dot_product_cudnn_attention.default

Does NOT repro in eager.
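For readers unfamiliar with the failing guard, here is a rough pure-Python sketch of the kind of check that the generated `assert_size_stride(buf27, (8, 40, 2048), (81920, 2048, 1), ...)` line performs. It is not Inductor's implementation: `contiguous_strides` and `check_size_stride` are hypothetical helpers written for illustration, and the 4-D shape below is made up, since the log only says the buffer had 4 dimensions. The expected strides are the contiguous strides of an `(8, 40, 2048)` tensor, so the guard fails on rank (3 expected, 4 seen) before any size or stride is compared.

```python
def contiguous_strides(shape):
    """Row-major (contiguous) strides, in elements, for a given shape."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def check_size_stride(shape, strides, expected_shape, expected_strides, op_name):
    """Hypothetical stand-in for the generated guard; not Inductor's code."""
    # The rank check comes first, which is the check that fires in the log above.
    if len(shape) != len(expected_shape):
        raise AssertionError(
            f"wrong number of dimensions{len(shape)} for op: {op_name}"
        )
    for i, (size, exp_size) in enumerate(zip(shape, expected_shape)):
        if size != exp_size:
            raise AssertionError(f"expected size {exp_size} at dim {i}, got {size}")
    for i, (stride, exp_stride) in enumerate(zip(strides, expected_strides)):
        # Assumption: strides of size-1 dims are ignored, as they do not affect layout.
        if shape[i] > 1 and stride != exp_stride:
            raise AssertionError(f"expected stride {exp_stride} at dim {i}, got {stride}")


OP = "torch.ops.aten._scaled_dot_product_cudnn_attention.default"
expected_shape = (8, 40, 2048)
expected_strides = contiguous_strides(expected_shape)  # (81920, 2048, 1)

# Made-up 4-D shape; the real shape of buf27 is not shown in the traceback.
made_up_shape = (8, 40, 2048, 1)
try:
    check_size_stride(
        made_up_shape,
        contiguous_strides(made_up_shape),
        expected_shape,
        expected_strides,
        OP,
    )
    message = None
except AssertionError as err:
    message = str(err)

print(message)
```

With these made-up inputs the sketch produces the same message text as the log, including the missing space before `4`, which suggests the `4` is the actual number of dimensions of the buffer and not part of a name.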


After sidestepping the issue (see #1713 (comment)), there's a second error: P1949996815


cc @xmfan @tianyu-l

Versions

  • torch cuda 12.8 nightly build
  • torchtitan latest main branch
