
Add moe topk softmax templated from vllm#4302

Merged
zhyncs merged 13 commits into sgl-project:main from qingquansong:qsong/moe-topk-softmax
Mar 14, 2025

Conversation

@qingquansong (Collaborator) commented Mar 11, 2025

Motivation

#2965

Modifications

  1. Cherry-picked the current vLLM MoE top-k softmax kernel template (with a fix for a naming typo in token_expert_indices)
  2. Polished the utility functions warpReduceMax / blockReduceMax to handle the AMD (ROCm) use case as well

Tests

Unit tests + benchmarking aligned with vllm counterpart

Checklist

@qingquansong qingquansong marked this pull request as draft March 11, 2025 08:26
@hebiao064 hebiao064 mentioned this pull request Mar 11, 2025
18 tasks
@qingquansong qingquansong force-pushed the qsong/moe-topk-softmax branch 11 times, most recently from b70667a to 6200f4a Compare March 12, 2025 07:11
@qingquansong qingquansong force-pushed the qsong/moe-topk-softmax branch 2 times, most recently from 2786ad3 to acd4fb7 Compare March 13, 2025 04:34
@qingquansong qingquansong force-pushed the qsong/moe-topk-softmax branch from acd4fb7 to 97f4eb0 Compare March 13, 2025 04:35
@qingquansong qingquansong changed the title from "add moe topk softmax templated from vllm to improve" to "Add moe topk softmax templated from vllm" Mar 13, 2025
@qingquansong qingquansong marked this pull request as ready for review March 13, 2025 21:59
Comment thread sgl-kernel/include/utils.h Outdated
Comment thread sgl-kernel/include/utils.h Outdated
@qingquansong qingquansong force-pushed the qsong/moe-topk-softmax branch 2 times, most recently from d1b7bb2 to 6b28b88 Compare March 14, 2025 03:28
@qingquansong qingquansong force-pushed the qsong/moe-topk-softmax branch from 6b28b88 to 2fa0db7 Compare March 14, 2025 03:29
Comment thread sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu
@qingquansong qingquansong force-pushed the qsong/moe-topk-softmax branch from c77f195 to c15a211 Compare March 14, 2025 18:37
@qingquansong qingquansong requested review from BBuf and hebiao064 March 14, 2025 18:41
@zhyncs zhyncs merged commit 61e4433 into sgl-project:main Mar 14, 2025
@yiakwy-xpu-ml-framework-team (Contributor) commented:

Hi @qingquansong

Once PR #4432 is merged:

#ifndef USE_ROCM
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
    __shfl_xor_sync(uint32_t(-1), var, lane_mask)
  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
    __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
#else
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
    __shfl_xor(var, lane_mask, width)
#endif

you can use __shfl_xor_sync directly, hence no need for these lines.

Does that sound good to you?

max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 16));
Contributor:

The modifications to these functions will be reverted in #4432.

cc @zhyncs

#else
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor((var), (lane_mask))
#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
#endif
@yiakwy-xpu-ml-framework-team (Contributor) Mar 15, 2025

I will keep these lines since the ROCm-specific macro is used in many places (the CUDA operation is no longer safe if we employ this approach). But #4432 is now merged, so the macro is no longer needed.

@qingquansong (Collaborator, Author):

Thank you! I'll change those and remove the definition.


const int thread_row_offset = blockIdx.x * num_cols;

cub::Sum sum;
Contributor:

hipCUB is experimental; we can try it, but it introduces new dependencies.

Could we just use a simple reduction kernel?

@qingquansong (Collaborator, Author) Mar 15, 2025

Definitely. We can change to custom reductions for both max and sum. I'll do it together with the macro change in a follow-up PR. How about the following? I can test correctness on CUDA and may need your help testing on an AMD machine.

__device__ __forceinline__ float warpReduceSum(float sum_value) {
  sum_value += __shfl_xor_sync(0xffffffff, sum_value, 16);
  sum_value += __shfl_xor_sync(0xffffffff, sum_value, 8);
  sum_value += __shfl_xor_sync(0xffffffff, sum_value, 4);
  sum_value += __shfl_xor_sync(0xffffffff, sum_value, 2);
  sum_value += __shfl_xor_sync(0xffffffff, sum_value, 1);
  return sum_value;
}

__device__ __forceinline__ float blockReduceSum(float sum_value) {
  static __shared__ float warpLevelSums[WARP_SIZE];
  const int laneId = threadIdx.x % WARP_SIZE;
  const int warpId = threadIdx.x / WARP_SIZE;

  sum_value = warpReduceSum(sum_value);

  if (laneId == 0) warpLevelSums[warpId] = sum_value;
  __syncthreads();

  sum_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelSums[laneId] : 0;
  if (warpId == 0) sum_value = warpReduceSum(sum_value);

  return sum_value;
}


@yiakwy-xpu-ml-framework-team (Contributor) Mar 15, 2025

Resolved in #4448.

@yiakwy-xpu-ml-framework-team (Contributor) Mar 15, 2025

I also recommend the shfl_xor-based implementation. The old solution from FasterTransformer (later incorporated into TRT-LLM) relies heavily on shared memory for reduction:

const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());

With the shfl_xor-based implementation you can get better results.

@qingquansong (Collaborator, Author) commented Mar 15, 2025

(Quoting @yiakwy-xpu-ml-framework-team's comment above: once PR #4432 is merged, __shfl_xor_sync can be used directly, so the macro lines are no longer needed.)

Sounds great! I'll remove the macro definition and change back to using __shfl_xor_sync directly.

5 participants