Skip to content

[typing] Add missing type annotations to torch.nn.init module#154504

Closed
janumiko wants to merge 1 commit intopytorch:mainfrom
janumiko:fix/type-hints-nn-init
Closed

[typing] Add missing type annotations to torch.nn.init module#154504
janumiko wants to merge 1 commit intopytorch:mainfrom
janumiko:fix/type-hints-nn-init

Conversation

@janumiko
Copy link
Copy Markdown
Contributor

@janumiko janumiko commented May 28, 2025

Summary

Adds missing type annotations to torch.nn.init and removes # mypy: allow-untyped-defs since all functions are now properly typed.

Changes

  • Added missing type annotations to initialization functions in the module.
  • Added missing typing imports: Any, Callable, Union
  • Removed # mypy: allow-untyped-defs comment
  • Create Literal types for kaiming initialization mode and nonlinearity.
  • Created __all__

Why

Better IDE support, catches type errors earlier, and brings the module up to PyTorch's typing standards. No runtime changes - purely additive typing improvements.

Tested with existing test suite and lintrunner.

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented May 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154504

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 810fdcb with merge base f58143b (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@janumiko
Copy link
Copy Markdown
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot Bot added the topic: not user facing topic category label May 28, 2025
@janumiko janumiko force-pushed the fix/type-hints-nn-init branch 2 times, most recently from 00bbf1a to fd1d2ee Compare May 28, 2025 13:16
Comment thread torch/nn/init.py Outdated
Comment on lines 680 to 684
Copy link
Copy Markdown
Collaborator

@Skylion007 Skylion007 May 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _make_deprecate(meth: Callable[..., Tensor]) -> Callable[..., Tensor]:
new_name = meth.__name__
old_name = new_name[:-1]
def deprecated_init(*args, **kwargs):
def deprecated_init(*args: tuple[Any], **kwargs: dict[str, Any]) -> Tensor:
def _make_deprecate(meth: Callable[_P, _R]) -> Callable[_P, _R]:
new_name = meth.__name__
old_name = new_name[:-1]
def deprecated_init(*args : _P.args, **kwargs: _P.kwargs) -> _R:

Where _R is a typing.TypeVar and _P is a typing_extensions.ParamSpec

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Besides, this should all be replaced with typing_extensions.deprecated anyway

Copy link
Copy Markdown
Contributor Author

@janumiko janumiko May 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback!
Updated the typing to use *args: _P.args, **kwargs: _P.kwargs) -> _R as requested.

Good point about replacing this with @deprecated, I'll work on that in the next PR and keep this one contained to typing.

@janumiko janumiko force-pushed the fix/type-hints-nn-init branch 2 times, most recently from cfe9803 to 5916600 Compare May 28, 2025 14:51
Comment thread torch/nn/init.py Outdated
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _calculate_correct_fan(tensor: Tensor, mode: str) -> int:
def _calculate_correct_fan(tensor: Tensor, mode: Literal["fan_in", "fan_out"]) -> int:

Copy link
Copy Markdown
Contributor Author

@janumiko janumiko May 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed, created FanMode type alias for Literal instead because it appears in the code multiple times.
Did the same for the nonlinearity parameter typing, due to over 10 possible values.

@janumiko janumiko force-pushed the fix/type-hints-nn-init branch 2 times, most recently from eeaf9b1 to 4c9f011 Compare May 28, 2025 17:01
@janumiko janumiko force-pushed the fix/type-hints-nn-init branch 2 times, most recently from 3f780af to eb9fb8d Compare May 28, 2025 21:34
@janumiko
Copy link
Copy Markdown
Contributor Author

janumiko commented May 28, 2025

Added __all__ due to failing CI tests.

@janumiko
Copy link
Copy Markdown
Contributor Author

janumiko commented May 28, 2025

@pytorchbot label "suppress-bc-linter"

Add supress-bc-linter due to false-positive about changing str to Literal in kaiming init.

Warning: Function kaiming_uniform_: mode changed from str to _FanMode
fetch github.event.pull_request.base.sha
  Warning: Function kaiming_uniform_: nonlinearity changed from str to _NonlinearityType
  Warning: Function kaiming_normal_: mode changed from str to _FanMode
  Warning: Function kaiming_normal_: nonlinearity changed from str to _NonlinearityType
Error: Process completed with exit code 1.

https://github.com/pytorch/pytorch/actions/runs/15306112609/job/43064631148

@pytorch-bot pytorch-bot Bot added the suppress-bc-linter Suppresses the failures of API backward-compatibility linter (Lint/bc_linter) label May 28, 2025
@janumiko
Copy link
Copy Markdown
Contributor Author

@pytorchbot merge

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented May 29, 2025

Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.

@janumiko janumiko force-pushed the fix/type-hints-nn-init branch from eb9fb8d to 810fdcb Compare May 30, 2025 23:00
@janumiko
Copy link
Copy Markdown
Contributor Author

Fixed linting issues that were causing CI failures. The only remaining CI error seems to be from the BC linter, which I believe is a false positive.

@janumiko
Copy link
Copy Markdown
Contributor Author

@pytorchbot merge

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented May 31, 2025

Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.

@Skylion007
Copy link
Copy Markdown
Collaborator

@pytorchbot merge

@pytorch-bot pytorch-bot Bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 3, 2025
@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

iupaikov-amd pushed a commit to ROCm/pytorch that referenced this pull request Jun 4, 2025
…h#154504)

## Summary

Adds missing type annotations to `torch.nn.init` and removes `# mypy: allow-untyped-defs` since all functions are now properly typed.

## Changes

- Added missing type annotations to initialization functions in the module.
- Added missing typing imports: `Any`, `Callable`, `Union`
- Removed `# mypy: allow-untyped-defs` comment
- Create Literal types for kaiming initialization mode and nonlinearity.
- Created `__all__`

## Why

Better IDE support, catches type errors earlier, and brings the module up to PyTorch's typing standards. No runtime changes - purely additive typing improvements.

Tested with existing test suite and lintrunner.
Pull Request resolved: pytorch#154504
Approved by: https://github.com/Skylion007
angelayi pushed a commit to angelayi/pytorch that referenced this pull request Jun 5, 2025
…h#154504)

## Summary

Adds missing type annotations to `torch.nn.init` and removes `# mypy: allow-untyped-defs` since all functions are now properly typed.

## Changes

- Added missing type annotations to initialization functions in the module.
- Added missing typing imports: `Any`, `Callable`, `Union`
- Removed `# mypy: allow-untyped-defs` comment
- Create Literal types for kaiming initialization mode and nonlinearity.
- Created `__all__`

## Why

Better IDE support, catches type errors earlier, and brings the module up to PyTorch's typing standards. No runtime changes - purely additive typing improvements.

Tested with existing test suite and lintrunner.
Pull Request resolved: pytorch#154504
Approved by: https://github.com/Skylion007
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged open source suppress-bc-linter Suppresses the failures of API backward-compatibility linter (Lint/bc_linter) topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants