Can't pickle torch.distributed.ProcessGroupNCCL objects #12168

@ezyang

🐛 Bug

Re-filing from #11683

import argparse
import torch

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)
    args = parser.parse_args()

    device = torch.device('cuda', args.local_rank)

    torch.distributed.init_process_group(backend='nccl')
    model = torch.nn.LSTM(10, 10).to(device)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank], output_device=args.local_rank, dim=1)

    torch.save(model, 'model.pt')

Run with

python -m torch.distributed.launch test.py

Fails with

Traceback (most recent call last):
  File "test.py", line 16, in <module>
    torch.save(model, 'model.pt')
  File "/data/users/ezyang/pytorch-tmp/torch/serialization.py", line 209, in save
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
  File "/data/users/ezyang/pytorch-tmp/torch/serialization.py", line 134, in _with_file_like
    return body(f)
  File "/data/users/ezyang/pytorch-tmp/torch/serialization.py", line 209, in <lambda>
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
  File "/data/users/ezyang/pytorch-tmp/torch/serialization.py", line 282, in _save
    pickler.dump(obj)
TypeError: can't pickle torch.distributed.ProcessGroupNCCL objects

@teng-li says this is because ProcessGroupNCCL serialization is not implemented yet.
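To illustrate the failure mode without a GPU or a process group, here is a minimal stdlib-only sketch. `FakeProcessGroup` and `FakeWrapper` are hypothetical stand-ins (not PyTorch classes): the first holds a handle that pickle cannot serialize, and the second holds plain weights plus that handle, roughly the way the DDP wrapper holds a module plus a process group. Pickling the whole wrapper fails, while pickling only the plain state works. In real code, the analogous step would presumably be saving `model.module.state_dict()` rather than the wrapped `model`.

```python
import pickle
import threading


class FakeProcessGroup:
    """Hypothetical stand-in for a communicator object with no pickle support."""

    def __init__(self):
        # A lock is a convenient stdlib object that pickle refuses to serialize.
        self._handle = threading.Lock()


class FakeWrapper:
    """Hypothetical stand-in for a wrapper holding weights plus a process group."""

    def __init__(self, weights):
        self.weights = list(weights)
        self.process_group = FakeProcessGroup()

    def state_dict(self):
        # Return only plain data; the runtime handle is deliberately left out.
        return {"weights": list(self.weights)}


if __name__ == "__main__":
    wrapper = FakeWrapper([0.1, 0.2])

    # Pickling the whole object walks into the lock and fails.
    try:
        pickle.dumps(wrapper)
    except TypeError as exc:
        print("whole-object pickling failed:", exc)

    # Pickling only the plain state round-trips fine.
    restored = pickle.loads(pickle.dumps(wrapper.state_dict()))
    print("state round-trip:", restored)
```

With this approach the process group is not part of the saved file, so it has to be re-created (and the wrapper rebuilt) after loading.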

Labels: oncall: distributed