
Commit c2e0540

refine quant kernel and fix test, address comments
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
1 parent eeb1315 commit c2e0540

3 files changed: 50 additions & 199 deletions

csrc/fp8_blockscale_gemm_sm90_binding.cu

Lines changed: 0 additions & 6 deletions
@@ -122,12 +122,6 @@ class Fp8BlockScaleGemmRunner : public tvm::ffi::ModuleObj {
     if (input_is_fp8) {
       TVM_FFI_ICHECK(scales_a.has_value() && scales_a.value().data_ptr() != nullptr)
           << "scales_a is required for FP8 input";
-      // TensorRT-LLM expects scale shape: (K/128, M) after transpose
-      // int64_t expected_scale_k = (shape_k + 127) / 128;
-      // TVM_FFI_ICHECK(scales_a.value().size(0) == expected_scale_k &&
-      //                scales_a.value().size(1) == shape_m)
-      //     << "scales_a shape mismatch: expected (" << expected_scale_k << ", " << shape_m
-      //     << "), got (" << scales_a.value().size(0) << ", " << scales_a.value().size(1) << ")";
     }

     if (weight_is_fp8) {

flashinfer/gemm/gemm_base.py

Lines changed: 3 additions & 47 deletions
@@ -3413,6 +3413,9 @@ def fp8_blockscale_gemm_sm90(
             f"K dimension must be divisible by block size ({BLOCK_SIZE}), got K={K}"
         )

+    if N % 64 != 0:
+        raise ValueError(f"N dimension must be divisible by 64, got N={N}")
+
     # Validate dtype combinations
     input_is_fp8 = input.dtype == torch.float8_e4m3fn
     weight_is_fp8 = weight.dtype == torch.float8_e4m3fn
@@ -3429,14 +3432,6 @@ def fp8_blockscale_gemm_sm90(
     if input_is_fp8:
         if input_scale is None:
             raise ValueError("input_scale is required when input is FP8. ")
-        # Users provide input_scale in shape (M, K//128), matching per_token_cast_to_fp8 output.
-        # We transpose it internally to (K//128, M) to match TensorRT-LLM kernel expectations.
-        expected_scale_shape = (M, K // BLOCK_SIZE)
-        if input_scale.shape != expected_scale_shape:
-            raise ValueError(
-                f"input_scale shape mismatch. Expected {expected_scale_shape}, "
-                f"got {input_scale.shape}"
-            )
         if input_scale.dtype != torch.float32:
             raise ValueError(f"input_scale must be float32, got {input_scale.dtype}")
         if input_scale.device != input.device:
@@ -3522,44 +3517,5 @@ def fp8_blockscale_gemm_sm90(
     workspace = torch.empty(workspace_size, dtype=torch.uint8, device=input.device)
     runner.configure_workspace(workspace)

-    if input_is_bf16 and weight_is_fp8 and input_scale is None:
-        # Quantize the bf16 input to FP8 with correct scale format to run gemm at fp8 x fp8
-        M_padded = ((M + 4 - 1) // 4) * 4  # Round M up to multiple of 4
-        K_blocks = (K + BLOCK_SIZE - 1) // BLOCK_SIZE
-        input_scale_size = ((K * M_padded * 4 + BLOCK_SIZE - 1) // BLOCK_SIZE) // 4
-
-        fp8_input = torch.empty((M_padded, K), dtype=torch.float8_e4m3fn, device=input.device)
-        input_scale = torch.empty((input_scale_size), dtype=torch.float32, device=input.device)
-        runner.fp8_quantize_1x128(input, fp8_input, input_scale, False)
-        input = fp8_input[:M, :]
-    else:
-        if input_scale is not None:
-            M_padded = ((M + 4 - 1) // 4) * 4  # Round M up to multiple of 4
-            K_blocks = K // BLOCK_SIZE
-
-            # Create padded tensor with the stride TRT-LLM expects
-            input_scale_padded = torch.zeros(
-                K_blocks, M_padded, dtype=torch.float32, device=input.device
-            )
-
-            # Copy scales into the non-padded region: (K//128, M)
-            # Transpose from (M, K//128) to (K//128, M) and copy
-            input_scale_padded[:, :M] = input_scale.T
-
-            # Extract view of the actual (K//128, M) region
-            # This view has stride (M_padded, 1) which matches TRT-LLM's expectations
-            input_scale = input_scale_padded[:, :M]
-
-            # Verify stride matches TRT-LLM's expectations
-            expected_stride_0 = M_padded
-            if input_scale.stride(0) != expected_stride_0:
-                raise ValueError(
-                    f"input_scale stride mismatch: expected stride[0]={expected_stride_0} "
-                    f"(M_padded={M_padded}), got {input_scale.stride(0)}"
-                )
-        else:
-            pass
-
     runner.run_gemm(input, weight, out, input_scale, weight_scale)
-
     return out
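
Note on the removal above: the wrapper no longer pads or transposes input_scale, so callers that pass pre-quantized FP8 activations must hand the kernel scales already in the (K//128, M) layout with stride (M_padded, 1) that TensorRT-LLM expects (the updated W8A8 test below does exactly this). A minimal sketch of that preparation, mirroring the logic removed here and assuming per_token_cast_to_fp8-style scales of shape (M, K//128); the helper name is illustrative only, not part of this commit:

    import torch

    def prepare_input_scale(scale_mk: torch.Tensor, M: int, K: int, block_size: int = 128) -> torch.Tensor:
        # Hypothetical helper: transpose (M, K//128) scales into a zero-padded
        # (K//128, M_padded) buffer, then return the (K//128, M) view whose
        # stride(0) is M_padded, as the TRT-LLM kernel expects.
        M_padded = ((M + 3) // 4) * 4            # round M up to a multiple of 4
        K_blocks = K // block_size
        padded = torch.zeros(K_blocks, M_padded, dtype=torch.float32, device=scale_mk.device)
        padded[:, :M] = scale_mk.T               # (M, K//128) -> (K//128, M)
        return padded[:, :M]                     # view keeps stride (M_padded, 1)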

tests/gemm/test_fp8_blockscale_gemm.py

Lines changed: 47 additions & 146 deletions
@@ -20,7 +20,7 @@

 import flashinfer
 from flashinfer.gemm import fp8_blockscale_gemm_sm90
-from flashinfer.testing.utils import per_token_cast_to_fp8, per_block_cast_to_fp8
+from flashinfer.testing.utils import per_token_cast_to_fp8
 from flashinfer.utils import (
     get_compute_capability,
     has_flashinfer_jit_cache,
@@ -29,22 +29,6 @@
 from flashinfer.jit.gemm import gen_fp8_blockscale_gemm_sm90_module


-def calc_diff(output: torch.Tensor, expected: torch.Tensor) -> float:
-    """Calculate similarity difference using TensorRT-LLM's metric.
-
-    Returns diff = 1 - sim, where sim = 2*<x,y> / (||x||² + ||y||²)
-    This is similar to cosine similarity but uses squared norms in denominator.
-
-    diff < 0.001 corresponds to >99.9% similarity.
-    """
-    output_f64 = output.to(torch.float64)
-    expected_f64 = expected.to(torch.float64)
-    denominator = (output_f64 * output_f64 + expected_f64 * expected_f64).sum()
-    sim = 2 * (output_f64 * expected_f64).sum() / denominator
-    diff = 1 - sim
-    return diff.item()
-
-
 @pytest.fixture(
     autouse=not has_flashinfer_jit_cache(),
     scope="module",
@@ -141,10 +125,16 @@ def test_fp8_blockscale_gemm_dtypes(m, n, k, input_dtype, weight_dtype):

     device = "cuda"
     torch.manual_seed(42)
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_max = fp8_info.max

     # Create BF16 data for reference
-    input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16)
-    weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16)
+    input_bf16 = (
+        (torch.rand(m, k, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
+    )
+    weight_bf16 = (
+        (torch.rand(n, k, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
+    )

     # Quantize input
     if input_dtype == torch.float8_e4m3fn:
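
Aside (not part of the diff): torch.finfo(torch.float8_e4m3fn).max is 448.0, so the new (torch.rand(...) - 0.5) * 2 * fp8_max expression draws test data roughly uniformly over (-448, 448), covering the full FP8 E4M3 dynamic range rather than the narrow spread of randn. A quick check:

    import torch

    fp8_max = torch.finfo(torch.float8_e4m3fn).max         # 448.0 for E4M3
    x = (torch.rand(4, 256, dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
    assert x.abs().max() <= fp8_max                         # stays within the representable range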
@@ -175,12 +165,7 @@ def test_fp8_blockscale_gemm_dtypes(m, n, k, input_dtype, weight_dtype):
         reference.flatten().float(), output.flatten().float(), dim=0
     )

-    if input_dtype == torch.bfloat16 and weight_dtype == torch.bfloat16:
-        threshold = 0.99
-    else:
-        # BF16+FP8: BF16 input quantized internally, FP8 weight pre-quantized
-        # TODO: check threshold
-        threshold = 0.967
+    threshold = 0.99

     assert cos_sim > threshold, (
         f"Cosine similarity {cos_sim:.4f} too low for "
@@ -191,7 +176,8 @@ def test_fp8_blockscale_gemm_dtypes(m, n, k, input_dtype, weight_dtype):
 @pytest.mark.parametrize("m", [7, 32, 128])
 @pytest.mark.parametrize("n", [1024, 4096])
 @pytest.mark.parametrize("k", [512, 4096])
-def test_fp8_blockscale_gemm_w8a8(m, n, k):
+@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float8_e4m3fn])
+def test_fp8_blockscale_gemm_w8a8(m, n, k, input_dtype):
     """Test W8A8 (FP8+FP8) GEMM with per-token scales for both input and weight.

     This test demonstrates full FP8 quantization for both activations and weights.
206192
device = "cuda"
207193
# m, n, k = 64, 2048, 4096
208194
torch.manual_seed(42)
195+
fp8_info = torch.finfo(torch.float8_e4m3fn)
196+
fp8_max = fp8_info.max
209197

210198
# Create BF16 inputs for reference (no normalization)
211199
# Raw randn values work well with FP8 quantization without causing numerical issues
212-
input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16)
213-
weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16)
200+
input_bf16 = (
201+
(torch.rand(m, k, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
202+
)
203+
weight_bf16 = (
204+
(torch.rand(n, k, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
205+
)
214206

215207
# Quantize both input and weight to FP8 with per-token (1x128) scales
216208
input_fp8, input_scale = per_token_cast_to_fp8(input_bf16)
@@ -226,18 +218,32 @@ def test_fp8_blockscale_gemm_w8a8(m, n, k):
     assert input_scale.min() > 0, "Input scale should be positive"
     assert weight_scale.min() > 0, "Weight scale should be positive"

-    # Run W8A8 GEMM: FP8 input + FP8 weight
-    output = fp8_blockscale_gemm_sm90(input_fp8, weight_fp8, input_scale, weight_scale)
+    M_padded = ((m + 4 - 1) // 4) * 4  # Round M up to multiple of 4
+    K_blocks = k // 128

-    # Dequantize FP8 tensors to create reference (tests kernel correctness, not quantization)
-    # Dequant: bf16 = fp8.to(bf16) * scale (applied per 128-element block)
-    input_dequant = torch.zeros_like(input_bf16)
-    for i in range(m):
-        for k_tile in range(k // 128):
-            start, end = k_tile * 128, (k_tile + 1) * 128
-            input_dequant[i, start:end] = (
-                input_fp8[i, start:end].to(torch.bfloat16) * input_scale[i, k_tile]
-            )
+    if input_dtype == torch.float8_e4m3fn:
+        # Create padded tensor with the stride TRT-LLM expects
+        input_scale_padded = torch.zeros(
+            K_blocks, M_padded, dtype=torch.float32, device=device
+        )
+        input_scale_padded[:, :m] = input_scale.T
+        input_scale_padded = input_scale_padded[:, :m]
+
+        output = fp8_blockscale_gemm_sm90(
+            input_fp8, weight_fp8, input_scale_padded, weight_scale
+        )
+        # Dequantize FP8 tensors to create reference (tests kernel correctness, not quantization)
+        # Dequant: bf16 = fp8.to(bf16) * scale (applied per 128-element block)
+        input_dequant = torch.zeros_like(input_bf16)
+        for i in range(m):
+            for k_tile in range(k // 128):
+                start, end = k_tile * 128, (k_tile + 1) * 128
+                input_dequant[i, start:end] = (
+                    input_fp8[i, start:end].to(torch.bfloat16) * input_scale[i, k_tile]
+                )
+    else:
+        output = fp8_blockscale_gemm_sm90(input_bf16, weight_fp8, None, weight_scale)
+        input_dequant = input_bf16

     weight_dequant = torch.zeros_like(weight_bf16)
     for j in range(n):
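
The per-row, per-tile dequantization loops in the hunk above keep the reference construction explicit; an equivalent vectorized form (an illustration only, not used by the test), assuming FP8 data of shape (rows, k) and float32 scales of shape (rows, k // 128):

    import torch

    def dequant_1x128(x_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # Expand each (row, k_tile) scale across its 128-element block and multiply.
        rows, k = x_fp8.shape
        blocks = k // 128
        x = x_fp8.to(torch.float32).view(rows, blocks, 128)
        return (x * scale.unsqueeze(-1)).view(rows, k).to(torch.bfloat16)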
@@ -253,115 +259,10 @@ def test_fp8_blockscale_gemm_w8a8(m, n, k):
     cos_sim = F.cosine_similarity(
         reference.flatten().float(), output.flatten().float(), dim=0
     )
-    # W8A8 achieves ~97% cosine similarity against dequantized FP8 reference
-    assert cos_sim > 0.967, (
-        f"W8A8 cosine similarity {cos_sim:.4f} too low (expected > 0.967)"
-    )
-
-    print(f"✓ W8A8 (FP8+FP8): cosine similarity = {cos_sim:.4f}")
-
-@pytest.mark.parametrize("m", [7, 32, 128])
-@pytest.mark.parametrize("n", [1024, 4096])
-@pytest.mark.parametrize("k", [512, 4096])
-def test_fp8_blockscale_gemm_w8a8_with_quant_kernel(m, n, k):
-    """Test W8A8 (FP8+FP8) GEMM with per-token scales for both input and weight.
-
-    This test demonstrates full FP8 quantization for both activations and weights.
-    """
-    compute_capability = get_compute_capability(torch.device("cuda"))
-    if compute_capability[0] < 9:
-        pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later")
-
-    if not is_sm90a_supported(torch.device("cuda")):
-        pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support")
-
-    device = "cuda"
-    # m, n, k = 64, 2048, 4096
-    torch.manual_seed(42)
-
-    # Create BF16 inputs for reference (no normalization)
-    # Raw randn values work well with FP8 quantization without causing numerical issues
-    input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16)
-    weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16)
-
-    # Quantize both input and weight to FP8 with per-token (1x128) scales
-    weight_fp8, weight_scale = per_token_cast_to_fp8(weight_bf16)

-    # Verify scale shapes
-    assert weight_scale.shape == (n, k // 128), (
-        f"Expected weight scale shape ({n}, {k // 128}), got {weight_scale.shape}"
-    )
-    assert weight_scale.min() > 0, "Weight scale should be positive"
-
-    # Run W8A8 GEMM: FP8 input + FP8 weight
-    output = fp8_blockscale_gemm_sm90(input_bf16, weight_fp8, None, weight_scale)
-
-    # Dequantize FP8 tensors to create reference (tests kernel correctness, not quantization)
-    # Dequant: bf16 = fp8.to(bf16) * scale (applied per 128-element block)
-    weight_dequant = torch.zeros_like(weight_bf16)
-    for j in range(n):
-        for k_tile in range(k // 128):
-            start, end = k_tile * 128, (k_tile + 1) * 128
-            weight_dequant[j, start:end] = (
-                weight_fp8[j, start:end].to(torch.bfloat16) * weight_scale[j, k_tile]
-            )
-
-    reference = torch.matmul(input_bf16, weight_dequant.T)
-
-    # Use cosine similarity (same metric as BF16+FP8 tests)
-    cos_sim = F.cosine_similarity(
-        reference.flatten().float(), output.flatten().float(), dim=0
+    assert cos_sim > 0.99, (
+        f"W8A8 cosine similarity {cos_sim:.4f} too low (expected > 0.99)"
     )
-    # W8A8 achieves ~97% cosine similarity against dequantized FP8 reference
-    assert cos_sim > 0.967, (
-        f"W8A8 cosine similarity {cos_sim:.4f} too low (expected > 0.967)"
-    )
-
-    print(f"✓ W8A8 (FP8+FP8): cosine similarity = {cos_sim:.4f}")
-
-
-def test_fp8_blockscale_gemm_per_block_weight_scales():
-    """Test BF16+FP8 GEMM with per-block (128x128) weight scales.
-
-    This test demonstrates using 128x128 block quantization for weights with BF16 input,
-    """
-    compute_capability = get_compute_capability(torch.device("cuda"))
-    if compute_capability[0] < 9:
-        pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later")
-
-    if not is_sm90a_supported(torch.device("cuda")):
-        pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support")
-
-    device = "cuda"
-    m, n, k = 16, 512, 512
-    torch.manual_seed(42)
-
-    # Create inputs
-    input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16)
-    weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16)
-
-    # Quantize weight with per-block (128x128) blocks
-    weight_fp8, weight_scale = per_block_cast_to_fp8(weight_bf16)
-
-    # Verify scale shape
-    assert weight_scale.shape == (n // 128, k // 128), (
-        f"Expected weight scale shape ({n // 128}, {k // 128}), got {weight_scale.shape}"
-    )
-    assert weight_scale.min() > 0, "Weight scale should be positive (reciprocal format)"
-
-    # Run GEMM: BF16 input (internal quant) + FP8 weight (per-block scales)
-    output = fp8_blockscale_gemm_sm90(input_bf16, weight_fp8, None, weight_scale)
-
-    # Compare to BF16 reference
-    reference = torch.matmul(input_bf16, weight_bf16.T)
-
-    cos_sim = F.cosine_similarity(
-        reference.flatten().float(), output.flatten().float(), dim=0
-    )
-    # TODO: check threshold
-    assert cos_sim > 0.967, f"Per-block weight scale accuracy too low: {cos_sim:.4f}"
-
-    print(f"✓ Per-block weight scales: cosine similarity = {cos_sim:.4f}")


 @pytest.mark.parametrize(
