Add type annotations to distribution.py#87577
Add type annotations to distribution.py#87577EPronovost wants to merge 2 commits intopytorch:masterfrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87577
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures, 1 PendingAs of commit c3dbf55: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
|
|
@EPronovost The linter says the types are not correct. |
|
Thanks @kit1980 . Fixed. |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
|
Hey @EPronovost. |
|
Hi @kit1980 , thanks for the review. The github action bot says I should add release notes and topic labels to the PR, but the github UI says I don't have permission to add labels to PRs. Is there something I should do here? |
|
@EPronovost It's fine ignore it for now. You can add labels by commenting to pytorchbot, see https://github.com/pytorch/pytorch/wiki/Bot-commands#labeling |
| return self.rsample(sample_shape) | ||
|
|
||
| def rsample(self, sample_shape=torch.Size()): | ||
| def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: |
There was a problem hiding this comment.
torch.Size is too strict here since something like
dist = torch.distributions.Uniform(0.0, 1.0)
dist.sample((2, 3))is no longer allowed since (2, 3) is not a torch.Size. This forces users to do
dist.sample(torch.Size((2, 3)))which is likely not intended, is it? We should use torch.types._size
Line 22 in 166b5d3
as annotation here and in all other occurrences of this annotation.
As title. Pull Request resolved: pytorch#87577 Approved by: https://github.com/kit1980
As title. Pull Request resolved: pytorch#87577 Approved by: https://github.com/kit1980
As title.