GPU memory consumption increases while training #1509

@EthanZhangYi

Hello, all
I am new to PyTorch, and I am seeing strange GPU memory behavior while training a CNN model for semantic segmentation. The batch size is 1 and the training set contains 100 image-label pairs, so there are 100 iterations per epoch. However, GPU memory consumption increases considerably over the first several iterations of training.

[Platform] GTX TITAN X (12G), CUDA-7.5, cuDNN-5.0

torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False

With cuDNN disabled, GPU memory consumption at the first six iterations is 2934M -- 4413M -- 4433M -- 4537M -- 4537M -- 4537M.

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

With cuDNN enabled and benchmark mode on, GPU memory consumption at the first six iterations is 1686M -- 1791M -- 1791M -- 1791M -- 1791M -- 1791M.
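
(For reference, the numbers above are overall GPU usage as reported by the driver. A per-iteration view from inside the process, assuming a recent PyTorch where torch.cuda.memory_allocated() exists -- it does not in the 0.1.x-era version used here -- would look like this minimal sketch:)

import torch

def log_gpu_memory(step):
    # Allocator's view of memory: tensors currently held vs. high-water mark.
    torch.cuda.synchronize()
    alloc = torch.cuda.memory_allocated() / 1024 ** 2
    peak = torch.cuda.max_memory_allocated() / 1024 ** 2
    print('iter %d: allocated %.0fM, peak %.0fM' % (step, alloc, peak))

(Note that nvidia-smi reports what PyTorch's caching allocator has reserved from the driver, which is usually larger than memory_allocated() and does not shrink when tensors are freed.)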

Why does GPU memory consumption increase during training, and why does it increase so much when cuDNN is disabled? (In my opinion, GPU memory consumption should not increase once the CNN has been built and training has started.)
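
(One effect that is easy to reproduce in isolation, and may account for part of the growth: gradient buffers and optimizer state are allocated lazily, on the first backward() and step() calls, so allocated memory rises over the first iterations even for a fixed model. A standalone sketch, using a recent PyTorch API and a made-up model rather than the code from this issue:)

import torch

model = torch.nn.Linear(4096, 4096).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
x = torch.randn(8, 4096, device='cuda')

for step in range(4):
    optimizer.zero_grad()
    loss = model(x).pow(2).mean()
    loss.backward()    # first call allocates a .grad buffer per parameter
    optimizer.step()   # first call allocates the momentum buffers
    torch.cuda.synchronize()
    print(step, torch.cuda.memory_allocated() // 2 ** 20, 'MiB')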

Has anyone met the same problem? Could anyone give some help?

This is the code snippet:

import time

import torch


def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        target = target.long()
        # `async=True` is this PyTorch version's spelling of `non_blocking=True`
        # (renamed in 0.4); the copy only overlaps compute if `input` is pinned
        input = input.cuda(async=True)
        target = target.cuda(async=True)
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # record loss; loss.data[0] reads the scalar in this PyTorch
        # version (loss.item() in 0.4 and later)
        losses.update(loss.data[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # `args` is assumed to be a module-level argparse namespace
        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                   epoch, i+1, len(train_loader),
                   batch_time=batch_time,
                   data_time=data_time,
                   loss=losses))
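
(AverageMeter is not defined in the snippet; a minimal version consistent with the .val / .avg / .update() usage above, patterned on the helper in the pytorch/examples ImageNet script, would be:)

class AverageMeter(object):
    # Tracks the latest value and a running average.
    def __init__(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count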
