Skip to content

Commit 2a04414

Browse files
v0drocsarofeen
authored andcommitted
Update the documentation of the scatter_ method with support for reduction methods. (pytorch#40962)
Summary: Follow up to pytorch#36447 . Update for pytorch#33389. Also removes unused `unordered_map` include from the CPP file. Pull Request resolved: pytorch#40962 Differential Revision: D22376253 Pulled By: ngimel fbshipit-source-id: 4e7432190e9a847321aec6d6f6634056fa69bdb8
1 parent 968a90a commit 2a04414

2 files changed

Lines changed: 29 additions & 2 deletions

File tree

aten/src/ATen/native/cpu/ScatterGatherKernel.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include <ATen/native/TensorIterator.h>
44
#include <ATen/native/TensorAdvancedIndexing.h>
55
#include <ATen/Parallel.h>
6-
#include <unordered_map>
76

87
namespace at { namespace native {
98

torch/_tensor_docs.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2575,7 +2575,7 @@ def callable(a, b) -> number
25752575

25762576
add_docstr_all('scatter_',
25772577
r"""
2578-
scatter_(dim, index, src) -> Tensor
2578+
scatter_(dim, index, src, reduce=None) -> Tensor
25792579
25802580
Writes all values from the tensor :attr:`src` into :attr:`self` at the indices
25812581
specified in the :attr:`index` tensor. For each value in :attr:`src`, its output
@@ -2599,6 +2599,27 @@ def callable(a, b) -> number
25992599
between ``0`` and ``self.size(dim) - 1`` inclusive, and all values in a row
26002600
along the specified dimension :attr:`dim` must be unique.
26012601
2602+
Additionally accepts an optional :attr:`reduce` argument that allows
2603+
specification of an optional reduction operation, which is applied to all
2604+
values in the tensor :attr:`src` into :attr:`self` at the indicies
2605+
specified in the :attr:`index`. For each value in :attr:`src`, the reduction
2606+
operation is applied to an index in :attr:`self` which is specified by
2607+
its index in :attr:`src` for ``dimension != dim`` and by the corresponding
2608+
value in :attr:`index` for ``dimension = dim``.
2609+
2610+
Given a 3-D tensor and reduction using the multiplication operation, :attr:`self`
2611+
is updated as::
2612+
2613+
self[index[i][j][k]][j][k] *= src[i][j][k] # if dim == 0
2614+
self[i][index[i][j][k]][k] *= src[i][j][k] # if dim == 1
2615+
self[i][j][index[i][j][k]] *= src[i][j][k] # if dim == 2
2616+
2617+
Reducing with the addition operation is the same as using
2618+
:meth:`~torch.Tensor.scatter_add_`.
2619+
2620+
Note:
2621+
Reduction is not yet implemented for the CUDA backend.
2622+
26022623
Args:
26032624
dim (int): the axis along which to index
26042625
index (LongTensor): the indices of elements to scatter,
@@ -2608,6 +2629,8 @@ def callable(a, b) -> number
26082629
incase `value` is not specified
26092630
value (float): the source element(s) to scatter,
26102631
incase `src` is not specified
2632+
reduce (string): reduction operation to apply,
2633+
can be either 'add', 'subtract', 'multiply' or 'divide'.
26112634
26122635
Example::
26132636
@@ -2624,6 +2647,11 @@ def callable(a, b) -> number
26242647
>>> z
26252648
tensor([[ 0.0000, 0.0000, 1.2300, 0.0000],
26262649
[ 0.0000, 0.0000, 0.0000, 1.2300]])
2650+
2651+
>>> z = torch.ones(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23, reduce='multiply')
2652+
>>> z
2653+
tensor([[1.0000, 1.0000, 1.2300, 1.0000],
2654+
[1.0000, 1.0000, 1.0000, 1.2300]])
26272655
""")
26282656

26292657
add_docstr_all('scatter_add_',

0 commit comments

Comments
 (0)