Better Grouped GEMM + EP#45621
Conversation
|
will ad bf16 deepgemm to testing as well |
ArthurZucker
left a comment
There was a problem hiding this comment.
nice~
Having TP code in moe is fine IMO let's not separate into TP since anything could be using sentinels actually!
LGTM otherwise, deepgemm isolation could be another PR 😉
yeah at first i created it because it was gonna be used everywhere but now only a clamp is needed in batched paths. |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
reverted the bf16 deepgemm and isolation |
|
Great that we have a couple models with TensorParallelTesterMixin they pass now ! |
| "finegrained-fp8": {"repo_id": "kernels-community/finegrained-fp8", "version": 1}, | ||
| "deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1}, | ||
| "sonic-moe": {"repo_id": "kernels-community/sonic-moe", "version": 1}, | ||
| "sonic-moe": {"repo_id": "IlyasMoutawwakil/sonic-moe", "revision": "main"}, |
There was a problem hiding this comment.
Quick sanity check, does this redirect mean that the currently-published kernels-community/sonic-moe does not yet have the metadata/sentinel handling?
Asking because I hit cudaErrorIllegalAddress reliably running sonicmoe with EP=8 (Qwen3-30B-A3B, 2 nodes × 8 H100, FSDP2 dp=2 ep=8) with the current hub kernel:
File "...sonic-moe/build/torch-cuda/quack/autotuner.py", line 84, in _gpu_warmup
a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
torch.AcceleratorError: CUDA error: an illegal memory access was encounteredThats some sticky CUDA errors that ran asynchronously and faulted before the error propagated back.
with EP=8, RouterParallel produces sentinel expert_ids ≥ num_local_experts (16, since 128/8). The v1 sonic-moe kernel internally does gate_up_proj[expert_ids[i]] which is OOB ??
when I add expert_ids.clamp(0, num_experts-1) and masked_fill_(invalid_mask, 0.0) in the wrapper before the kernel call everything works.
Removing the clamp brings the crash back 🥲 . So the v1 kernel really does need its inputs in-bounds, and the new build in your fork is what actually fixes it ?? Just want to confirm the plan is to republish the fixed build to kernels-community/sonic-moe and revert this redirect once that's done ?
| ws_down = self.down_proj_scale_inv | ||
| proj_fp8, proj_scales = deepgemm_per_token_cast_to_fp8(proj_out, use_ue8m0=False) | ||
| proj_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) | ||
| proj_out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) |
There was a problem hiding this comment.
nice if we write to all, which we do
There was a problem hiding this comment.
loads of rep from moe vs fp8 cool if we re-use stuff, fine otherwise haha
|
I benchmarked this PR's
Repro: import torch
torch.manual_seed(0)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16
T, TOP_K, H = 8, 2, 4
S = T * TOP_K # 16 routing slots
def run_mode(label, *, sentinel_makes_nan, prezero_before_mul):
sentinel_mask = torch.zeros(S, dtype=torch.bool, device=DEVICE)
sentinel_mask[S // 2:] = True
sample_weights = torch.randn(S, device=DEVICE, dtype=torch.float32)
sample_weights[sentinel_mask] = 0.0
sample_weights = sample_weights.to(DTYPE).clone().requires_grad_(True)
proj_out = torch.randn(S, H, device=DEVICE, dtype=DTYPE)
if sentinel_makes_nan:
proj_out[sentinel_mask] = float("nan") # simulate uninitialized rows
proj_out = proj_out.clone().requires_grad_(True)
proj_out_used = (
proj_out.masked_fill(sentinel_mask.unsqueeze(-1), 0.0)
if prezero_before_mul else proj_out
)
weighted = proj_out_used * sample_weights.unsqueeze(-1)
weighted_zero = weighted.masked_fill(sentinel_mask.unsqueeze(-1), 0.0)
out = weighted_zero.view(T, TOP_K, H).sum(dim=1)
out.backward(torch.randn_like(out))
sw_nan = torch.isnan(sample_weights.grad).sum().item()
print(f"[{label}]")
print(f" forward out finite : {torch.isfinite(out).all().item()}")
print(f" sample_weights.grad nan: {sw_nan} of {S}")
print(f" sw.grad sentinel slice : {sample_weights.grad[sentinel_mask].float().tolist()}")
run_mode("A: proj_out fully init (baseline)",
sentinel_makes_nan=False, prezero_before_mul=False)
run_mode("B: NaN at sentinels, NO pre-zero (this PR's pattern)",
sentinel_makes_nan=True, prezero_before_mul=False)
run_mode("C: NaN at sentinels, pre-zero before mul (proposed fix)",
sentinel_makes_nan=True, prezero_before_mul=True)Output: Suggested fix: # Apply routing weights
+ # Zero sentinel rows of proj_out *before* the multiply so backward's
+ # `d_sample_weights_g = (d_weighted * proj_out).sum(-1)` doesn't read
+ # uninitialized memory at sentinel positions.
+ proj_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0)
weighted_out = proj_out * sample_weights_g.unsqueeze(-1)
- weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) |
|
Update on my earlier suggestion. I reran with After more digging, there are actually three separate sentinel grad leaks selected_hidden_states_g = hidden_states[perm // num_top_k] # (S, H)
proj_out = _grouped_linear(selected_hidden_states_g, gate_up_proj, offsets) # ←── leak 2
proj_out = proj_out.masked_fill(sentinel_mask, 0.0)
proj_out = self._apply_gate(proj_out)
proj_out = _grouped_linear(proj_out, down_proj, offsets) # ←── leak 3
proj_out.masked_fill_(sentinel_mask, 0.0)
weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # ←── leak 1Leak 1 is what my first comment, grad Nan in Leaks 2 & 3 are in the bwd path of Walking through what happens for the up-projection: # 1. Build the per-(token, slot) input tensor for the kernel by gathering rows from hidden_states:
selected_hidden_states_g = hidden_states[perm // num_top_k] # shape (S, H)
# 2. Run the kernel. It only does work for rows [0, offsets[-1]) sentinels are excluded.
proj_out = _grouped_mm(selected_hidden_states_g, gate_up_proj, offsets)In backward, autograd needs the gradient w.r.t. each input. The kernel produces Sentinel rows have NaN, so NaN gets added to Under the ~94 % all-sentinel-token pattern EP=8 produces, almost every token has at least one sentinel slot, so almost every row of Suggested fix is to maskout at each stage to zerout NaN values for backward |
* init * style * full support * support EP better using offsets ! * comments * get rid of neutralize_ep_sentinels * remove deepgemm stuff * fix * prefix * move * fix * remove comment * fix unintilized outputs leaking * revert unnecessary changes * more unnecessary changes * revert downcast * keep it simple * guard deepgemm cuda version * fix style * moe sentinel support * fix * compilable sonicmoe * fix * dtensor support * more dtensor * simpler * remove comment * make loaders return a dataclass/namespace that's correctly typed * style * fix
What does this PR do?
The idea is that we shouldn't be clamping experts ids at all, clamping them makes their tokens get prjected as if they were routed to the last expert, instead, we should let the sentinels be:
I micro-benchmarked the kernel with sentinel tokens and it is faster (it is skipping their compute as expected) :
and still data independent / compatible with torch.compile / cuda graphs
Code Agent Policy
The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by
code agents. We are currently bottlenecked by our ability to review and respond to them. As a result,
we ask that new users do not submit pure code agent PRs at this time.
You may use code agents in drafting or to help you diagnose issues. We'd also ask autonomous "OpenClaw"-like agents
not to open any PRs or issues for the moment.
PRs that appear to be fully agent-written will probably be closed without review, and we may block users who do this
repeatedly or maliciously.
This is a rapidly-evolving situation that's causing significant shockwaves in the open-source community. As a result,
this policy is likely to be updated regularly in the near future. For more information, please read
CONTRIBUTING.md.Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.