Adds truncated normal initializer #32397
Conversation
💊 Dr. CI (automated): As of commit bd41aeb, there are no CircleCI failures.
Tentatively assigning review to @alicanb, let me know if you need someone else to look.
I am happy to do it after the ICML deadline :D
This mostly looks good. Maybe one thing I would add is a warning if the bounds are too far away from the center (I actually don't know when it will break; it might be a good idea to test that and place the warnings accordingly). I see you're only testing `truncnorm(0, 1, -2, 2)` in `test_nn.py`.
@alicanb Thanks for the feedback. I experimented with randomly selected parameters to see where the method breaks down. With that, I will add a warning if either bound is too far from the mean. Below is the code I used to test various results, and I have also attached some example failures. The currently chosen warning threshold is based on those experiments.
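A minimal sketch of the kind of distribution check described here (this is not the exact attached snippet; it assumes `scipy` is available for the reference `truncnorm` distribution and uses a Kolmogorov–Smirnov test):

```python
import torch
import torch.nn as nn
from scipy import stats

def check_trunc_normal(mean, std, a, b, n=100_000, alpha=0.01):
    """Draw samples with trunc_normal_ and compare them to scipy's truncnorm."""
    samples = torch.empty(n)
    nn.init.trunc_normal_(samples, mean=mean, std=std, a=a, b=b)
    # scipy parameterizes truncnorm by the bounds in standard-normal units.
    ref = stats.truncnorm((a - mean) / std, (b - mean) / std, loc=mean, scale=std)
    _, p_value = stats.kstest(samples.numpy(), ref.cdf)
    return p_value > alpha  # True if the samples match the target distribution

# Probe a few parameter settings, including ones where the mean sits
# outside the truncation interval (the cases expected to fail).
for mean, std, a, b in [(0., 1., -2., 2.), (5., 1., -2., 2.), (0., 3., -1., 1.)]:
    print(mean, std, a, b, check_trunc_normal(mean, std, a, b))
```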
I updated the warning message to be more descriptive and report at the appropriate stack level. The current message triggers if the mean is too far from the interval; a call like the sketch below is the kind of case it targets.
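For illustration (assuming the merged `torch.nn.init.trunc_normal_(tensor, mean, std, a, b)` signature), a mean that lies well outside the truncation interval should trigger the warning:

```python
import torch
import torch.nn as nn

w = torch.empty(3, 5)
# mean=10 is far from [a, b] = [-2, 2]; the inverse-CDF method produces a poor
# approximation of the truncated distribution here, so a warning is expected.
nn.init.trunc_normal_(w, mean=10., std=1., a=-2., b=2.)
```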
@alicanb let me know when to land!
LGTM, sorry I forgot about this.
facebook-github-bot left a comment:
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Great! So this will be out in the next PyTorch minor version? Any estimate for when that'll be? Thanks!
Summary: This adds the `trunc_normal_` function to `torch.nn.init`, which allows modifying tensors in-place with values drawn from a truncated normal distribution. I chose to use the inverse CDF method to implement this. I have included the appropriate code in `test_nn.py` for verifying that the values are from the correct distribution.

Reasons I chose this method:
1. Easily implemented to operate on memory in place, as the other initializers are.
2. No resampling delays.
3. This method's main weakness is unlikely to be an issue. While the inverse CDF method can fail to generate the correct distribution when `b < mean` or `mean < a`, I expect users will choose `a` and `b` so that `a < mean < b`. This method is extremely effective in this case.

Pull Request resolved: pytorch#32397
Differential Revision: D20550996
Pulled By: ezyang
fbshipit-source-id: 298a325043a3fd7d1e24d266e3b9b6cc14f81829
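For readers unfamiliar with the inverse-CDF technique referenced above, here is a minimal sketch of how such an in-place initializer can be written. This is not necessarily the exact code merged in this PR; the default bounds and the final `clamp_` are assumptions.

```python
import math
import torch

def trunc_normal_sketch(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in place with values from N(mean, std) truncated to [a, b]."""
    def norm_cdf(x):
        # CDF of the standard normal distribution.
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    with torch.no_grad():
        # Map the truncation bounds through the standard normal CDF.
        lo = norm_cdf((a - mean) / std)
        hi = norm_cdf((b - mean) / std)

        # Draw uniform values in [2*lo - 1, 2*hi - 1], then invert with erfinv
        # to get standard-normal values restricted to [(a-mean)/std, (b-mean)/std].
        tensor.uniform_(2 * lo - 1, 2 * hi - 1)
        tensor.erfinv_()

        # Rescale to the requested mean/std and clamp to guard against
        # floating-point excursions just outside [a, b].
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor

# Example usage: initialize a weight tensor with values truncated to [-2, 2].
w = torch.empty(3, 5)
trunc_normal_sketch(w, mean=0., std=1., a=-2., b=2.)
```

As the summary notes, this approach works well when `a < mean < b`; when the mean lies far outside `[a, b]`, the uniform range collapses and the sampled distribution degrades, which is what the added warning is meant to flag.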