Bug description
Reported by @tianyu-l. This breaks CI. Repro command:
NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_train.sh --module llama4 --config llama4_debugmodel --parallelism.pipeline_parallel_degree 2 --parallelism.pipeline_parallel_schedule Interleaved1F1B --parallelism.data_parallel_shard_degree 2 --parallelism.tensor_parallel_degree 2 --parallelism.expert_parallel_degree 4 --parallelism.expert_tensor_parallel_degree 1 --compile.enable
Logs: https://gist.github.com/xmfan/8ebfad161cd02af9c1ae5c5818799f43
[rank1]:[rank1]: Traceback (most recent call last):
[rank1]:[rank1]: File "/home/xmfan/core/a/pytorch/torch/distributed/pipelining/_backward.py", line 368, in stage_backward
[rank1]:[rank1]: torch.autograd.backward(
[rank1]:[rank1]: File "/home/xmfan/core/a/pytorch/torch/autograd/__init__.py", line 379, in backward
[rank1]:[rank1]: _engine_run_backward(
[rank1]:[rank1]: File "/home/xmfan/core/a/pytorch/torch/autograd/graph.py", line 877, in _engine_run_backward
[rank1]:[rank1]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/home/xmfan/core/a/pytorch/torch/utils/checkpoint.py", line 1168, in unpack_hook
[rank1]:[rank1]: frame.check_recomputed_tensors_match(gid)
[rank1]:[rank1]: File "/home/xmfan/core/a/pytorch/torch/utils/checkpoint.py", line 899, in check_recomputed_tensors_match
[rank1]:[rank1]: raise CheckpointError(
[rank1]:[rank1]: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
[rank1]:[rank1]: tensor at position 3:
[rank1]:[rank1]: saved metadata: {'shape': torch.Size([1, 1, 16, 16]), 'dtype': torch.int32, 'device': device(type='cuda', index=1)}
[rank1]:[rank1]: recomputed metadata: {'shape': torch.Size([128, 256]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]:[rank1]: tensor at position 4:
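To help read the error above: the checkpoint unpack hook compares the metadata (shape, dtype, device) of each tensor saved during the forward pass against the tensor recomputed at the same position during backward. The sketch below is a plain-Python illustration of that kind of position-by-position comparison. It is not PyTorch's implementation; `TensorMeta` and `find_metadata_mismatches` are hypothetical names, and the entries at positions 0 to 2 are placeholders. Only the position 3 values are taken from the log.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorMeta:
    """Hypothetical stand-in for the metadata recorded per saved tensor."""
    shape: tuple
    dtype: str
    device: str


def find_metadata_mismatches(saved, recomputed):
    """Return (position, saved_meta, recomputed_meta) for every slot that differs."""
    return [
        (pos, s, r)
        for pos, (s, r) in enumerate(zip(saved, recomputed))
        if s != r
    ]


# Placeholder metadata for positions 0-2 (assumed, not from the log).
placeholder = TensorMeta((128, 256), "bfloat16", "cuda:1")

# Position 3 mirrors the values reported in the traceback.
saved = [placeholder] * 3 + [TensorMeta((1, 1, 16, 16), "int32", "cuda:1")]
recomputed = [placeholder] * 3 + [TensorMeta((128, 256), "bfloat16", "cuda:1")]

mismatches = find_metadata_mismatches(saved, recomputed)
for pos, s, r in mismatches:
    print(f"tensor at position {pos}: saved={s} recomputed={r}")
```

A mismatch in both shape and dtype at one position, as seen here, suggests that the recomputation saved a different sequence of tensors than the original forward pass did, rather than the same tensor with slightly different values.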
Versions
main