
Bug: grouped_mm produces non-zero gradients for zero-size groups on B200 (Blackwell) GPUs #172439

@BIGBALLON


🐛 Describe the bug

torch._grouped_mm produces non-zero weight gradients for zero-size groups (experts with no tokens) on B200 (Blackwell, SM 100) GPUs, while it correctly produces zero gradients on H800 (Hopper, SM 90) GPUs.

This is a critical bug for Mixture-of-Experts (MoE) models, where some experts routinely receive zero tokens in a given batch; those experts' weights then get updated with spurious gradients.


🔍 Expected Behavior

When a group has split_size=0 (no tokens assigned to that expert), the weight gradient for that expert should be exactly zero, because no input rows contribute to it.
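To spell out why: in the 2D × 3D form used in the example, row block g of `inputs` (rows `offsets[g-1]:offsets[g]`, starting from row 0 for the first group) is multiplied by `weights[g]`. Writing $X_g \in \mathbb{R}^{m_g \times K}$ for that block ($m_g$ = `split_sizes[g]`, $K$ = `hidden_size`, $N$ = `intermediate_size`) and $\mathrm{d}Y_g$ for the matching rows of the upstream gradient, the weight gradient is a sum over the group's rows, which is empty when $m_g = 0$:

```latex
\begin{aligned}
Y_g &= X_g W_g, \qquad
\frac{\partial L}{\partial W_g} = X_g^{\top}\,\mathrm{d}Y_g
  = \sum_{i \in \mathcal{G}_g} x_i^{\top}\,\mathrm{d}y_i \in \mathbb{R}^{K \times N} \\
m_g = |\mathcal{G}_g| = 0 \;&\Longrightarrow\;
\frac{\partial L}{\partial W_g} = \sum_{i \in \varnothing} x_i^{\top}\,\mathrm{d}y_i = 0
\end{aligned}
```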

❌ Actual Behavior

On B200 GPUs, zero-size groups get non-zero gradients of large magnitude (norms around 4e4 to 7e4, see Test Results). This suggests that the backward kernel either never writes the output slice for empty groups, leaving uninitialized memory in place, or computes it from invalid data.
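One plausible mechanism (a hypothesis, not a confirmed root cause) is a kernel that skips empty groups but writes into a buffer that was never zeroed. The sketch below computes the expected per-group weight gradient in plain PyTorch on CPU; `reference_weight_grad` is a hypothetical helper, and it uses a zero-initialized buffer so that skipped groups stay exactly zero.

```python
import torch


def reference_weight_grad(inputs: torch.Tensor, grad_out: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    """Per-group weight gradient dW[g] = X_g^T @ dY_g for the 2D x 3D grouped mm.

    `offsets` holds cumulative end rows, as passed to torch._grouped_mm.
    The result buffer is zero-initialized, so groups that are skipped
    (split_size == 0) stay exactly zero. Had it come from torch.empty,
    a skipped group would keep whatever bytes were already in memory.
    """
    num_groups = offsets.numel()
    k, n = inputs.shape[1], grad_out.shape[1]
    grad_w = torch.zeros(num_groups, k, n, dtype=torch.float32, device=inputs.device)
    start = 0
    for g, end in enumerate(offsets.tolist()):
        if end > start:  # empty group: nothing to accumulate, leave zeros
            x = inputs[start:end].float()
            dy = grad_out[start:end].float()
            grad_w[g] = x.T @ dy
        start = end
    return grad_w
```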

📋 Steps to Reproduce

  1. Run the provided test script on a B200 GPU:
     python test_grouped_mm_zero_size_bug.py
  2. The script prints non-zero gradients for every zero-size group.

Minimal Reproducible Example

import torch

# Create test case with zero-size groups
split_sizes = [162, 0, 2577, 446, 0, 83]  # Groups at indices 1 and 4 have size 0
hidden_size = 2048
intermediate_size = 768

inputs = torch.randn(sum(split_sizes), hidden_size, dtype=torch.bfloat16, device="cuda")
weights = torch.randn(len(split_sizes), hidden_size, intermediate_size, dtype=torch.bfloat16, device="cuda")
inputs.requires_grad_(True)
weights.requires_grad_(True)

# offsets hold the cumulative end row of each group: [162, 162, 2739, 3185, 3185, 3268]
offsets = torch.tensor(split_sizes, dtype=torch.int32, device="cuda").cumsum(0, dtype=torch.int32)

# Forward pass
outputs = torch._grouped_mm(inputs, weights, offsets, out_dtype=torch.bfloat16)

# Backward pass
outputs.backward(torch.ones_like(outputs))
torch.cuda.synchronize()

# Check gradients for zero-size groups
zero_indices = [i for i, s in enumerate(split_sizes) if s == 0]
for idx in zero_indices:
    grad_norm = weights.grad[idx].norm().item()
    print(f"Expert {idx} (split_size=0): gradient norm = {grad_norm:.6e}")
    # On B200: Non-zero (BUG)
    # On H800: Zero (CORRECT)
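To check which side is wrong independently of either GPU kernel, a slow loop-based forward can serve as a reference: autograd through it yields the expected gradients. `grouped_mm_reference` is a hypothetical helper that assumes the same `offsets` semantics as above (group g covers rows `offsets[g-1]:offsets[g]`).

```python
import torch


def grouped_mm_reference(inputs: torch.Tensor, weights: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    """Slow stand-in for torch._grouped_mm(inputs, weights, offsets).

    Row block g is multiplied by weights[g]. An empty block has shape (0, K),
    so its product has 0 rows and its weight gradient is an all-zero (K, N) matrix.
    """
    outs = []
    start = 0
    for g, end in enumerate(offsets.tolist()):
        outs.append(inputs[start:end] @ weights[g])  # (end - start, N), possibly 0 rows
        start = end
    return torch.cat(outs, dim=0)
```

On the B200 machine, running this on float32 copies of `inputs` and `weights` (with `requires_grad_(True)` on the weight copy) and backpropagating `torch.ones_like(...)` should give exactly zero gradients for experts 1 and 4, which can then be compared against `weights.grad` from the example.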

🔧 Environment

B200 (Bug Present)

  • PyTorch version: 2.9.1
  • CUDA version: 12.9
  • Compute Capability: 10.0 (SM 100)
  • Architecture: Blackwell

H800 (Working Correctly)

  • PyTorch version: 2.9.1
  • CUDA version: 12.9
  • Compute Capability: 9.0 (SM 90)
  • Architecture: Hopper
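The fields above can be gathered with public torch APIs when comparing machines; `describe_environment` is a hypothetical helper name. For a fuller report, `python -m torch.utils.collect_env` is the standard tool.

```python
import torch


def describe_environment() -> dict:
    """Collect the fields listed in the Environment section."""
    info = {
        "pytorch_version": torch.__version__,
        "cuda_version": torch.version.cuda,  # None on CPU-only builds
        "device_name": None,
        "compute_capability": None,
    }
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        info["device_name"] = torch.cuda.get_device_name()
        # e.g. "10.0 (SM 100)" on B200, "9.0 (SM 90)" on H800
        info["compute_capability"] = f"{major}.{minor} (SM {major}{minor})"
    return info


for key, value in describe_environment().items():
    print(f"{key}: {value}")
```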

📊 Test Results

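Output of test_grouped_mm_zero_size_bug.py (the full script uses more experts than the minimal example, hence expert indices up to 25):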

On B200 (Bug):

Expert 1 (split_size=0): gradient norm = 5.580800e+04  ✗ BUG
Expert 4 (split_size=0): gradient norm = 4.582400e+04  ✗ BUG
Expert 10 (split_size=0): gradient norm = 4.403200e+04 ✗ BUG
Expert 15 (split_size=0): gradient norm = 4.352000e+04 ✗ BUG
Expert 25 (split_size=0): gradient norm = 6.707200e+04 ✗ BUG

On H800 (Correct):

Expert 1 (split_size=0): gradient norm = 0.000000e+00  ✓ CORRECT
Expert 4 (split_size=0): gradient norm = 0.000000e+00  ✓ CORRECT
Expert 10 (split_size=0): gradient norm = 0.000000e+00 ✓ CORRECT
Expert 15 (split_size=0): gradient norm = 0.000000e+00 ✓ CORRECT
Expert 25 (split_size=0): gradient norm = 0.000000e+00 ✓ CORRECT
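If training has to run on B200 before a fix lands, one possible stopgap is to force the weight gradients of empty experts to zero with a tensor hook. This is a sketch under assumptions: `zero_empty_group_grads` is a hypothetical helper (not a PyTorch API), it only hides the symptom in `weights.grad`, and it has not been verified against the B200 kernel. It uses `masked_fill` rather than multiplying by a 0/1 mask, because an uninitialized gradient could contain NaN or Inf, and `NaN * 0` is still NaN.

```python
import torch


def zero_empty_group_grads(weights: torch.Tensor, split_sizes: list) -> torch.utils.hooks.RemovableHandle:
    """Register a hook that sets the gradient of every zero-size group to exactly 0."""
    empty = torch.tensor([s == 0 for s in split_sizes], device=weights.device)
    # Shape (G, 1, 1) so the mask broadcasts over each expert's (K, N) slice.
    empty = empty.view(-1, 1, 1)

    def hook(grad: torch.Tensor) -> torch.Tensor:
        # Runs before the gradient is accumulated into weights.grad.
        return grad.masked_fill(empty, 0)

    return weights.register_hook(hook)
```

Because the set of empty experts changes from batch to batch, the hook would be registered before each `outputs.backward(...)` call and removed afterwards with `handle.remove()`.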

💡 Logs

Blackwell.log
Hopper.log

cc @ptrblck @msaroufim @eqy @jerryzh168 @tinglvv @nWEIdia @csarofeen @xwang233 @jianyuh @nikitaved @mruberry @walterddr @lezcano


Labels: module: cublas, module: cuda, module: linear algebra
