
Manually unrolling cuDNN RNN OOM #914

@csarofeen

Description

Manually unrolling an LSTM over time with the cuDNN backend makes memory usage go sky-high.

Unrolled non-cuDNN PyTorch takes ~1.8 GB of memory.
Non-unrolled cuDNN takes ~3 GB of memory.
Manually unrolling over time in the user script takes >12 GB of memory.

This matters for attention models @bmccann. Repro code below.

import torch
import torch.optim as optim
import torch.nn as nn
import time
from torch.autograd import Variable

# Uncomment to compare the non-cuDNN path, or to run on a different GPU:
#torch.backends.cudnn.enabled=False
#torch.cuda.set_device(1)

input_size = 1024
hidden_size = 1024*2
batch_size = 64
time_steps = 32*2
lr = 4e-2

def initHidden(bsz):
    return (Variable(torch.cuda.FloatTensor(2, bsz, hidden_size).zero_()),
            Variable(torch.cuda.FloatTensor(2, bsz, hidden_size).zero_()))
def resetHidden(hidden):
    return (Variable(hidden[0].data.zero_()), Variable(hidden[1].data.zero_()))

model = nn.LSTM(input_size, hidden_size, num_layers=2).cuda()
input = torch.randn(time_steps, batch_size, input_size).cuda()
target = torch.randn(time_steps, batch_size, hidden_size).cuda()


optimizer = optim.SGD(model.parameters(),
                      lr=lr,
                      momentum=0.9,
                      dampening=0.0,
                      weight_decay=0.0)

criterion = nn.MSELoss().cuda()

loss = 0

hidden = initHidden(batch_size)
input = Variable(input)
target = Variable(target, requires_grad=False)

for i in range(1):
    start = time.time()
    for epoch in range(9999):
        print(epoch)
        loss = 0
        model.zero_grad()
        optimizer.zero_grad()

        hidden = initHidden(batch_size)
        # Single-call (non-unrolled) cuDNN path, for comparison:
        # output, hidden = model(input, hidden)

        # Manually unroll over time: feed one step at a time and
        # concatenate the per-step outputs (this is the path that blows up memory).
        outputs = []
        for j in range(input.size(0)):
            output, hidden = model(input[j].view(1, *input[j].size()), hidden)
            outputs.append(output)
        output = torch.cat(outputs, 0)

        
        loss = criterion(output, target)
        # retain_variables was later renamed retain_graph
        loss.backward(retain_variables=True)

        optimizer.step()

    print("Test ran in " + str( time.time() - start) + " seconds")
