Describe the bug
Fine-tuning QwenVL with tensor parallel size 4 fails with an AssertionError when the model is wrapped in torch.compile for kernel optimization.
Steps/Code to reproduce bug
model = model_provider_func()
model = torch.compile(model)
Expected behavior
Training runs without errors and is accelerated by torch.compile on Megatron.
Additional context
The following error occurs:
File "/tmp/torchinductor_root/ab/cab27rpc5syxsj7fvzyur3lov3xqznlnvvqq553eqxnzco5hq4rr.py", line 37, in call
[rank3]: assert_size_stride(arg0_1, (2005, 1, 5120), (5120, 10265600, 1))
[rank3]: AssertionError: wrong number of dimensions
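One way to read the assertion, offered as an interpretation rather than a confirmed diagnosis: the expected stride of the size-1 dimension is 10265600 = 2005 * 5120, which is the layout you get when a contiguous (1, 2005, 5120) tensor is viewed with its first two dimensions swapped. The message "wrong number of dimensions" suggests the tensor actually passed to the compiled graph did not have three dimensions. The pure-Python sketch below reproduces that arithmetic without torch. All helper names are hypothetical, and check_rank_and_size is a simplified stand-in for illustration, not the real Inductor check.

```python
# Pure-Python sketch; helper names are hypothetical and torch is not required.

def contiguous_strides(shape):
    """Row-major strides (in elements) for a dense tensor of this shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def transpose(shape, strides, d0, d1):
    """Swap two dimensions of a (shape, strides) pair, like a transposed view."""
    shape, strides = list(shape), list(strides)
    shape[d0], shape[d1] = shape[d1], shape[d0]
    strides[d0], strides[d1] = strides[d1], strides[d0]
    return tuple(shape), tuple(strides)


def check_rank_and_size(actual_shape, expected_shape):
    """Simplified stand-in: compares rank first, then sizes (strides omitted)."""
    if len(actual_shape) != len(expected_shape):
        raise AssertionError("wrong number of dimensions")
    if tuple(actual_shape) != tuple(expected_shape):
        raise AssertionError("size mismatch")


# Assumed origin of the expected layout: a contiguous (batch, seq, hidden) tensor.
base_shape = (1, 2005, 5120)
base_strides = contiguous_strides(base_shape)

# Viewing it as (seq, batch, hidden) yields the size/stride pair from the traceback.
expected_shape, expected_strides = transpose(base_shape, base_strides, 0, 1)
```

If this reading is right, it would help to print the shape of the input on rank 3 just before the failing call and compare it with (2005, 1, 5120), for example to see whether a 2-D (seq, hidden) tensor is passed on some steps.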