
Memory leak after OOM (maybe RRelu specific) #38966

@Roffild

Description


🐛 Bug

PyTorch 1.5
Python 3.7
Windows 10

NVIDIA RTX 2080 8GB (6GB)

To reproduce, catch the "RuntimeError: CUDA out of memory." exception raised inside torch.rrelu(). After the exception, allocated GPU memory is not released.

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(in_features=1746, out_features=500, bias=True),
    torch.nn.RReLU(lower=0.125, upper=0.3333333333333333),
    torch.nn.Linear(in_features=500, out_features=100, bias=True),
    torch.nn.RReLU(lower=0.125, upper=0.3333333333333333),
    torch.nn.Linear(in_features=100, out_features=2, bias=True)
)
dev = torch.device("cuda")
mem = 480593  # batch size chosen so the forward pass runs out of GPU memory
intrain = torch.rand((int(mem), 1746), device=dev, dtype=torch.float32)
outtrain = torch.rand((int(mem), 2), device=dev, dtype=torch.float32)
model.to(device=dev)
# model.train()
before = torch.cuda.memory_allocated(dev)
try:
    model(intrain)
except RuntimeError:
    # Expected: "RuntimeError: CUDA out of memory."
    pass
after = torch.cuda.memory_allocated(dev)
print("Memory:")
print(before, "- before")
print(after, "- after")
Memory:
3364003328 - before
6440500224 - after
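The measurement pattern above can be generalized into a small guard. The sketch below is a hypothetical helper (the function name and structure are not from the issue): it records torch.cuda.memory_allocated() before the forward pass, catches only OOM-style RuntimeErrors, and reports how many bytes remain allocated after the exception. On an affected build it would print a large positive delta, as in the output above; on CPU the CUDA counters are skipped.

```python
import torch

def forward_with_oom_guard(model, batch):
    """Run a forward pass; on a CUDA OOM, report how much allocated
    memory survives the exception. Hypothetical helper for this repro."""
    dev = next(model.parameters()).device
    use_cuda = dev.type == "cuda"
    before = torch.cuda.memory_allocated(dev) if use_cuda else 0
    try:
        return model(batch)
    except RuntimeError as e:
        # Re-raise anything that is not an out-of-memory error.
        if "out of memory" not in str(e):
            raise
        after = torch.cuda.memory_allocated(dev) if use_cuda else 0
        print("bytes still allocated after OOM:", after - before)
        return None
```

With a batch that fits in memory the helper behaves like a plain forward call, so it can wrap training loops while the leak is being investigated.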

cc @ezyang @gchanan @zou3519 @ngimel

Labels

high priority
module: cuda (Related to torch.cuda, and CUDA support in general)
module: memory usage (PyTorch is using more memory than it should, or it is leaking memory)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
