@vkuzo
Created September 2, 2020 17:13
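The script below spawns 8 processes, splits them into two process groups of 4 ranks each, converts a model's BatchNorm2d layer to SyncBatchNorm scoped to the local group, wraps the result in DistributedDataParallel, and runs a single forward pass. Since each rank moves its model to device `rank` via `.to(rank)`, running the script as written assumes a machine with 8 CUDA devices.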
import os
import random

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Seeds Python's random module only; torch.randn below uses torch's own RNG.
random.seed(0)


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def demo_basic(rank, world_size):
    print(f"Starting rank {rank}.")
    setup(rank, world_size)

    model = nn.Sequential(
        nn.Conv2d(1, 1, 1),
        nn.BatchNorm2d(1),
    )

    # Split the 8 ranks into two process groups of 4. Per the torch.distributed
    # docs, every rank must call new_group() for every group, in the same
    # order, even for groups it is not a member of; each rank then keeps the
    # group it belongs to.
    group_a = dist.new_group([0, 1, 2, 3])
    group_b = dist.new_group([4, 5, 6, 7])
    process_group = group_a if rank < 4 else group_b
    print(rank, process_group)

    # Replace the BatchNorm2d layer with SyncBatchNorm that synchronizes batch
    # statistics only within this rank's 4-process group.
    syncbn_model = nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group)
    syncbn_model.to(rank)
    ddp_model = DDP(syncbn_model, device_ids=[rank], process_group=process_group)

    # One forward pass; SyncBatchNorm reduces mean/var across the group.
    data = torch.randn(4, 1, 4, 4).to(rank)
    ddp_model(data)
    print('done')
    cleanup()


def run_demo(demo_fn, world_size):
    # mp.spawn passes the process index (the rank) as the first argument.
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    run_demo(demo_basic, 8)
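As a quick sanity check (not part of the original gist), the forward pass in demo_basic could be followed by printing each rank's running statistics; because SyncBatchNorm reduces batch statistics across its process group, the four ranks sharing a group should report identical values, while ranks in different groups generally will not:

    # Hypothetical addition at the end of demo_basic, after ddp_model(data):
    bn = ddp_model.module[1]  # the converted SyncBatchNorm layer
    print(rank, bn.running_mean.item(), bn.running_var.item())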