🐛 Describe the bug
torch._grouped_mm produces non-zero gradients for zero-size groups (experts with no tokens) on B200 (Blackwell, SM 100) GPUs, while it correctly produces zero gradients on H800 (Hopper, SM 90) GPUs.
This is a critical bug for Mixture-of-Experts (MoE) models where some experts may have zero tokens assigned in certain batches.
Versions
🔍 Expected Behavior
When a group has split_size=0 (no tokens assigned to that expert), the weight gradient for that expert should be zero because there is no input to compute gradients from.
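The expected behavior can be sketched with a loop-based reference that runs on CPU. This is only an illustration under stated assumptions: `grouped_mm_reference` is a hypothetical helper, not part of PyTorch, and it takes per-group sizes rather than the cumulative offsets that `torch._grouped_mm` expects. For an empty group the weight gradient is a product of a `(hidden, 0)` and a `(0, intermediate)` matrix, which is all zeros.

```python
import torch

def grouped_mm_reference(inputs, weights, split_sizes):
    """Hypothetical reference: multiply each group's rows by that group's weight."""
    outputs = []
    start = 0
    for g, size in enumerate(split_sizes):
        # For size == 0 this is an empty (0, hidden) slice and yields a (0, intermediate) result.
        outputs.append(inputs[start:start + size] @ weights[g])
        start += size
    return torch.cat(outputs, dim=0)

# Small shapes chosen for illustration; groups 1 and 3 are empty.
split_sizes = [3, 0, 5, 0, 2]
hidden_size, intermediate_size = 8, 4

torch.manual_seed(0)
inputs = torch.randn(sum(split_sizes), hidden_size, requires_grad=True)
weights = torch.randn(len(split_sizes), hidden_size, intermediate_size, requires_grad=True)

out = grouped_mm_reference(inputs, weights, split_sizes)
out.backward(torch.ones_like(out))

# Gradient norm of every empty group's weight slice.
zero_group_norms = {
    g: weights.grad[g].norm().item()
    for g, s in enumerate(split_sizes)
    if s == 0
}
print(zero_group_norms)
```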
❌ Actual Behavior
On B200 GPUs, the weight gradients of zero-size groups are non-zero, with norms on the order of 4e4 to 7e4 in the test results below. This suggests that the backward kernel computes these gradients from uninitialized or otherwise invalid data.
📋 Steps to Reproduce
- Run the provided test script on a B200 GPU:
python test_grouped_mm_zero_size_bug.py
- The script will show non-zero gradients for all zero-size groups.
Minimal Reproducible Example
import torch
# Create test case with zero-size groups
split_sizes = [162, 0, 2577, 446, 0, 83] # Groups at indices 1 and 4 have size 0
hidden_size = 2048
intermediate_size = 768
inputs = torch.randn(sum(split_sizes), hidden_size, dtype=torch.bfloat16, device="cuda")
weights = torch.randn(len(split_sizes), hidden_size, intermediate_size, dtype=torch.bfloat16, device="cuda")
inputs.requires_grad_(True)
weights.requires_grad_(True)
offsets = torch.tensor(split_sizes, dtype=torch.int32, device="cuda").cumsum(0, dtype=torch.int32)
# Forward pass
outputs = torch._grouped_mm(inputs, weights, offsets, out_dtype=torch.bfloat16)
# Backward pass
outputs.backward(torch.ones_like(outputs))
torch.cuda.synchronize()
# Check gradients for zero-size groups
zero_indices = [i for i, s in enumerate(split_sizes) if s == 0]
for idx in zero_indices:
    grad_norm = weights.grad[idx].norm().item()
    print(f"Expert {idx} (split_size=0): gradient norm = {grad_norm:.6e}")
# On B200: Non-zero (BUG)
# On H800: Zero (CORRECT)
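Until the kernel is fixed, one possible stopgap is to mask the weight gradient of empty groups with a tensor hook. This is an assumption rather than a verified fix: it has not been tested on B200, the helper name `mask_empty_group_grads` is hypothetical, and it does nothing about any other gradient that might be affected. It uses `torch.where` rather than multiplying by a 0/1 mask so that NaN or Inf values in the invalid slices are also discarded.

```python
import torch

def mask_empty_group_grads(weights, split_sizes):
    """Hypothetical workaround: zero the gradient slices of groups with no tokens."""
    nonempty = torch.tensor(
        [s > 0 for s in split_sizes], dtype=torch.bool, device=weights.device
    ).view(-1, 1, 1)

    def hook(grad):
        # Keep gradients of non-empty groups and replace the rest with zeros.
        return torch.where(nonempty, grad, torch.zeros_like(grad))

    # The returned handle can be used to remove the hook later via handle.remove().
    return weights.register_hook(hook)

# CPU demonstration with a stand-in loss that gives every weight a gradient of 2.
demo_split_sizes = [4, 0, 7]
demo_weights = torch.randn(len(demo_split_sizes), 2, 3, requires_grad=True)
handle = mask_empty_group_grads(demo_weights, demo_split_sizes)
(demo_weights * 2).sum().backward()
print(demo_weights.grad[1].norm().item())  # gradient norm of the empty group
handle.remove()
```

In the reproduction above, the hook would be registered on `weights` before calling `backward()`.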
🔧 Environment
B200 (Bug Present)
- PyTorch version: 2.9.1
- CUDA version: 12.9
- Compute Capability: 10.0 (SM 100)
- Architecture: Blackwell
H800 (Working Correctly)
- PyTorch version: 2.9.1
- CUDA version: 12.9
- Compute Capability: 9.0 (SM 90)
- Architecture: Hopper
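For anyone comparing against their own machine, the fields listed above can be collected with a short snippet like the following. It is a convenience sketch, and `describe_environment` is a hypothetical helper name; the GPU fields are only filled in when a CUDA device is visible.

```python
import torch

def describe_environment():
    """Collect the PyTorch version, CUDA version, and GPU compute capability."""
    info = {"torch": torch.__version__, "cuda": torch.version.cuda}
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        info["device"] = torch.cuda.get_device_name()
        info["compute_capability"] = f"{major}.{minor}"
    return info

env_info = describe_environment()
print(env_info)
```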
📊 Test Results
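Output of test_grouped_mm_zero_size_bug.py (the expert indices show that the full script uses more groups than the minimal example above):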
On B200 (Bug):
Expert 1 (split_size=0): gradient norm = 5.580800e+04 ✗ BUG
Expert 4 (split_size=0): gradient norm = 4.582400e+04 ✗ BUG
Expert 10 (split_size=0): gradient norm = 4.403200e+04 ✗ BUG
Expert 15 (split_size=0): gradient norm = 4.352000e+04 ✗ BUG
Expert 25 (split_size=0): gradient norm = 6.707200e+04 ✗ BUG
On H800 (Correct):
Expert 1 (split_size=0): gradient norm = 0.000000e+00 ✓ CORRECT
Expert 4 (split_size=0): gradient norm = 0.000000e+00 ✓ CORRECT
Expert 10 (split_size=0): gradient norm = 0.000000e+00 ✓ CORRECT
Expert 15 (split_size=0): gradient norm = 0.000000e+00 ✓ CORRECT
Expert 25 (split_size=0): gradient norm = 0.000000e+00 ✓ CORRECT
💡 logs
Blackwell.log
Hopper.log
cc @ptrblck @msaroufim @eqy @jerryzh168 @tinglvv @nWEIdia @csarofeen @xwang233 @jianyuh @nikitaved @mruberry @walterddr @lezcano