
ConvNd function leaks memory #3835

@apaszke

Description

Originally reported by @dmarnerides in #3743:


I can reproduce the memory leak with this:

import gc
import resource
import torch
from torch import nn, autograd
from torch.autograd import Variable

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        # Two stacked 1x1 convolutions; the leak only shows up with
        # two or more (see the single-convolution variant below).
        self.main = nn.Sequential(
            nn.Conv2d(1, 1, 1, 1),
            nn.Conv2d(1, 1, 1, 1),
        )
    def forward(self, v_x):
        return self.main(v_x).view(v_x.size(0), 1)

net = Network()

i = 0
while True:
    v_in = Variable(torch.Tensor(2, 1, 1, 1), requires_grad=True)
    grad_out = Variable(torch.ones(2, 1, 1, 1))

    # First-order gradient w.r.t. the input, with create_graph=True so the
    # gradient itself is differentiable (double backward).
    gradient = autograd.grad(outputs=net(v_in), inputs=v_in,
                             grad_outputs=grad_out,
                             create_graph=True, retain_graph=True,
                             only_inputs=True)[0]
    gradient.mean().backward()

    i += 1
    if i % 512 == 0:
        gc.collect()
        # ru_maxrss is reported in kilobytes on Linux, hence / 1024 for MB.
        max_mem_used = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        print("{:.2f} MB".format(max_mem_used / 1024))

However, the leak disappears when only one convolution is used:

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        # A single convolution: memory usage stays flat with this network.
        self.main = nn.Sequential(
            nn.Conv2d(1, 1, 1, 1),
        )
    def forward(self, v_x):
        return self.main(v_x).view(v_x.size(0), 1)

Also, there is no leak if nn.Linear is used in place of nn.Conv2d; a sketch of that variant follows.
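
For reference, a minimal sketch of the nn.Linear variant, assuming the same driver loop as above (the layer sizes and the flattening in forward are illustrative guesses, not from the original report):

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        # Two stacked linear layers in place of the two convolutions.
        self.main = nn.Sequential(
            nn.Linear(1, 1),
            nn.Linear(1, 1),
        )
    def forward(self, v_x):
        # Flatten the (N, 1, 1, 1) input to (N, 1) for the linear layers.
        return self.main(v_x.view(v_x.size(0), 1))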

Torch version: 0.4.0a0+8ebf18b
