Skip to content

[Refactor] Refactor backward interface in DP#141

Merged
nijkah merged 1 commit intoEleutherAI:distributed_data_parallelfrom
nijkah:ddp_backward
Mar 2, 2023
Merged

[Refactor] Refactor backward interface in DP#141
nijkah merged 1 commit intoEleutherAI:distributed_data_parallelfrom
nijkah:ddp_backward

Conversation

@nijkah
Copy link
Copy Markdown
Contributor

@nijkah nijkah commented Mar 2, 2023

Title

Refactor backward in DP

Description

Followed https://github.com/KKIEEK/oslo/blob/3ca6b1aa0d87688af891f12b22837d89847680e9/oslo/torch/nn/parallel/data_parallel/distributed_data_parallel.py#L96.
And committed as KKIEEK for the code ownership.

Test Script

Torch DDP

import os
import torch.multiprocessing as mp

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch import nn
from torch import optim
import torch.distributed as dist

from oslo.torch.distributed.parallel_context import ParallelContext


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)


def cleanup():
    dist.destroy_process_group()

def main_print(args):
    if dist.get_rank() != 0:
        return
    print(args)

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def train(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)
    parallel_context = ParallelContext.from_torch(data_parallel_size=world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    main_print('= before model =')
    main_print(model.net1.weight)

    optimizer.zero_grad()
    outputs = ddp_model(torch.ones(20, 10).to(rank))
    labels = torch.ones(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    main_print('= after model =')
    main_print(ddp_model.module.net1.weight)
    main_print('= output =')
    main_print(outputs[0])
    cleanup()


def main(world_size):
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main(2)

OSLO DDP

import os
import torch.multiprocessing as mp

import torch
from torch import nn
from torch import optim
import torch.distributed as dist

import oslo
from oslo.torch.distributed.parallel_context import ParallelContext
from oslo.torch.nn.parallel.data_parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)


def cleanup():
    dist.destroy_process_group()

def main_print(args):
    if dist.get_rank() != 0:
        return
    print(args)


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def train(rank, world_size):
    print(f"Running oslo DDP example on rank {rank}.")
    setup(rank, world_size)
    parallel_context = ParallelContext.from_torch(data_parallel_size=world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, parallel_context)

    loss_fn = nn.MSELoss()
    main_print('= before model =')
    main_print(model.net1.weight)
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)
    oslo.ready(ddp_model, parallel_context)
    optimizer.zero_grad()
    outputs = ddp_model(torch.ones(20, 10).to(rank))
    labels = torch.ones(20, 5).to(rank)
    loss = loss_fn(outputs, labels)
    loss.backward()
    optimizer.step()
    main_print('= after model =')
    main_print(ddp_model.net1.weight)
    main_print('= output =')
    main_print(outputs[0])
    cleanup()


def main(world_size):
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main(2)

Results

Running basic DDP example on rank 0.
Running basic DDP example on rank 1.
= before model =
Parameter containing:
tensor([[ 0.2418,  0.2625, -0.0741,  0.2905, -0.0693,  0.0638, -0.1540,  0.1857,
          0.2788, -0.2320],
        [ 0.2749,  0.0592,  0.2336,  0.0428,  0.1525, -0.0446,  0.2438,  0.0467,
         -0.1476,  0.0806],
        [-0.1457, -0.0371, -0.1284,  0.2098, -0.2496, -0.1458, -0.0893, -0.1901,
          0.0298, -0.3123],
        [ 0.2856, -0.2686,  0.2441,  0.0526, -0.1027,  0.1954,  0.0493,  0.2555,
          0.0346, -0.0997],
        [ 0.0850, -0.0858,  0.1331,  0.2823,  0.1828, -0.1382,  0.1825,  0.0566,
          0.1606, -0.1927],
        [-0.3130, -0.1222, -0.2426,  0.2595,  0.0911,  0.1310,  0.1000, -0.0055,
          0.2475, -0.2247],
        [ 0.0199, -0.2158,  0.0975, -0.1089,  0.0969, -0.0659,  0.2623, -0.1874,
         -0.1886, -0.1886],
        [ 0.2844,  0.1054,  0.3043, -0.2610, -0.3137, -0.2474, -0.2127,  0.1281,
          0.1132,  0.2628],
        [-0.1633, -0.2156,  0.1678, -0.1278,  0.1919, -0.0750,  0.1809, -0.2457,
         -0.1596,  0.0964],
        [ 0.0669, -0.0806,  0.1885,  0.2150, -0.2293, -0.1688,  0.2896, -0.1067,
         -0.1121, -0.3060]], device='cuda:0', requires_grad=True)
