Skip to content

Commit 9699c70

Browse files
nikitaved authored and facebook-github-bot committed
Stable sort for the CPU take 2. (#51790)
Summary: Fixes #38681. A duplicate of #50052 created to become importable to the fb internal tests. Pull Request resolved: #51790 Reviewed By: agolynski Differential Revision: D26279045 Pulled By: glaringlee fbshipit-source-id: 348e171dee9c370a76002b65d0c82c329f57a421
1 parent 5fda3b0 commit 9699c70

15 files changed

Lines changed: 249 additions & 29 deletions

aten/src/ATen/LegacyTHFunctionsCUDA.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,8 @@ std::tuple<Tensor &,Tensor &> _th_mode_out(Tensor & values, Tensor & indices, co
2828
std::tuple<Tensor,Tensor> _th_mode(const Tensor & self, int64_t dim, bool keepdim);
2929
std::tuple<Tensor &,Tensor &> _th_sort_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t dim, bool descending);
3030
std::tuple<Tensor,Tensor> _th_sort(const Tensor & self, int64_t dim, bool descending);
31+
std::tuple<Tensor &,Tensor &> _th_sort_out_stable(Tensor & values, Tensor & indices, const Tensor & self, c10::optional<bool> stable, int64_t dim, bool descending);
32+
std::tuple<Tensor,Tensor> _th_sort_stable(const Tensor & self, c10::optional<bool> stable, int64_t dim, bool descending);
3133
std::tuple<Tensor &,Tensor &> _th_topk_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted);
3234
std::tuple<Tensor,Tensor> _th_topk(const Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted);
3335
Tensor & _th_renorm_out(Tensor & result, const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm);

aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -304,10 +304,13 @@ std::tuple<Tensor,Tensor> _th_mode(const Tensor & self, int64_t dim, bool keepdi
304304
}
305305
return std::tuple<Tensor, Tensor>(values, indices);
306306
}
307-
std::tuple<Tensor &,Tensor &> _th_sort_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t dim, bool descending) {
307+
std::tuple<Tensor &,Tensor &> _th_sort_out_stable(Tensor & values, Tensor & indices, const Tensor & self, c10::optional<bool> stable, int64_t dim, bool descending) {
308308
// DeviceGuard omitted
309309
auto dispatch_scalar_type = infer_scalar_type(self);
310310

311+
TORCH_INTERNAL_ASSERT(stable.has_value(), "sort_out(): c10::optional<bool> for stable has to have value.");
312+
TORCH_CHECK(!stable.value(), "stable=True is not implemented on CUDA yet.");
313+
311314
switch (dispatch_scalar_type) {
312315
case ScalarType::Byte: {
313316
auto values_ = checked_dense_tensor_unwrap(values, "values", 0, "_th_sort_out", false, DeviceType::CUDA, dispatch_scalar_type);
@@ -370,8 +373,15 @@ std::tuple<Tensor &,Tensor &> _th_sort_out(Tensor & values, Tensor & indices, co
370373
}
371374
return std::tuple<Tensor &, Tensor &>(values, indices);
372375
}
373-
std::tuple<Tensor,Tensor> _th_sort(const Tensor & self, int64_t dim, bool descending) {
376+
std::tuple<Tensor &,Tensor &> _th_sort_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t dim, bool descending) {
377+
return _th_sort_out_stable(values, indices, self, /*stable=*/false, dim, descending);
378+
}
379+
std::tuple<Tensor,Tensor> _th_sort_stable(const Tensor & self, c10::optional<bool> stable, int64_t dim, bool descending) {
374380
// DeviceGuard omitted
381+
382+
TORCH_INTERNAL_ASSERT(stable.has_value(), "sort_out(): c10::optional<bool> for stable has to have value.");
383+
TORCH_CHECK(!stable.value(), "stable=True is not implemented on CUDA yet.");
384+
375385
auto dispatch_scalar_type = infer_scalar_type(self);
376386
auto values_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
377387
auto values = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(values_));
@@ -423,6 +433,9 @@ std::tuple<Tensor,Tensor> _th_sort(const Tensor & self, int64_t dim, bool descen
423433
}
424434
return std::tuple<Tensor, Tensor>(values, indices);
425435
}
436+
std::tuple<Tensor,Tensor> _th_sort(const Tensor & self, int64_t dim, bool descending) {
437+
return _th_sort_stable(self, /*stable=*/false, dim, descending);
438+
}
426439
std::tuple<Tensor &,Tensor &> _th_topk_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted) {
427440
// DeviceGuard omitted
428441
auto dispatch_scalar_type = infer_scalar_type(self);

aten/src/ATen/native/CompositeRandomAccessorCommon.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -122,6 +122,9 @@ class CompositeRandomAccessor {
122122
using difference_type = typename std::iterator_traits<KeyAccessor>::difference_type;
123123
using iterator_category = std::random_access_iterator_tag;
124124

125+
C10_HOST_DEVICE
126+
CompositeRandomAccessor() = default;
127+
125128
C10_HOST_DEVICE
126129
CompositeRandomAccessor(KeyAccessor keys, ValueAccessor values)
127130
: keys(keys), values(values)

aten/src/ATen/native/NamedTensor.cpp

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -359,9 +359,15 @@ Tensor scatter_add(const Tensor& self, Dimname dim, const Tensor& index, const T
359359
Tensor& scatter_add_(Tensor& self, Dimname dim, const Tensor& index, const Tensor& source) {
360360
reportNYIDimnameOverload("scatter_add");
361361
}
362+
std::tuple<Tensor&, Tensor&> sort_out(Tensor& values, Tensor& indices, const Tensor& self, c10::optional<bool> stable, Dimname dim, bool keepdim) {
363+
reportNYIDimnameOverload("sort");
364+
}
362365
std::tuple<Tensor&, Tensor&> sort_out(Tensor& values, Tensor& indices, const Tensor& self, Dimname dim, bool keepdim) {
363366
reportNYIDimnameOverload("sort");
364367
}
368+
std::tuple<Tensor, Tensor> sort(const Tensor& self, c10::optional<bool> stable, Dimname dim, bool keepdim) {
369+
reportNYIDimnameOverload("sort");
370+
}
365371
std::tuple<Tensor, Tensor> sort(const Tensor& self, Dimname dim, bool keepdim) {
366372
reportNYIDimnameOverload("sort");
367373
}

aten/src/ATen/native/Sorting.cpp

Lines changed: 24 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -682,10 +682,11 @@ Tensor nanmedian_cpu(const Tensor& self) {
682682
return median_impl(self, /*ignore_nan=*/true);
683683
}
684684

685-
std::tuple<Tensor&, Tensor&> sort_out_cpu(
685+
std::tuple<Tensor&, Tensor&> sort_out_cpu_stable(
686686
Tensor& values,
687687
Tensor& indices,
688688
const Tensor& self,
689+
c10::optional<bool> stable,
689690
int64_t dim,
690691
bool descending) {
691692
values.resize_(self.sizes()).copy_(self);
@@ -697,18 +698,37 @@ std::tuple<Tensor&, Tensor&> sort_out_cpu(
697698
return std::forward_as_tuple(values, indices);
698699
}
699700

700-
sort_stub(kCPU, values, indices, dim, descending);
701+
TORCH_INTERNAL_ASSERT(stable.has_value(), "sort_out(): c10::optional<bool> for stable has to have value.");
702+
sort_stub(kCPU, values, indices, dim, descending, stable.value());
701703

702704
return std::forward_as_tuple(values, indices);
703705
}
704706

705-
std::tuple<Tensor, Tensor> sort_cpu(
707+
std::tuple<Tensor&, Tensor&> sort_out_cpu(
708+
Tensor& values,
709+
Tensor& indices,
710+
const Tensor& self,
711+
int64_t dim,
712+
bool descending) {
713+
return sort_out_cpu_stable(values, indices, self, /*stable=*/false, dim, descending);
714+
}
715+
716+
std::tuple<Tensor, Tensor> sort_cpu_stable(
706717
const Tensor& self,
718+
c10::optional<bool> stable,
707719
int64_t dim,
708720
bool descending) {
721+
TORCH_CHECK(!self.is_complex(), "sort(): input tensor must be of non-complex type");
709722
Tensor values = at::empty({0}, self.options());
710723
Tensor indices = at::empty({0}, self.options().dtype(kLong));
711-
return sort_out_cpu(values, indices, self, dim, descending);
724+
return sort_out_cpu_stable(values, indices, self, stable, dim, descending);
725+
}
726+
727+
std::tuple<Tensor, Tensor> sort_cpu(
728+
const Tensor& self,
729+
int64_t dim,
730+
bool descending) {
731+
return sort_cpu_stable(self, /*stable=*/false, dim, descending);
712732
}
713733

714734
Tensor& msort_out(Tensor& values, const Tensor& self) {

aten/src/ATen/native/Sorting.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55

66
namespace at { namespace native {
77

8-
using sort_fn = void(*)(Tensor& values, Tensor& indices, int64_t dim, bool descending);
8+
using sort_fn = void(*)(Tensor& values, Tensor& indices, int64_t dim, bool descending, bool stable);
99
using topk_fn = void(*)(Tensor&, Tensor&, const Tensor&, int64_t, int64_t, bool, bool);
1010

1111
DECLARE_DISPATCH(sort_fn, sort_stub);

aten/src/ATen/native/cpu/SortingKernel.cpp

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,8 @@ static void sort_kernel(
9696
Tensor& values,
9797
Tensor& indices,
9898
int64_t dim,
99-
bool descending) {
99+
bool descending,
100+
bool stable) {
100101
dim = maybe_wrap_dim(dim, values.dim());
101102
_fill_indices(indices, dim);
102103
_dim_apply(
@@ -116,12 +117,24 @@ static void sort_kernel(
116117
>(values_accessor, indices_accessor);
117118

118119
if (descending) {
119-
std::sort(composite_accessor, composite_accessor + dim_size,
120-
KeyValueCompDesc<scalar_t>());
120+
if (stable) {
121+
std::stable_sort(composite_accessor, composite_accessor + dim_size,
122+
KeyValueCompDesc<scalar_t>());
123+
}
124+
else {
125+
std::sort(composite_accessor, composite_accessor + dim_size,
126+
KeyValueCompDesc<scalar_t>());
127+
}
121128
}
122129
else {
123-
std::sort(composite_accessor, composite_accessor + dim_size,
124-
KeyValueCompAsc<scalar_t>());
130+
if (stable) {
131+
std::stable_sort(composite_accessor, composite_accessor + dim_size,
132+
KeyValueCompAsc<scalar_t>());
133+
}
134+
else {
135+
std::sort(composite_accessor, composite_accessor + dim_size,
136+
KeyValueCompAsc<scalar_t>());
137+
}
125138
}
126139
}
127140
);

aten/src/ATen/native/native_functions.yaml

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6391,19 +6391,38 @@
63916391
CPU: sort_out_cpu
63926392
CUDA: legacy::cuda::_th_sort_out
63936393

6394+
- func: sort.values_stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
6395+
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
6396+
dispatch:
6397+
CPU: sort_out_cpu_stable
6398+
CUDA: legacy::cuda::_th_sort_out_stable
6399+
63946400
- func: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
63956401
variants: method, function
63966402
dispatch:
63976403
CPU: sort_cpu
63986404
CUDA: legacy::cuda::_th_sort
63996405
QuantizedCPU: sort_quantized_cpu
64006406

6407+
- func: sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
6408+
variants: method, function
6409+
dispatch:
6410+
CPU: sort_cpu_stable
6411+
CUDA: legacy::cuda::_th_sort_stable
6412+
QuantizedCPU: sort_quantized_cpu_stable
6413+
64016414
- func: sort.dimname_values(Tensor self, Dimname dim, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
64026415
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
64036416

6417+
- func: sort.dimname_values_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
6418+
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
6419+
64046420
- func: sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)
64056421
variants: method, function
64066422

6423+
- func: sort.dimname_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)
6424+
variants: method, function
6425+
64076426
- func: msort.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
64086427
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
64096428
dispatch:

aten/src/ATen/native/quantized/TensorCompare.cpp

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,19 +20,27 @@ Tensor min_quantized_cpu(const Tensor& self) {
2020

2121
// TODO: move to TensorMath.cpp
2222

23-
std::tuple<Tensor, Tensor> sort_quantized_cpu(
23+
std::tuple<Tensor, Tensor> sort_quantized_cpu_stable(
2424
const Tensor& self,
25+
c10::optional<bool> stable,
2526
int64_t dim,
2627
bool descending) {
2728
Tensor sort_int;
2829
Tensor sort_indicies;
2930
std::tie(sort_int, sort_indicies) =
30-
at::sort(self.int_repr(), dim, descending);
31+
at::sort(self.int_repr(), stable, dim, descending);
3132
return std::forward_as_tuple(
3233
at::_make_per_tensor_quantized_tensor(
3334
sort_int, self.q_scale(), self.q_zero_point()),
3435
sort_indicies);
3536
}
3637

38+
std::tuple<Tensor, Tensor> sort_quantized_cpu(
39+
const Tensor& self,
40+
int64_t dim,
41+
bool descending) {
42+
return sort_quantized_cpu_stable(self, /*stable=*/false, dim, descending);
43+
}
44+
3745
} // namespace native
3846
} // namespace at

test/test_sort_and_select.py

Lines changed: 79 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
(TestCase, run_tests, make_tensor)
1010
from torch.testing._internal.common_device_type import \
1111
(instantiate_device_type_tests, dtypes, onlyOnCPUAndCUDA,
12-
skipCUDAIfRocm, onlyCUDA, dtypesIfCUDA)
12+
skipCUDAIfRocm, onlyCUDA, dtypesIfCUDA, onlyCPU)
1313

1414
# TODO: remove this
1515
SIZE = 100
@@ -113,6 +113,84 @@ def test_sort(self, device):
113113
self.assertIsOrdered('descending', x, res2val, res2ind,
114114
'random with NaNs')
115115

116+
@onlyCUDA
117+
@dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128})
118+
def test_stable_sort_fails_on_CUDA(self, device, dtype):
119+
x = torch.tensor([1, 0, 1, 0], dtype=dtype, device=device)
120+
with self.assertRaisesRegex(RuntimeError, "stable=True is not implemented on CUDA yet."):
121+
x.sort(stable=True)
122+
123+
@onlyCPU
124+
@dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128})
125+
def test_stable_sort(self, device, dtype):
126+
for ncopies in (100, 1000, 10000):
127+
x = torch.tensor([0, 1] * ncopies, dtype=dtype, device=device)
128+
_, idx = x.sort(stable=True)
129+
self.assertEqual(
130+
idx[:ncopies],
131+
torch.arange(start=0, end=2 * ncopies, step=2, device=device)
132+
)
133+
self.assertEqual(
134+
idx[ncopies:],
135+
torch.arange(start=1, end=2 * ncopies, step=2, device=device)
136+
)
137+
138+
@onlyCPU
139+
@dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128})
140+
def test_stable_sort_against_numpy(self, device, dtype):
141+
if dtype in torch.testing.floating_types_and(torch.float16):
142+
inf = float('inf')
143+
neg_inf = -float('inf')
144+
nan = float('nan')
145+
else:
146+
if dtype != torch.bool:
147+
# no torch.iinfo support for torch.bool
148+
inf = torch.iinfo(dtype).max
149+
neg_inf = torch.iinfo(dtype).min
150+
else:
151+
inf = True
152+
neg_inf = ~inf
153+
# no nan for integral types, we use inf instead for simplicity
154+
nan = inf
155+
156+
def generate_samples():
157+
from itertools import chain, combinations
158+
159+
def repeated_index_fill(t, dim, idxs, vals):
160+
res = t
161+
for idx, val in zip(idxs, vals):
162+
res = res.index_fill(dim, idx, val)
163+
return res
164+
165+
for sizes in [(1, 10), (10, 1), (10, 10), (10, 10, 10)]:
166+
size = min(*sizes)
167+
x = (torch.randn(*sizes, device=device) * size).to(dtype)
168+
yield (x, 0)
169+
170+
# Generate tensors which are being filled at random locations
171+
# with values from the non-empty subsets of the set (inf, neg_inf, nan)
172+
# for each dimension.
173+
n_fill_vals = 3 # cardinality of (inf, neg_inf, nan)
174+
for dim in range(len(sizes)):
175+
idxs = (torch.randint(high=size, size=(size // 10,)) for i in range(n_fill_vals))
176+
vals = (inf, neg_inf, nan)
177+
subsets = chain.from_iterable(combinations(list(zip(idxs, vals)), r)
178+
for r in range(1, n_fill_vals + 1))
179+
for subset in subsets:
180+
idxs_subset, vals_subset = zip(*subset)
181+
yield (repeated_index_fill(x, dim, idxs_subset, vals_subset), dim)
182+
183+
for sizes in [(100,), (1000,), (10000,)]:
184+
size = sizes[0]
185+
# binary strings
186+
yield (torch.tensor([0, 1] * size, dtype=dtype, device=device), 0)
187+
188+
for sample, dim in generate_samples():
189+
_, idx_torch = sample.sort(dim=dim, stable=True)
190+
sample_numpy = sample.numpy()
191+
idx_numpy = np.argsort(sample_numpy, axis=dim, kind='stable')
192+
self.assertEqual(idx_torch, idx_numpy)
193+
116194
@dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False)))
117195
def test_msort(self, device, dtype):
118196
def test(shape):

0 commit comments

Comments (0)