@@ -2575,7 +2575,7 @@ def callable(a, b) -> number
25752575
25762576add_docstr_all ('scatter_' ,
25772577 r"""
2578- scatter_(dim, index, src) -> Tensor
2578+ scatter_(dim, index, src, reduce=None ) -> Tensor
25792579
25802580Writes all values from the tensor :attr:`src` into :attr:`self` at the indices
25812581specified in the :attr:`index` tensor. For each value in :attr:`src`, its output
@@ -2599,6 +2599,27 @@ def callable(a, b) -> number
25992599between ``0`` and ``self.size(dim) - 1`` inclusive, and all values in a row
26002600along the specified dimension :attr:`dim` must be unique.
26012601
2602+ Additionally accepts an optional :attr:`reduce` argument that allows
2603+ specification of an optional reduction operation, which is applied to all
2604+ values in the tensor :attr:`src` into :attr:`self` at the indicies
2605+ specified in the :attr:`index`. For each value in :attr:`src`, the reduction
2606+ operation is applied to an index in :attr:`self` which is specified by
2607+ its index in :attr:`src` for ``dimension != dim`` and by the corresponding
2608+ value in :attr:`index` for ``dimension = dim``.
2609+
2610+ Given a 3-D tensor and reduction using the multiplication operation, :attr:`self`
2611+ is updated as::
2612+
2613+ self[index[i][j][k]][j][k] *= src[i][j][k] # if dim == 0
2614+ self[i][index[i][j][k]][k] *= src[i][j][k] # if dim == 1
2615+ self[i][j][index[i][j][k]] *= src[i][j][k] # if dim == 2
2616+
2617+ Reducing with the addition operation is the same as using
2618+ :meth:`~torch.Tensor.scatter_add_`.
2619+
2620+ Note:
2621+ Reduction is not yet implemented for the CUDA backend.
2622+
26022623Args:
26032624 dim (int): the axis along which to index
26042625 index (LongTensor): the indices of elements to scatter,
@@ -2608,6 +2629,8 @@ def callable(a, b) -> number
26082629 incase `value` is not specified
26092630 value (float): the source element(s) to scatter,
26102631 incase `src` is not specified
2632+ reduce (string): reduction operation to apply,
2633+ can be either 'add', 'subtract', 'multiply' or 'divide'.
26112634
26122635Example::
26132636
@@ -2624,6 +2647,11 @@ def callable(a, b) -> number
26242647 >>> z
26252648 tensor([[ 0.0000, 0.0000, 1.2300, 0.0000],
26262649 [ 0.0000, 0.0000, 0.0000, 1.2300]])
2650+
2651+ >>> z = torch.ones(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23, reduce='multiply')
2652+ >>> z
2653+ tensor([[1.0000, 1.0000, 1.2300, 1.0000],
2654+ [1.0000, 1.0000, 1.0000, 1.2300]])
26272655""" )
26282656
26292657add_docstr_all ('scatter_add_' ,
0 commit comments