
feat: FP8 groupwise scaling along M #1

Closed
soundOfDestiny wants to merge 1 commit into manishucsd:f8_blockwise_scaling_pr_branch from soundOfDestiny:f8_blockwise_scaling_pr_branch

Conversation


@soundOfDestiny soundOfDestiny commented Dec 12, 2024

Summary

NVIDIA#1932 adds a blockwise scaling strategy; this PR is a patch on top of NVIDIA#1932 that adds a groupwise scaling strategy along M in the A tensor. The scaling granularity along M is made independent of the CTA block configuration, while the scaling granularities along N and K remain blockwise (i.e. one scaling value per CTA block).

This PR restricts the scaling granularity along M to a factor of TILE_SHAPE_M in the CTA block configuration. To simulate a granularity that is a multiple of TILE_SHAPE_M, one can instead set the GEMM scaling granularity along M to exactly TILE_SHAPE_M (i.e. fall back to the blockwise scaling strategy) and call the repeat_interleave method on the input tensor ScaleA.
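For illustration, the repeat_interleave fallback can be sketched with NumPy; the sizes and variable names below are hypothetical and not the CUTLASS interface. A coarse ScaleA with granularity 2 × TILE_SHAPE_M is expanded to one value per TILE_SHAPE_M tile, after which the blockwise path applies unchanged:

```python
import numpy as np

# Hypothetical sizes: granularity along M is 256 = 2 * TILE_SHAPE_M.
M, TILE_SHAPE_M, group_m = 1024, 128, 256

# One scale per 256-row group of A (for a single K block).
scale_a_coarse = np.arange(1, M // group_m + 1, dtype=np.float32)

# Expand to one scale per TILE_SHAPE_M rows so the kernel can run with
# scaling granularity == TILE_SHAPE_M (the blockwise fallback).
scale_a = np.repeat(scale_a_coarse, group_m // TILE_SHAPE_M)
```

Here `np.repeat` plays the role of `repeat_interleave`: consecutive CTA tiles that share a group simply receive identical scale values.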

Groupwise Scaling

In this implementation, we load scaling tensors with more elements than NVIDIA#1932 into shared memory, since there may be multiple scale values along M per CTA block. However, each thread only needs to load at most 2 scale values for the A tensor and exactly one scale value for the B tensor from shared memory to registers per iteration, because each thread's WGMMA accumulators cover only 2 rows of the result tensor.
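The 2-row claim can be checked with a small sketch, assuming the Hopper WGMMA accumulator layout in which each warp of a 128-thread warpgroup covers a 16-row slice and a thread's fragments land in rows r and r+8 of that slice (the function name and group size below are illustrative):

```python
# Rows of the 64-row accumulator tile owned by thread t of a warpgroup,
# under the assumed WGMMA fragment layout.
def thread_rows(t):
    warp = t // 32
    lane = t % 32
    r = warp * 16 + lane // 4
    return (r, r + 8)

# With scaling granularity g_m along M, the two rows of any thread fall
# into at most 2 scale groups, so 2 register loads of ScaleA suffice.
g_m = 8
max_groups = max(len({r // g_m for r in thread_rows(t)}) for t in range(128))
```

Since the B scale is blockwise along N, one value per thread covers all of its columns.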

Performance

I haven't observed a performance degradation compared with NVIDIA#1932.

blockwise scaling

./64_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling 
  Disposition: Passed
  Problem Size: 1024x512x1024x1
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0112583 ms
  GFLOPS: 95373.3

groupwise scaling (this PR, setting scaling granularity along M to 64)

./64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling 
  Disposition: Passed
  Problem Size: 1024x512x1024x1
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0112435 ms
  GFLOPS: 95499.3

Background (copied from NVIDIA#1932)

  1. Tensorwise Scaling: Uses a single scaling factor per tensor, applied in the epilogue.
  2. Rowwise Scaling: Uses a row vector for scaling, with dimensions Mx1 for operand A and 1xN for operand B, avoiding the scaling along the reduction dimension. This can also be handled in the epilogue with EpilogueVisitorTree.
  3. Blockwise Scaling (NVIDIA/cutlass#1932): Introduces a 2D scaling tensor, assigning one scaling value per CTA Block. Since this scaling involves the reduction dimension (M, N, K), it must be applied during the mainloop, impacting performance. That PR implements blockwise scaling for CUTLASS F8 GEMM, staging scaling tensors via shared memory, and prepares for future support of groupwise scaling.
  4. Groupwise Scaling (this diff, along M in A tensor): Uses a 2D scaling tensor with multiple scaling values per CTA Block. Scaling granularity is independent of CTA Block configuration, allowing greater flexibility for future implementations.
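A reference for the groupwise-along-M variant can be sketched in NumPy (shapes, block sizes, and names below are illustrative assumptions, not the kernel interface). It shows why the scaling must be folded into the K loop rather than the epilogue: each K block carries its own scale factors.

```python
import numpy as np

# Hypothetical problem and block sizes: one A scale per (g_m x TILE_K)
# group, one B scale per (TILE_N x TILE_K) block.
M, N, K = 8, 4, 6
TILE_N, TILE_K, g_m = 4, 3, 2

rng = np.random.default_rng(0)
A = rng.standard_normal((M, K)).astype(np.float32)
B = rng.standard_normal((K, N)).astype(np.float32)
scale_a = rng.uniform(0.5, 2.0, (M // g_m, K // TILE_K)).astype(np.float32)
scale_b = rng.uniform(0.5, 2.0, (N // TILE_N, K // TILE_K)).astype(np.float32)

# Scaling touches the reduction dimension K, so it is applied per K block
# inside the main loop, not once in the epilogue.
C = np.zeros((M, N), dtype=np.float32)
for kb in range(K // TILE_K):
    Ak = A[:, kb * TILE_K:(kb + 1) * TILE_K]
    Bk = B[kb * TILE_K:(kb + 1) * TILE_K, :]
    sa = np.repeat(scale_a[:, kb], g_m)[:, None]     # per-row scale of A
    sb = np.repeat(scale_b[:, kb], TILE_N)[None, :]  # per-column scale of B
    C += sa * (Ak @ Bk) * sb
```

With g_m set equal to the M tile size this degenerates to blockwise scaling, matching the fallback described above.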

@soundOfDestiny soundOfDestiny changed the title FP8 groupwise scaling along M feat: FP8 groupwise scaling along M Dec 12, 2024
@manishucsd manishucsd force-pushed the f8_blockwise_scaling_pr_branch branch 3 times, most recently from 6834abc to 5ddebb9 Compare December 27, 2024 17:20
@soundOfDestiny soundOfDestiny closed this by deleting the head repository Jan 13, 2025