Skip to content

Commit 864f0cf

Browse files
rgommersfacebook-github-bot
authored andcommitted
Fix type annotations for torch.sparse, enable in CI (#43108)
Summary: Closes gh-42982 Pull Request resolved: #43108 Reviewed By: malfet Differential Revision: D23167560 Pulled By: ezyang fbshipit-source-id: 0d660ca686ada2347bf440c6349551d1539f99ef
1 parent 6db0b87 commit 864f0cf

2 files changed

Lines changed: 10 additions & 12 deletions

File tree

mypy.ini

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,6 @@ ignore_errors = True
9393
[mypy-torch.jit.*]
9494
ignore_errors = True
9595

96-
[mypy-torch.sparse]
97-
ignore_errors = True
98-
9996
[mypy-torch.tensor]
10097
ignore_errors = True
10198

torch/sparse/__init__.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
# The Tensor classes are added to this module by python_tensor.cpp
2-
from typing import Optional, Tuple
2+
from typing import Optional, Tuple, List, Union
33

44
import torch
55
from torch import Tensor
66

77
# A workaround to support both TorchScript and MyPy:
88
from typing import TYPE_CHECKING
99
if TYPE_CHECKING:
10-
from torch import dtype as DType
10+
from torch.types import _dtype as DType
11+
DimOrDims = Optional[Union[int, Tuple[int], List[int]]]
1112
else:
13+
# The JIT doesn't understand Union, nor torch.dtype here
1214
DType = int
13-
# TODO: replace the above with
14-
# from torch.types import _dtype as DType
15+
DimOrDims = Optional[Tuple[int]]
1516

1617

1718
__all__ = [
@@ -23,8 +24,8 @@
2324
]
2425

2526

26-
def addmm(mat, mat1, mat2, beta=1, alpha=1):
27-
# type: (Tensor, Tensor, Tensor, float, float) -> Tensor
27+
def addmm(mat: Tensor, mat1: Tensor, mat2: Tensor,
28+
beta: float = 1., alpha: float = 1.) -> Tensor:
2829
r"""
2930
This function does exact same thing as :func:`torch.addmm` in the forward,
3031
except that it supports backward for sparse matrix :attr:`mat1`. :attr:`mat1`
@@ -41,7 +42,7 @@ def addmm(mat, mat1, mat2, beta=1, alpha=1):
4142
return torch._sparse_addmm(mat, mat1, mat2, beta=beta, alpha=alpha)
4243

4344

44-
def mm(mat1, mat2):
45+
def mm(mat1: Tensor, mat2: Tensor) -> Tensor:
4546
r"""
4647
Performs a matrix multiplication of the sparse matrix :attr:`mat1`
4748
and dense matrix :attr:`mat2`. Similar to :func:`torch.mm`, If :attr:`mat1` is a
@@ -83,8 +84,8 @@ def mm(mat1, mat2):
8384
return torch._sparse_mm(mat1, mat2)
8485

8586

86-
def sum(input, dim=None, dtype=None):
87-
# type: (Tensor, Optional[Tuple[int]], Optional[int]) -> Tensor
87+
def sum(input: Tensor, dim: DimOrDims = None,
88+
dtype: Optional[DType] = None) -> Tensor:
8889
r"""
8990
Returns the sum of each row of SparseTensor :attr:`input` in the given
9091
dimensions :attr:`dim`. If :attr:`dim` is a list of dimensions,

0 commit comments

Comments
 (0)