[typing] Add missing type annotations to torch.nn.init module#154504
[typing] Add missing type annotations to torch.nn.init module#154504janumiko wants to merge 1 commit intopytorch:mainfrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154504
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit 810fdcb with merge base f58143b ( BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@pytorchbot label "topic: not user facing" |
00bbf1a to
fd1d2ee
Compare
There was a problem hiding this comment.
| def _make_deprecate(meth: Callable[..., Tensor]) -> Callable[..., Tensor]: | |
| new_name = meth.__name__ | |
| old_name = new_name[:-1] | |
| def deprecated_init(*args, **kwargs): | |
| def deprecated_init(*args: tuple[Any], **kwargs: dict[str, Any]) -> Tensor: | |
| def _make_deprecate(meth: Callable[_P, _R]) -> Callable[_P, _R]: | |
| new_name = meth.__name__ | |
| old_name = new_name[:-1] | |
| def deprecated_init(*args : _P.args, **kwargs: _P.kwargs) -> _R: |
Where _R is a typing.TypeVar and _P is a typing_extensions.ParamSpec
There was a problem hiding this comment.
Besides, this should all be replaced with typing_extensions.deprecated anyway
There was a problem hiding this comment.
Thanks for the feedback!
Updated the typing to use *args: _P.args, **kwargs: _P.kwargs) -> _R as requested.
Good point about replacing this with @deprecated, I'll work on that in the next PR and keep this one contained to typing.
cfe9803 to
5916600
Compare
There was a problem hiding this comment.
| def _calculate_correct_fan(tensor: Tensor, mode: str) -> int: | |
| def _calculate_correct_fan(tensor: Tensor, mode: Literal["fan_in", "fan_out"]) -> int: |
There was a problem hiding this comment.
Fixed, created FanMode type alias for Literal instead because it appears in the code multiple times.
Did the same for the nonlinearity parameter typing, due to over 10 possible values.
eeaf9b1 to
4c9f011
Compare
3f780af to
eb9fb8d
Compare
|
Added |
|
@pytorchbot label "suppress-bc-linter" Add supress-bc-linter due to false-positive about changing str to Literal in kaiming init. https://github.com/pytorch/pytorch/actions/runs/15306112609/job/43064631148 |
|
@pytorchbot merge |
|
Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra. |
eb9fb8d to
810fdcb
Compare
|
Fixed linting issues that were causing CI failures. The only remaining CI error seems to be from the BC linter, which I believe is a false positive. |
|
@pytorchbot merge |
|
Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra. |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…h#154504) ## Summary Adds missing type annotations to `torch.nn.init` and removes `# mypy: allow-untyped-defs` since all functions are now properly typed. ## Changes - Added missing type annotations to initialization functions in the module. - Added missing typing imports: `Any`, `Callable`, `Union` - Removed `# mypy: allow-untyped-defs` comment - Create Literal types for kaiming initialization mode and nonlinearity. - Created `__all__` ## Why Better IDE support, catches type errors earlier, and brings the module up to PyTorch's typing standards. No runtime changes - purely additive typing improvements. Tested with existing test suite and lintrunner. Pull Request resolved: pytorch#154504 Approved by: https://github.com/Skylion007
…h#154504) ## Summary Adds missing type annotations to `torch.nn.init` and removes `# mypy: allow-untyped-defs` since all functions are now properly typed. ## Changes - Added missing type annotations to initialization functions in the module. - Added missing typing imports: `Any`, `Callable`, `Union` - Removed `# mypy: allow-untyped-defs` comment - Create Literal types for kaiming initialization mode and nonlinearity. - Created `__all__` ## Why Better IDE support, catches type errors earlier, and brings the module up to PyTorch's typing standards. No runtime changes - purely additive typing improvements. Tested with existing test suite and lintrunner. Pull Request resolved: pytorch#154504 Approved by: https://github.com/Skylion007
Summary
Adds missing type annotations to
torch.nn.initand removes# mypy: allow-untyped-defssince all functions are now properly typed.Changes
Any,Callable,Union# mypy: allow-untyped-defscomment__all__Why
Better IDE support, catches type errors earlier, and brings the module up to PyTorch's typing standards. No runtime changes - purely additive typing improvements.
Tested with existing test suite and lintrunner.