Manually unrolling an LSTM over time with the cuDNN backend makes memory usage go sky high:

- Unrolled, cuDNN disabled: ~1.8 GB
- Not unrolled (single call), cuDNN enabled: ~3 GB
- Manually unrolled over time in the user script, cuDNN enabled: >12 GB

This is important for attention models, which typically have to call the RNN one step at a time (cc @bmccann). A sketch of the three configurations is below, followed by the repro code.
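For clarity, a minimal sketch of the three configurations compared above. This is hypothetical: run_training_loop stands in for the repro's epoch loop, with unrolled switching between the single-call and per-step paths.

# Hypothetical sketch; run_training_loop stands in for the repro's epoch loop.

# 1) Unrolled, cuDNN disabled -> ~1.8 GB
torch.backends.cudnn.enabled = False
run_training_loop(unrolled=True)

# 2) Not unrolled, cuDNN enabled -> ~3 GB
torch.backends.cudnn.enabled = True
run_training_loop(unrolled=False)

# 3) Unrolled, cuDNN enabled -> >12 GB
torch.backends.cudnn.enabled = True
run_training_loop(unrolled=True)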
import torch
import torch.optim as optim
import torch.nn as nn
import time
from torch.autograd import Variable

# Toggles for the configurations compared above:
# torch.backends.cudnn.enabled = False
# torch.cuda.set_device(1)

input_size = 1024
hidden_size = 1024 * 2
batch_size = 64
time_steps = 32 * 2
lr = 4e-2

def initHidden(bsz):
    # Zeroed (h, c) state: (num_layers, batch, hidden_size).
    return (Variable(torch.cuda.FloatTensor(2, bsz, hidden_size).zero_()),
            Variable(torch.cuda.FloatTensor(2, bsz, hidden_size).zero_()))

def resetHidden(hidden):
    # Unused in this repro; kept for experimenting with state reuse.
    return (Variable(hidden[0].data.zero_()), Variable(hidden[1].data.zero_()))

model = nn.LSTM(input_size, hidden_size, num_layers=2).cuda()
input = Variable(torch.randn(time_steps, batch_size, input_size).cuda())
target = Variable(torch.randn(time_steps, batch_size, hidden_size).cuda(),
                  requires_grad=False)

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9,
                      dampening=0.0, weight_decay=0.0)
criterion = nn.MSELoss().cuda()

start = time.time()
for epoch in range(9999):
    print(epoch)
    optimizer.zero_grad()
    hidden = initHidden(batch_size)
    # Non-unrolled path (single cuDNN call over all time steps):
    # output, hidden = model(input, hidden)
    # Manually unrolled path (one call per time step):
    outputs = []
    for j in range(input.size(0)):
        output, hidden = model(input[j].view(1, *input[j].size()), hidden)
        outputs.append(output)
    output = torch.cat(outputs, 0)
    loss = criterion(output, target)
    loss.backward(retain_variables=True)
    optimizer.step()
print("Test ran in " + str(time.time() - start) + " seconds")