[Common] MXFP8 kernel for grouped tensors#2586
Conversation
e6bf02a to
fc2a53f
Compare
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
74a7917 to
88cf1b2
Compare
for more information, see https://pre-commit.ci
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
7c4fda7 to
39bb24f
Compare
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile OverviewGreptile SummaryThis PR implements MXFP8 quantization for grouped tensors, adding a new GPU kernel that uses TMA (Tensor Memory Accelerator) descriptors for efficient data transfer and O(log N) binary search for tensor identification in grouped tensor scenarios. Key Changes:
Issues from Previous Comments:
Architecture: Confidence Score: 3/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant User
participant API as C API Layer<br/>(cast.cu)
participant Dispatcher as Dispatch Layer<br/>(quantize.cuh)
participant Kernel as MXFP8 Kernel<br/>(group_quantize_mxfp8.cuh)
participant GPU as GPU Device
User->>API: nvte_group_quantize(input, output, stream)
API->>Dispatcher: group_quantize_fwd_helper()
Dispatcher->>Dispatcher: Convert NVTEGroupedTensor to GroupedTensor*
Dispatcher->>Dispatcher: Check scaling_mode == NVTE_MXFP8_1D_SCALING
alt Multi-tensor case (VARYING_LAST_DIM or VARYING_BOTH_DIMS)
Dispatcher->>Kernel: update_tma_descriptors<<<num_tensors, 32>>>()
Kernel->>GPU: Launch descriptor update kernel
loop For each tensor in group
GPU->>GPU: modify_base_tensor_map()<br/>Update tensor map for each tensor's data pointer
end
GPU-->>Kernel: TMA descriptors updated
end
Dispatcher->>Kernel: group_quantize_mxfp8_kernel<<<blocks, 128>>>()
Kernel->>GPU: Launch main quantization kernel
loop For each block
GPU->>GPU: get_current_tensor_id()<br/>Binary search to find tensor ID
GPU->>GPU: Acquire TMA fence for tensor map
GPU->>GPU: TMA load input data to shared memory
alt COLWISE_SCALING
GPU->>GPU: Compute column-wise AMAX
GPU->>GPU: Generate E8M0 scale factor
GPU->>GPU: Quantize to MXFP8 with column-wise scale
GPU->>GPU: TMA store to global memory
end
alt ROWWISE_SCALING
GPU->>GPU: Compute row-wise AMAX
GPU->>GPU: Generate E8M0 scale factor
GPU->>GPU: Quantize to MXFP8 with row-wise scale
GPU->>GPU: TMA store to global memory
end
end
GPU-->>Dispatcher: Quantization complete
Dispatcher-->>API: Return
API-->>User: Return
|
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
| const __grid_constant__ CUtensorMap tensor_map_act_input_static, | ||
| const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static, | ||
| const __grid_constant__ CUtensorMap tensor_map_output_colwise_static, | ||
| const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t first_logical_dim, |
There was a problem hiding this comment.
Is having it as a regular parameter not impacting the performance?
There was a problem hiding this comment.
I haven’t measured the performance impact, but it should be very small since it’s only used during initialization and isn’t on the critical path
| NVTE_CHECK(last_logical_dim % 128 == 0, | ||
| "Last dimension of a grouped tensor should be divisible by 128."); |
There was a problem hiding this comment.
Do we need that? I think we only need that if we want columnwise scaling, no?
There was a problem hiding this comment.
I initially assumed a full 128×128 tile input, but we can relax this restriction for a single-tensor view with a simple change. The input/output alignment is validated when the tensor descriptor is created. However, we need special care when the last dimension varies across inputs (i.e., when it can’t be viewed as a single tensor). In that case, we should validate alignment when updating the tensor descriptors in the helper kernel and raise an error if the data is not aligned.
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
d2621c4 to
e9ddde1
Compare
for more information, see https://pre-commit.ci
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci L1 pytorch |
|
@Oleg-Goncharov, I have tested grouped_quantize from a Pytorch binding created for nvte_grouped_quantize and it works fine for all four cases of (first_dims, last_dims). And the changes in the PR look ok to me based on what I could understand. Could we merge this @Oleg-Goncharov @ptrendx ? |
|
/te-ci pytorch |
|
/te-ci pytorch |
|
Thank you for checking it, @vthumbe1503. I’m working on extending |
|
@Oleg-Goncharov Please open a new PR with the proper dbias support - let's try to minimize the review effort. |
Description
This PR adds a new kernel that supports MXFP8 quantization of grouped tensors.
Below is a performance comparison of tensor-descriptor updates with O(log N) vs. O(N) complexity for varying numbers of descriptors (N = 2, 4, 8, …, 64). The input grouped tensors are
N × [256, 8192]. Run on GB300.Type of change
Changes
Checklist: