I haven't tried the new pytorch implementation of weight norm, but I was having problems applying a similar implementation to RNNs: recomputing the weight on every call to `forward` was causing me to run out of memory. The solution seemed to be to recompute weights only after each call to `backward`. The version I was having problems with and my fix are both here: https://gist.github.com/rtqichen/b22a9c6bfc4f36e605a7b3ac1ab4122f
(I should warn that I'm not totally sure my implementation is correct, since the results I got from weight norm in my model were much worse than I expected.) Should the official pytorch implementation of weight norm also recompute weights only after a call to `backward`?
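For illustration, here is a minimal, framework-free sketch (plain Python, hypothetical names, not the gist's actual code) of the caching idea: the normalized weight `w = g * v / ||v||` is recomputed lazily, and the cache is invalidated only after a backward/update step, so the many `forward` calls of an unrolled RNN all reuse the same weight instead of rebuilding it each time.

```python
import math

class LazyWeightNorm:
    """Hypothetical sketch: cache w = g * v / ||v|| and recompute
    only after invalidation (e.g. after each backward pass), rather
    than on every forward call."""

    def __init__(self, v, g):
        self.v = list(v)   # direction vector
        self.g = g         # scalar magnitude
        self._w = None     # cached normalized weight

    def weight(self):
        # Recompute only when the cache is empty.
        if self._w is None:
            norm = math.sqrt(sum(x * x for x in self.v))
            self._w = [self.g * x / norm for x in self.v]
        return self._w

    def invalidate(self):
        # Call once per backward/optimizer step, not per forward.
        self._w = None

wn = LazyWeightNorm(v=[3.0, 4.0], g=2.0)
w1 = wn.weight()   # computed once
w2 = wn.weight()   # reused across repeated forward calls
assert w1 is w2
wn.invalidate()    # after backward/update, force a fresh recompute
```

In a real module the recomputation would also have to rebuild the autograd graph from `v` and `g`, which is exactly why doing it on every `forward` of an unrolled RNN blows up memory.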
cc @VitalyFedyunin @ngimel @mruberry