
Reference cycle in _LRScheduler #25605

@huzecong

Description

🐛 Bug

To Reproduce

In the following code, introduced in version 1.2.0 by #20124, constructing a `_LRScheduler` replaces the optimizer's `step` method with a new closure that captures the optimizer itself.

# Following https://github.com/pytorch/pytorch/issues/20124
# We would like to ensure that `lr_scheduler.step()` is called after
# `optimizer.step()`
def with_counter(func, opt):
    @wraps(func)
    def wrapper(*args, **kwargs):
        opt._step_count += 1
        return func(*args, **kwargs)
    wrapper._with_counter = True
    return wrapper
self.optimizer.step = with_counter(self.optimizer.step, self.optimizer)

This makes the optimizer reference itself through the stored closure, creating a reference cycle that prevents it from being freed by reference counting; it can only be reclaimed when the cyclic garbage collector runs.
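The shape of the cycle can be shown with a small self-contained sketch (the `Opt` class and names below are hypothetical stand-ins, not PyTorch code): the wrapper's closure cell holds the object, and the wrapper is stored as an attribute of that same object.

```python
from functools import wraps

class Opt:
    """Hypothetical stand-in for an optimizer, for illustration only."""
    def __init__(self):
        self._step_count = 0
    def step(self):
        return "stepped"

def with_counter(func, opt):
    @wraps(func)
    def wrapper(*args, **kwargs):
        opt._step_count += 1          # the closure captures `opt`...
        return func(*args, **kwargs)
    return wrapper

opt = Opt()
# ...and the wrapper is then stored on `opt` itself:
# opt.__dict__['step'] -> wrapper -> closure cell -> opt
opt.step = with_counter(opt.step, opt)

captured = [c.cell_contents for c in opt.step.__closure__]
print(opt in captured)  # True: opt is reachable from its own attribute
```

The counter still works as intended; the problem is purely the lifetime of `opt`.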

To verify this, run the following snippet:

import torch
from torch import nn

import gc

def main():
    param = nn.Parameter(torch.randn(10))

    optim = torch.optim.Adam([param])
    scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lambda epoch: 1.0)
    del scheduler

    print(gc.get_referrers(optim))
    
    gc.collect()
    del optim
    print(gc.collect())

if __name__ == '__main__':
    main()

The output is:

[<bound method Adam.step of Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    initial_lr: 0.001
    lr: 0.001
    weight_decay: 0
)>, <cell at 0x109354948: Adam object at 0x10940d8d0>]
12

which shows that optim.step is a referrer of optim, and that optim is not freed when it goes out of scope; the `12` is the number of unreachable objects the cyclic collector had to reclaim.

If the scheduler is not created, the output becomes:

[]
0

which shows that no referrers of optim exist and it is freed promptly by reference counting.
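The same lifetime effect can be observed in isolation with a weak reference (a minimal sketch with a hypothetical `Holder` class, independent of PyTorch): an object caught in such a cycle survives until `gc.collect()` runs, whereas without the cycle it would die as soon as the last name goes away.

```python
import gc
import weakref

class Holder:
    """Hypothetical stand-in for an object that references itself."""

def make_cycle():
    obj = Holder()
    obj.callback = lambda: obj   # closure stored on obj captures obj
    return weakref.ref(obj)

ref = make_cycle()
alive_before = ref() is not None   # refcounting alone cannot free obj
gc.collect()
alive_after = ref() is not None    # the cyclic collector reclaims it
print(alive_before, alive_after)
```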

Expected behavior

The self-reference in the closure could easily be prevented by giving the wrapper an explicit `self` argument and binding it with `types.MethodType(wrapper, opt)`.
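A sketch of that suggestion (hypothetical names, not the actual PyTorch patch): take the underlying function via `__func__`, declare `self` explicitly, and rebind with `types.MethodType`, so the wrapper function itself closes over no instance and the optimizer is reached through the bound method's `__self__` slot instead.

```python
import types
from functools import wraps

def with_counter(method):
    # `method` is a bound method such as optimizer.step
    func = method.__func__        # the plain function; no instance captured
    instance = method.__self__

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        self._step_count += 1     # use the explicit `self`, not a closure cell
        return func(self, *args, **kwargs)

    wrapper._with_counter = True
    return types.MethodType(wrapper, instance)

class FakeOptim:
    """Hypothetical optimizer stand-in for demonstration."""
    def __init__(self):
        self._step_count = 0
    def step(self):
        return "stepped"

optim = FakeOptim()
optim.step = with_counter(optim.step)
```

With this variant the wrapper function's closure contains only `func`, never the optimizer; the reference to the instance lives in the bound-method object's `__self__` attribute.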

Environment

  • PyTorch Version: 1.2.0
  • Other information probably irrelevant.

cc @ezyang @gchanan @zou3519 @vincentqb

Metadata

Labels

  • high priority
  • module: optimizer (Related to torch.optim)
  • triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
