should weight norm only recompute weights at the beginning and after each call to backward? #2176

@greaber

Description

I haven't tried the new PyTorch implementation of weight norm, but I had problems applying a similar implementation to RNNs: recomputing the weights on every call to forward caused me to run out of memory. The solution seemed to be to recompute the weights only after each call to backward. The version I had problems with and my fix are both here: https://gist.github.com/rtqichen/b22a9c6bfc4f36e605a7b3ac1ab4122f
(I should warn that I am not totally sure my implementation is correct, since the results I got from weight norm in my model were much worse than I expected.) Should the official PyTorch implementation of weight norm also recompute weights only after a call to backward?
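To make the trade-off concrete, here is a minimal, dependency-free sketch of the caching idea (not the actual torch.nn.utils.weight_norm implementation, and the class and method names here are hypothetical). Weight norm reparameterizes w = g * v / ||v||; instead of recomputing w on every forward, the cache recomputes it only after the parameters change, i.e. once per backward/optimizer step, which is what matters when an RNN calls forward many times per step:

```python
import math

class WeightNormCache:
    """Hypothetical sketch: cache w = g * v / ||v|| and recompute it
    only after the underlying parameters v, g have been updated
    (e.g. after backward + optimizer step), not on every forward."""

    def __init__(self, v, g):
        self.v = list(v)      # direction parameter (here, a plain list)
        self.g = g            # scale parameter
        self._w = None        # cached effective weight
        self._stale = True    # True whenever v or g may have changed

    def mark_updated(self):
        # Call this after each backward/optimizer step.
        self._stale = True

    def weight(self):
        # Recompute w only if the cache is stale; an RNN calling this
        # at every time step pays the norm/division cost once per update.
        if self._stale or self._w is None:
            norm = math.sqrt(sum(x * x for x in self.v))
            self._w = [self.g * x / norm for x in self.v]
            self._stale = False
        return self._w
```

In real PyTorch code the "mark stale" step would have to be tied to backward (e.g. via a hook or an explicit call in the training loop), which is exactly the design question raised above: per-forward recomputation is simpler and always correct, while per-backward recomputation saves memory and compute when there are many forwards per backward.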

cc @VitalyFedyunin @ngimel @mruberry

Metadata

Labels

module: nn (Related to torch.nn)
module: performance (Issues related to performance, either of kernel code or framework glue)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
