
Adds truncated normal initializer#32397

Closed
Enealor wants to merge 21 commits into pytorch:master from Enealor:trunc_norm

Conversation

@Enealor
Contributor

@Enealor Enealor commented Jan 18, 2020

This adds the trunc_normal_ function to torch.nn.init, which fills a tensor in place with values drawn from a truncated normal distribution. I chose the inverse CDF method to implement this. I have included the appropriate code in test_nn.py for verifying that the values are drawn from the correct distribution.

Reasons I chose this method:

  1. Easily implemented to operate on memory in place, as the other initializers are.
  2. No resampling delays.
  3. This method's main weakness is unlikely to be an issue. While the inverse CDF method can fail to generate the correct distribution when b < mean or mean < a, I expect users will choose a and b so that a < mean < b. This method is extremely effective in this case.
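The inverse CDF approach described above can be sketched as follows. This is an illustrative sketch, not the PR's exact code; the function name is hypothetical, and it assumes the standard erfinv-based formulation of the inverse normal CDF:

```python
import math
import torch

def trunc_normal_sketch(tensor, mean=0., std=1., a=-2., b=2.):
    # CDF of the standard normal distribution.
    def norm_cdf(x):
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    # Map the truncation bounds through the CDF to get the valid quantile
    # range [l, u], sample uniformly inside it, then push the samples back
    # through the inverse CDF. No resampling loop is needed.
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)
    tensor.uniform_(2 * l - 1, 2 * u - 1)    # erfinv's domain is (-1, 1)
    tensor.erfinv_()                         # inverse CDF, up to scaling
    tensor.mul_(std * math.sqrt(2.)).add_(mean)
    tensor.clamp_(min=a, max=b)              # guard against rounding at the edges
    return tensor
```

Because every step is an in-place tensor operation, this matches how the other initializers in torch.nn.init modify memory without allocating a fresh tensor.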

@Enealor Enealor requested a review from apaszke as a code owner January 18, 2020 21:02
@kostmo
Member

kostmo commented Jan 18, 2020

💊 CircleCI build failures summary and remediations

As of commit bd41aeb (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no CircleCI failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

@Enealor
Contributor Author

Enealor commented Jan 18, 2020

Addresses #2129 and part of #32293 by adding a truncated normal initializer. @cossio

@Enealor
Contributor Author

Enealor commented Feb 1, 2020

The current version is passing the available checks. Any suggestions on changes or does it look good? @alicanb @cossio

@ezyang ezyang requested a review from alicanb February 3, 2020 15:51
@ezyang
Contributor

ezyang commented Feb 3, 2020

Tentatively assigning review to @alicanb, let me know if you need someone else to look

@ezyang ezyang added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Feb 3, 2020
@alicanb
Collaborator

alicanb commented Feb 3, 2020

I am happy to do it after the ICML deadline :D

@alicanb
Collaborator

alicanb commented Feb 13, 2020

This mostly looks good. One thing I would add is a warning if the bounds are too far from the center (I actually don't know when it will break; it might be a good idea to test that and place the warnings accordingly). I see you're only testing truncnorm(0, 1, -2, 2) in test_trunc_normal; it might be nice to add tests for various upper and lower bounds as well.

@Enealor
Contributor Author

Enealor commented Feb 18, 2020

@alicanb Thanks for the feedback. I experimented with randomly selected values of a and b for the truncated standard normal. It seems the most common failure mode is choosing a mean more than 2 standard deviations outside the interval [a, b]. In that situation, the method fails to be statistically similar to SciPy's truncated normal. It can also fail if the mean is closer to the interval but b-a is particularly small (around 1e-6).

With that, I will add a warning if a is more than 2 standard deviations above the mean or b is more than 2 standard deviations below the mean. I'll update the documentation to clarify that either the mean should be in the interval [a, b], or the distribution may be incorrect for particularly small b-a. I don't have enough information to turn the small-interval case into a specific warning, though.

Below is the code I used to test various results. I have also attached some example failures. The currently chosen range for values of a will typically pass. If a is chosen from the interval [3, 5], failures will start to occur. Similarly, if b is changed so that smaller numbers are more likely, then failures will start to become more frequent.

import random

import scipy.stats
import torch
from torch.nn import init


def _trunc_normal_p_value(tensor, mean, std, a, b):
    # Two-sided KS test against SciPy's truncnorm. Note that truncnorm
    # expects standardized bounds, i.e. (a - mean) / std and (b - mean) / std;
    # here mean=0 and std=1, so a and b pass through unchanged.
    return scipy.stats.kstest(tensor.flatten().tolist(), 'truncnorm', args=(a, b))[1]


if __name__ == '__main__':
    input_tensor = torch.empty((10, 10, 20))
    for _ in range(1000):
        a = random.uniform(-3, 3)
        b = random.uniform(a, a + 1)
        init.trunc_normal_(input_tensor, mean=0., std=1., a=a, b=b)

        p_value = _trunc_normal_p_value(input_tensor, 0., 1., a, b)
        if p_value <= 0.0001:
            print("Failed for interval [{0:.3}, {1:.3}], length {2:.3}".format(a, b, b - a))

# Failed for interval [4.98, 5.18], length 0.202
# Failed for interval [3.87, 3.87], length 0.00138
# Failed for interval [2.89, 2.89], length 0.000154

@Enealor
Contributor Author

Enealor commented Feb 22, 2020

I updated the warning message to be more descriptive and report at the appropriate stack level. The current message triggers if the mean is too far from the interval. For example,

t = torch.zeros((1, 10))
init.trunc_normal_(t, 3, 0.1, -2, 2) # mean is 3, std is .1, truncated to [-2, 2]
# __main__:1: UserWarning: mean is more than 2 std from [a, b] in nn.init.trunc_normal_.
# The distribution of values may be incorrect.

@ezyang
Contributor

ezyang commented Mar 9, 2020

@alicanb let me know when to land!

@alicanb
Collaborator

alicanb commented Mar 19, 2020

LGTM, sorry I forgot about this.

Contributor

@facebook-github-bot facebook-github-bot left a comment


@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@cossio

cossio commented Mar 20, 2020

Great! So this will be out in the next PyTorch minor version? Any estimate for when that'll be? Thanks!

@facebook-github-bot
Contributor

@ezyang merged this pull request in 8bcedf7.

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
This adds the `trunc_normal_` function to `torch.nn.init` which allows for modifying tensors in-place to values drawn from a truncated normal distribution. I chose to use the inverse CDF method to implement this. I have included the appropriate code in `test_nn.py` for verifying that the values are from the correct distribution.

Reasons I chose this method:
1. Easily implemented to operate on memory in place, as the other initializers are.
2. No resampling delays.
3. This method's main weakness is unlikely to be an issue. While the inverse CDF method can fail to generate the correct distribution when `b < mean` or `mean < a`, I expect users will choose `a` and `b` so that `a < mean < b`. This method is extremely effective in this case.
Pull Request resolved: pytorch#32397

Differential Revision: D20550996

Pulled By: ezyang

fbshipit-source-id: 298a325043a3fd7d1e24d266e3b9b6cc14f81829

Labels

Merged, open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


8 participants