|
9 | 9 | (TestCase, run_tests, make_tensor) |
10 | 10 | from torch.testing._internal.common_device_type import \ |
11 | 11 | (instantiate_device_type_tests, dtypes, onlyOnCPUAndCUDA, |
12 | | - skipCUDAIfRocm, onlyCUDA, dtypesIfCUDA) |
| 12 | + skipCUDAIfRocm, onlyCUDA, dtypesIfCUDA, onlyCPU) |
13 | 13 |
|
14 | 14 | # TODO: remove this |
15 | 15 | SIZE = 100 |
@@ -113,6 +113,84 @@ def test_sort(self, device): |
113 | 113 | self.assertIsOrdered('descending', x, res2val, res2ind, |
114 | 114 | 'random with NaNs') |
115 | 115 |
|
@onlyCUDA
@dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128})
def test_stable_sort_fails_on_CUDA(self, device, dtype):
    """Check that requesting a stable sort raises the expected RuntimeError.

    A small tensor with duplicate values is built on ``device`` with the
    parametrized ``dtype``; calling ``sort(stable=True)`` on it must raise
    a ``RuntimeError`` whose message matches the pattern below.
    """
    expected_message = "stable=True is not implemented on CUDA yet."
    with_duplicates = torch.tensor([1, 0, 1, 0], dtype=dtype, device=device)
    with self.assertRaisesRegex(RuntimeError, expected_message):
        with_duplicates.sort(stable=True)
| 122 | + |
@onlyCPU
@dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128})
def test_stable_sort(self, device, dtype):
    """Check that a stable sort keeps equal elements in their input order.

    The input alternates 0, 1, 0, 1, ... so every 0 sits at an even
    position and every 1 at an odd position. After a stable ascending
    sort the first half of the returned indices must be the even
    positions in increasing order and the second half the odd positions
    in increasing order.
    """
    for ncopies in (100, 1000, 10000):
        length = 2 * ncopies
        alternating = torch.tensor([0, 1] * ncopies, dtype=dtype, device=device)
        indices = alternating.sort(stable=True)[1]

        # Positions of the zeros (even) and of the ones (odd) in the input.
        even_positions = torch.arange(start=0, end=length, step=2, device=device)
        odd_positions = torch.arange(start=1, end=length, step=2, device=device)

        self.assertEqual(indices[:ncopies], even_positions)
        self.assertEqual(indices[ncopies:], odd_positions)
| 137 | + |
@onlyCPU
@dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128})
def test_stable_sort_against_numpy(self, device, dtype):
    """Compare indices from ``sort(stable=True)`` with NumPy's stable argsort.

    Samples are random tensors of several shapes, copies of those tensors
    with extreme values (largest, smallest, NaN or their integral
    stand-ins) written at random indices along each dimension, and long
    alternating 0/1 sequences. For every sample the index tensor returned
    by ``Tensor.sort(dim=dim, stable=True)`` must equal
    ``np.argsort(..., axis=dim, kind='stable')``.
    """
    if dtype in torch.testing.floating_types_and(torch.float16):
        inf = float('inf')
        neg_inf = -float('inf')
        nan = float('nan')
    else:
        if dtype != torch.bool:
            # no torch.iinfo support for torch.bool
            inf = torch.iinfo(dtype).max
            neg_inf = torch.iinfo(dtype).min
        else:
            inf = True
            # `~True` is the Python int -2, which is truthy; the smallest
            # bool value has to be spelled out explicitly.
            neg_inf = False
        # no nan for integral types, we use inf instead for simplicity
        nan = inf

    def generate_samples():
        """Yield ``(tensor, dim)`` pairs to be sorted along ``dim``."""
        from itertools import chain, combinations

        def repeated_index_fill(t, dim, idxs, vals):
            """Apply ``index_fill`` along ``dim`` once per (idx, val) pair, in order."""
            res = t
            for idx, val in zip(idxs, vals):
                res = res.index_fill(dim, idx, val)
            return res

        fill_vals = (inf, neg_inf, nan)

        for sizes in [(1, 10), (10, 1), (10, 10), (10, 10, 10)]:
            size = min(*sizes)
            x = (torch.randn(*sizes, device=device) * size).to(dtype)
            yield (x, 0)

            # Generate tensors which are being filled at random locations
            # with values from the non-empty subsets of the set (inf, neg_inf, nan)
            # for each dimension.
            for dim in range(len(sizes)):
                # Materialize the (index tensor, value) pairs once. A lazy
                # generator would be exhausted after the size-1 subsets,
                # silently dropping every subset of size 2 and 3.
                # NOTE(review): size // 10 is 0 for the (1, 10) and (10, 1)
                # shapes, so their filled variants equal x -- confirm intended.
                fills = [(torch.randint(high=size, size=(size // 10,)), val)
                         for val in fill_vals]
                subsets = chain.from_iterable(combinations(fills, r)
                                              for r in range(1, len(fills) + 1))
                for subset in subsets:
                    idxs_subset, vals_subset = zip(*subset)
                    yield (repeated_index_fill(x, dim, idxs_subset, vals_subset), dim)

        for sizes in [(100,), (1000,), (10000,)]:
            size = sizes[0]
            # binary strings
            yield (torch.tensor([0, 1] * size, dtype=dtype, device=device), 0)

    for sample, dim in generate_samples():
        _, idx_torch = sample.sort(dim=dim, stable=True)
        sample_numpy = sample.numpy()
        idx_numpy = np.argsort(sample_numpy, axis=dim, kind='stable')
        self.assertEqual(idx_torch, idx_numpy)
| 193 | + |
116 | 194 | @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False))) |
117 | 195 | def test_msort(self, device, dtype): |
118 | 196 | def test(shape): |
|
0 commit comments