6 changes: 6 additions & 0 deletions sgl-kernel/csrc/common_extension.cc
@@ -157,6 +157,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"Tensor output_scale_offset_by_experts) -> ()");
m.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);

m.def(
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts, Tensor mask) -> ()");
m.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA, &silu_and_mul_scaled_fp4_experts_quant);

m.def(
"cutlass_fp4_group_mm(Tensor! output, Tensor a, Tensor b,"
"Tensor a_blockscale, Tensor b_blockscale, Tensor alphas,"
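For reference, here is a minimal sketch of how the op registered above could be invoked through the PyTorch op registry once the extension is loaded. The sizes, dtypes, and offsets are assumptions inferred from the schema string and from the shape checks in silu_and_mul_scaled_fp4_experts_quant_sm100a later in this diff; in particular, the row count of output_scale is not checked in the visible hunks and is assumed to be m_topk here.

import torch
import sgl_kernel  # assumed to load the compiled extension and its op registrations

# Illustrative sizes only: 2 experts with 2 rows each, k = 32 output features.
# The input carries 2*k columns because the fused silu_and_mul halves the last dim.
m_topk, k, n_experts = 4, 32, 2
x = torch.randn(m_topk, 2 * k, dtype=torch.bfloat16, device="cuda")
input_global_scale = torch.ones(n_experts, dtype=torch.float32, device="cuda")
input_offset_by_experts = torch.tensor([0, 2, 4], dtype=torch.int32, device="cuda")
output_scale_offset_by_experts = input_offset_by_experts.clone()
mask = torch.tensor([2, 2], dtype=torch.int32, device="cuda")  # valid rows per expert

# Packed outputs: two FP4 values per uint8, four FP8 scale factors per int32.
output = torch.empty(m_topk, k // 2, dtype=torch.uint8, device="cuda")
scales_k = k // 16
padded_k = (scales_k + 3) // 4 * 4
output_scale = torch.empty(m_topk, padded_k // 4, dtype=torch.int32, device="cuda")  # row padding assumed

torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant(
    output, output_scale, x, input_global_scale,
    input_offset_by_experts, output_scale_offset_by_experts, mask)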
172 changes: 160 additions & 12 deletions sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu
@@ -239,6 +239,33 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
#endif
}

__device__ __forceinline__ float silu(const float& val) {
return val / (1.0f + __expf(-val));
}

template <class Type>
inline __device__ void silu_and_mul(PackedVec<Type>& x_vec, const PackedVec<Type>& y_vec) {
float2 x[CVT_FP4_ELTS_PER_THREAD / 2];
float2 y[CVT_FP4_ELTS_PER_THREAD / 2];

#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, half>) {
x[i] = __half22float2(x_vec.elts[i]);
y[i] = __half22float2(y_vec.elts[i]);
x[i].x = silu(x[i].x) * y[i].x;
x[i].y = silu(x[i].y) * y[i].y;
x_vec.elts[i] = __float22half2_rn(x[i]);
} else {
x[i] = __bfloat1622float2(x_vec.elts[i]);
y[i] = __bfloat1622float2(y_vec.elts[i]);
x[i].x = silu(x[i].x) * y[i].x;
x[i].y = silu(x[i].y) * y[i].y;
x_vec.elts[i] = __float22bfloat162_rn(x[i]);
}
}
}

// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
__global__ void
@@ -255,6 +282,7 @@ cvt_fp16_to_fp4(
uint32_t* SFout,
uint32_t* input_offset_by_experts,
uint32_t* output_scale_offset_by_experts,
int32_t* mask,
int n_experts,
bool low_latency) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
@@ -265,20 +293,18 @@ cvt_fp16_to_fp4(
// Input tensor row/col loops.
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
// TODO(kaixih@nvidia): For now, we assume the mask is only used together with
// silu_and_mul; a more general mask behavior may be added later. In the
// silu case, the input's last dim doubles.
bool use_mask = mask != nullptr;
int actualColsPerRow = use_mask ? colsPerRow * 2 : colsPerRow;

// Each global thread processes one element
for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x * blockDim.x) {
// Calculate which row and column this global thread should process
int rowIdx = globalIdx / colsPerRow;
int colIdx = globalIdx % colsPerRow;

int64_t inOffset = rowIdx * colsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];

// Find index within the experts using different strategies based on expert
// count
int rowIdx_in_expert = 0;
@@ -321,6 +347,23 @@
}
}

// Early exit when using masks.
if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
continue;
}

int64_t inOffset = rowIdx * actualColsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
if (use_mask) {
PackedVec in_vec_mul = reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
silu_and_mul(in_vec, in_vec_mul);
}

// Get the output tensor offset. Each uint32_t packs 8 FP4 values, so the
// output row width stays colsPerRow even when the input row is doubled for silu_and_mul.
int64_t outOffset = rowIdx * colsPerRow + colIdx;
auto& out_pos = out[outOffset];

// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
@@ -356,6 +399,7 @@ cvt_fp16_to_fp4(
uint32_t* SFout,
uint32_t* input_offset_by_experts,
uint32_t* output_scale_offset_by_experts,
int32_t* mask,
int n_experts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>;
@@ -383,18 +427,15 @@ cvt_fp16_to_fp4(

int tid = blockIdx.x * blockDim.x + threadIdx.x;
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
bool use_mask = mask != nullptr;
int actualColsPerRow = use_mask ? colsPerRow * 2 : colsPerRow;

// Each global thread processes one element
for (int globalIdx = tid; globalIdx < numRows * colsPerRow; globalIdx += gridDim.x * blockDim.x) {
// Calculate which row and column this global thread should process
int rowIdx = globalIdx / colsPerRow;
int colIdx = globalIdx % colsPerRow;

int64_t inOffset = rowIdx * colsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];

// Find expert using binary search for better performance with large m_topk
int rowIdx_in_expert = 0;
int expert_idx = 0;
@@ -419,6 +460,21 @@
}
}

if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
continue;
}

int64_t inOffset = rowIdx * actualColsPerRow + colIdx;

PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
if (use_mask) {
PackedVec in_vec_mul = reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
silu_and_mul(in_vec, in_vec_mul);
}

int64_t outOffset = rowIdx * colsPerRow + colIdx;
auto& out_pos = out[outOffset];

float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];

int factor = CVT_FP4_SF_VEC_SIZE * 4;
@@ -442,6 +498,7 @@ void quant_impl(
void* input_global_scale,
void* input_offset_by_experts,
void* output_scale_offset_by_experts,
void* mask,
int m_topk,
int k,
int n_experts,
Expand Down Expand Up @@ -478,6 +535,7 @@ void quant_impl(
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
reinterpret_cast<int32_t*>(mask),
n_experts);
} else {
cvt_fp16_to_fp4<T, false, true><<<grid, block, shared_mem_size, stream>>>(
@@ -489,6 +547,7 @@
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
reinterpret_cast<int32_t*>(mask),
n_experts);
}
} else {
@@ -502,6 +561,7 @@
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
reinterpret_cast<int32_t*>(mask),
n_experts,
/* bool low_latency */ true);
} else {
@@ -514,6 +574,7 @@
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
reinterpret_cast<int32_t*>(mask),
n_experts,
/* bool low_latency */ true);
}
@@ -590,6 +651,92 @@ void scaled_fp4_experts_quant_sm100a(
input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(),
nullptr, // mask
m_topk,
k,
n_experts,
stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
quant_impl<__nv_bfloat16>(
output.data_ptr(),
output_scale.data_ptr(),
input.data_ptr(),
input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(),
nullptr, // mask
m_topk,
k,
n_experts,
stream);
} else {
TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
}
}

void silu_and_mul_scaled_fp4_experts_quant_sm100a(
torch::Tensor& output,
torch::Tensor& output_scale,
torch::Tensor const& input,
torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts,
torch::Tensor const& mask) {
CHECK_INPUT(output, "output must be a CUDA tensor");
CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
CHECK_INPUT(input, "input must be a CUDA tensor");
CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts must be a CUDA tensor");
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts must be a CUDA tensor");
CHECK_INPUT(mask, "mask must be a CUDA tensor");

TORCH_CHECK(output.dim() == 2);
TORCH_CHECK(output_scale.dim() == 2);
TORCH_CHECK(input.dim() == 2);
TORCH_CHECK(input_global_scale.dim() == 1);
TORCH_CHECK(input_offset_by_experts.dim() == 1);
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);

TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
TORCH_CHECK(mask.scalar_type() == INT);
// output is uint8 (two nvfp4 values are packed into one uint8)
// output_scale is int32 (four fp8 values are packed into one int32)
TORCH_CHECK(output.scalar_type() == UINT8);
TORCH_CHECK(output_scale.scalar_type() == INT);

const int BLOCK_SIZE = 16;
auto m_topk = input.size(0);
auto k_by_2 = input.size(1);
TORCH_CHECK(k_by_2 % 2 == 0, "input last dim must be a multiple of 2");
auto k = k_by_2 / 2;
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
auto n_experts = input_global_scale.size(0);
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(mask.size(0) == n_experts);
TORCH_CHECK(output.size(0) == m_topk);
TORCH_CHECK(output.size(1) == k / 2);
int scales_k = k / BLOCK_SIZE;
// Pad to a multiple of 4, as required by the NVFP4 swizzled scale-factor layout.
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
// 4 means 4 fp8 values are packed into one int32
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);

auto in_dtype = input.dtype();
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());
if (in_dtype == at::ScalarType::Half) {
quant_impl<half>(
output.data_ptr(),
output_scale.data_ptr(),
input.data_ptr(),
input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(),
mask.data_ptr(),
m_topk,
k,
n_experts,
@@ -602,6 +749,7 @@ void scaled_fp4_experts_quant_sm100a(
input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(),
mask.data_ptr(),
m_topk,
k,
n_experts,
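To make the data flow of the fused path explicit, here is a hedged, pure-PyTorch emulation of what the kernel computes for one row that passes the mask check: the first k elements are gated with SiLU against the second k elements, and the result is quantized in 16-element blocks with a per-block scale folded with the per-expert global scale. This is only a sketch of the math, not the kernel's implementation — it clamps instead of rounding to the e2m1 grid, approximates the scale-factor formula, and ignores the swizzled scale-factor layout and the uint8/int32 packing; the helper name is made up for illustration.

import torch
import torch.nn.functional as F

FP4_MAX = 6.0      # largest magnitude representable in nvfp4 (e2m1)
BLOCK_SIZE = 16    # number of elements sharing one block scale factor

def emulate_silu_mul_fp4_row(row: torch.Tensor, global_scale: float):
    # row holds 2*k elements: x = row[:k], y = row[k:], matching the doubled
    # input last dimension of the masked kernel path; k must be a multiple of 16.
    k = row.numel() // 2
    x, y = row[:k].float(), row[k:].float()
    fused = F.silu(x) * y  # same fusion as the device-side silu_and_mul

    blocks = fused.view(-1, BLOCK_SIZE)
    # Per-block scale chosen so the largest element maps to FP4_MAX, folded with
    # the per-expert global scale (an approximation of the kernel's SF math).
    amax = blocks.abs().amax(dim=1, keepdim=True)
    sf = amax / FP4_MAX * global_scale
    sf = torch.where(sf == 0, torch.ones_like(sf), sf)
    quantized = torch.clamp(blocks / sf, -FP4_MAX, FP4_MAX)  # not rounded to the e2m1 grid
    return quantized.view(-1), sf.view(-1)

The real kernel additionally skips rows whose index within their expert is at or beyond mask[expert_idx] (the early-exit branch above) and writes the scale factors in the swizzled layout expected by the subsequent grouped GEMM.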
24 changes: 24 additions & 0 deletions sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu
@@ -27,6 +27,15 @@ void scaled_fp4_experts_quant_sm100a(
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);

void silu_and_mul_scaled_fp4_experts_quant_sm100a(
torch::Tensor& output,
torch::Tensor& output_scale,
torch::Tensor const& input,
torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts,
torch::Tensor const& mask);

#endif

void scaled_fp4_quant(
@@ -50,3 +59,18 @@ void scaled_fp4_experts_quant(
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel");
}

void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output,
torch::Tensor& output_scale,
torch::Tensor const& input,
torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts,
torch::Tensor const& mask) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
return silu_and_mul_scaled_fp4_experts_quant_sm100a(
output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts, mask);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel");
}
8 changes: 8 additions & 0 deletions sgl-kernel/include/sgl_kernel_ops.h
@@ -389,6 +389,14 @@ void scaled_fp4_experts_quant(
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);

void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output,
torch::Tensor& output_scale,
torch::Tensor const& input,
torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts,
torch::Tensor const& mask);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
2 changes: 2 additions & 0 deletions sgl-kernel/python/sgl_kernel/__init__.py
@@ -52,12 +52,14 @@
qserve_w4a8_per_chn_gemm,
qserve_w4a8_per_group_gemm,
scaled_fp4_experts_quant,
scaled_fp4_grouped_quant,
scaled_fp4_quant,
sgl_per_tensor_quant_fp8,
sgl_per_token_group_quant_fp8,
sgl_per_token_group_quant_int8,
sgl_per_token_quant_fp8,
shuffle_rows,
silu_and_mul_scaled_fp4_grouped_quant,
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.kvcacheio import (
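Finally, a usage-level sketch of the newly exported Python entry point. The wrapper's signature is not part of this diff, so the call below is an assumption modeled on the C++ entry point; the real wrapper may take the expert offsets explicitly and may either allocate the packed output and block scales itself or expect them preallocated.

import torch
from sgl_kernel import silu_and_mul_scaled_fp4_grouped_quant

# Illustrative inputs; see the shape checks in nvfp4_expert_quant.cu above.
m_topk, k, n_experts = 4, 32, 2
gateup_output = torch.randn(m_topk, 2 * k, dtype=torch.bfloat16, device="cuda")
input_global_scale = torch.ones(n_experts, dtype=torch.float32, device="cuda")
masked_m = torch.tensor([2, 2], dtype=torch.int32, device="cuda")

# Hypothetical call: argument names and return values are assumed, not taken from this diff.
out, out_scale = silu_and_mul_scaled_fp4_grouped_quant(
    gateup_output, input_global_scale, masked_m)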