
Free Memory after CUDA out of memory error #27600

@thequilo

Description

🐛 Bug

Sometimes, PyTorch does not free memory after a CUDA out of memory exception.

To Reproduce

Consider the following function:

import torch

def oom():
    try:
        x = torch.randn(100, 10000, device=1)
        for i in range(100):
            l = torch.nn.Linear(10000, 10000)
            l.to(1)
            x = l(x)
    except RuntimeError as e:
        print(e)
        print('at iteration', i)

Executing it one time gives the expected out of memory error after some iterations:

>>> oom()
CUDA out of memory. Tried to allocate 381.50 MiB (GPU 1; 7.92 GiB total capacity; 7.16 GiB already allocated; 231.00 MiB free; 452.50 KiB cached)
at iteration 19

Executing it a second time gives an OOM error immediately, at the first iteration, which means the memory consumed by the local variables x and l from the previous call is still occupied (which is a little weird):

>>> oom()
CUDA out of memory. Tried to allocate 381.50 MiB (GPU 1; 7.92 GiB total capacity; 7.16 GiB already allocated; 231.00 MiB free; 452.50 KiB cached)
at iteration 0

At this point, calling gc.collect() sometimes (!!) frees the memory and sometimes it does not.
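A plausible explanation (an assumption on my side, not verified against PyTorch internals) is that the traceback attached to the caught exception keeps the frame's locals, and therefore x, l, and their CUDA storage, alive. The following torch-free sketch shows the effect with a plain Python list standing in for a large tensor:

```python
import gc

def scope():
    big = [0] * 10**6  # stands in for a large CUDA tensor
    raise RuntimeError("OOM stand-in")

def handler():
    try:
        scope()
    except RuntimeError as e:
        # e.__traceback__ references scope()'s frame, which still holds
        # `big`; as long as e is reachable, so is the big allocation.
        return e

err = handler()

# Walk to the innermost frame of the traceback chain.
tb = err.__traceback__
while tb.tb_next is not None:
    tb = tb.tb_next

print('big' in tb.tb_frame.f_locals)  # True: the large list is still alive
```

This would also explain why gc.collect() behaves inconsistently: whether the memory comes back depends on whether anything still references the exception or its traceback at collection time.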

Expected behavior

I expected consistent behavior: the memory should be freed after the OOM exception occurs, or at least after gc.collect() is called.

Environment

Collecting environment information...
PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: CentOS Linux release 7.3.1611 (Core)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-11)
CMake version: version 2.8.12.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce GTX 1080
GPU 1: GeForce GTX 1080

Nvidia driver version: 387.26
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.16.4
[pip] pytorch-wpe==0.0.0
[pip] torch==1.0.0
[pip] torch-complex==0.0.1
[conda] blas 1.0 mkl
[conda] cuda90 1.0 h6433d27_0 pytorch
[conda] cuda91 1.0 h4c16780_0 pytorch
[conda] mkl 2019.4 243
[conda] mkl-service 2.3.0 py37he904b0f_0
[conda] mkl_fft 1.0.14 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.0.0 py3.7_cuda9.0.176_cudnn7.4.1_1 pytorch
[conda] pytorch-wpe 0.0.0 pypi_0 pypi
[conda] torch-complex 0.0.1 pypi_0

Additional context

I stumbled upon this while trying to fall back to the CPU for the computation of a single batch after an OOM error. I noticed that even after the computation finished on the CPU, the GPU memory sometimes remained occupied, which caused all following batches to be computed on the CPU as well.
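A minimal sketch of that fallback pattern (torch-free; run_batch and the simulated OOM are hypothetical stand-ins, and with real PyTorch one would additionally call torch.cuda.empty_cache() after clearing references):

```python
import gc

def run_batch(batch, device):
    # Hypothetical compute step; raises to simulate a CUDA OOM on the
    # GPU path so the fallback branch is exercised.
    if device == "gpu":
        raise RuntimeError("CUDA out of memory (simulated)")
    return [x * 2 for x in batch]

def run_with_cpu_fallback(batch):
    try:
        return run_batch(batch, "gpu")
    except RuntimeError:
        # No `as e` binding here, so no lingering reference to the
        # exception pins the failed attempt's allocations; collect
        # before retrying on the CPU.
        gc.collect()
        return run_batch(batch, "cpu")

print(run_with_cpu_fallback([1, 2, 3]))  # [2, 4, 6]
```

The point of the sketch is the cleanup between attempts; in the scenario described above, that cleanup is exactly what does not reliably release the GPU memory.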

cc @ezyang @gchanan @zou3519

Labels: module: cuda, module: memory usage, triaged
