11# The Tensor classes are added to this module by python_tensor.cpp
2- from typing import Optional , Tuple
2+ from typing import Optional , Tuple , List , Union
33
44import torch
55from torch import Tensor
66
77# A workaround to support both TorchScript and MyPy:
88from typing import TYPE_CHECKING
99if TYPE_CHECKING :
10- from torch import dtype as DType
10+ from torch .types import _dtype as DType
11+ DimOrDims = Optional [Union [int , Tuple [int ], List [int ]]]
1112else :
13+ # The JIT doesn't understand Union, nor torch.dtype here
1214 DType = int
13- # TODO: replace the above with
14- # from torch.types import _dtype as DType
15+ DimOrDims = Optional [Tuple [int ]]
1516
1617
1718__all__ = [
2324]
2425
2526
26- def addmm (mat , mat1 , mat2 , beta = 1 , alpha = 1 ):
27- # type: (Tensor, Tensor, Tensor, float, float) -> Tensor
27+ def addmm (mat : Tensor , mat1 : Tensor , mat2 : Tensor ,
28+ beta : float = 1. , alpha : float = 1. ) -> Tensor :
2829 r"""
2930 This function does exact same thing as :func:`torch.addmm` in the forward,
3031 except that it supports backward for sparse matrix :attr:`mat1`. :attr:`mat1`
@@ -41,7 +42,7 @@ def addmm(mat, mat1, mat2, beta=1, alpha=1):
4142 return torch ._sparse_addmm (mat , mat1 , mat2 , beta = beta , alpha = alpha )
4243
4344
44- def mm (mat1 , mat2 ) :
45+ def mm (mat1 : Tensor , mat2 : Tensor ) -> Tensor :
4546 r"""
4647 Performs a matrix multiplication of the sparse matrix :attr:`mat1`
4748 and dense matrix :attr:`mat2`. Similar to :func:`torch.mm`, If :attr:`mat1` is a
@@ -83,8 +84,8 @@ def mm(mat1, mat2):
8384 return torch ._sparse_mm (mat1 , mat2 )
8485
8586
86- def sum (input , dim = None , dtype = None ):
87- # type: (Tensor, Optional[Tuple[int]], Optional[int] ) -> Tensor
87+ def sum (input : Tensor , dim : DimOrDims = None ,
88+ dtype : Optional [DType ] = None ) -> Tensor :
8889 r"""
8990 Returns the sum of each row of SparseTensor :attr:`input` in the given
9091 dimensions :attr:`dim`. If :attr:`dim` is a list of dimensions,
0 commit comments