 
 import flashinfer
 from flashinfer.gemm import fp8_blockscale_gemm_sm90
-from flashinfer.testing.utils import per_token_cast_to_fp8, per_block_cast_to_fp8
+from flashinfer.testing.utils import per_token_cast_to_fp8
 from flashinfer.utils import (
     get_compute_capability,
     has_flashinfer_jit_cache,
     is_sm90a_supported,
 )
 from flashinfer.jit.gemm import gen_fp8_blockscale_gemm_sm90_module
 
 
-def calc_diff(output: torch.Tensor, expected: torch.Tensor) -> float:
-    """Calculate similarity difference using TensorRT-LLM's metric.
-
-    Returns diff = 1 - sim, where sim = 2*<x,y> / (||x||² + ||y||²)
-    This is similar to cosine similarity but uses squared norms in denominator.
-
-    diff < 0.001 corresponds to >99.9% similarity.
-    """
-    output_f64 = output.to(torch.float64)
-    expected_f64 = expected.to(torch.float64)
-    denominator = (output_f64 * output_f64 + expected_f64 * expected_f64).sum()
-    sim = 2 * (output_f64 * expected_f64).sum() / denominator
-    diff = 1 - sim
-    return diff.item()
-
-
 @pytest.fixture(
     autouse=not has_flashinfer_jit_cache(),
     scope="module",
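Note on the deleted helper: `calc_diff` measured `1 - 2*<x, y> / (||x||^2 + ||y||^2)`, while the remaining tests use `torch.nn.functional.cosine_similarity`. A standalone sketch comparing the two metrics on made-up tensors (the tensor names here are illustrative, not from the test file):

```python
import torch
import torch.nn.functional as F

x = torch.randn(64, 128, dtype=torch.float64)
y = x + 0.01 * torch.randn_like(x)

# TensorRT-LLM-style metric from the removed calc_diff:
# sim = 2*<x, y> / (||x||^2 + ||y||^2); diff = 1 - sim, diff < 0.001 ~ >99.9% similar
sim = 2 * (x * y).sum() / (x.square().sum() + y.square().sum())

# Metric the tests keep: cosine similarity over the flattened tensors
cos_sim = F.cosine_similarity(x.flatten(), y.flatten(), dim=0)

print(f"calc_diff = {1 - sim:.6f}, cos_sim = {cos_sim:.6f}")
```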
@@ -141,10 +125,16 @@ def test_fp8_blockscale_gemm_dtypes(m, n, k, input_dtype, weight_dtype):
 
     device = "cuda"
     torch.manual_seed(42)
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_max = fp8_info.max
 
     # Create BF16 data for reference
-    input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16)
-    weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16)
+    input_bf16 = (
+        (torch.rand(m, k, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
+    )
+    weight_bf16 = (
+        (torch.rand(n, k, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
+    )
 
     # Quantize input
     if input_dtype == torch.float8_e4m3fn:
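For context on the new input distribution: `torch.finfo(torch.float8_e4m3fn).max` is 448.0, so `(torch.rand(...) - 0.5) * 2 * fp8_max` draws values roughly uniformly over (-448, 448), the full e4m3 range, instead of the unit-variance values `torch.randn` produced. A quick standalone check:

```python
import torch

fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
x = (torch.rand(4, 256, dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
print(fp8_max, x.min().item(), x.max().item())  # values fall inside (-448, 448)
```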
@@ -175,12 +165,7 @@ def test_fp8_blockscale_gemm_dtypes(m, n, k, input_dtype, weight_dtype):
         reference.flatten().float(), output.flatten().float(), dim=0
     )
 
-    if input_dtype == torch.bfloat16 and weight_dtype == torch.bfloat16:
-        threshold = 0.99
-    else:
-        # BF16+FP8: BF16 input quantized internally, FP8 weight pre-quantized
-        # TODO: check threshold
-        threshold = 0.967
+    threshold = 0.99
 
     assert cos_sim > threshold, (
         f"Cosine similarity {cos_sim:.4f} too low for "
@@ -191,7 +176,8 @@ def test_fp8_blockscale_gemm_dtypes(m, n, k, input_dtype, weight_dtype):
 @pytest.mark.parametrize("m", [7, 32, 128])
 @pytest.mark.parametrize("n", [1024, 4096])
 @pytest.mark.parametrize("k", [512, 4096])
-def test_fp8_blockscale_gemm_w8a8(m, n, k):
+@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float8_e4m3fn])
+def test_fp8_blockscale_gemm_w8a8(m, n, k, input_dtype):
     """Test W8A8 (FP8+FP8) GEMM with per-token scales for both input and weight.
 
     This test demonstrates full FP8 quantization for both activations and weights.
@@ -206,11 +192,17 @@ def test_fp8_blockscale_gemm_w8a8(m, n, k):
     device = "cuda"
     # m, n, k = 64, 2048, 4096
     torch.manual_seed(42)
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_max = fp8_info.max
 
     # Create BF16 inputs for reference (no normalization)
     # Raw randn values work well with FP8 quantization without causing numerical issues
-    input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16)
-    weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16)
+    input_bf16 = (
+        (torch.rand(m, k, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
+    )
+    weight_bf16 = (
+        (torch.rand(n, k, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
+    )
 
     # Quantize both input and weight to FP8 with per-token (1x128) scales
     input_fp8, input_scale = per_token_cast_to_fp8(input_bf16)
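The test relies on `flashinfer.testing.utils.per_token_cast_to_fp8`, which (per the shape asserts below) returns an FP8 tensor plus one float scale per 1x128 block of each row, i.e. scales of shape `(rows, k // 128)`. A hypothetical re-implementation of that behavior, for illustration only; the exact scale convention is an assumption, though the dequant-scale form shown here matches how the test multiplies scales back in when building its reference:

```python
import torch

def per_token_cast_to_fp8_sketch(x: torch.Tensor):
    """Illustrative sketch, not the flashinfer helper. Assumes x is 2D with k % 128 == 0."""
    rows, k = x.shape
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    blocks = x.float().reshape(rows, k // 128, 128)
    amax = blocks.abs().amax(dim=-1).clamp(min=1e-4)           # (rows, k // 128)
    scale = amax / fp8_max                                      # dequant scale per 1x128 block
    q = (blocks / scale.unsqueeze(-1)).to(torch.float8_e4m3fn).reshape(rows, k)
    return q, scale                                             # q: fp8, scale: float32
```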
@@ -226,18 +218,32 @@ def test_fp8_blockscale_gemm_w8a8(m, n, k):
     assert input_scale.min() > 0, "Input scale should be positive"
     assert weight_scale.min() > 0, "Weight scale should be positive"
 
-    # Run W8A8 GEMM: FP8 input + FP8 weight
-    output = fp8_blockscale_gemm_sm90(input_fp8, weight_fp8, input_scale, weight_scale)
+    M_padded = ((m + 4 - 1) // 4) * 4  # Round M up to multiple of 4
+    K_blocks = k // 128
 
-    # Dequantize FP8 tensors to create reference (tests kernel correctness, not quantization)
-    # Dequant: bf16 = fp8.to(bf16) * scale (applied per 128-element block)
-    input_dequant = torch.zeros_like(input_bf16)
-    for i in range(m):
-        for k_tile in range(k // 128):
-            start, end = k_tile * 128, (k_tile + 1) * 128
-            input_dequant[i, start:end] = (
-                input_fp8[i, start:end].to(torch.bfloat16) * input_scale[i, k_tile]
-            )
+    if input_dtype == torch.float8_e4m3fn:
+        # Create padded tensor with the stride TRT-LLM expects
+        input_scale_padded = torch.zeros(
+            K_blocks, M_padded, dtype=torch.float32, device=device
+        )
+        input_scale_padded[:, :m] = input_scale.T
+        input_scale_padded = input_scale_padded[:, :m]
+
+        output = fp8_blockscale_gemm_sm90(
+            input_fp8, weight_fp8, input_scale_padded, weight_scale
+        )
+        # Dequantize FP8 tensors to create reference (tests kernel correctness, not quantization)
+        # Dequant: bf16 = fp8.to(bf16) * scale (applied per 128-element block)
+        input_dequant = torch.zeros_like(input_bf16)
+        for i in range(m):
+            for k_tile in range(k // 128):
+                start, end = k_tile * 128, (k_tile + 1) * 128
+                input_dequant[i, start:end] = (
+                    input_fp8[i, start:end].to(torch.bfloat16) * input_scale[i, k_tile]
+                )
+    else:
+        output = fp8_blockscale_gemm_sm90(input_bf16, weight_fp8, None, weight_scale)
+        input_dequant = input_bf16
 
     weight_dequant = torch.zeros_like(weight_bf16)
     for j in range(n):
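On the padded scale layout introduced above: the activation scales come back as `(m, K_blocks)`, but the FP8 path is fed a `(K_blocks, m)` view whose row stride is `M_padded` (m rounded up to a multiple of 4), which the in-code comment attributes to the stride TRT-LLM expects. A standalone snippet with made-up sizes showing what the zero-pad / transpose / slice produces:

```python
import torch

m, k = 7, 512
K_blocks = k // 128                      # 4
M_padded = ((m + 4 - 1) // 4) * 4        # 8

input_scale = torch.rand(m, K_blocks)    # per-token scales as returned by quantization
padded = torch.zeros(K_blocks, M_padded)
padded[:, :m] = input_scale.T
view = padded[:, :m]

print(view.shape, view.stride())         # torch.Size([4, 7]) (8, 1): rows are M_padded apart
```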
@@ -253,115 +259,10 @@ def test_fp8_blockscale_gemm_w8a8(m, n, k):
     cos_sim = F.cosine_similarity(
         reference.flatten().float(), output.flatten().float(), dim=0
     )
-    # W8A8 achieves ~97% cosine similarity against dequantized FP8 reference
-    assert cos_sim > 0.967, (
-        f"W8A8 cosine similarity {cos_sim:.4f} too low (expected > 0.967)"
-    )
-
-    print(f"✓ W8A8 (FP8+FP8): cosine similarity = {cos_sim:.4f}")
-
-@pytest.mark.parametrize("m", [7, 32, 128])
-@pytest.mark.parametrize("n", [1024, 4096])
-@pytest.mark.parametrize("k", [512, 4096])
-def test_fp8_blockscale_gemm_w8a8_with_quant_kernel(m, n, k):
-    """Test W8A8 (FP8+FP8) GEMM with per-token scales for both input and weight.
-
-    This test demonstrates full FP8 quantization for both activations and weights.
-    """
-    compute_capability = get_compute_capability(torch.device("cuda"))
-    if compute_capability[0] < 9:
-        pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later")
-
-    if not is_sm90a_supported(torch.device("cuda")):
-        pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support")
-
-    device = "cuda"
-    # m, n, k = 64, 2048, 4096
-    torch.manual_seed(42)
-
-    # Create BF16 inputs for reference (no normalization)
-    # Raw randn values work well with FP8 quantization without causing numerical issues
-    input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16)
-    weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16)
-
-    # Quantize both input and weight to FP8 with per-token (1x128) scales
-    weight_fp8, weight_scale = per_token_cast_to_fp8(weight_bf16)
 
-    # Verify scale shapes
-    assert weight_scale.shape == (n, k // 128), (
-        f"Expected weight scale shape ({n}, {k // 128}), got {weight_scale.shape}"
-    )
-    assert weight_scale.min() > 0, "Weight scale should be positive"
-
-    # Run W8A8 GEMM: FP8 input + FP8 weight
-    output = fp8_blockscale_gemm_sm90(input_bf16, weight_fp8, None, weight_scale)
-
-    # Dequantize FP8 tensors to create reference (tests kernel correctness, not quantization)
-    # Dequant: bf16 = fp8.to(bf16) * scale (applied per 128-element block)
-    weight_dequant = torch.zeros_like(weight_bf16)
-    for j in range(n):
-        for k_tile in range(k // 128):
-            start, end = k_tile * 128, (k_tile + 1) * 128
-            weight_dequant[j, start:end] = (
-                weight_fp8[j, start:end].to(torch.bfloat16) * weight_scale[j, k_tile]
-            )
-
-    reference = torch.matmul(input_bf16, weight_dequant.T)
-
-    # Use cosine similarity (same metric as BF16+FP8 tests)
-    cos_sim = F.cosine_similarity(
-        reference.flatten().float(), output.flatten().float(), dim=0
+    assert cos_sim > 0.99, (
+        f"W8A8 cosine similarity {cos_sim:.4f} too low (expected > 0.99)"
     )
-    # W8A8 achieves ~97% cosine similarity against dequantized FP8 reference
-    assert cos_sim > 0.967, (
-        f"W8A8 cosine similarity {cos_sim:.4f} too low (expected > 0.967)"
-    )
-
-    print(f"✓ W8A8 (FP8+FP8): cosine similarity = {cos_sim:.4f}")
-
-
-def test_fp8_blockscale_gemm_per_block_weight_scales():
-    """Test BF16+FP8 GEMM with per-block (128x128) weight scales.
-
-    This test demonstrates using 128x128 block quantization for weights with BF16 input,
-    """
-    compute_capability = get_compute_capability(torch.device("cuda"))
-    if compute_capability[0] < 9:
-        pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later")
-
-    if not is_sm90a_supported(torch.device("cuda")):
-        pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support")
-
-    device = "cuda"
-    m, n, k = 16, 512, 512
-    torch.manual_seed(42)
-
-    # Create inputs
-    input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16)
-    weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16)
-
-    # Quantize weight with per-block (128x128) blocks
-    weight_fp8, weight_scale = per_block_cast_to_fp8(weight_bf16)
-
-    # Verify scale shape
-    assert weight_scale.shape == (n // 128, k // 128), (
-        f"Expected weight scale shape ({n // 128}, {k // 128}), got {weight_scale.shape}"
-    )
-    assert weight_scale.min() > 0, "Weight scale should be positive (reciprocal format)"
-
-    # Run GEMM: BF16 input (internal quant) + FP8 weight (per-block scales)
-    output = fp8_blockscale_gemm_sm90(input_bf16, weight_fp8, None, weight_scale)
-
-    # Compare to BF16 reference
-    reference = torch.matmul(input_bf16, weight_bf16.T)
-
-    cos_sim = F.cosine_similarity(
-        reference.flatten().float(), output.flatten().float(), dim=0
-    )
-    # TODO: check threshold
-    assert cos_sim > 0.967, f"Per-block weight scale accuracy too low: {cos_sim:.4f}"
-
-    print(f"✓ Per-block weight scales: cosine similarity = {cos_sim:.4f}")
 
 
 @pytest.mark.parametrize(