Commit 8cfe5e2
TennyWang1223, root, valarLip, and amd-ruitang3 authored
Refactor allreduce for supporting prefill case (#2453)
* fea(ar): refactor custom allreduce
* fea: support prefill
* add latency cmp with rccl
* fix: remove ck in new kernel
* fix: ruff check
* fix: test script format
* fix: ruff check
* fix: pa_metadata macro err
* fea(car): support aiter tensor
* fix: move pybind aiter tensor to dtypes.py
* add aiter_tensor_module
* update (×6)
* fix: fused_ar_rms gpt n=2880 case
* [Kernel][Perf] Make allreduce fusion kernels support arbitrary hidden_dim

  Previously the fused allreduce+rmsnorm+quant kernels only supported N=512/1024/2048/4096 via compile-time template dispatch. This made models with other hidden_dim (e.g. GLM-5 N=6144, GPT-OSS N=2880) fall back to the slower non-fused path.

  Changes:
  - Convert HIDDEN_DIM/BLOCK_SIZE from a template parameter to a runtime parameter in the 1stage/2stage/split fusion kernels
  - Use __launch_bounds__(1024,1) with a runtime thread count
  - Fix block_reduce for non-power-of-2 warp counts (round up reduce_width for shfl_xor correctness)
  - Pad 1stage launch threads to WARP_SIZE multiples with an active guard
  - Use dynamic shared memory for the 2stage kernel
  - Generalize step2 dispatch (local_device_load_rmsnorm) to support any N where n_packs >= 64, removing the n_bytes % 1024 alignment requirement
  - Replace silent printf errors with throw for unsupported shapes
  - Add the AITER_AR_1STAGE env override for benchmarking
  - Improve test_fused_ar_rms.py: add an error column, a --test flag, multi-shape support, and a markdown summary table

  Now supports any N that satisfies N % pack_size == 0 and N / pack_size <= 1024 (i.e. N <= 8192 for bf16).

* fix: add param support_prefill in ar
* fix: test_fused_ar_rms.py
* fix: test_fused_ar_rms.py

Signed-off-by: root <root@hjbog-srdc-24.amd.com>
Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>
Co-authored-by: root <root@hjbog-srdc-24.amd.com>
Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com>
Co-authored-by: amd-ruitang3 <rui.tang2@amd.com>
Co-authored-by: amd-ruitang3 <145657428+amd-ruitang3@users.noreply.github.com>
1 parent e47cc0e commit 8cfe5e2

21 files changed: 1622 additions & 1220 deletions
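The supported-shape rule quoted at the end of the commit message can be checked mechanically. A minimal sketch, assuming the kernels pack 16 bytes per access (so pack_size = 8 elements for bf16, consistent with the stated N <= 8192 bound); the helper name is illustrative and not part of this change:

def fused_path_supports(n: int, elem_bytes: int = 2, pack_bytes: int = 16) -> bool:
    # "N % pack_size == 0 and N / pack_size <= 1024" from the commit message
    pack_size = pack_bytes // elem_bytes  # elements per pack (8 for bf16)
    return n % pack_size == 0 and n // pack_size <= 1024

# bf16 hidden_dims mentioned above:
assert fused_path_supports(2880)       # GPT-OSS
assert fused_path_supports(6144)       # GLM-5
assert fused_path_supports(8192)       # stated upper bound for bf16
assert not fused_path_supports(8200)   # would exceed 1024 packs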

aiter/dist/communication_op.py

Lines changed: 29 additions & 8 deletions
@@ -28,23 +28,40 @@
 def tensor_model_parallel_all_reduce(
-    input_: torch.Tensor, use_new: bool = True, open_fp8_quant: bool = False
+    input_: torch.Tensor,
+    use_new: bool = True,
+    open_fp8_quant: bool = False,
+    prefill_support: bool = False,
 ) -> torch.Tensor:
     """All-reduce the input tensor across model parallel group."""
-    return get_tp_group().all_reduce(input_, use_new, open_fp8_quant)
+    return get_tp_group().all_reduce(input_, use_new, open_fp8_quant, prefill_support)
 
 
 def tensor_model_parallel_fused_allreduce_rmsnorm(
-    input_: torch.Tensor, residual_inp_: torch.Tensor, weight_: torch.Tensor, eps: float
+    input_: torch.Tensor,
+    residual_inp_: torch.Tensor,
+    weight_: torch.Tensor,
+    eps: float,
+    prefill_support: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    return get_tp_group().fused_allreduce_rmsnorm(input_, residual_inp_, weight_, eps)
+    return get_tp_group().fused_allreduce_rmsnorm(
+        input_, residual_inp_, weight_, eps, prefill_support
+    )
 
 
 def tensor_model_parallel_fused_allreduce_rmsnorm_quant(
-    input_: torch.Tensor, residual_inp_: torch.Tensor, weight_: torch.Tensor, eps: float
+    input_: torch.Tensor,
+    residual_inp_: torch.Tensor,
+    weight_: torch.Tensor,
+    eps: float,
+    prefill_support: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     return get_tp_group().fused_allreduce_rmsnorm_quant(
-        input_, residual_inp_, weight_, eps
+        input_,
+        residual_inp_,
+        weight_,
+        eps,
+        prefill_support,
     )
 
 
@@ -53,13 +70,17 @@ def tensor_model_parallel_custom_all_gather(input_: torch.Tensor) -> torch.Tensor:
 
 
 def tensor_model_parallel_reduce_scatter(
-    input_: torch.Tensor, use_custom: bool = True, dim: int = 0
+    input_: torch.Tensor,
+    use_custom: bool = True,
+    dim: int = 0,
 ) -> torch.Tensor:
     return get_tp_group().reduce_scatter_tensor(input_, use_custom, dim)
 
 
 def tensor_model_parallel_all_gather(
-    input_: torch.Tensor, use_custom: bool = False, dim: int = -1
+    input_: torch.Tensor,
+    use_custom: bool = False,
+    dim: int = -1,
 ) -> torch.Tensor:
     """All-gather the input tensor across model parallel group."""
     return get_tp_group().all_gather(input_, use_custom, dim)
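For orientation, a hedged usage sketch of the widened wrappers: the tensors and shapes are placeholders, and it assumes the tensor-parallel group has already been initialized elsewhere; only the prefill_support keyword itself comes from this diff.

import torch
from aiter.dist.communication_op import (
    tensor_model_parallel_all_reduce,
    tensor_model_parallel_fused_allreduce_rmsnorm,
)

# Placeholder activations; prefill typically means many tokens per rank.
hidden_states = torch.randn(4096, 8192, dtype=torch.bfloat16, device="cuda")
residual = torch.zeros_like(hidden_states)
rms_weight = torch.ones(hidden_states.shape[-1], dtype=torch.bfloat16, device="cuda")

# prefill_support=True keeps the custom/fused path even for large tensors;
# the default (False) falls back to rccl above 64 MiB (see communicator_cuda.py below).
out = tensor_model_parallel_all_reduce(hidden_states, prefill_support=True)
out, res_out = tensor_model_parallel_fused_allreduce_rmsnorm(
    hidden_states, residual, rms_weight, eps=1e-6, prefill_support=True
)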

aiter/dist/device_communicators/communicator_cuda.py

Lines changed: 54 additions & 16 deletions
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 
+import os
 import torch
 from torch.distributed import ProcessGroup
 
@@ -14,6 +15,11 @@
 
 
 class CudaCommunicator(DeviceCommunicatorBase):
+    # AITER_AR_1STAGE=1 forces 1stage, =0 forces non-1stage, unset uses auto
+    _ar_1stage_override = {"1": True, "0": False}.get(
+        os.environ.get("AITER_AR_1STAGE", "")
+    )
+
     def __init__(
         self,
         cpu_group: ProcessGroup,
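The override is a tri-state resolved by a plain dict lookup. A standalone sketch of how each setting resolves, for illustration only:

import os

# "1" -> True (force 1stage), "0" -> False (force non-1stage),
# anything else / unset -> None (auto: fall back to the 128 KiB heuristic below).
for value in ("1", "0", "", "2"):
    os.environ["AITER_AR_1STAGE"] = value
    override = {"1": True, "0": False}.get(os.environ.get("AITER_AR_1STAGE", ""))
    print(f"AITER_AR_1STAGE={value!r} -> {override}")

Because the real attribute is evaluated when the class body runs, the environment variable has to be set before the module is imported.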
@@ -148,7 +154,11 @@ def all2all_manager(self, value):
         self._all2all_manager_created = True
 
     def all_reduce(
-        self, input_, use_new: bool = True, ca_fp8_quant: bool = False
+        self,
+        input_,
+        use_new: bool = True,
+        ca_fp8_quant: bool = False,
+        prefill_support: bool = False,
     ) -> torch.Tensor:
         # always try quick reduce first, then custom allreduce,
         # and then pynccl. (quick reduce just for ROCM MI3*)
@@ -169,9 +179,13 @@ def all_reduce(
             and not ca_comm.disabled
             and ca_comm.should_custom_ar(input_)
         ):
-            out = ca_comm.custom_all_reduce(input_, use_new, ca_fp8_quant)
-            assert out is not None
-            return out
+            inp_size = input_.numel() * input_.element_size()
+            if not prefill_support and inp_size > 64 * 1024 * 1024:
+                pass  # fall through to rccl for large prefill tensors
+            else:
+                out = ca_comm.custom_all_reduce(input_, use_new, ca_fp8_quant)
+                assert out is not None
+                return out
         symm_mem_comm = self.symm_mem_comm
         if symm_mem_comm is not None and symm_mem_comm.should_use_symm_mem(input_):
            out = symm_mem_comm.all_reduce(input_)
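Read as a predicate, the new guard simply bounds the custom path by message size unless the caller opts into prefill. A sketch with the 64 MiB threshold copied from the hunk above; the helper and tensors are illustrative, and the other ca_comm checks are assumed to pass:

import torch

def uses_custom_allreduce(input_: torch.Tensor, prefill_support: bool) -> bool:
    inp_size = input_.numel() * input_.element_size()
    return prefill_support or inp_size <= 64 * 1024 * 1024

x_decode = torch.empty(8, 8192, dtype=torch.bfloat16)      # 128 KiB
x_prefill = torch.empty(8192, 8192, dtype=torch.bfloat16)  # 128 MiB

assert uses_custom_allreduce(x_decode, prefill_support=False)
assert not uses_custom_allreduce(x_prefill, prefill_support=False)  # falls through to rccl
assert uses_custom_allreduce(x_prefill, prefill_support=True)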
@@ -191,7 +205,12 @@ def all_reduce(
             return out
 
     def fused_allreduce_rmsnorm(
-        self, input_, res_inp_, weight_, eps
+        self,
+        input_,
+        res_inp_,
+        weight_,
+        eps,
+        prefill_support: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         n = input_.shape[-1]
         total_bytes = input_.numel() * input_.element_size()
@@ -205,15 +224,22 @@ def fused_allreduce_rmsnorm(
             and ca_comm.should_custom_ar(input_)
             and can_use_fuse_ar_rms
         ):
-            use_1stage = True if total_bytes <= 128 * 1024 else False
-            out, res_out = ca_comm.custom_fused_ar_rms(
-                input_, res_inp_, weight_, eps, use_1stage
-            )
-            assert out is not None
-            assert res_out is not None
-            return out, res_out
+            if not prefill_support and total_bytes > 64 * 1024 * 1024:
+                pass  # fall through to rccl for large prefill tensors
+            else:
+                use_1stage = (
+                    self._ar_1stage_override
+                    if self._ar_1stage_override is not None
+                    else (total_bytes <= 128 * 1024)
+                )
+                out, res_out = ca_comm.custom_fused_ar_rms(
+                    input_, res_inp_, weight_, eps, use_1stage
+                )
+                assert out is not None
+                assert res_out is not None
+                return out, res_out
         # call split kernel
-        ar_out = self.all_reduce(input_)
+        ar_out = self.all_reduce(input_, prefill_support=prefill_support)
         out = torch.empty_like(ar_out)
         residual_out = torch.empty_like(ar_out)
         from aiter import rmsnorm2d_fwd_with_add
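use_1stage is now a three-way decision: the env override wins when set, otherwise the existing 128 KiB heuristic applies. A sketch of the selection; the wrapper function is illustrative, only the constant comes from the code above:

from typing import Optional

def select_use_1stage(total_bytes: int, override: Optional[bool]) -> bool:
    # override mirrors _ar_1stage_override: True/False from AITER_AR_1STAGE, None for auto.
    if override is not None:
        return override
    return total_bytes <= 128 * 1024  # auto: 1stage only for small messages

assert select_use_1stage(64 * 1024, None) is True          # small -> 1stage
assert select_use_1stage(4 * 1024 * 1024, None) is False   # large -> 2stage
assert select_use_1stage(4 * 1024 * 1024, True) is True    # forced via env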
@@ -230,19 +256,31 @@ def fused_allreduce_rmsnorm(
         return out, residual_out
 
     def fused_allreduce_rmsnorm_quant(
-        self, input_, res_inp_, weight_, eps
+        self,
+        input_,
+        res_inp_,
+        weight_,
+        eps,
+        prefill_support: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         total_bytes = input_.numel() * input_.element_size()
         if (
             int(input_.shape[-1]) in [512, 1024, 2048, 4096]
             and total_bytes <= 4096 * 1024
+            and (prefill_support or total_bytes <= 64 * 1024 * 1024)
         ):
-            use_1stage = True if total_bytes <= 128 * 1024 else False
+            use_1stage = (
+                self._ar_1stage_override
+                if self._ar_1stage_override is not None
+                else (total_bytes <= 128 * 1024)
+            )
             out, res_out, scale_out = self.ca_comm.custom_fused_ar_rms_quant(
                 input_, res_inp_, weight_, eps, use_1stage
             )
         else:
-            out_, res_out = self.fused_allreduce_rmsnorm(input_, res_inp_, weight_, eps)
+            out_, res_out = self.fused_allreduce_rmsnorm(
+                input_, res_inp_, weight_, eps, prefill_support
+            )
             hip_quant = get_hip_quant(QuantType.per_Token)
             out, scale_out = hip_quant(out_, quant_dtype=fp8)
             assert out is not None