[Common] MXFP8 kernel for grouped tensors#2586

Merged
ptrendx merged 28 commits into NVIDIA:main from Oleg-Goncharov:pr_mxfp8_grouped_kernel on Feb 6, 2026

@Oleg-Goncharov Oleg-Goncharov commented Jan 12, 2026

Collaborator

Description

This PR adds a new kernel that supports MXFP8 quantization of grouped tensors.

Below is a performance comparison of tensor-descriptor updates with O(log N) vs. O(N) complexity for varying numbers of descriptors (N = 2, 4, 8, …, 64). The input grouped tensors are N × [256, 8192]. Run on GB300.

[Figure: tensor update speedup]

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added MXFP8 cast kernel for grouped tensors
  • Added the test suite

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_mxfp8_grouped_kernel branch 4 times, most recently from e6bf02a to fc2a53f Compare January 15, 2026 16:15
@Oleg-Goncharov Oleg-Goncharov added enhancement New feature or request MoE labels Jan 15, 2026
@ptrendx ptrendx linked an issue Jan 16, 2026 that may be closed by this pull request
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_mxfp8_grouped_kernel branch from 74a7917 to 88cf1b2 Compare January 21, 2026 17:00
pre-commit-ci Bot and others added 6 commits January 21, 2026 17:00
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_mxfp8_grouped_kernel branch from 7c4fda7 to 39bb24f Compare January 22, 2026 18:12
@Oleg-Goncharov Oleg-Goncharov marked this pull request as ready for review January 24, 2026 00:53
greptile-apps Bot commented Jan 24, 2026

Contributor

Greptile Overview

Greptile Summary

This PR implements MXFP8 quantization for grouped tensors, adding a new GPU kernel that uses TMA (Tensor Memory Accelerator) descriptors for efficient data transfer and O(log N) binary search for tensor identification in grouped tensor scenarios.

Key Changes:

  • Added group_quantize_mxfp8.cuh with the core MXFP8 grouped quantization kernel supporting rowwise/columnwise scaling
  • Implemented TMA descriptor update mechanism with O(log N) complexity for varying tensor shapes
  • Extended C API in cast.h with 7 new grouped tensor quantization functions
  • Added grouped variants for activation functions (GeLU, ReLU, SiLU, QGeLU, SReLU) with dbias support
  • Comprehensive test suite with reference implementation covering multiple shape representations

Issues from Previous Comments:
Previous review threads identified several concerns that appear to remain unresolved:

  • Binary search underflow risk when current_offset < offsets_ptr[0] (line 85)
  • Uninitialized shape_rep variable if none of the four shape conditions match (lines 755-764)
  • Commented-out code creating ambiguity in switch statement (line 104)
  • Typos: "gropued" instead of "grouped" in multiple locations in cast.h
  • Missing newline after conditional compilation block
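One way to rule out the uninitialized `shape_rep` concern listed above is to derive the value in a function whose every path returns. The sketch below is a host-side illustration only: the enum is a hypothetical mirror of the kernel's `ShapeRepresentation` (only `VARYING_LAST_DIM` and `VARYING_BOTH_DIMS` are named in this thread; the other two names are assumptions), and `classify_shape` is a made-up helper, not a function from the PR.

```cpp
// Hypothetical mirror of the kernel's ShapeRepresentation enum.
// SAME_BOTH_DIMS and VARYING_FIRST_DIM are assumed names.
enum class ShapeRepresentation {
  SAME_BOTH_DIMS,
  VARYING_FIRST_DIM,
  VARYING_LAST_DIM,
  VARYING_BOTH_DIMS
};

// Derive the representation from which per-tensor dimension arrays are present.
// All four flag combinations return a value, so the caller never sees an
// uninitialized result.
ShapeRepresentation classify_shape(bool has_first_dims, bool has_last_dims) {
  if (has_first_dims && has_last_dims) return ShapeRepresentation::VARYING_BOTH_DIMS;
  if (has_first_dims) return ShapeRepresentation::VARYING_FIRST_DIM;
  if (has_last_dims) return ShapeRepresentation::VARYING_LAST_DIM;
  return ShapeRepresentation::SAME_BOTH_DIMS;
}
```

Returning from each branch, rather than assigning to a variable declared earlier, lets the compiler flag any missing case.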

Architecture:
The implementation uses a two-kernel approach: first updating TMA descriptors per tensor, then launching the main quantization kernel that uses binary search to identify which tensor each block processes. The kernel supports both single-tensor (constant last dimension) and multi-tensor cases with different optimization paths.
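The tensor lookup described here can be sketched on the host as a binary search over an offsets array. This is an illustration under stated assumptions, not the kernel's code: `find_tensor_id` is a hypothetical name, and the layout (one offset per tensor plus a final entry holding the total element count) is assumed. The explicit range check covers the case where the queried offset lies below the first entry, which is the underflow scenario the review raised.

```cpp
#include <cstddef>
#include <vector>

// Assumed layout: offsets has num_tensors + 1 entries; offsets[i] is the first
// element of tensor i and offsets.back() is the total element count.
// Returns i such that offsets[i] <= current_offset < offsets[i + 1],
// or -1 if current_offset lies outside [offsets.front(), offsets.back()).
int find_tensor_id(const std::vector<std::size_t>& offsets, std::size_t current_offset) {
  if (offsets.size() < 2) return -1;
  // Reject out-of-range queries up front so the search below never underflows.
  if (current_offset < offsets.front() || current_offset >= offsets.back()) return -1;

  std::size_t lo = 0;
  std::size_t hi = offsets.size() - 1;
  // Invariant: offsets[lo] <= current_offset < offsets[hi].
  while (hi - lo > 1) {
    const std::size_t mid = lo + (hi - lo) / 2;
    if (offsets[mid] <= current_offset) {
      lo = mid;
    } else {
      hi = mid;
    }
  }
  return static_cast<int>(lo);
}
```

Each iteration halves the candidate range, which gives the O(log N) lookup cost mentioned in the summary. Empty tensors (two equal consecutive offsets) are skipped naturally because the search settles on the last index whose offset does not exceed the query.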

Confidence Score: 3/5

  • This PR requires addressing several logic issues before merging, particularly around variable initialization and edge case handling
  • Score reflects substantial new functionality with comprehensive tests, but multiple unresolved concerns from previous review including potential binary search underflow, uninitialized variables, and typos that need to be addressed
  • Primary attention needed on transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh for initialization and edge case issues, and transformer_engine/common/include/transformer_engine/cast.h for typo corrections

Important Files Changed

Filename | Overview
--- | ---
transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh | New MXFP8 grouped quantization kernel with TMA descriptors and binary search for tensor identification; previous comments identified potential issues with uninitialized variables and binary search underflow
tests/cpp/operator/test_cast_mxfp8_grouped.cu | Comprehensive test suite with reference implementation and multiple shape representations for grouped tensor quantization
transformer_engine/common/cast/dispatch/quantize.cuh | Added dispatcher helpers for grouped tensor quantization following existing pattern for regular tensors
transformer_engine/common/include/transformer_engine/cast.h | Added C API declarations for grouped tensor quantization; typos reported in previous comments

Sequence Diagram

sequenceDiagram
    participant User
    participant API as C API Layer<br/>(cast.cu)
    participant Dispatcher as Dispatch Layer<br/>(quantize.cuh)
    participant Kernel as MXFP8 Kernel<br/>(group_quantize_mxfp8.cuh)
    participant GPU as GPU Device

    User->>API: nvte_group_quantize(input, output, stream)
    API->>Dispatcher: group_quantize_fwd_helper()
    Dispatcher->>Dispatcher: Convert NVTEGroupedTensor to GroupedTensor*
    Dispatcher->>Dispatcher: Check scaling_mode == NVTE_MXFP8_1D_SCALING
    
    alt Multi-tensor case (VARYING_LAST_DIM or VARYING_BOTH_DIMS)
        Dispatcher->>Kernel: update_tma_descriptors<<<num_tensors, 32>>>()
        Kernel->>GPU: Launch descriptor update kernel
        loop For each tensor in group
            GPU->>GPU: modify_base_tensor_map()<br/>Update tensor map for each tensor's data pointer
        end
        GPU-->>Kernel: TMA descriptors updated
    end
    
    Dispatcher->>Kernel: group_quantize_mxfp8_kernel<<<blocks, 128>>>()
    Kernel->>GPU: Launch main quantization kernel
    
    loop For each block
        GPU->>GPU: get_current_tensor_id()<br/>Binary search to find tensor ID
        GPU->>GPU: Acquire TMA fence for tensor map
        GPU->>GPU: TMA load input data to shared memory
        
        alt COLWISE_SCALING
            GPU->>GPU: Compute column-wise AMAX
            GPU->>GPU: Generate E8M0 scale factor
            GPU->>GPU: Quantize to MXFP8 with column-wise scale
            GPU->>GPU: TMA store to global memory
        end
        
        alt ROWWISE_SCALING
            GPU->>GPU: Compute row-wise AMAX
            GPU->>GPU: Generate E8M0 scale factor
            GPU->>GPU: Quantize to MXFP8 with row-wise scale
            GPU->>GPU: TMA store to global memory
        end
    end
    
    GPU-->>Dispatcher: Quantization complete
    Dispatcher-->>API: Return
    API-->>User: Return
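The "Compute AMAX" and "Generate E8M0 scale factor" steps in the diagram can be illustrated with a small host-side reference. This is a sketch under assumptions, not the kernel's implementation: the helper names are hypothetical, the output type is assumed to be FP8 E4M3 (largest finite value 448), and the exponent is rounded up so that scaled values cannot exceed that maximum. In MXFP8 each scale factor covers a block of 32 elements along the scaling direction.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

constexpr float kFp8E4M3Max = 448.0f;  // largest finite E4M3 value
constexpr int kE8M0Bias = 127;         // E8M0 stores a biased power-of-two exponent

// Maximum absolute value over one scaling block.
float block_amax(const float* x, std::size_t n) {
  float amax = 0.0f;
  for (std::size_t i = 0; i < n; ++i) amax = std::max(amax, std::fabs(x[i]));
  return amax;
}

// Biased E8M0 exponent for a block with the given amax.
// Assumption: zero maps to the smallest exponent and non-finite input to 0xFF (NaN).
std::uint8_t e8m0_from_amax(float amax) {
  if (!std::isfinite(amax)) return 0xFF;
  if (amax <= 0.0f) return 0;
  // Round up so that amax / 2^exponent <= kFp8E4M3Max.
  const int exponent = static_cast<int>(std::ceil(std::log2(amax / kFp8E4M3Max)));
  const int biased = std::min(std::max(exponent + kE8M0Bias, 0), 254);
  return static_cast<std::uint8_t>(biased);
}

// Decode a biased E8M0 exponent back into a float scale (a power of two).
float e8m0_to_scale(std::uint8_t biased) {
  return std::ldexp(1.0f, static_cast<int>(biased) - kE8M0Bias);
}
```

With these choices a block whose amax is exactly 448 gets biased exponent 127, i.e. a scale of 1.0. Each element is then divided by the decoded scale before conversion to FP8; the conversion itself is omitted here.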

@greptile-apps greptile-apps Bot left a comment

Contributor

10 files reviewed, 6 comments

Comment thread transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh Outdated
Comment thread transformer_engine/common/include/transformer_engine/cast.h Outdated
Comment thread transformer_engine/common/include/transformer_engine/cast.h Outdated
Comment thread transformer_engine/common/include/transformer_engine/cast.h Outdated
Comment thread transformer_engine/common/include/transformer_engine/cast.h Outdated
Comment thread transformer_engine/common/include/transformer_engine/cast.h Outdated
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

@greptile-apps greptile-apps Bot left a comment

Contributor

10 files reviewed, 1 comment

Comment thread transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh Outdated
Comment thread transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh Outdated
Comment thread transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh
const __grid_constant__ CUtensorMap tensor_map_act_input_static,
const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static,
const __grid_constant__ CUtensorMap tensor_map_output_colwise_static,
const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t first_logical_dim,

Member

Doesn't having it as a regular parameter impact performance?

Member

it = shape_rep

Collaborator Author

I haven’t measured the performance impact, but it should be very small since it’s only used during initialization and isn’t on the critical path.

Comment thread transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh
Comment on lines +790 to +791
NVTE_CHECK(last_logical_dim % 128 == 0,
"Last dimension of a grouped tensor should be divisible by 128.");

Member

Do we need that? I think we only need that if we want columnwise scaling, no?

Collaborator Author

I initially assumed a full 128×128 tile input, but we can relax this restriction for a single-tensor view with a simple change. The input/output alignment is validated when the tensor descriptor is created. However, we need special care when the last dimension varies across inputs (i.e., when it can’t be viewed as a single tensor). In that case, we should validate alignment when updating the tensor descriptors in the helper kernel and raise an error if the data is not aligned.
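The per-tensor validation described in this reply could look roughly like the following host-side sketch. It is an assumption-laden illustration: `first_misaligned_tensor` is a hypothetical helper, the multiple of 128 is taken from the `NVTE_CHECK` quoted above, and how the real helper kernel would detect and report misalignment is not specified in this thread.

```cpp
#include <cstddef>
#include <vector>

// Returns the index of the first tensor whose last dimension is not a multiple
// of `alignment`, or -1 if every tensor is aligned (or alignment is 0).
int first_misaligned_tensor(const std::vector<std::size_t>& last_dims,
                            std::size_t alignment) {
  if (alignment == 0) return -1;  // nothing to check
  for (std::size_t i = 0; i < last_dims.size(); ++i) {
    if (last_dims[i] % alignment != 0) return static_cast<int>(i);
  }
  return -1;
}
```

Returning the offending index, rather than a plain boolean, lets the caller name the misaligned tensor in the error message.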

@greptile-apps greptile-apps Bot left a comment

Contributor

4 files reviewed, no comments

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

@greptile-apps greptile-apps Bot left a comment

Contributor

3 files reviewed, no comments

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@Oleg-Goncharov Oleg-Goncharov force-pushed the pr_mxfp8_grouped_kernel branch from d2621c4 to e9ddde1 Compare February 4, 2026 11:44

@greptile-apps greptile-apps Bot left a comment

Contributor

4 files reviewed, no comments

@greptile-apps greptile-apps Bot left a comment

Contributor

4 files reviewed, no comments

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

@greptile-apps greptile-apps Bot left a comment

Contributor

4 files reviewed, no comments

@greptile-apps greptile-apps Bot left a comment

Contributor

3 files reviewed, no comments

@vthumbe1503

Collaborator

/te-ci L1 pytorch

@vthumbe1503

Collaborator

@Oleg-Goncharov, I have tested grouped_quantize from a PyTorch binding created for nvte_grouped_quantize, and it works fine for all four cases of (first_dims, last_dims). The changes in the PR look OK to me based on what I could understand. Could we merge this, @Oleg-Goncharov @ptrendx?

@vthumbe1503

Collaborator

/te-ci pytorch

@vthumbe1503

Collaborator

/te-ci pytorch

@greptile-apps greptile-apps Bot left a comment

Contributor

4 files reviewed, no comments

@Oleg-Goncharov

Collaborator Author

Thank you for checking it, @vthumbe1503. I’m working on extending dbias support to output a split grouped tensor, since the kernel currently accumulates dbias into a single tensor. Let’s merge this once that’s in.
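The difference between a single accumulated dbias and a split (per-tensor) dbias can be shown with a small host-side reference. This is a sketch only: both function names are hypothetical, and the layout is assumed to be row-major with a constant last dimension and a varying number of rows per tensor.

```cpp
#include <cstddef>
#include <vector>

// grad: row-major [total_rows, cols]; first_dims[i] is the row count of tensor i.
// Precondition: grad.size() == (sum of first_dims) * cols.
// Returns one dbias vector (column sums) per tensor in the group.
std::vector<std::vector<float>> dbias_per_tensor(const std::vector<float>& grad,
                                                 const std::vector<std::size_t>& first_dims,
                                                 std::size_t cols) {
  std::vector<std::vector<float>> out;
  std::size_t row = 0;  // running row index across the whole group
  for (const std::size_t rows : first_dims) {
    std::vector<float> acc(cols, 0.0f);
    for (std::size_t r = 0; r < rows; ++r, ++row) {
      for (std::size_t c = 0; c < cols; ++c) acc[c] += grad[row * cols + c];
    }
    out.push_back(acc);
  }
  return out;
}

// Single accumulated dbias: the element-wise sum of the per-tensor results.
std::vector<float> dbias_single(const std::vector<std::vector<float>>& per_tensor,
                                std::size_t cols) {
  std::vector<float> total(cols, 0.0f);
  for (const auto& d : per_tensor) {
    for (std::size_t c = 0; c < cols; ++c) total[c] += d[c];
  }
  return total;
}
```

The split form keeps one bias gradient per tensor in the group, while the single form collapses them into one vector and loses the per-tensor breakdown.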

@ptrendx ptrendx merged commit 7393947 into NVIDIA:main Feb 6, 2026
21 of 24 checks passed
ptrendx commented Feb 6, 2026

Member

@Oleg-Goncharov Please open a new PR with the proper dbias support - let's try to minimize the review effort.

Labels

enhancement New feature or request MoE

Development

Successfully merging this pull request may close these issues.

Quantization support for GroupedTensor: MXFP8

3 participants