Resolve #25605 cyclic reference in _LRScheduler #25776

Closed
huzecong wants to merge 7 commits into pytorch:master from huzecong:fix-scheduler

@huzecong huzecong commented Sep 6, 2019

A cyclic reference was introduced in a previous version by overwriting the bound method `optimizer.step` at runtime. This PR avoids it by keeping only a weak reference to the optimizer instance.

Credit: https://stackoverflow.com/questions/26157952/why-set-a-bound-method-to-python-object-create-a-circular-reference
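
The cycle and the weak-reference fix can be sketched on a toy class. `Opt`, `wrap_strong`, `wrap_weak`, and `freed_without_gc` are hypothetical names chosen for illustration and are not part of torch.optim; the check also assumes CPython, where an object without cycles is freed as soon as its last reference goes away.

```python
import gc
import weakref
from functools import wraps


class Opt:
    """Hypothetical stand-in for an optimizer; not the torch.optim API."""

    def __init__(self):
        self._step_count = 0

    def step(self):
        return "stepped"


def wrap_strong(opt):
    # The wrapper closes over `opt` and its bound method, and is then stored
    # on `opt` itself: opt -> wrapper -> opt is a reference cycle.
    method = opt.step

    @wraps(method)
    def wrapper(*args, **kwargs):
        opt._step_count += 1
        return method(*args, **kwargs)

    opt.step = wrapper


def wrap_weak(opt):
    # Keep only a weak reference to the instance plus the plain function,
    # so the wrapper stored on `opt` does not point back to `opt`.
    instance_ref = weakref.ref(opt)
    func = type(opt).step

    @wraps(func)
    def wrapper(*args, **kwargs):
        instance = instance_ref()
        instance._step_count += 1
        return func(instance, *args, **kwargs)

    opt.step = wrapper


def freed_without_gc(wrap):
    """Return True if the wrapped object is freed by refcounting alone."""
    gc.disable()
    try:
        opt = Opt()
        wrap(opt)
        opt.step()
        ref = weakref.ref(opt)
        del opt
        return ref() is None
    finally:
        gc.enable()
        gc.collect()  # clean up any cycle left behind by wrap_strong
```

Under these assumptions, `freed_without_gc(wrap_weak)` reports that the object is released on `del`, while `freed_without_gc(wrap_strong)` reports that it lingers until the cycle collector runs.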

@pytorchbot pytorchbot added the module: optimizer Related to torch.optim label Sep 6, 2019
@soumith soumith requested a review from vincentqb September 6, 2019 15:34
@ezyang
Contributor

ezyang commented Sep 6, 2019

Thank you! Do you think there is a way you can conveniently test this (e.g., by setting `__del__` on the optimizer and testing that it is promptly disposed)?
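
One way such a check might look, as a rough sketch rather than the test that was eventually added: a toy class records when its `__del__` runs, and the cycle collector is disabled so that only reference counting can dispose of it. `Tracked` and `finalized_on_del` are hypothetical names, and the behavior shown assumes CPython.

```python
import gc


class Tracked:
    """Hypothetical object whose finalizer records when it runs."""

    def __init__(self, log):
        self.log = log

    def __del__(self):
        self.log.append("finalized")

    def step(self):
        pass


def finalized_on_del(make_cycle):
    """Return True if `del` alone triggers the finalizer."""
    log = []
    gc.disable()  # only reference counting may free the object
    try:
        obj = Tracked(log)
        if make_cycle:
            # Storing a bound method on its own instance creates a cycle.
            obj.step = obj.step
        del obj
        return log == ["finalized"]
    finally:
        gc.enable()
        gc.collect()  # reclaim the cyclic case afterwards
```

Without the cycle the finalizer runs as soon as the name is deleted; with the cycle it is deferred until a collection happens.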

@ezyang ezyang self-requested a review September 6, 2019 16:05
Comment thread torch/optim/lr_scheduler.py Outdated
@vincentqb
Contributor

> Thank you! Do you think there is a way you can conveniently test this (e.g., by setting `__del__` on the optimizer and testing that it is promptly disposed)?

like this?

@huzecong
Contributor Author

huzecong commented Sep 6, 2019

@ezyang I can add a unit test similar to the snippet I included in the issue. Would that do?

@ezyang
Contributor

ezyang commented Sep 6, 2019

Yes that would be great, ty

When multiple schedulers are constructed for the same optimizer,
`optim.step` is only wrapped once.
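
A minimal sketch of how wrapping only once could work, assuming a marker attribute on the wrapper. The marker name `_with_counter`, the helper `attach_scheduler`, and the toy `Opt` class are assumptions made for this example and are not quoted from the diff.

```python
import weakref
from functools import wraps


class Opt:
    """Toy optimizer used only for this sketch."""

    def __init__(self):
        self._step_count = 0

    def step(self):
        return "stepped"


def with_counter(method):
    # If step() has already been wrapped, hand it back unchanged.
    if getattr(method, "_with_counter", False):
        return method

    instance_ref = weakref.ref(method.__self__)  # weak: avoids the cycle
    func = method.__func__                       # plain function, no instance

    @wraps(func)
    def wrapper(*args, **kwargs):
        instance = instance_ref()
        instance._step_count += 1
        return func(instance, *args, **kwargs)

    wrapper._with_counter = True  # marker checked on later calls
    return wrapper


def attach_scheduler(opt):
    """Stand-in for what a scheduler constructor would do to the optimizer."""
    opt.step = with_counter(opt.step)
```

Calling `attach_scheduler` twice on the same object leaves the first wrapper in place, so one call to `step()` increments the counter once rather than twice.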
def wrapper(*args, **kwargs):
opt._step_count += 1
return func(*args, **kwargs)
instance = instance_ref()
Contributor

Is it provable that the optimizer instance is always live at this point? If not maybe we should check the return result of instance...

Contributor Author

I think it should be, since this function will only be accessible through optimizer.step().
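
For illustration only, a defensive variant that checks the dereferenced weak reference might look like the sketch below. It covers the case where a caller keeps the wrapper in a separate variable and then deletes the owner. `Opt`, `make_wrapper`, and `call_after_owner_deleted` are hypothetical names, the error message is made up, and this is not the code under review.

```python
import weakref


class Opt:
    """Toy optimizer used only for this sketch."""

    def __init__(self):
        self._step_count = 0

    def step(self):
        return self._step_count


def make_wrapper(opt):
    instance_ref = weakref.ref(opt)
    func = type(opt).step

    def wrapper(*args, **kwargs):
        instance = instance_ref()
        if instance is None:
            # The owner is gone; fail with a clear error instead of an
            # AttributeError on None.
            raise ReferenceError("optimizer no longer exists")
        instance._step_count += 1
        return func(instance, *args, **kwargs)

    return wrapper


def call_after_owner_deleted():
    """Keep the wrapper alive on its own, delete the owner, call again."""
    opt = Opt()
    step = make_wrapper(opt)
    step()
    del opt  # on CPython the object is freed here
    try:
        step()
    except ReferenceError as exc:
        return str(exc)
    return None
```

When the wrapper is reached only through `optimizer.step()`, the instance is necessarily alive and the check never fires.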

@huzecong
Contributor Author

huzecong commented Sep 6, 2019

#17630 could be due to a cyclic reference, which prevents the optimizer and its parameters from being garbage collected when `del` is called.

For #9942, storing a partial function over a bound method would probably also create a self-reference. TBH I don't really see the need for a partial function here; it's perfectly fine to check `self.mode` and such in `_cmp()`. Also, `_cmp()` can just be a staticmethod.

Should I also fix this in the same PR?

@vincentqb
Contributor

> For #9942, storing a partial function over a bound method would probably also create a self-reference. TBH I don't really see the need for a partial function here; it's perfectly fine to check `self.mode` and such in `_cmp()`. Also, `_cmp()` can just be a staticmethod.

Good point: might as well remove partial and let _cmp() access self.*. Even if it doesn't have a negative performance impact, removing it makes the code more readable :)

> Should I also fix this in the same PR?

I'm leaning toward a separate PR, but either way is fine, whichever is more convenient for you. Thanks!

@ezyang
Contributor

ezyang commented Sep 9, 2019

The test fails, btw

Sep 06 20:43:36 ======================================================================
Sep 06 20:43:36 FAIL: test_no_cyclic_references (__main__.TestLRScheduler)
Sep 06 20:43:36 ----------------------------------------------------------------------
Sep 06 20:43:36 Traceback (most recent call last):
Sep 06 20:43:36   File "test_optim.py", line 554, in test_no_cyclic_references
Sep 06 20:43:36     "Optimizer should contain no cyclic references")
Sep 06 20:43:36 AssertionError: False is not true : Optimizer should contain no cyclic references

Contributor

@ezyang ezyang left a comment

fails its test

Python `gc` module behavior changed in version 3.7. Prior to 3.7,
when calling `gc.get_referrers` on a local variable in a function, the
current frame would be included in the returned list.
ReduceLROnPlateau previously stored a partial function over an
instance method, which created a cyclic reference. Refactored the code
to use the attributes directly instead of creating a partial function.
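
The idea behind this refactor can be sketched on a simplified class. `PlateauTracker` and `is_better` are hypothetical names, and the comparison rule is deliberately reduced to a mode plus an absolute threshold; this is not the actual ReduceLROnPlateau implementation.

```python
class PlateauTracker:
    """Hypothetical, simplified stand-in; not the torch ReduceLROnPlateau API."""

    def __init__(self, mode="min", threshold=1e-4):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        # Plain attributes only: nothing stored here refers back to `self`,
        # unlike functools.partial(self._cmp, ...) kept on the instance.
        self.mode = mode
        self.threshold = threshold

    def is_better(self, a, best):
        # Read the configuration at call time instead of baking it into a
        # partial function when the object is constructed.
        if self.mode == "min":
            return a < best - self.threshold
        return a > best + self.threshold
```

Because the instance holds no callable bound to itself, reference counting alone can release it.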
@huzecong
Contributor Author

Sorry for the delay, been busy during the week.

The test error was due to a change in Python version 3.7. Prior to 3.7, gc.get_referrers would include the current frame if called on a local variable in a function. I've only been testing on Python 3.7 so I didn't notice this locally. My bad.
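
A check that does not depend on this version difference could filter stack frames out of the referrer list. The sketch below is one possible approach, not necessarily what the updated test does; `Thing`, `referrers_excluding_frames`, and `demo` are hypothetical names.

```python
import gc
import types


class Thing:
    """Toy class for this sketch."""

    def step(self):
        pass


def referrers_excluding_frames(obj):
    """Referrers of `obj` with stack frames filtered out, so the result does
    not depend on whether the interpreter reports frames as referrers."""
    return [r for r in gc.get_referrers(obj)
            if not isinstance(r, types.FrameType)]


def demo():
    plain = Thing()
    cyclic = Thing()
    cyclic.step = cyclic.step  # bound method stored on its own instance
    return referrers_excluding_frames(plain), referrers_excluding_frames(cyclic)
```

For the object carrying the self-assigned bound method, the filtered list still contains that bound method, which is the referrer a cycle test cares about.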

I've also addressed #9942 (fix ReduceLROnPlateau) in a separate commit.

Didn't realize that `ReduceLROnPlateau` inherited from `object` rather
than `_LRScheduler`... Might leave that change to the future.
@huzecong huzecong requested a review from ezyang September 15, 2019 14:22
@ezyang
Contributor

ezyang commented Sep 16, 2019

@pytorchbot rebase this please

@ezyang
Contributor

ezyang commented Sep 16, 2019

Just waiting on CI

Contributor

@vincentqb vincentqb left a comment

Thanks again!

@huzecong
Contributor Author

Cool! Glad to help!

Contributor

@facebook-github-bot facebook-github-bot left a comment

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@huzecong huzecong deleted the fix-scheduler branch September 18, 2019 13:56
@facebook-github-bot
Contributor

@ezyang merged this pull request in b8ae4d0.

@mcarilli
Collaborator

mcarilli commented Oct 14, 2019

@huzecong This stuff is really cool and highly instructive to a chronic Cuda programmer/beginner Pythonista. Is there a reason you chose to grab a manual weakref to the optimizer instance instead of using types.MethodType? In your original issue you suggested types.MethodType resolves the circular reference issue (and it certainly seems simpler than manual reference juggling) but types.MethodType is never mentioned again in either the issue or PR discussion afaict.

Does types.MethodType actually resolve the circular reference issue? I'm very interested because I'm a frequent (ab)user of it. I realize that, in general, functions are descriptors that become MethodType instances whenever they're retrieved and called (as in optimizer.func()), but that MethodType instance is a temporary used only for the call itself. Even if it stores a direct non-weak reference to optimizer, there's no danger of a long-lived reference to the instance surviving the optimizer.func() invocation, because the MethodType instance dies immediately. My point is that, since MethodType instances are typically short-lived, I could imagine the implementation of MethodType simply storing a non-weak reference to optimizer, in which case

def wrapper(self, *args, **kwargs):
    ...
optimizer.step = types.MethodType(wrapper, optimizer)

would in fact create a non-weak reference cycle for optimizer. Conversely, it's also possible that MethodType is smart enough to store a weakref to optimizer, in which case it would be safe to use, and then I'm curious why you didn't use it. I can dig into this some more myself, but I'm very interested to hear your take...

@huzecong
Contributor Author

huzecong commented Oct 14, 2019

@mcarilli Sorry if the previous issue confused you, but yes you're right, types.MethodType does not solve the circular reference issue because it keeps a non-weak reference to the object. This is also mentioned in the link I posted in the PR message, which shows that:

obj = Foo()
obj.func = obj.func

saves the MethodType instance on the object, which results in a circular reference. You can also see this from gc.get_referrers(obj), which will list <bound method func of <Foo object at ...>> as a referrer of obj.

But anyway, this isn't a big problem because Python has mechanisms (albeit more expensive) to deal with circular references. For simple personal scripts, it should be fine as long as the object in question is not holding too much memory.
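
Both points can be observed in a small experiment: `types.MethodType` holds a strong reference to its instance, and the resulting cycle is reclaimed only once the cycle collector runs. This is a sketch on a toy class; the module-level `wrapper` and `lifetime_with_method_type` are hypothetical names, and the timing described assumes CPython.

```python
import gc
import types
import weakref


class Foo:
    def func(self):
        return "called"


def wrapper(self):
    return "wrapped " + Foo.func(self)


def lifetime_with_method_type():
    """Return (strong, alive_after_del, alive_after_collect)."""
    gc.disable()  # keep automatic collection out of the picture
    try:
        obj = Foo()
        bound = types.MethodType(wrapper, obj)
        strong = bound.__self__ is obj  # the method object holds obj itself
        obj.func = bound                # obj -> bound -> obj: a cycle
        del bound
        ref = weakref.ref(obj)
        del obj
        alive_after_del = ref() is not None
        gc.collect()                    # explicit collection still works
        alive_after_collect = ref() is not None
        return strong, alive_after_del, alive_after_collect
    finally:
        gc.enable()
```

The object survives `del` because of the cycle and disappears after `gc.collect()`, which matches the description above: correct, but with cleanup deferred to the collector.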

@mcarilli
Collaborator

mcarilli commented Oct 14, 2019

Makes sense, thanks very much! My only remaining point of confusion is then: why did you suggest

> The self-reference could be easily prevented by using the self argument and wrapping the created function with types.MethodType(wrapper, opt).

in the original issue? Did you originally assume (as I did) that MethodType was smart enough to store a weakref to self rather than an ordinary non-weak ref?

@huzecong
Contributor Author

Yeah, that's pretty much what I thought at the beginning :D

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
Cyclic reference was introduced in a previous version due to runtime overwriting of the bound method `optimizer.step`. This is now avoided by keeping a weak reference to the optimizer instance.

Credit: https://stackoverflow.com/questions/26157952/why-set-a-bound-method-to-python-object-create-a-circular-reference
Pull Request resolved: pytorch#25776

Differential Revision: D17420770

Pulled By: ezyang

fbshipit-source-id: 546ec94cf725ebfddb310b24e6a2e146ddecd1f6