PP + Compile breaking CI #2771

@xmfan

Bug description

Reported by @tianyu-l; this breaks CI. Reproduce with:

NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_train.sh --module llama4 --config llama4_debugmodel --parallelism.pipeline_parallel_degree 2 --parallelism.pipeline_parallel_schedule Interleaved1F1B --parallelism.data_parallel_shard_degree 2 --parallelism.tensor_parallel_degree 2 --parallelism.expert_parallel_degree 4 --parallelism.expert_tensor_parallel_degree 1 --compile.enable

Logs: https://gist.github.com/xmfan/8ebfad161cd02af9c1ae5c5818799f43

[rank1]:[rank1]: Traceback (most recent call last):
[rank1]:[rank1]:   File "/home/xmfan/core/a/pytorch/torch/distributed/pipelining/_backward.py", line 368, in stage_backward
[rank1]:[rank1]:     torch.autograd.backward(
[rank1]:[rank1]:   File "/home/xmfan/core/a/pytorch/torch/autograd/__init__.py", line 379, in backward
[rank1]:[rank1]:     _engine_run_backward(
[rank1]:[rank1]:   File "/home/xmfan/core/a/pytorch/torch/autograd/graph.py", line 877, in _engine_run_backward
[rank1]:[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/home/xmfan/core/a/pytorch/torch/utils/checkpoint.py", line 1168, in unpack_hook
[rank1]:[rank1]:     frame.check_recomputed_tensors_match(gid)
[rank1]:[rank1]:   File "/home/xmfan/core/a/pytorch/torch/utils/checkpoint.py", line 899, in check_recomputed_tensors_match
[rank1]:[rank1]:     raise CheckpointError(
[rank1]:[rank1]: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
[rank1]:[rank1]: tensor at position 3:
[rank1]:[rank1]: saved metadata: {'shape': torch.Size([1, 1, 16, 16]), 'dtype': torch.int32, 'device': device(type='cuda', index=1)}
[rank1]:[rank1]: recomputed metadata: {'shape': torch.Size([128, 256]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]:[rank1]: tensor at position 4:
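The CheckpointError above comes from non-reentrant activation checkpointing comparing, position by position, the metadata (shape, dtype, device) recorded for each saved tensor during the forward pass against the tensor produced at the same position during recomputation. The plain-Python sketch below only illustrates that kind of positional check. It is not PyTorch's implementation: the function `find_metadata_mismatches` and the dict layout are hypothetical, and the entries for positions 0 to 2 are placeholders, since the log shows only position 3 (and truncates position 4).

```python
from typing import Dict, List, Tuple

# Hypothetical metadata record: shape as a tuple, dtype and device as strings.
Metadata = Dict[str, object]


def find_metadata_mismatches(
    saved: List[Metadata], recomputed: List[Metadata]
) -> List[Tuple[int, Metadata, Metadata]]:
    """Compare saved vs. recomputed metadata position by position.

    Returns (position, saved, recomputed) for every position whose
    metadata differs. As a simplification, positions present in only
    one list are not compared here.
    """
    mismatches = []
    for pos, (s, r) in enumerate(zip(saved, recomputed)):
        if s != r:
            mismatches.append((pos, s, r))
    return mismatches


if __name__ == "__main__":
    # Hypothetical placeholder for positions 0-2 (identical on both sides).
    placeholder = {"shape": (8,), "dtype": "float32", "device": "cuda:1"}
    # Position 3 uses the metadata printed in the log above.
    saved = [placeholder] * 3 + [
        {"shape": (1, 1, 16, 16), "dtype": "int32", "device": "cuda:1"}
    ]
    recomputed = [placeholder] * 3 + [
        {"shape": (128, 256), "dtype": "bfloat16", "device": "cuda:1"}
    ]
    for pos, s, r in find_metadata_mismatches(saved, recomputed):
        print(f"tensor at position {pos}:")
        print(f"saved metadata: {s}")
        print(f"recomputed metadata: {r}")
```

Because the comparison is positional, a mismatch like the one at position 3 means the recomputation produced a different tensor at that slot than the forward pass saved; the sketch does not say why that happens in this PP + compile configuration.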

Versions

main

Status: Done