= after model =
Parameter containing:
tensor([[ 0.2445,  0.2652, -0.0714,  0.2932, -0.0666,  0.0665, -0.1513,  0.1884,
          0.2815, -0.2293],
        [ 0.2726,  0.0569,  0.2313,  0.0405,  0.1502, -0.0470,  0.2415,  0.0444,
         -0.1499,  0.0783],
        [-0.1457, -0.0371, -0.1284,  0.2098, -0.2496, -0.1458, -0.0893, -0.1901,
          0.0298, -0.3123],
        [ 0.2819, -0.2723,  0.2405,  0.0489, -0.1064,  0.1917,  0.0456,  0.2518,
          0.0309, -0.1034],
        [ 0.0837, -0.0871,  0.1318,  0.2810,  0.1815, -0.1395,  0.1812,  0.0553,
          0.1593, -0.1940],
        [-0.3130, -0.1222, -0.2426,  0.2595,  0.0911,  0.1310,  0.1000, -0.0055,
          0.2475, -0.2247],
        [ 0.0199, -0.2158,  0.0975, -0.1089,  0.0969, -0.0659,  0.2623, -0.1874,
         -0.1886, -0.1886],
        [ 0.2844,  0.1054,  0.3043, -0.2610, -0.3137, -0.2474, -0.2127,  0.1281,
          0.1132,  0.2628],
        [-0.1633, -0.2156,  0.1678, -0.1278,  0.1919, -0.0750,  0.1809, -0.2457,
         -0.1596,  0.0964],
        [ 0.0669, -0.0806,  0.1885,  0.2150, -0.2293, -0.1688,  0.2896, -0.1067,
         -0.1121, -0.3060]], device='cuda:0', requires_grad=True)
= output =
tensor([ 0.2793, -0.2856,  0.3405, -0.6016,  0.1381], device='cuda:0',
       grad_fn=<SelectBackward0>)



Running oslo DDP example on rank 1.
Running oslo DDP example on rank 0.
= before model =
Parameter containing:
tensor([[ 0.2418,  0.2625, -0.0741,  0.2905, -0.0693,  0.0638, -0.1540,  0.1857,
          0.2788, -0.2320],
        [ 0.2749,  0.0592,  0.2336,  0.0428,  0.1525, -0.0446,  0.2438,  0.0467,
         -0.1476,  0.0806],
        [-0.1457, -0.0371, -0.1284,  0.2098, -0.2496, -0.1458, -0.0893, -0.1901,
          0.0298, -0.3123],
        [ 0.2856, -0.2686,  0.2441,  0.0526, -0.1027,  0.1954,  0.0493,  0.2555,
          0.0346, -0.0997],
        [ 0.0850, -0.0858,  0.1331,  0.2823,  0.1828, -0.1382,  0.1825,  0.0566,
          0.1606, -0.1927],
        [-0.3130, -0.1222, -0.2426,  0.2595,  0.0911,  0.1310,  0.1000, -0.0055,
          0.2475, -0.2247],
        [ 0.0199, -0.2158,  0.0975, -0.1089,  0.0969, -0.0659,  0.2623, -0.1874,
         -0.1886, -0.1886],
        [ 0.2844,  0.1054,  0.3043, -0.2610, -0.3137, -0.2474, -0.2127,  0.1281,
          0.1132,  0.2628],
        [-0.1633, -0.2156,  0.1678, -0.1278,  0.1919, -0.0750,  0.1809, -0.2457,
         -0.1596,  0.0964],
        [ 0.0669, -0.0806,  0.1885,  0.2150, -0.2293, -0.1688,  0.2896, -0.1067,
         -0.1121, -0.3060]], device='cuda:0', requires_grad=True)
= after model =
Parameter containing:
tensor([[ 0.2445,  0.2652, -0.0714,  0.2932, -0.0666,  0.0665, -0.1513,  0.1884,
          0.2815, -0.2293],
        [ 0.2726,  0.0569,  0.2313,  0.0405,  0.1502, -0.0470,  0.2415,  0.0444,
         -0.1499,  0.0783],
        [-0.1457, -0.0371, -0.1284,  0.2098, -0.2496, -0.1458, -0.0893, -0.1901,
          0.0298, -0.3123],
        [ 0.2819, -0.2723,  0.2405,  0.0489, -0.1064,  0.1917,  0.0456,  0.2518,
          0.0309, -0.1034],
        [ 0.0837, -0.0871,  0.1318,  0.2810,  0.1815, -0.1395,  0.1812,  0.0553,
          0.1593, -0.1940],
        [-0.3130, -0.1222, -0.2426,  0.2595,  0.0911,  0.1310,  0.1000, -0.0055,
          0.2475, -0.2247],
        [ 0.0199, -0.2158,  0.0975, -0.1089,  0.0969, -0.0659,  0.2623, -0.1874,
         -0.1886, -0.1886],
        [ 0.2844,  0.1054,  0.3043, -0.2610, -0.3137, -0.2474, -0.2127,  0.1281,
          0.1132,  0.2628],
        [-0.1633, -0.2156,  0.1678, -0.1278,  0.1919, -0.0750,  0.1809, -0.2457,
         -0.1596,  0.0964],
        [ 0.0669, -0.0806,  0.1885,  0.2150, -0.2293, -0.1688,  0.2896, -0.1067,
         -0.1121, -0.3060]], device='cuda:0', requires_grad=True)
= output =
tensor([ 0.2793, -0.2856,  0.3405, -0.6016,  0.1381], device='cuda:0',
       grad_fn=<SelectBackward0>)

@nijkah nijkah requested a review from jinwonkim93 March 2, 2023 07:09
@nijkah nijkah changed the title Refactor backward in DP [Refactor] Refactor backward interface in DP Mar 2, 2023
@jinwonkim93
Copy link
Copy Markdown
Member

We have discussed this issue on discord. LGTM

Copy link
Copy Markdown
Member

@jinwonkim93 jinwonkim93 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@nijkah nijkah marked this pull request as ready for review March 2, 2023 07:32
@nijkah nijkah merged commit b49092c into EleutherAI:distributed_data_parallel Mar 2, 2023
@nijkah nijkah deleted the ddp_backward branch March 2, 2023 07:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants