torch.compile is incompatible with tensor parallelism #2598

@robotsp

Description

Describe the bug

Fine-tuning QwenVL with tensor parallel size = 4 fails when the model is wrapped in torch.compile for kernel optimization.

Steps/Code to reproduce bug

model = model_provider_func()
model = torch.compile(model)

Expected behavior

Training runs without errors and is accelerated by torch.compile on Megatron.

Additional context

The following error occurs:
File "/tmp/torchinductor_root/ab/cab27rpc5syxsj7fvzyur3lov3xqznlnvvqq553eqxnzco5hq4rr.py", line 37, in call
[rank3]: assert_size_stride(arg0_1, (2005, 1, 5120), (5120, 10265600, 1))
[rank3]: AssertionError: wrong number of dimensions
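
To make the assertion easier to read, here is a rough pure-Python sketch of what a size/stride guard of this kind checks. The helper `check_size_stride` is a hypothetical, simplified stand-in written for illustration, not the actual Inductor implementation, and the 2-D input at the end is an assumed example rather than the tensor from this run. The sketch also shows that the expected strides (5120, 10265600, 1) are what you get from a contiguous (1, 2005, 5120) tensor with its first two dimensions swapped, since 2005 * 5120 = 10265600.

```python
from typing import Sequence, Tuple

# Values copied from the assert_size_stride call in the traceback.
EXPECTED_SHAPE = (2005, 1, 5120)
EXPECTED_STRIDE = (5120, 10265600, 1)


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major (C-contiguous) strides, in elements, for a given shape."""
    strides = []
    running = 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


def check_size_stride(
    shape: Sequence[int],
    stride: Sequence[int],
    expected_shape: Sequence[int],
    expected_stride: Sequence[int],
) -> None:
    """Hypothetical, simplified stand-in for the guard in the generated code."""
    # The rank is compared first, which is the check that fails in this issue.
    if len(shape) != len(expected_shape) or len(stride) != len(expected_stride):
        raise AssertionError("wrong number of dimensions")
    if tuple(shape) != tuple(expected_shape):
        raise AssertionError(f"expected size {tuple(expected_shape)}, got {tuple(shape)}")
    if tuple(stride) != tuple(expected_stride):
        raise AssertionError(f"expected stride {tuple(expected_stride)}, got {tuple(stride)}")


# A contiguous (1, 2005, 5120) layout with dims 0 and 1 swapped reproduces
# the shape and strides the compiled graph was specialized on.
base_shape = (1, 2005, 5120)
base_stride = contiguous_strides(base_shape)
traced_shape = (base_shape[1], base_shape[0], base_shape[2])
traced_stride = (base_stride[1], base_stride[0], base_stride[2])
check_size_stride(traced_shape, traced_stride, EXPECTED_SHAPE, EXPECTED_STRIDE)

# Assumed example: an input with a different rank trips the same message.
try:
    check_size_stride((2005, 5120), (5120, 1), EXPECTED_SHAPE, EXPECTED_STRIDE)
except AssertionError as err:
    print(err)  # wrong number of dimensions
```

If this reading is right, the compiled graph on rank 3 was specialized on a 3-D input and later received a tensor with a different number of dimensions, so comparing the rank of that input at trace time and at failure time may be a useful first step.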
