import argparse
import os

import torch
import torch.distributed as dist
if __name__ == '__main__':
    # Minimal repro: torch.save() on a DistributedDataParallel (DDP) wrapper
    # fails because pickling the wrapper reaches its ProcessGroupNCCL, which
    # raises "TypeError: can't pickle torch.distributed.ProcessGroupNCCL objects".
    # Workaround: save the wrapped module instead of the wrapper, e.g.
    #     torch.save(model.module.state_dict(), 'model.pt')
    parser = argparse.ArgumentParser()
    # Older torch.distributed.launch releases pass --local_rank, newer ones
    # pass --local-rank; accept both spellings (dest stays `local_rank`) and
    # fall back to the LOCAL_RANK environment variable when neither is given.
    parser.add_argument(
        "--local_rank",
        "--local-rank",
        type=int,
        default=int(os.environ.get("LOCAL_RANK", 0)),
    )
    args = parser.parse_args()

    # Bind this process to its GPU before creating the NCCL process group.
    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda', args.local_rank)
    dist.init_process_group(backend='nccl')
    try:
        model = torch.nn.LSTM(10, 10).to(device)
        # nn.LSTM defaults to batch_first=False, so the batch dimension is 1.
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            dim=1,
        )
        # Only rank 0 writes the checkpoint so ranks do not race on one file.
        # This call is the one that raises the TypeError being reported.
        if dist.get_rank() == 0:
            torch.save(model, 'model.pt')
    finally:
        # Tear down the process group even when torch.save raises.
        dist.destroy_process_group()
# Launch the repro script (saved as test.py) under torch.distributed.launch.
# NOTE(review): newer PyTorch releases deprecate this launcher in favor of
# `torchrun test.py`; confirm which one applies to your version.
python -m torch.distributed.launch test.py
Traceback (most recent call last):
File "test.py", line 16, in <module>
torch.save(model, 'model.pt')
File "/data/users/ezyang/pytorch-tmp/torch/serialization.py", line 209, in save
return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
File "/data/users/ezyang/pytorch-tmp/torch/serialization.py", line 134, in _with_file_like
return body(f)
File "/data/users/ezyang/pytorch-tmp/torch/serialization.py", line 209, in <lambda>
return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
File "/data/users/ezyang/pytorch-tmp/torch/serialization.py", line 282, in _save
pickler.dump(obj)
TypeError: can't pickle torch.distributed.ProcessGroupNCCL objects
🐛 Bug
Re-filing from #11683
Run with:
Fails with:
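According to @teng-li, this fails because serialization (pickling) of ProcessGroupNCCL is not implemented yet. Until it is, save the wrapped module (`model.module`) or its `state_dict()` instead of the DistributedDataParallel wrapper.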