Commit 8cfe5e2
Refactor allreduce for supporting prefill case (#2453)
* fea(ar): refactor custom allreduce
Signed-off-by: root <root@hjbog-srdc-24.amd.com>
* fea: support prefill
Signed-off-by: root <root@hjbog-srdc-24.amd.com>
* add latency cmp with rccl
Signed-off-by: root <root@hjbog-srdc-24.amd.com>
* fix: remove ck in new kernel
Signed-off-by: root <root@hjbog-srdc-24.amd.com>
* fix: ruff check
Signed-off-by: root <root@hjbog-srdc-24.amd.com>
* fix: test script format
Signed-off-by: root <root@hjbog-srdc-24.amd.com>
* fix: ruff check
Signed-off-by: root <root@hjbog-srdc-24.amd.com>
* fix: pa_metadata macro err
Signed-off-by: root <root@hjbog-srdc-24.amd.com>
* fea(car): support aiter tensor
Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>
* fix: move pybind aiter tensor to dtypes.py
Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>
* add aiter_tensor_module
* update
* update
* update
* update
* update
* update
* fix: fused_ar_rms gpt n=2880 case
Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>
* [Kernel][Perf] Make allreduce fusion kernels support arbitrary hidden_dim
Previously the fused allreduce+rmsnorm+quant kernels only supported
N=512/1024/2048/4096 via compile-time template dispatch. This made
models with other hidden_dim (e.g. GLM-5 N=6144, GPT-OSS N=2880)
fall back to the slower non-fused path.
Changes:
- Convert HIDDEN_DIM/BLOCK_SIZE from template parameter to runtime
  parameter in 1stage/2stage/split fusion kernels
- Use __launch_bounds__(1024,1) with runtime thread count
- Fix block_reduce for non-power-of-2 warp counts (round up
  reduce_width for shfl_xor correctness)
- Pad 1stage launch threads to WARP_SIZE multiples with active guard
- Use dynamic shared memory for 2stage kernel
- Generalize step2 dispatch (local_device_load_rmsnorm) to support
  any N where n_packs >= 64, removing n_bytes%1024 alignment requirement
- Replace silent printf errors with throw for unsupported shapes
- Add AITER_AR_1STAGE env override for benchmarking
- Improve test_fused_ar_rms.py: add error column, --test flag,
  multi-shape support, markdown summary table
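
The thread-padding and reduce_width items above can be illustrated with a small host-side sketch. This is not the kernel code from this commit. It assumes one thread per pack, a 64-lane wavefront (WARP_SIZE = 64), and 8 bf16 elements per pack; the helper names (pad_to_warp, next_pow2, launch_geometry, xor_butterfly_sum) are hypothetical and exist only for this example.

```python
WARP_SIZE = 64  # assumption: 64-lane wavefront; not stated in the commit


def pad_to_warp(n_threads, warp_size=WARP_SIZE):
    """Round the launch thread count up to a whole number of warps."""
    return (n_threads + warp_size - 1) // warp_size * warp_size


def next_pow2(x):
    """Smallest power of two >= x, for x >= 1."""
    return 1 << (x - 1).bit_length()


def launch_geometry(hidden_dim, pack_size):
    """Return (active, launched, n_warps, reduce_width).

    Assumes one thread per pack. Threads with index >= active are the
    padded ones that an "active guard" would keep from touching memory.
    """
    active = hidden_dim // pack_size
    launched = pad_to_warp(active)
    n_warps = launched // WARP_SIZE
    reduce_width = next_pow2(n_warps)
    return active, launched, n_warps, reduce_width


def xor_butterfly_sum(values, width):
    """Simulate an XOR-shuffle reduction over `width` lanes.

    Lane i pairs with lane i ^ offset, so `width` must be a power of two
    for every partner to exist; missing lanes contribute the identity (0).
    """
    lanes = list(values) + [0] * (width - len(values))
    offset = width // 2
    while offset:
        lanes = [lanes[i] + lanes[i ^ offset] for i in range(width)]
        offset //= 2
    return lanes[0]


# GPT-OSS shape from the commit message: N=2880, bf16, 8 elements per pack.
print(launch_geometry(2880, 8))
# Six per-warp partial sums reduced with the width rounded up to 8.
print(xor_butterfly_sum([1, 2, 3, 4, 5, 6], 8))
```

Under these assumptions N=2880 gives 360 active threads padded to 384, which is 6 warps, so the cross-warp reduction needs a width of 8 rather than 6.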
Now supports any N that satisfies: N % pack_size == 0 and
N / pack_size <= 1024 (i.e. N <= 8192 for bf16).
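
The shape rule above can be written as a small predicate. This is a sketch under one assumption: a pack is 16 bytes, which is inferred from the "N <= 8192 for bf16" figure (8192 / 1024 = 8 two-byte elements per pack). The function name is_supported_hidden_dim and the constants are hypothetical, not part of aiter.

```python
MAX_PACKS = 1024  # from the rule N / pack_size <= 1024
PACK_BYTES = 16   # assumption: inferred from N <= 8192 for bf16


def is_supported_hidden_dim(n, elem_bytes=2):
    """Check N % pack_size == 0 and N / pack_size <= 1024.

    elem_bytes defaults to 2 (bf16). pack_size is the number of
    elements in one pack under the 16-byte assumption above.
    """
    pack_size = PACK_BYTES // elem_bytes
    return n % pack_size == 0 and n // pack_size <= MAX_PACKS


# Shapes named in the commit message, plus two that break the rule.
for n in (2880, 6144, 8192, 8200, 2884):
    print(n, is_supported_hidden_dim(n))
```

Here 8200 fails because it needs 1025 packs, and 2884 fails because it is not a multiple of 8.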
* fix: add param support_prefill in ar
Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>
* fix: test_fused_ar_rms.py
Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>
* fix: test_fused_ar_rms.py
Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>
---------
Signed-off-by: root <root@hjbog-srdc-24.amd.com>
Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>
Co-authored-by: root <root@hjbog-srdc-24.amd.com>
Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com>
Co-authored-by: amd-ruitang3 <rui.tang2@amd.com>
Co-authored-by: amd-ruitang3 <145657428+amd-ruitang3@users.noreply.github.com>
1 parent e47cc0e commit 8cfe5e2
21 files changed
Lines changed: 1622 additions & 1220 deletions
File tree
- aiter
  - dist
    - device_communicators
  - jit
    - utils
  - ops
  - utility
- csrc
  - include
  - kernels
  - pybind
- op_tests/multigpu_tests