Skip to content

Add Permute/Unpermute Fusion Code Path#588

Merged
jershi425 merged 39 commits into
deepseek-ai:hybrid-epfrom
Autumn1998:tongliu_permute_fusion
Mar 31, 2026
Merged

Add Permute/Unpermute Fusion Code Path#588
jershi425 merged 39 commits into
deepseek-ai:hybrid-epfrom
Autumn1998:tongliu_permute_fusion

Conversation

@Autumn1998

@Autumn1998 Autumn1998 commented Mar 3, 2026

Copy link
Copy Markdown
Collaborator
  • Refactor
  • Add (un)permute fusion path
  • Update documentation
  • Refactor tests for cleaner output

@Autumn1998 Autumn1998 closed this Mar 5, 2026
@Autumn1998 Autumn1998 reopened this Mar 6, 2026

@Eliasthunderdog Eliasthunderdog left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have some concerns about remote rank visibility. Please check the lines below in the comments.

uint32_t* last_chunk_flag_addr = intra_node_expert_output_chunk_flags[remote_rank_id] + last_chunk_global_chunk_id;
// Need a strong system-scope red to make sure the target ranks can observe the update of the flag,
// Notify last chunk.
asm volatile("red.relaxed.sys.global.add.u32 [%0], %1;"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to use "red.release" here to guarantee the chunk is visible when observing the flag update? Although cp_async_bulk_wait_group within the cache-coherent NVL can implicitly guarantee the ordering, I think we need a semantically correct way to do this.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, thanks for the comment.
Yes, we need a release semantic here to make sure the flag is observed after the data chunk.
In fact, we already have it, you can see we explicitly call fence.release.sys; at Ln 1468 before we start notifying any flag, this instruction is an explicit release semantic memory fence, so it guarantee all the red.relaxed happens after the cp_async_bulk_wait_group(i.e. all notification happens after the data chunk is written to the global location) in system-scoped memory order. So it has the same effect as calling red.release here.
BTW, calling a red.release instruction or any memory intructions with .release semantics(such as st.release etc.) will be compiled into a sequence of fence.release.scope + mop.relaxed.scope(scope can be sys ,gpu etc. and mop can be red, st etc.), so there is NO difference by using fence.release.sys+red.relaxed.sys vs red.release.sys. As you can see the red.relaxed.sys here is within a loop, so if you are using red.release.sys here, then you will call fence.release.sys+red.relaxed.sys multiple time along with the loop which will kill all performance since fence.release.sys is a very very expensive operation and you want to call it as less time as possible. So instead of calling fence.release.sys+red.relaxed.sys within a loop, we will call fence.release.sys once then call red.relaxed.sys within a loop.
Even if we only call fence.release.sys once, it is still a very expensive operation, and according to our test, it will kill 50-100 GB/s NVLink bandwidth on Blackwell platform, so we want to avoid it whenever possible. According to our test, if the communcaition and sychronization happens only within a NVLink domain(i.e. no RDMA communication), w/o this memory fence can still get the corrent result. But if PCIe device such as RDMA NIC come into the communication, the result will corrupt w/o this memory fence. So, that's why we call this fence only when RDMA is enabled(i.e. HYBRID_EP_BUILD_MULTINODE_ENABLE is defined). Yes, for a well-defined memory behavior, we need the fence.release.sys here unconditionally, but as we are pursuing performance on NVLink system, we currently implement a correct(at least under our test), fast but undefined behavior here for NVLink scenario.

uint32_t* current_chunk_flag_addr = intra_node_expert_output_chunk_flags[remote_rank_id] + global_chunk_id;
// Need a strong system-scope red to make sure the target ranks can observe the update of the flag,
// Notify last chunk.
asm volatile("red.relaxed.sys.global.add.u32 [%0], %1;"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

red.release

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as Ln1482.

: "l"(__cvta_generic_to_global(last_chunk_flag_addr)), "n"(1)
: "memory");
// Notify current chunk.
asm volatile("red.relaxed.sys.global.add.u32 [%0], %1;"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

red.release

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as Ln1482.

uint32_t* last_chunk_flag_addr = intra_node_expert_output_chunk_flags[remote_rank_id] + last_chunk_global_chunk_id;
// Need a strong system-scope red to make sure the target ranks can observe the update of the flag,
// Notify last chunk.
asm volatile("red.relaxed.sys.global.add.u32 [%0], %1;"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

red.release

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as Ln1482.

do{
intra_node_chunk_flag = 0;
// Need a strong system-scope load to observe peer ranks' Atomic result.
asm volatile("ld.relaxed.sys.global.u32 %0, [%1];"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ld.acquire

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we need a acquire semantic here to make sure the flag is observed before consuming the data chunk.
However, on current GPU arch, the acquire semantic will be compiled to mop.relaxed.scope + L1 dirty cache line flush(so no acquire fence actually. Here, the ld.acquire.sys will be compiled to ld.relaxed.sys + L1 dirty cache line flush), this will make sure the later memory operation after the flag polling will not accidentally load the stale data from L1 cache. This alone will make sure that if a memory operation can be observed when a flag polling(i.e. the ld.relaxed.sys) success, it can be observed by later memory operations.
Since the data chunk is totally consumed by TMA instructions after polling the flag, and TMA instrcutions bypass L1 cache entirely, so we actually don't need the L1 dirty cache line flush after the ld.relaxed.sys, so that's why we use ld.relaxed.sys instead of ld.acquire.sys here.
Again, the L1 dirty cache line flush is very very expensive, it may take thousands of cycles to finish. We want to avoid it at all cost. So, we currently use a arch-specific implementation here, not a well-defined implementation.

uint32_t* current_chunk_flag_addr = intra_node_expert_input_chunk_flags[current_rank_id] + current_flag_id;
// Need a strong system-scope red to make sure all ranks can observe the update of the flag,
// Notify last chunk.
asm volatile("red.relaxed.sys.global.add.u32 [%0], %1;"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

red.release

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as Ln1482.

: "l"(__cvta_generic_to_global(last_chunk_flag_addr)), "n"(1)
: "memory");
// Notify current chunk.
asm volatile("red.relaxed.sys.global.add.u32 [%0], %1;"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

red.release

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as Ln1482.

uint32_t* last_chunk_flag_addr = intra_node_expert_input_chunk_flags[last_chunk_rank_id] + last_chunk_global_chunk_id;
// Need a strong system-scope red to make sure all ranks can observe the update of the flag,
// Notify last chunk.
asm volatile("red.relaxed.sys.global.add.u32 [%0], %1;"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

red.release

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as Ln1482.

do{
intra_node_chunk_flag = 0;
// Need a strong system-scope load to observe peer ranks' Atomic result.
asm volatile("ld.relaxed.sys.global.u32 %0, [%1];"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ld.acquire

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as Ln1757.

do{
intra_node_chunk_flag = 0;
// Need a strong system-scope load to observe peer ranks' Atomic result.
asm volatile("ld.relaxed.sys.global.u32 %0, [%1];"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ld.acquire

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as Ln1757.

@CarpenterLee

Copy link
Copy Markdown

Hi @Autumn1998, could you share any performance comparison results with the target branch?

@Autumn1998

Autumn1998 commented Mar 25, 2026

Copy link
Copy Markdown
Collaborator Author

Hi @Autumn1998, could you share any performance comparison results with the target branch?

Hybrid_EP_NVL72.md
Here is some data on nvl72

I just submitted a bug fix that will cost us a slight performance regression. We haven't had time to update the detailed benchmarks yet, but to avoid blocking subsequent nixl PRs, I suggest this PR can be prioritized for merging.

@alpha-baby

alpha-baby commented Apr 14, 2026

Copy link
Copy Markdown
Contributor

why combine kernel performance reduce after merge this PR:

before:

[rank 0] HybridEP dispatch kernel(NVL) (BF16): 866.87 GB/s, avg_t=287.36 us | HybridEP combine kernel(NVL): 789.26 GB/s, avg_t=315.61 us

after

[rank 0] HybridEP dispatch kernel(NVL) (BF16): 867.18 GB/s, avg_t=287.26 us | HybridEP combine kernel(NVL): 476.63 GB/s, avg_t=522.63 us

before current PR:

[rank 0] Correctness check passed (BF16)
[rank 1] Correctness check passed (BF16)
[rank 2] Correctness check passed (BF16)
[rank 3] Correctness check passed (BF16)
[rank 0] HybridEP dispatch torch API (BF16): 548.75 GB/s (NVL), t: 453.94 us, nvl_recv_bytes: 249.10 MB
[rank 1] HybridEP dispatch torch API (BF16): 548.83 GB/s (NVL), t: 453.94 us, nvl_recv_bytes: 249.14 MB
[rank 2] HybridEP dispatch torch API (BF16): 550.67 GB/s (NVL), t: 454.44 us, nvl_recv_bytes: 250.25 MB
[rank 3] HybridEP dispatch torch API (BF16): 548.02 GB/s (NVL), t: 454.49 us, nvl_recv_bytes: 249.07 MB
[rank 0] HybridEP combine torch API: 600.77 GB/s (NVL), t: 414.64 us, combine_send_bytes: 249.10 MB
[rank 1] HybridEP combine torch API: 601.05 GB/s (NVL), t: 414.50 us, combine_send_bytes: 249.14 MB
[rank 2] HybridEP combine torch API: 603.04 GB/s (NVL), t: 414.98 us, combine_send_bytes: 250.25 MB
[rank 3] HybridEP combine torch API: 600.13 GB/s (NVL), t: 415.02 us, combine_send_bytes: 249.07 MB
[rank 0] HybridEP dispatch+permute torch API (BF16): 512.00 GB/s (NVL), t: 486.53 us, nvl_recv_bytes: 249.10 MB
[rank 1] HybridEP dispatch+permute torch API (BF16): 511.81 GB/s (NVL), t: 486.78 us, nvl_recv_bytes: 249.14 MB
[rank 2] HybridEP dispatch+permute torch API (BF16): 513.71 GB/s (NVL), t: 487.14 us, nvl_recv_bytes: 250.25 MB
[rank 3] HybridEP dispatch+permute torch API (BF16): 511.21 GB/s (NVL), t: 487.22 us, nvl_recv_bytes: 249.07 MB
[rank 0] HybridEP combine+unpermute torch API: 461.85 GB/s (NVL), t: 539.36 us, combine_send_bytes: 249.10 MB
[rank 1] HybridEP combine+unpermute torch API: 461.87 GB/s (NVL), t: 539.40 us, combine_send_bytes: 249.14 MB
[rank 2] HybridEP combine+unpermute torch API: 463.48 GB/s (NVL), t: 539.94 us, combine_send_bytes: 250.25 MB
[rank 3] HybridEP combine+unpermute torch API: 461.24 GB/s (NVL), t: 540.00 us, combine_send_bytes: 249.07 MB
[rank 0] HybridEP dispatch kernel(NVL) (BF16): 866.87 GB/s, avg_t=287.36 us | HybridEP combine kernel(NVL): 789.26 GB/s, avg_t=315.61 us
[rank 1] HybridEP dispatch kernel(NVL) (BF16): 869.53 GB/s, avg_t=286.52 us | HybridEP combine kernel(NVL): 800.22 GB/s, avg_t=311.33 us
[rank 2] HybridEP dispatch kernel(NVL) (BF16): 878.08 GB/s, avg_t=285.00 us | HybridEP combine kernel(NVL): 804.73 GB/s, avg_t=310.97 us
[rank 3] HybridEP dispatch kernel(NVL) (BF16): 840.67 GB/s, avg_t=296.27 us | HybridEP combine kernel(NVL): 799.29 GB/s, avg_t=311.61 us
[rank 0] Correctness check passed (FP8)
[rank 1] Correctness check passed (FP8)
[rank 2] Correctness check passed (FP8)
[rank 3] Correctness check passed (FP8)
[rank 0] HybridEP dispatch torch API (FP8): 424.78 GB/s (NVL), t: 303.81 us, nvl_recv_bytes: 129.05 MB
[rank 1] HybridEP dispatch torch API (FP8): 421.48 GB/s (NVL), t: 303.79 us, nvl_recv_bytes: 128.04 MB
[rank 2] HybridEP dispatch torch API (FP8): 422.80 GB/s (NVL), t: 304.27 us, nvl_recv_bytes: 128.65 MB
[rank 3] HybridEP dispatch torch API (FP8): 423.07 GB/s (NVL), t: 304.24 us, nvl_recv_bytes: 128.71 MB
[rank 0] HybridEP combine torch API: 605.81 GB/s (NVL), t: 413.14 us, combine_send_bytes: 250.28 MB
[rank 1] HybridEP combine torch API: 600.93 GB/s (NVL), t: 413.22 us, combine_send_bytes: 248.32 MB
[rank 2] HybridEP combine torch API: 603.02 GB/s (NVL), t: 413.74 us, combine_send_bytes: 249.50 MB
[rank 3] HybridEP combine torch API: 603.30 GB/s (NVL), t: 413.77 us, combine_send_bytes: 249.63 MB
[rank 0] HybridEP dispatch+permute torch API (FP8): 392.13 GB/s (NVL), t: 329.11 us, nvl_recv_bytes: 129.05 MB
[rank 1] HybridEP dispatch+permute torch API (FP8): 388.96 GB/s (NVL), t: 329.18 us, nvl_recv_bytes: 128.04 MB
[rank 2] HybridEP dispatch+permute torch API (FP8): 390.43 GB/s (NVL), t: 329.50 us, nvl_recv_bytes: 128.65 MB
[rank 3] HybridEP dispatch+permute torch API (FP8): 390.55 GB/s (NVL), t: 329.57 us, nvl_recv_bytes: 128.71 MB
[rank 0] HybridEP combine+unpermute torch API: 465.14 GB/s (NVL), t: 538.08 us, combine_send_bytes: 250.28 MB
[rank 1] HybridEP combine+unpermute torch API: 461.39 GB/s (NVL), t: 538.19 us, combine_send_bytes: 248.32 MB
[rank 2] HybridEP combine+unpermute torch API: 463.25 GB/s (NVL), t: 538.57 us, combine_send_bytes: 249.50 MB
[rank 3] HybridEP combine+unpermute torch API: 463.49 GB/s (NVL), t: 538.58 us, combine_send_bytes: 249.63 MB
[rank 0] HybridEP dispatch kernel(NVL) (FP8): 810.44 GB/s, avg_t=159.24 us | HybridEP combine kernel(NVL): 796.52 GB/s, avg_t=314.22 us
[rank 1] HybridEP dispatch kernel(NVL) (FP8): 807.32 GB/s, avg_t=158.60 us | HybridEP combine kernel(NVL): 799.91 GB/s, avg_t=310.43 us
[rank 2] HybridEP dispatch kernel(NVL) (FP8): 800.83 GB/s, avg_t=160.64 us | HybridEP combine kernel(NVL): 796.92 GB/s, avg_t=313.07 us
[rank 3] HybridEP dispatch kernel(NVL) (FP8): 761.48 GB/s, avg_t=169.03 us | HybridEP combine kernel(NVL): 802.90 GB/s, avg_t=310.90 us

after current PR:

[rank 0] Correctness check passed (BF16)
[rank 1] Correctness check passed (BF16)
[rank 2] Correctness check passed (BF16)
[rank 3] Correctness check passed (BF16)
[rank 0] HybridEP dispatch torch API (BF16): 544.91 GB/s (NVL), t: 457.15 us, nvl_recv_bytes: 249.10 MB
[rank 1] HybridEP dispatch torch API (BF16): 545.03 GB/s (NVL), t: 457.10 us, nvl_recv_bytes: 249.14 MB
[rank 2] HybridEP dispatch torch API (BF16): 546.80 GB/s (NVL), t: 457.66 us, nvl_recv_bytes: 250.25 MB
[rank 3] HybridEP dispatch torch API (BF16): 544.24 GB/s (NVL), t: 457.64 us, nvl_recv_bytes: 249.07 MB
[rank 0] HybridEP combine torch API: 387.94 GB/s (NVL), t: 642.12 us, combine_send_bytes: 249.10 MB
[rank 1] HybridEP combine torch API: 387.98 GB/s (NVL), t: 642.14 us, combine_send_bytes: 249.14 MB
[rank 2] HybridEP combine torch API: 389.45 GB/s (NVL), t: 642.58 us, combine_send_bytes: 250.25 MB
[rank 3] HybridEP combine torch API: 387.60 GB/s (NVL), t: 642.59 us, combine_send_bytes: 249.07 MB
[rank 0] HybridEP dispatch+permute torch API (BF16): 498.28 GB/s (NVL), t: 499.92 us, nvl_recv_bytes: 249.10 MB
[rank 1] HybridEP dispatch+permute torch API (BF16): 498.28 GB/s (NVL), t: 499.99 us, nvl_recv_bytes: 249.14 MB
[rank 2] HybridEP dispatch+permute torch API (BF16): 500.15 GB/s (NVL), t: 500.35 us, nvl_recv_bytes: 250.25 MB
[rank 3] HybridEP dispatch+permute torch API (BF16): 497.77 GB/s (NVL), t: 500.37 us, nvl_recv_bytes: 249.07 MB
[rank 0] HybridEP combine+unpermute torch API: 325.00 GB/s (NVL), t: 766.46 us, combine_send_bytes: 249.10 MB
[rank 1] HybridEP combine+unpermute torch API: 325.04 GB/s (NVL), t: 766.47 us, combine_send_bytes: 249.14 MB
[rank 2] HybridEP combine+unpermute torch API: 326.28 GB/s (NVL), t: 766.98 us, combine_send_bytes: 250.25 MB
[rank 3] HybridEP combine+unpermute torch API: 324.77 GB/s (NVL), t: 766.91 us, combine_send_bytes: 249.07 MB
[rank 0] HybridEP dispatch kernel(NVL) (BF16): 867.18 GB/s, avg_t=287.26 us | HybridEP combine kernel(NVL): 476.63 GB/s, avg_t=522.63 us
[rank 1] HybridEP dispatch kernel(NVL) (BF16): 870.55 GB/s, avg_t=286.18 us | HybridEP combine kernel(NVL): 483.91 GB/s, avg_t=514.84 us
[rank 2] HybridEP dispatch kernel(NVL) (BF16): 881.76 GB/s, avg_t=283.81 us | HybridEP combine kernel(NVL): 464.21 GB/s, avg_t=539.09 us
[rank 3] HybridEP dispatch kernel(NVL) (BF16): 838.93 GB/s, avg_t=296.89 us | HybridEP combine kernel(NVL): 482.56 GB/s, avg_t=516.14 us
[rank 0] Correctness check passed (FP8)
[rank 1] Correctness check passed (FP8)
[rank 2] Correctness check passed (FP8)
[rank 3] Correctness check passed (FP8)
[rank 0] HybridEP dispatch torch API (FP8): 420.36 GB/s (NVL), t: 307.01 us, nvl_recv_bytes: 129.05 MB
[rank 1] HybridEP dispatch torch API (FP8): 416.59 GB/s (NVL), t: 307.34 us, nvl_recv_bytes: 128.04 MB
[rank 2] HybridEP dispatch torch API (FP8): 418.40 GB/s (NVL), t: 307.47 us, nvl_recv_bytes: 128.65 MB
[rank 3] HybridEP dispatch torch API (FP8): 418.24 GB/s (NVL), t: 307.75 us, nvl_recv_bytes: 128.71 MB
[rank 0] HybridEP combine torch API: 390.12 GB/s (NVL), t: 641.54 us, combine_send_bytes: 250.28 MB
[rank 1] HybridEP combine torch API: 386.84 GB/s (NVL), t: 641.91 us, combine_send_bytes: 248.32 MB
[rank 2] HybridEP combine torch API: 388.61 GB/s (NVL), t: 642.02 us, combine_send_bytes: 249.50 MB
[rank 3] HybridEP combine torch API: 388.54 GB/s (NVL), t: 642.48 us, combine_send_bytes: 249.63 MB
[rank 0] HybridEP dispatch+permute torch API (FP8): 381.47 GB/s (NVL), t: 338.31 us, nvl_recv_bytes: 129.05 MB
[rank 1] HybridEP dispatch+permute torch API (FP8): 377.90 GB/s (NVL), t: 338.82 us, nvl_recv_bytes: 128.04 MB
[rank 2] HybridEP dispatch+permute torch API (FP8): 379.58 GB/s (NVL), t: 338.92 us, nvl_recv_bytes: 128.65 MB
[rank 3] HybridEP dispatch+permute torch API (FP8): 379.40 GB/s (NVL), t: 339.25 us, nvl_recv_bytes: 128.71 MB
[rank 0] HybridEP combine+unpermute torch API: 326.68 GB/s (NVL), t: 766.15 us, combine_send_bytes: 250.28 MB
[rank 1] HybridEP combine+unpermute torch API: 323.97 GB/s (NVL), t: 766.47 us, combine_send_bytes: 248.32 MB
[rank 2] HybridEP combine+unpermute torch API: 325.44 GB/s (NVL), t: 766.64 us, combine_send_bytes: 249.50 MB
[rank 3] HybridEP combine+unpermute torch API: 325.46 GB/s (NVL), t: 767.00 us, combine_send_bytes: 249.63 MB
[rank 0] HybridEP dispatch kernel(NVL) (FP8): 817.56 GB/s, avg_t=157.85 us | HybridEP combine kernel(NVL): 477.14 GB/s, avg_t=524.54 us
[rank 1] HybridEP dispatch kernel(NVL) (FP8): 806.13 GB/s, avg_t=158.83 us | HybridEP combine kernel(NVL): 482.78 GB/s, avg_t=514.35 us
[rank 2] HybridEP dispatch kernel(NVL) (FP8): 804.44 GB/s, avg_t=159.92 us | HybridEP combine kernel(NVL): 463.49 GB/s, avg_t=538.29 us
[rank 3] HybridEP dispatch kernel(NVL) (FP8): 759.75 GB/s, avg_t=169.42 us | HybridEP combine kernel(NVL): 482.17 GB/s, avg_t=517.71 us

test command:

MASTER_ADDR=x.x.x.x MASTER_PORT=29500 WORLD_SIZE=1 RANK=0 NUM_SMS_DISPATCH=32 NUM_SMS_COMBINE=32 NUM_BLOCKS_PERMUTE=120 NUM_BLOCKS_UNPERMUTE=120 CUDA_HOME=/usr/local/cuda HIDDEN_DIM=4096 MAX_NUM_OF_TOKENS_PER_RANK=8192 NUM_TOKENS_PER_RANK=8192 NUM_LOCAL_EXPERTS=8 TOPK=8 USE_MNNVL=1 NVSHMEM_DEBUG=WARN NCCL_DEBUG=WARN nohup python3 /elasticdl/share/models/lisiyuan.li/autolab/fujianhao.fjh/hybridep-opensouce/tests/combine_performance_test_hybrid_ep.py --num-processes 4

combine_performance_test_hybrid_ep.py

# SPDX-License-Identifier: MIT
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved
import argparse
import time
import torch
import torch.distributed as dist
import os
import deep_ep

from utils import TorchRef, bench, bench_kineto, init_dist, count_rdma_send_from_routing_map

HIDDEN_DIM = int(os.environ.get("HIDDEN_DIM", 7168))
MAX_NUM_OF_TOKENS_PER_RANK = int(os.environ.get("MAX_NUM_OF_TOKENS_PER_RANK", 4096))
# NUM_TOKENS_PER_RANK should equal or less than MAX_NUM_OF_TOKENS_PER_RANK
NUM_TOKENS_PER_RANK = int(os.environ.get("NUM_TOKENS_PER_RANK", 4096))
NUM_LOCAL_EXPERTS = int(os.environ.get("NUM_LOCAL_EXPERTS", 8))
TOPK = int(os.environ.get("TOPK", 8))
PAD_MULTIPLE = int(os.environ.get("PAD_MULTIPLE", 32))
ITERATIONS = int(os.environ.get("ITERATIONS", 100))
SEED = int(os.environ.get("SEED", 42))
USE_MNNVL = os.environ.get("USE_MNNVL", "0").strip().lower() in {"1", "true", "t", "yes", "y", "on"}
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Will be set after the process group is initialized
NUM_OF_RANKS_PER_NODE = None
NUM_OF_NODES = None
NUM_OF_EXPERTS = None

def print_in_order(msg: str):
    """Print message in order by rank to avoid interleaved output"""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    for i in range(world_size):
        if i == rank:
            print(msg, flush=True)
        dist.barrier()

def bitwise_equal(a: torch.Tensor, b: torch.Tensor) -> bool:
    if a.dtype != b.dtype or a.shape != b.shape or a.device != b.device:
        return False
    a_bytes = a.contiguous().view(torch.uint8)
    b_bytes = b.contiguous().view(torch.uint8)
    return torch.equal(a_bytes, b_bytes)

def init_tensor(
    hidden_dim: int,
    seq_len: int,
    topk: int,
    num_of_experts: int,
    use_fp8: bool = False,
):
    if use_fp8:
        hidden = torch.randint(
            low=0,
            high=256,
            size=(seq_len, hidden_dim),
            device="cuda",
            dtype=torch.uint8,
        )
    else:
        hidden = torch.randn(seq_len, hidden_dim, device="cuda", dtype=torch.bfloat16)
    probs = torch.zeros(seq_len, num_of_experts, device="cuda", dtype=torch.float32)
    topk_idx = torch.zeros(seq_len, topk, device="cuda", dtype=torch.int64)
    topk_weights = torch.zeros(seq_len, topk, device="cuda", dtype=torch.float32)
    scaling_factor = torch.randn(
        seq_len, hidden_dim // 128, device="cuda", dtype=torch.float32
    )

    routing_map = torch.zeros(seq_len, num_of_experts, device="cuda", dtype=torch.bool)

    for i in range(seq_len):
        # Force balanced routing for testing
        # selected_experts = torch.tensor([
        #     ((i * topk) % num_of_experts + val) % num_of_experts for val in range(topk)
        # ], device="cuda")
        selected_experts = torch.randperm(num_of_experts, device="cuda")[:topk]
        topk_idx[i, :] = selected_experts.to(torch.int64)
        topk_weights[i, :] = torch.ones(topk, device="cuda", dtype=torch.float32)
        routing_map[i, selected_experts] = True
        probs[i, selected_experts] = topk_weights[i, :]

    return hidden, probs, scaling_factor, routing_map, topk_idx, topk_weights


def test_hybrid_ep_correctness(buffer: deep_ep.HybridEPBuffer, ref: TorchRef, use_fp8: bool):
    hidden, probs, scaling_factor, routing_map, topk_idx, topk_weights  = init_tensor(
        hidden_dim=HIDDEN_DIM,
        seq_len=NUM_TOKENS_PER_RANK,
        topk=TOPK,
        num_of_experts=NUM_OF_EXPERTS,
        use_fp8=use_fp8,
    )

    # Dispatch correctness check
    for with_probs in [True, False]:
        # The check for the dispatch
        dispatched_hidden_ref, dispatched_probs_ref, dispatched_scaling_factor_ref = (
            ref.dispatch(
                hidden, routing_map, probs if with_probs else None, scaling_factor
            )
        )
        (
            dispatched_hidden,
            dispatched_probs,
            dispatched_scaling_factor,
            handle,
        ) = buffer.dispatch(
            hidden=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights if with_probs else None, num_of_experts=NUM_OF_EXPERTS,
        )

        assert bitwise_equal(dispatched_hidden_ref, dispatched_hidden)
        if dispatched_probs is not None and dispatched_probs_ref is not None:
            start, end = ref._local_expert_range_per_node()
            masked_probs = torch.zeros_like(dispatched_probs)
            masked_probs[:, start:end] = dispatched_probs[:, start:end]
            assert bitwise_equal(dispatched_probs_ref, dispatched_probs[:, start:end])
            dispatched_probs = masked_probs
        if (
            dispatched_scaling_factor is not None
            and dispatched_scaling_factor_ref is not None
        ):
            assert bitwise_equal(
                dispatched_scaling_factor_ref, dispatched_scaling_factor
            )

        _, _, _, num_dispatched_tokens, local_expert_routing_map, _, _ = handle
        num_dispatched_tokens = num_dispatched_tokens.cpu()
        local_expert_routing_map = local_expert_routing_map[
            : num_dispatched_tokens.item()
        ]
        # Simulate the permute and expert and unpermute. The expert is identity op
        copy_times = local_expert_routing_map.sum(dim=1)
        dispatched_hidden = dispatched_hidden.to(torch.bfloat16)  
        # The combine only support bf16
        hidden_to_combine = dispatched_hidden * copy_times.unsqueeze(1)
        probs_to_combine = dispatched_probs

        # The check for the combine
        combined_hidden, combined_probs = buffer.combine(
            hidden_to_combine, probs_to_combine, handle
        )

        # The reconstucted value should be TOPK times larger than the input hidden
        combined_hidden = combined_hidden / TOPK

        assert torch.allclose(combined_hidden, hidden.to(torch.bfloat16), atol=2e-5, rtol=1e-2)
        if combined_probs is not None and probs is not None:
            assert bitwise_equal(combined_probs, probs)

    # Dispatch with permute correctness check
    for with_probs in [True, False]:
        # The check for the dispatch
        (
            dispatched_hidden,
            dispatched_probs,
            dispatched_scaling_factor,
            tokens_per_expert,
            handle,
        ) = buffer.dispatch_with_permute(
            hidden=hidden,
            routing_map=routing_map,
            probs=probs if with_probs else None,
            scaling_factor=scaling_factor,
            pad_multiple=PAD_MULTIPLE,
        )
        _, _, _, num_dispatched_tokens_tensor, local_expert_routing_map, _, _, _, _ = (
            handle
        )
        num_dispatched_tokens_tensor = num_dispatched_tokens_tensor.cpu()
        local_expert_routing_map = local_expert_routing_map[
            : num_dispatched_tokens_tensor.item()
        ]
        # The out_token_num of permutation is the sum of the tokens_per_expert
        out_token_num = tokens_per_expert.sum().item()
        (
            dispatched_hidden_ref,
            dispatched_probs_ref,
            dispatched_scaling_factor_ref,
        ) = ref.dispatch(
            hidden,
            routing_map,
            probs if with_probs else None,
            scaling_factor,
            local_expert_routing_map=local_expert_routing_map,
            out_token_num=out_token_num,
            pad_multiple=PAD_MULTIPLE,
            enable_permute=True,
        )

        assert bitwise_equal(dispatched_hidden_ref, dispatched_hidden)
        if dispatched_probs is not None and dispatched_probs_ref is not None:
            assert bitwise_equal(dispatched_probs_ref, dispatched_probs)
        if (
            dispatched_scaling_factor is not None
            and dispatched_scaling_factor_ref is not None
        ):
            assert bitwise_equal(
                dispatched_scaling_factor_ref, dispatched_scaling_factor
            )

        # The combine only support bf16
        dispatched_hidden = dispatched_hidden.to(torch.bfloat16)  
        hidden_to_combine = dispatched_hidden
        probs_to_combine = dispatched_probs
 
        # The check for the combine
        combined_hidden, combined_probs = buffer.combine_with_unpermute(
            hidden=hidden_to_combine,
            probs=probs_to_combine,
            handle=handle,
            pad_multiple=PAD_MULTIPLE,
        )

        # The reconstucted value should be TOPK times larger than the input hidden
        combined_hidden = combined_hidden / TOPK

        assert torch.allclose(
            combined_hidden, hidden.to(torch.bfloat16), atol=2e-5, rtol=1e-2
        )
        if combined_probs is not None and probs is not None:
            assert bitwise_equal(combined_probs, probs)

    print_in_order(f'[rank {dist.get_rank()}] Correctness check passed ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})')


def test_hybrid_ep_benchmark(buffer: deep_ep.HybridEPBuffer, group: dist.ProcessGroup, use_fp8: bool, nsys_profile: bool):
    hidden, probs, scaling_factor, routing_map, topk_idx, topk_weights = init_tensor(
        hidden_dim=HIDDEN_DIM,
        seq_len=NUM_TOKENS_PER_RANK,
        topk=TOPK,
        num_of_experts=NUM_OF_EXPERTS,
        use_fp8=use_fp8,
    )

    # warmup
    for _ in range(10):
        dispatched_hidden, dispatched_probs, _, handle = (
            buffer.dispatch(hidden=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights, num_of_experts=NUM_OF_EXPERTS)
        )
        # The combine only support bf16
        dispatched_hidden_bf16 = dispatched_hidden.to(torch.bfloat16)
        dispatched_probs = None
        _, _ = buffer.combine(dispatched_hidden_bf16, dispatched_probs, handle)

    rank = dist.get_rank()
    fp8_factor = (1 + 4 / 128) / 2
    dispatch_bf16_nvl_recv_bytes = dispatched_hidden.numel() * 2
    combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
    if NUM_OF_NODES > 1:
        local_node_id = rank // NUM_OF_RANKS_PER_NODE
        num_rdma_send = count_rdma_send_from_routing_map(routing_map, local_node_id, NUM_OF_NODES)
        dispatch_bf16_rdma_send_bytes = num_rdma_send * HIDDEN_DIM * 2
        combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes

    '''
    Benchmark of the dispatch and combine torch API without permute
    '''

    dispatched_hidden, dispatched_probs, _, handle= (
        buffer.dispatch(hidden=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights, num_of_experts=NUM_OF_EXPERTS)
    )
    dispatched_hidden_bf16 = dispatched_hidden.to(torch.bfloat16)

    dispatch_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'topk_idx': topk_idx, 'topk_weights': topk_weights, 'num_of_experts': NUM_OF_EXPERTS, 'handle': handle}
    t = bench(lambda: buffer.dispatch(**dispatch_args))[0]
    nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if hidden.dtype == torch.uint8 else dispatch_bf16_nvl_recv_bytes
    if NUM_OF_NODES > 1:
        rdma_send_bytes = dispatch_bf16_rdma_send_bytes * fp8_factor if hidden.dtype == torch.uint8 else dispatch_bf16_rdma_send_bytes
    print_in_order(f'[rank {rank}] HybridEP dispatch torch API ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): '
            f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), t: {t * 1e6:.2f} us, nvl_recv_bytes: {nvl_recv_bytes / 1e6:.2f} MB')
    if NUM_OF_NODES > 1:
        print_in_order(f'[rank {rank}] HybridEP dispatch torch API ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): '
                f'{rdma_send_bytes / 1e9 / t:.2f} GB/s (IB), t: {t * 1e6:.2f} us, rdma_send_bytes: {rdma_send_bytes / 1e6:.2f} MB')

    combine_args = {'hidden': dispatched_hidden_bf16, 'probs': dispatched_probs, 'handle': handle}
    t = bench(lambda: buffer.combine(**combine_args))[0]
    print_in_order(f'[rank {rank}] HybridEP combine torch API: '
            f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), t: {t * 1e6:.2f} us, combine_send_bytes: {combine_bf16_nvl_send_bytes / 1e6:.2f} MB')
    if NUM_OF_NODES > 1:
        print_in_order(f'[rank {rank}] HybridEP combine torch API: '
                    f'{combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (IB), t: {t * 1e6:.2f} us, rdma_recv_bytes: {combine_bf16_rdma_recv_bytes / 1e6:.2f} MB')

    '''
    Benchmark of the dispatch and combine with permute extension
    '''
    dispatched_hidden_with_permute, dispatched_probs_with_permute, _, tokens_per_expert, handle_with_permute= (
        buffer.dispatch_with_permute(hidden=hidden, scaling_factor=scaling_factor, routing_map=routing_map, probs=probs, pad_multiple=PAD_MULTIPLE)
    )
    num_permuted_tokens = tokens_per_expert.sum().item()
    dispatched_hidden_bf16_with_permute = dispatched_hidden_with_permute.to(torch.bfloat16)

    dispatch_with_permute_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'routing_map': routing_map, 'probs': probs, 'pad_multiple': PAD_MULTIPLE, 'handle': handle_with_permute, 'num_permuted_tokens': num_permuted_tokens}
    t = bench(lambda: buffer.dispatch_with_permute(**dispatch_with_permute_args))[0]
    nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if hidden.dtype == torch.uint8 else dispatch_bf16_nvl_recv_bytes
    print_in_order(f'[rank {rank}] HybridEP dispatch+permute torch API ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): '
            f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), t: {t * 1e6:.2f} us, nvl_recv_bytes: {nvl_recv_bytes / 1e6:.2f} MB')
    if NUM_OF_NODES > 1:
        print_in_order(f'[rank {rank}] HybridEP dispatch+permute torch API ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): '
                f'{rdma_send_bytes / 1e9 / t:.2f} GB/s (IB), t: {t * 1e6:.2f} us, rdma_send_bytes: {rdma_send_bytes / 1e6:.2f} MB')

    combine_with_unpermute_args = {'hidden': dispatched_hidden_bf16_with_permute, 'probs': dispatched_probs_with_permute, 'handle': handle_with_permute, 'pad_multiple': PAD_MULTIPLE}
    t = bench(lambda: buffer.combine_with_unpermute(**combine_with_unpermute_args))[0]
    print_in_order(f'[rank {rank}] HybridEP combine+unpermute torch API: '
            f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), t: {t * 1e6:.2f} us, combine_send_bytes: {combine_bf16_nvl_send_bytes / 1e6:.2f} MB')
    if NUM_OF_NODES > 1:
        print_in_order(f'[rank {rank}] HybridEP combine+unpermute torch API: '
                f'{combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (IB), t: {t * 1e6:.2f} us, rdma_recv_bytes: {combine_bf16_rdma_recv_bytes / 1e6:.2f} MB')

    if not nsys_profile:
        # noinspection PyShadowingNames
        def test_func():
            dispatched_hidden, dispatched_probs, _, handle = (
                buffer.dispatch(hidden=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights, num_of_experts=NUM_OF_EXPERTS)
            )
            # The combine only support bf16
            dispatched_hidden_bf16 = dispatched_hidden.to(torch.bfloat16)
            dispatched_probs = None
            _, _ = buffer.combine(dispatched_hidden_bf16, dispatched_probs, handle)

        group.barrier()
        dispatch_t, combine_t = bench_kineto(test_func,
                                             kernel_names=('dispatch_kernel', 'combine_kernel'), barrier_comm_profiling=True,
                                             suppress_kineto_output=True)
        print_in_order(f'[rank {rank}] HybridEP dispatch kernel(NVL) ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): {nvl_recv_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
              f'HybridEP combine kernel(NVL): {combine_bf16_nvl_send_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us')
        if NUM_OF_NODES > 1:
            print_in_order(f'[rank {rank}] HybridEP dispatch kernel(IB) ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): {rdma_send_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
                  f'HybridEP combine kernel(IB): {combine_bf16_rdma_recv_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us')
    else:
        if torch.distributed.get_rank() == 0:
            torch.cuda.profiler.start()
        with torch.cuda.nvtx.range(f"hybrid-ep dispatch ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})"):
            if rank == 0:
                print(f"profile hybrid-ep dispatch ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})", flush=True)
            dispatch_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'topk_idx': topk_idx, 'topk_weights': topk_weights, 'num_of_experts': NUM_OF_EXPERTS}
            bench(lambda: buffer.dispatch(**dispatch_args))
        with torch.cuda.nvtx.range("hybrid-ep combine"):
            if rank == 0:
                print(f"profile hybrid-ep combine", flush=True)
            combine_args = {'hidden': dispatched_hidden_bf16, 'probs': dispatched_probs, 'handle': handle}
            bench(lambda: buffer.combine(**combine_args))
        with torch.cuda.nvtx.range(f"hybrid-ep dispatch+permute ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})"):
            if rank == 0:
                print(f"profile hybrid-ep dispatch+permute ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})", flush=True)
            dispatch_with_permute_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'routing_map': routing_map, 'probs': probs, 'pad_multiple': PAD_MULTIPLE}
            bench(lambda: buffer.dispatch_with_permute(**dispatch_with_permute_args))
        with torch.cuda.nvtx.range("hybrid-ep combine+unpermute"):
            if rank == 0:
                print(f"profile hybrid-ep combine+unpermute", flush=True)
            combine_with_unpermute_args = {'hidden': dispatched_hidden_bf16_with_permute, 'probs': dispatched_probs_with_permute, 'handle': handle_with_permute, 'pad_multiple': PAD_MULTIPLE}
            bench(lambda: buffer.combine_with_unpermute(**combine_with_unpermute_args))
        time.sleep(1)
        if torch.distributed.get_rank() == 0:
            torch.cuda.profiler.stop()


def test_main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    _, _, group = init_dist(local_rank, num_local_ranks)

    # Set missing global vars
    global NUM_OF_RANKS_PER_NODE, NUM_OF_NODES, NUM_OF_EXPERTS
    if USE_MNNVL:
        NUM_OF_RANKS_PER_NODE = group.size()
        NUM_OF_NODES = 1
        NUM_OF_EXPERTS = NUM_LOCAL_EXPERTS * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES
    else:
        NUM_OF_RANKS_PER_NODE = args.num_processes
        NUM_OF_NODES = group.size() // NUM_OF_RANKS_PER_NODE
        NUM_OF_EXPERTS = NUM_LOCAL_EXPERTS * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES

    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        for use_fp8 in [False, True]:
            buffer = deep_ep.HybridEPBuffer(
                group=group,
                hidden_dim=HIDDEN_DIM,
                max_num_of_tokens_per_rank=MAX_NUM_OF_TOKENS_PER_RANK,
                num_local_experts=NUM_LOCAL_EXPERTS,
                num_sms_dispatch_api=32,
                num_sms_combine_api=32,
                use_fp8=use_fp8
            )
            
            ref = TorchRef(
                ep_group=group,
                num_of_experts=NUM_OF_EXPERTS,
                num_of_ranks_per_node=NUM_OF_RANKS_PER_NODE,
            )

            test_hybrid_ep_correctness(buffer, ref, use_fp8)
            test_hybrid_ep_benchmark(buffer, group, use_fp8, args.nsys_profile)
    dist.barrier()
    dist.destroy_process_group()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Test intranode EP kernels')
    parser.add_argument('--num-processes', type=int, default=4,
                       help='Number of processes to spawn (default: 4)')
    parser.add_argument('--nsys-profile', action='store_true', default=False,
                       help='benchmark with nsys profile or not (default: False)')
    args = parser.parse_args()
    torch.multiprocessing.spawn(test_main, args=(args.num_processes, args), nprocs=args.num_processes)

@Autumn1998

Copy link
Copy Markdown
Collaborator Author

why combine kernel performance reduce after merge this PR:

before:

[rank 0] HybridEP dispatch kernel(NVL) (BF16): 866.87 GB/s, avg_t=287.36 us | HybridEP combine kernel(NVL): 789.26 GB/s, avg_t=315.61 us

after

[rank 0] HybridEP dispatch kernel(NVL) (BF16): 867.18 GB/s, avg_t=287.26 us | HybridEP combine kernel(NVL): 476.63 GB/s, avg_t=522.63 us

before current PR:

[rank 0] Correctness check passed (BF16)
[rank 1] Correctness check passed (BF16)
[rank 2] Correctness check passed (BF16)
[rank 3] Correctness check passed (BF16)
[rank 0] HybridEP dispatch torch API (BF16): 548.75 GB/s (NVL), t: 453.94 us, nvl_recv_bytes: 249.10 MB
[rank 1] HybridEP dispatch torch API (BF16): 548.83 GB/s (NVL), t: 453.94 us, nvl_recv_bytes: 249.14 MB
[rank 2] HybridEP dispatch torch API (BF16): 550.67 GB/s (NVL), t: 454.44 us, nvl_recv_bytes: 250.25 MB
[rank 3] HybridEP dispatch torch API (BF16): 548.02 GB/s (NVL), t: 454.49 us, nvl_recv_bytes: 249.07 MB
[rank 0] HybridEP combine torch API: 600.77 GB/s (NVL), t: 414.64 us, combine_send_bytes: 249.10 MB
[rank 1] HybridEP combine torch API: 601.05 GB/s (NVL), t: 414.50 us, combine_send_bytes: 249.14 MB
[rank 2] HybridEP combine torch API: 603.04 GB/s (NVL), t: 414.98 us, combine_send_bytes: 250.25 MB
[rank 3] HybridEP combine torch API: 600.13 GB/s (NVL), t: 415.02 us, combine_send_bytes: 249.07 MB
[rank 0] HybridEP dispatch+permute torch API (BF16): 512.00 GB/s (NVL), t: 486.53 us, nvl_recv_bytes: 249.10 MB
[rank 1] HybridEP dispatch+permute torch API (BF16): 511.81 GB/s (NVL), t: 486.78 us, nvl_recv_bytes: 249.14 MB
[rank 2] HybridEP dispatch+permute torch API (BF16): 513.71 GB/s (NVL), t: 487.14 us, nvl_recv_bytes: 250.25 MB
[rank 3] HybridEP dispatch+permute torch API (BF16): 511.21 GB/s (NVL), t: 487.22 us, nvl_recv_bytes: 249.07 MB
[rank 0] HybridEP combine+unpermute torch API: 461.85 GB/s (NVL), t: 539.36 us, combine_send_bytes: 249.10 MB
[rank 1] HybridEP combine+unpermute torch API: 461.87 GB/s (NVL), t: 539.40 us, combine_send_bytes: 249.14 MB
[rank 2] HybridEP combine+unpermute torch API: 463.48 GB/s (NVL), t: 539.94 us, combine_send_bytes: 250.25 MB
[rank 3] HybridEP combine+unpermute torch API: 461.24 GB/s (NVL), t: 540.00 us, combine_send_bytes: 249.07 MB
[rank 0] HybridEP dispatch kernel(NVL) (BF16): 866.87 GB/s, avg_t=287.36 us | HybridEP combine kernel(NVL): 789.26 GB/s, avg_t=315.61 us
[rank 1] HybridEP dispatch kernel(NVL) (BF16): 869.53 GB/s, avg_t=286.52 us | HybridEP combine kernel(NVL): 800.22 GB/s, avg_t=311.33 us
[rank 2] HybridEP dispatch kernel(NVL) (BF16): 878.08 GB/s, avg_t=285.00 us | HybridEP combine kernel(NVL): 804.73 GB/s, avg_t=310.97 us
[rank 3] HybridEP dispatch kernel(NVL) (BF16): 840.67 GB/s, avg_t=296.27 us | HybridEP combine kernel(NVL): 799.29 GB/s, avg_t=311.61 us
[rank 0] Correctness check passed (FP8)
[rank 1] Correctness check passed (FP8)
[rank 2] Correctness check passed (FP8)
[rank 3] Correctness check passed (FP8)
[rank 0] HybridEP dispatch torch API (FP8): 424.78 GB/s (NVL), t: 303.81 us, nvl_recv_bytes: 129.05 MB
[rank 1] HybridEP dispatch torch API (FP8): 421.48 GB/s (NVL), t: 303.79 us, nvl_recv_bytes: 128.04 MB
[rank 2] HybridEP dispatch torch API (FP8): 422.80 GB/s (NVL), t: 304.27 us, nvl_recv_bytes: 128.65 MB
[rank 3] HybridEP dispatch torch API (FP8): 423.07 GB/s (NVL), t: 304.24 us, nvl_recv_bytes: 128.71 MB
[rank 0] HybridEP combine torch API: 605.81 GB/s (NVL), t: 413.14 us, combine_send_bytes: 250.28 MB
[rank 1] HybridEP combine torch API: 600.93 GB/s (NVL), t: 413.22 us, combine_send_bytes: 248.32 MB
[rank 2] HybridEP combine torch API: 603.02 GB/s (NVL), t: 413.74 us, combine_send_bytes: 249.50 MB
[rank 3] HybridEP combine torch API: 603.30 GB/s (NVL), t: 413.77 us, combine_send_bytes: 249.63 MB
[rank 0] HybridEP dispatch+permute torch API (FP8): 392.13 GB/s (NVL), t: 329.11 us, nvl_recv_bytes: 129.05 MB
[rank 1] HybridEP dispatch+permute torch API (FP8): 388.96 GB/s (NVL), t: 329.18 us, nvl_recv_bytes: 128.04 MB
[rank 2] HybridEP dispatch+permute torch API (FP8): 390.43 GB/s (NVL), t: 329.50 us, nvl_recv_bytes: 128.65 MB
[rank 3] HybridEP dispatch+permute torch API (FP8): 390.55 GB/s (NVL), t: 329.57 us, nvl_recv_bytes: 128.71 MB
[rank 0] HybridEP combine+unpermute torch API: 465.14 GB/s (NVL), t: 538.08 us, combine_send_bytes: 250.28 MB
[rank 1] HybridEP combine+unpermute torch API: 461.39 GB/s (NVL), t: 538.19 us, combine_send_bytes: 248.32 MB
[rank 2] HybridEP combine+unpermute torch API: 463.25 GB/s (NVL), t: 538.57 us, combine_send_bytes: 249.50 MB
[rank 3] HybridEP combine+unpermute torch API: 463.49 GB/s (NVL), t: 538.58 us, combine_send_bytes: 249.63 MB
[rank 0] HybridEP dispatch kernel(NVL) (FP8): 810.44 GB/s, avg_t=159.24 us | HybridEP combine kernel(NVL): 796.52 GB/s, avg_t=314.22 us
[rank 1] HybridEP dispatch kernel(NVL) (FP8): 807.32 GB/s, avg_t=158.60 us | HybridEP combine kernel(NVL): 799.91 GB/s, avg_t=310.43 us
[rank 2] HybridEP dispatch kernel(NVL) (FP8): 800.83 GB/s, avg_t=160.64 us | HybridEP combine kernel(NVL): 796.92 GB/s, avg_t=313.07 us
[rank 3] HybridEP dispatch kernel(NVL) (FP8): 761.48 GB/s, avg_t=169.03 us | HybridEP combine kernel(NVL): 802.90 GB/s, avg_t=310.90 us

after current PR:

[rank 0] Correctness check passed (BF16)
[rank 1] Correctness check passed (BF16)
[rank 2] Correctness check passed (BF16)
[rank 3] Correctness check passed (BF16)
[rank 0] HybridEP dispatch torch API (BF16): 544.91 GB/s (NVL), t: 457.15 us, nvl_recv_bytes: 249.10 MB
[rank 1] HybridEP dispatch torch API (BF16): 545.03 GB/s (NVL), t: 457.10 us, nvl_recv_bytes: 249.14 MB
[rank 2] HybridEP dispatch torch API (BF16): 546.80 GB/s (NVL), t: 457.66 us, nvl_recv_bytes: 250.25 MB
[rank 3] HybridEP dispatch torch API (BF16): 544.24 GB/s (NVL), t: 457.64 us, nvl_recv_bytes: 249.07 MB
[rank 0] HybridEP combine torch API: 387.94 GB/s (NVL), t: 642.12 us, combine_send_bytes: 249.10 MB
[rank 1] HybridEP combine torch API: 387.98 GB/s (NVL), t: 642.14 us, combine_send_bytes: 249.14 MB
[rank 2] HybridEP combine torch API: 389.45 GB/s (NVL), t: 642.58 us, combine_send_bytes: 250.25 MB
[rank 3] HybridEP combine torch API: 387.60 GB/s (NVL), t: 642.59 us, combine_send_bytes: 249.07 MB
[rank 0] HybridEP dispatch+permute torch API (BF16): 498.28 GB/s (NVL), t: 499.92 us, nvl_recv_bytes: 249.10 MB
[rank 1] HybridEP dispatch+permute torch API (BF16): 498.28 GB/s (NVL), t: 499.99 us, nvl_recv_bytes: 249.14 MB
[rank 2] HybridEP dispatch+permute torch API (BF16): 500.15 GB/s (NVL), t: 500.35 us, nvl_recv_bytes: 250.25 MB
[rank 3] HybridEP dispatch+permute torch API (BF16): 497.77 GB/s (NVL), t: 500.37 us, nvl_recv_bytes: 249.07 MB
[rank 0] HybridEP combine+unpermute torch API: 325.00 GB/s (NVL), t: 766.46 us, combine_send_bytes: 249.10 MB
[rank 1] HybridEP combine+unpermute torch API: 325.04 GB/s (NVL), t: 766.47 us, combine_send_bytes: 249.14 MB
[rank 2] HybridEP combine+unpermute torch API: 326.28 GB/s (NVL), t: 766.98 us, combine_send_bytes: 250.25 MB
[rank 3] HybridEP combine+unpermute torch API: 324.77 GB/s (NVL), t: 766.91 us, combine_send_bytes: 249.07 MB
[rank 0] HybridEP dispatch kernel(NVL) (BF16): 867.18 GB/s, avg_t=287.26 us | HybridEP combine kernel(NVL): 476.63 GB/s, avg_t=522.63 us
[rank 1] HybridEP dispatch kernel(NVL) (BF16): 870.55 GB/s, avg_t=286.18 us | HybridEP combine kernel(NVL): 483.91 GB/s, avg_t=514.84 us
[rank 2] HybridEP dispatch kernel(NVL) (BF16): 881.76 GB/s, avg_t=283.81 us | HybridEP combine kernel(NVL): 464.21 GB/s, avg_t=539.09 us
[rank 3] HybridEP dispatch kernel(NVL) (BF16): 838.93 GB/s, avg_t=296.89 us | HybridEP combine kernel(NVL): 482.56 GB/s, avg_t=516.14 us
[rank 0] Correctness check passed (FP8)
[rank 1] Correctness check passed (FP8)
[rank 2] Correctness check passed (FP8)
[rank 3] Correctness check passed (FP8)
[rank 0] HybridEP dispatch torch API (FP8): 420.36 GB/s (NVL), t: 307.01 us, nvl_recv_bytes: 129.05 MB
[rank 1] HybridEP dispatch torch API (FP8): 416.59 GB/s (NVL), t: 307.34 us, nvl_recv_bytes: 128.04 MB
[rank 2] HybridEP dispatch torch API (FP8): 418.40 GB/s (NVL), t: 307.47 us, nvl_recv_bytes: 128.65 MB
[rank 3] HybridEP dispatch torch API (FP8): 418.24 GB/s (NVL), t: 307.75 us, nvl_recv_bytes: 128.71 MB
[rank 0] HybridEP combine torch API: 390.12 GB/s (NVL), t: 641.54 us, combine_send_bytes: 250.28 MB
[rank 1] HybridEP combine torch API: 386.84 GB/s (NVL), t: 641.91 us, combine_send_bytes: 248.32 MB
[rank 2] HybridEP combine torch API: 388.61 GB/s (NVL), t: 642.02 us, combine_send_bytes: 249.50 MB
[rank 3] HybridEP combine torch API: 388.54 GB/s (NVL), t: 642.48 us, combine_send_bytes: 249.63 MB
[rank 0] HybridEP dispatch+permute torch API (FP8): 381.47 GB/s (NVL), t: 338.31 us, nvl_recv_bytes: 129.05 MB
[rank 1] HybridEP dispatch+permute torch API (FP8): 377.90 GB/s (NVL), t: 338.82 us, nvl_recv_bytes: 128.04 MB
[rank 2] HybridEP dispatch+permute torch API (FP8): 379.58 GB/s (NVL), t: 338.92 us, nvl_recv_bytes: 128.65 MB
[rank 3] HybridEP dispatch+permute torch API (FP8): 379.40 GB/s (NVL), t: 339.25 us, nvl_recv_bytes: 128.71 MB
[rank 0] HybridEP combine+unpermute torch API: 326.68 GB/s (NVL), t: 766.15 us, combine_send_bytes: 250.28 MB
[rank 1] HybridEP combine+unpermute torch API: 323.97 GB/s (NVL), t: 766.47 us, combine_send_bytes: 248.32 MB
[rank 2] HybridEP combine+unpermute torch API: 325.44 GB/s (NVL), t: 766.64 us, combine_send_bytes: 249.50 MB
[rank 3] HybridEP combine+unpermute torch API: 325.46 GB/s (NVL), t: 767.00 us, combine_send_bytes: 249.63 MB
[rank 0] HybridEP dispatch kernel(NVL) (FP8): 817.56 GB/s, avg_t=157.85 us | HybridEP combine kernel(NVL): 477.14 GB/s, avg_t=524.54 us
[rank 1] HybridEP dispatch kernel(NVL) (FP8): 806.13 GB/s, avg_t=158.83 us | HybridEP combine kernel(NVL): 482.78 GB/s, avg_t=514.35 us
[rank 2] HybridEP dispatch kernel(NVL) (FP8): 804.44 GB/s, avg_t=159.92 us | HybridEP combine kernel(NVL): 463.49 GB/s, avg_t=538.29 us
[rank 3] HybridEP dispatch kernel(NVL) (FP8): 759.75 GB/s, avg_t=169.42 us | HybridEP combine kernel(NVL): 482.17 GB/s, avg_t=517.71 us

test command:

MASTER_ADDR=x.x.x.x MASTER_PORT=29500 WORLD_SIZE=1 RANK=0 NUM_SMS_DISPATCH=32 NUM_SMS_COMBINE=32 NUM_BLOCKS_PERMUTE=120 NUM_BLOCKS_UNPERMUTE=120 CUDA_HOME=/usr/local/cuda HIDDEN_DIM=4096 MAX_NUM_OF_TOKENS_PER_RANK=8192 NUM_TOKENS_PER_RANK=8192 NUM_LOCAL_EXPERTS=8 TOPK=8 USE_MNNVL=1 NVSHMEM_DEBUG=WARN NCCL_DEBUG=WARN nohup python3 /elasticdl/share/models/lisiyuan.li/autolab/fujianhao.fjh/hybridep-opensouce/tests/combine_performance_test_hybrid_ep.py --num-processes 4

combine_performance_test_hybrid_ep.py

# SPDX-License-Identifier: MIT
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved
import argparse
import time
import torch
import torch.distributed as dist
import os
import deep_ep

from utils import TorchRef, bench, bench_kineto, init_dist, count_rdma_send_from_routing_map

HIDDEN_DIM = int(os.environ.get("HIDDEN_DIM", 7168))
MAX_NUM_OF_TOKENS_PER_RANK = int(os.environ.get("MAX_NUM_OF_TOKENS_PER_RANK", 4096))
# NUM_TOKENS_PER_RANK should equal or less than MAX_NUM_OF_TOKENS_PER_RANK
NUM_TOKENS_PER_RANK = int(os.environ.get("NUM_TOKENS_PER_RANK", 4096))
NUM_LOCAL_EXPERTS = int(os.environ.get("NUM_LOCAL_EXPERTS", 8))
TOPK = int(os.environ.get("TOPK", 8))
PAD_MULTIPLE = int(os.environ.get("PAD_MULTIPLE", 32))
ITERATIONS = int(os.environ.get("ITERATIONS", 100))
SEED = int(os.environ.get("SEED", 42))
USE_MNNVL = os.environ.get("USE_MNNVL", "0").strip().lower() in {"1", "true", "t", "yes", "y", "on"}
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Will be set after the process group is initialized
NUM_OF_RANKS_PER_NODE = None
NUM_OF_NODES = None
NUM_OF_EXPERTS = None

def print_in_order(msg: str):
    """Print message in order by rank to avoid interleaved output"""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    for i in range(world_size):
        if i == rank:
            print(msg, flush=True)
        dist.barrier()

def bitwise_equal(a: torch.Tensor, b: torch.Tensor) -> bool:
    if a.dtype != b.dtype or a.shape != b.shape or a.device != b.device:
        return False
    a_bytes = a.contiguous().view(torch.uint8)
    b_bytes = b.contiguous().view(torch.uint8)
    return torch.equal(a_bytes, b_bytes)

def init_tensor(
    hidden_dim: int,
    seq_len: int,
    topk: int,
    num_of_experts: int,
    use_fp8: bool = False,
):
    if use_fp8:
        hidden = torch.randint(
            low=0,
            high=256,
            size=(seq_len, hidden_dim),
            device="cuda",
            dtype=torch.uint8,
        )
    else:
        hidden = torch.randn(seq_len, hidden_dim, device="cuda", dtype=torch.bfloat16)
    probs = torch.zeros(seq_len, num_of_experts, device="cuda", dtype=torch.float32)
    topk_idx = torch.zeros(seq_len, topk, device="cuda", dtype=torch.int64)
    topk_weights = torch.zeros(seq_len, topk, device="cuda", dtype=torch.float32)
    scaling_factor = torch.randn(
        seq_len, hidden_dim // 128, device="cuda", dtype=torch.float32
    )

    routing_map = torch.zeros(seq_len, num_of_experts, device="cuda", dtype=torch.bool)

    for i in range(seq_len):
        # Force balanced routing for testing
        # selected_experts = torch.tensor([
        #     ((i * topk) % num_of_experts + val) % num_of_experts for val in range(topk)
        # ], device="cuda")
        selected_experts = torch.randperm(num_of_experts, device="cuda")[:topk]
        topk_idx[i, :] = selected_experts.to(torch.int64)
        topk_weights[i, :] = torch.ones(topk, device="cuda", dtype=torch.float32)
        routing_map[i, selected_experts] = True
        probs[i, selected_experts] = topk_weights[i, :]

    return hidden, probs, scaling_factor, routing_map, topk_idx, topk_weights


def test_hybrid_ep_correctness(buffer: deep_ep.HybridEPBuffer, ref: TorchRef, use_fp8: bool):
    hidden, probs, scaling_factor, routing_map, topk_idx, topk_weights  = init_tensor(
        hidden_dim=HIDDEN_DIM,
        seq_len=NUM_TOKENS_PER_RANK,
        topk=TOPK,
        num_of_experts=NUM_OF_EXPERTS,
        use_fp8=use_fp8,
    )

    # Dispatch correctness check
    for with_probs in [True, False]:
        # The check for the dispatch
        dispatched_hidden_ref, dispatched_probs_ref, dispatched_scaling_factor_ref = (
            ref.dispatch(
                hidden, routing_map, probs if with_probs else None, scaling_factor
            )
        )
        (
            dispatched_hidden,
            dispatched_probs,
            dispatched_scaling_factor,
            handle,
        ) = buffer.dispatch(
            hidden=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights if with_probs else None, num_of_experts=NUM_OF_EXPERTS,
        )

        assert bitwise_equal(dispatched_hidden_ref, dispatched_hidden)
        if dispatched_probs is not None and dispatched_probs_ref is not None:
            start, end = ref._local_expert_range_per_node()
            masked_probs = torch.zeros_like(dispatched_probs)
            masked_probs[:, start:end] = dispatched_probs[:, start:end]
            assert bitwise_equal(dispatched_probs_ref, dispatched_probs[:, start:end])
            dispatched_probs = masked_probs
        if (
            dispatched_scaling_factor is not None
            and dispatched_scaling_factor_ref is not None
        ):
            assert bitwise_equal(
                dispatched_scaling_factor_ref, dispatched_scaling_factor
            )

        _, _, _, num_dispatched_tokens, local_expert_routing_map, _, _ = handle
        num_dispatched_tokens = num_dispatched_tokens.cpu()
        local_expert_routing_map = local_expert_routing_map[
            : num_dispatched_tokens.item()
        ]
        # Simulate the permute and expert and unpermute. The expert is identity op
        copy_times = local_expert_routing_map.sum(dim=1)
        dispatched_hidden = dispatched_hidden.to(torch.bfloat16)  
        # The combine only support bf16
        hidden_to_combine = dispatched_hidden * copy_times.unsqueeze(1)
        probs_to_combine = dispatched_probs

        # The check for the combine
        combined_hidden, combined_probs = buffer.combine(
            hidden_to_combine, probs_to_combine, handle
        )

        # The reconstucted value should be TOPK times larger than the input hidden
        combined_hidden = combined_hidden / TOPK

        assert torch.allclose(combined_hidden, hidden.to(torch.bfloat16), atol=2e-5, rtol=1e-2)
        if combined_probs is not None and probs is not None:
            assert bitwise_equal(combined_probs, probs)

    # Dispatch with permute correctness check
    for with_probs in [True, False]:
        # The check for the dispatch
        (
            dispatched_hidden,
            dispatched_probs,
            dispatched_scaling_factor,
            tokens_per_expert,
            handle,
        ) = buffer.dispatch_with_permute(
            hidden=hidden,
            routing_map=routing_map,
            probs=probs if with_probs else None,
            scaling_factor=scaling_factor,
            pad_multiple=PAD_MULTIPLE,
        )
        _, _, _, num_dispatched_tokens_tensor, local_expert_routing_map, _, _, _, _ = (
            handle
        )
        num_dispatched_tokens_tensor = num_dispatched_tokens_tensor.cpu()
        local_expert_routing_map = local_expert_routing_map[
            : num_dispatched_tokens_tensor.item()
        ]
        # The out_token_num of permutation is the sum of the tokens_per_expert
        out_token_num = tokens_per_expert.sum().item()
        (
            dispatched_hidden_ref,
            dispatched_probs_ref,
            dispatched_scaling_factor_ref,
        ) = ref.dispatch(
            hidden,
            routing_map,
            probs if with_probs else None,
            scaling_factor,
            local_expert_routing_map=local_expert_routing_map,
            out_token_num=out_token_num,
            pad_multiple=PAD_MULTIPLE,
            enable_permute=True,
        )

        assert bitwise_equal(dispatched_hidden_ref, dispatched_hidden)
        if dispatched_probs is not None and dispatched_probs_ref is not None:
            assert bitwise_equal(dispatched_probs_ref, dispatched_probs)
        if (
            dispatched_scaling_factor is not None
            and dispatched_scaling_factor_ref is not None
        ):
            assert bitwise_equal(
                dispatched_scaling_factor_ref, dispatched_scaling_factor
            )

        # The combine only support bf16
        dispatched_hidden = dispatched_hidden.to(torch.bfloat16)  
        hidden_to_combine = dispatched_hidden
        probs_to_combine = dispatched_probs
 
        # The check for the combine
        combined_hidden, combined_probs = buffer.combine_with_unpermute(
            hidden=hidden_to_combine,
            probs=probs_to_combine,
            handle=handle,
            pad_multiple=PAD_MULTIPLE,
        )

        # The reconstucted value should be TOPK times larger than the input hidden
        combined_hidden = combined_hidden / TOPK

        assert torch.allclose(
            combined_hidden, hidden.to(torch.bfloat16), atol=2e-5, rtol=1e-2
        )
        if combined_probs is not None and probs is not None:
            assert bitwise_equal(combined_probs, probs)

    print_in_order(f'[rank {dist.get_rank()}] Correctness check passed ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})')


def test_hybrid_ep_benchmark(buffer: deep_ep.HybridEPBuffer, group: dist.ProcessGroup, use_fp8: bool, nsys_profile: bool):
    hidden, probs, scaling_factor, routing_map, topk_idx, topk_weights = init_tensor(
        hidden_dim=HIDDEN_DIM,
        seq_len=NUM_TOKENS_PER_RANK,
        topk=TOPK,
        num_of_experts=NUM_OF_EXPERTS,
        use_fp8=use_fp8,
    )

    # warmup
    for _ in range(10):
        dispatched_hidden, dispatched_probs, _, handle = (
            buffer.dispatch(hidden=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights, num_of_experts=NUM_OF_EXPERTS)
        )
        # The combine only support bf16
        dispatched_hidden_bf16 = dispatched_hidden.to(torch.bfloat16)
        dispatched_probs = None
        _, _ = buffer.combine(dispatched_hidden_bf16, dispatched_probs, handle)

    rank = dist.get_rank()
    fp8_factor = (1 + 4 / 128) / 2
    dispatch_bf16_nvl_recv_bytes = dispatched_hidden.numel() * 2
    combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
    if NUM_OF_NODES > 1:
        local_node_id = rank // NUM_OF_RANKS_PER_NODE
        num_rdma_send = count_rdma_send_from_routing_map(routing_map, local_node_id, NUM_OF_NODES)
        dispatch_bf16_rdma_send_bytes = num_rdma_send * HIDDEN_DIM * 2
        combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes

    '''
    Benchmark of the dispatch and combine torch API without permute
    '''

    dispatched_hidden, dispatched_probs, _, handle= (
        buffer.dispatch(hidden=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights, num_of_experts=NUM_OF_EXPERTS)
    )
    dispatched_hidden_bf16 = dispatched_hidden.to(torch.bfloat16)

    dispatch_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'topk_idx': topk_idx, 'topk_weights': topk_weights, 'num_of_experts': NUM_OF_EXPERTS, 'handle': handle}
    t = bench(lambda: buffer.dispatch(**dispatch_args))[0]
    nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if hidden.dtype == torch.uint8 else dispatch_bf16_nvl_recv_bytes
    if NUM_OF_NODES > 1:
        rdma_send_bytes = dispatch_bf16_rdma_send_bytes * fp8_factor if hidden.dtype == torch.uint8 else dispatch_bf16_rdma_send_bytes
    print_in_order(f'[rank {rank}] HybridEP dispatch torch API ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): '
            f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), t: {t * 1e6:.2f} us, nvl_recv_bytes: {nvl_recv_bytes / 1e6:.2f} MB')
    if NUM_OF_NODES > 1:
        print_in_order(f'[rank {rank}] HybridEP dispatch torch API ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): '
                f'{rdma_send_bytes / 1e9 / t:.2f} GB/s (IB), t: {t * 1e6:.2f} us, rdma_send_bytes: {rdma_send_bytes / 1e6:.2f} MB')

    combine_args = {'hidden': dispatched_hidden_bf16, 'probs': dispatched_probs, 'handle': handle}
    t = bench(lambda: buffer.combine(**combine_args))[0]
    print_in_order(f'[rank {rank}] HybridEP combine torch API: '
            f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), t: {t * 1e6:.2f} us, combine_send_bytes: {combine_bf16_nvl_send_bytes / 1e6:.2f} MB')
    if NUM_OF_NODES > 1:
        print_in_order(f'[rank {rank}] HybridEP combine torch API: '
                    f'{combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (IB), t: {t * 1e6:.2f} us, rdma_recv_bytes: {combine_bf16_rdma_recv_bytes / 1e6:.2f} MB')

    '''
    Benchmark of the dispatch and combine with permute extension
    '''
    dispatched_hidden_with_permute, dispatched_probs_with_permute, _, tokens_per_expert, handle_with_permute= (
        buffer.dispatch_with_permute(hidden=hidden, scaling_factor=scaling_factor, routing_map=routing_map, probs=probs, pad_multiple=PAD_MULTIPLE)
    )
    num_permuted_tokens = tokens_per_expert.sum().item()
    dispatched_hidden_bf16_with_permute = dispatched_hidden_with_permute.to(torch.bfloat16)

    dispatch_with_permute_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'routing_map': routing_map, 'probs': probs, 'pad_multiple': PAD_MULTIPLE, 'handle': handle_with_permute, 'num_permuted_tokens': num_permuted_tokens}
    t = bench(lambda: buffer.dispatch_with_permute(**dispatch_with_permute_args))[0]
    nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if hidden.dtype == torch.uint8 else dispatch_bf16_nvl_recv_bytes
    print_in_order(f'[rank {rank}] HybridEP dispatch+permute torch API ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): '
            f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), t: {t * 1e6:.2f} us, nvl_recv_bytes: {nvl_recv_bytes / 1e6:.2f} MB')
    if NUM_OF_NODES > 1:
        print_in_order(f'[rank {rank}] HybridEP dispatch+permute torch API ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): '
                f'{rdma_send_bytes / 1e9 / t:.2f} GB/s (IB), t: {t * 1e6:.2f} us, rdma_send_bytes: {rdma_send_bytes / 1e6:.2f} MB')

    combine_with_unpermute_args = {'hidden': dispatched_hidden_bf16_with_permute, 'probs': dispatched_probs_with_permute, 'handle': handle_with_permute, 'pad_multiple': PAD_MULTIPLE}
    t = bench(lambda: buffer.combine_with_unpermute(**combine_with_unpermute_args))[0]
    print_in_order(f'[rank {rank}] HybridEP combine+unpermute torch API: '
            f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), t: {t * 1e6:.2f} us, combine_send_bytes: {combine_bf16_nvl_send_bytes / 1e6:.2f} MB')
    if NUM_OF_NODES > 1:
        print_in_order(f'[rank {rank}] HybridEP combine+unpermute torch API: '
                f'{combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (IB), t: {t * 1e6:.2f} us, rdma_recv_bytes: {combine_bf16_rdma_recv_bytes / 1e6:.2f} MB')

    if not nsys_profile:
        # noinspection PyShadowingNames
        def test_func():
            dispatched_hidden, dispatched_probs, _, handle = (
                buffer.dispatch(hidden=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights, num_of_experts=NUM_OF_EXPERTS)
            )
            # The combine only support bf16
            dispatched_hidden_bf16 = dispatched_hidden.to(torch.bfloat16)
            dispatched_probs = None
            _, _ = buffer.combine(dispatched_hidden_bf16, dispatched_probs, handle)

        group.barrier()
        dispatch_t, combine_t = bench_kineto(test_func,
                                             kernel_names=('dispatch_kernel', 'combine_kernel'), barrier_comm_profiling=True,
                                             suppress_kineto_output=True)
        print_in_order(f'[rank {rank}] HybridEP dispatch kernel(NVL) ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): {nvl_recv_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
              f'HybridEP combine kernel(NVL): {combine_bf16_nvl_send_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us')
        if NUM_OF_NODES > 1:
            print_in_order(f'[rank {rank}] HybridEP dispatch kernel(IB) ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): {rdma_send_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
                  f'HybridEP combine kernel(IB): {combine_bf16_rdma_recv_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us')
    else:
        if torch.distributed.get_rank() == 0:
            torch.cuda.profiler.start()
        with torch.cuda.nvtx.range(f"hybrid-ep dispatch ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})"):
            if rank == 0:
                print(f"profile hybrid-ep dispatch ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})", flush=True)
            dispatch_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'topk_idx': topk_idx, 'topk_weights': topk_weights, 'num_of_experts': NUM_OF_EXPERTS}
            bench(lambda: buffer.dispatch(**dispatch_args))
        with torch.cuda.nvtx.range("hybrid-ep combine"):
            if rank == 0:
                print(f"profile hybrid-ep combine", flush=True)
            combine_args = {'hidden': dispatched_hidden_bf16, 'probs': dispatched_probs, 'handle': handle}
            bench(lambda: buffer.combine(**combine_args))
        with torch.cuda.nvtx.range(f"hybrid-ep dispatch+permute ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})"):
            if rank == 0:
                print(f"profile hybrid-ep dispatch+permute ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})", flush=True)
            dispatch_with_permute_args = {'hidden': hidden, 'scaling_factor': scaling_factor, 'routing_map': routing_map, 'probs': probs, 'pad_multiple': PAD_MULTIPLE}
            bench(lambda: buffer.dispatch_with_permute(**dispatch_with_permute_args))
        with torch.cuda.nvtx.range("hybrid-ep combine+unpermute"):
            if rank == 0:
                print(f"profile hybrid-ep combine+unpermute", flush=True)
            combine_with_unpermute_args = {'hidden': dispatched_hidden_bf16_with_permute, 'probs': dispatched_probs_with_permute, 'handle': handle_with_permute, 'pad_multiple': PAD_MULTIPLE}
            bench(lambda: buffer.combine_with_unpermute(**combine_with_unpermute_args))
        time.sleep(1)
        if torch.distributed.get_rank() == 0:
            torch.cuda.profiler.stop()


def test_main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    _, _, group = init_dist(local_rank, num_local_ranks)

    # Set missing global vars
    global NUM_OF_RANKS_PER_NODE, NUM_OF_NODES, NUM_OF_EXPERTS
    if USE_MNNVL:
        NUM_OF_RANKS_PER_NODE = group.size()
        NUM_OF_NODES = 1
        NUM_OF_EXPERTS = NUM_LOCAL_EXPERTS * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES
    else:
        NUM_OF_RANKS_PER_NODE = args.num_processes
        NUM_OF_NODES = group.size() // NUM_OF_RANKS_PER_NODE
        NUM_OF_EXPERTS = NUM_LOCAL_EXPERTS * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES

    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        for use_fp8 in [False, True]:
            buffer = deep_ep.HybridEPBuffer(
                group=group,
                hidden_dim=HIDDEN_DIM,
                max_num_of_tokens_per_rank=MAX_NUM_OF_TOKENS_PER_RANK,
                num_local_experts=NUM_LOCAL_EXPERTS,
                num_sms_dispatch_api=32,
                num_sms_combine_api=32,
                use_fp8=use_fp8
            )
            
            ref = TorchRef(
                ep_group=group,
                num_of_experts=NUM_OF_EXPERTS,
                num_of_ranks_per_node=NUM_OF_RANKS_PER_NODE,
            )

            test_hybrid_ep_correctness(buffer, ref, use_fp8)
            test_hybrid_ep_benchmark(buffer, group, use_fp8, args.nsys_profile)
    dist.barrier()
    dist.destroy_process_group()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Test intranode EP kernels')
    parser.add_argument('--num-processes', type=int, default=4,
                       help='Number of processes to spawn (default: 4)')
    parser.add_argument('--nsys-profile', action='store_true', default=False,
                       help='benchmark with nsys profile or not (default: False)')
    args = parser.parse_args()
    torch.multiprocessing.spawn(test_main, args=(args.num_processes, args), nprocs=args.num_processes)

This may be caused by the default values set in this PR: NUM_OF_TOKENS_PER_CHUNK_DISPATCH_API=64 and NUM_OF_TOKENS_PER_CHUNK_COMBINE_API=64. While 128 is use before. 128 chunk-size might yield better performance for the non-fused case, we set it to 64 because it keeps performance acceptable for both the fused and non-fused case

@alpha-baby

Copy link
Copy Markdown
Contributor

@Autumn1998

I've noticed that the hybrid-ep project has increasingly more configurable parameters, making it challenging to tune to the optimal settings. Currently, it seems that setting the configuration to 128 would be a better option!

currently the performance is normal:

=== Correctness Check (BF16, 4 ranks) ===
  dispatch+combine API: PASS
  dispatch_with_permute + combine_with_unpermute API (non-fused): PASS
  dispatch_with_permute + combine_with_unpermute API (fused): PASS

=== Torch API Benchmark (BF16, 4 ranks) ===
  Non-permute:  dispatch = dispatch_kernel + d2d + misc
                combine  = d2d + combine_kernel + misc
  Permute:      dispatch = dispatch_kernel + permute_kernel + misc
                combine  = unpermute_kernel + combine_kernel + misc
  Fused:        dispatch = fused_permute_dispatch_kernel + misc
                combine  = fused_combine_unpermute_kernel + misc
  (misc = device_sync, update_flag, etc.)
[rank 0] HybridEP dispatch torch API (BF16): 549.14 GB/s (NVL), t: 456.49 us, nvl_recv_bytes: 250.68 MB
[rank 1] HybridEP dispatch torch API (BF16): 544.07 GB/s (NVL), t: 456.40 us, nvl_recv_bytes: 248.32 MB
[rank 2] HybridEP dispatch torch API (BF16): 545.01 GB/s (NVL), t: 457.00 us, nvl_recv_bytes: 249.07 MB
[rank 3] HybridEP dispatch torch API (BF16): 544.70 GB/s (NVL), t: 456.96 us, nvl_recv_bytes: 248.91 MB
dispatch (BF16):                                 548.87 GB/s (NVL), t: 456.7 us [min=456.4, max=457.0], nvl_recv_bytes: 250.68 MB
combine:                                         596.39 GB/s (NVL), t: 420.3 us [min=420.0, max=420.6], combine_send_bytes: 250.68 MB
dispatch+permute (BF16):                         375.09 GB/s (NVL), t: 668.3 us [min=668.0, max=668.7], nvl_recv_bytes: 250.68 MB
combine+unpermute:                               260.71 GB/s (NVL), t: 961.5 us [min=961.3, max=961.8], combine_send_bytes: 250.68 MB
fused dispatch+permute (BF16):                   504.44 GB/s (NVL), t: 496.9 us [min=496.7, max=497.2], nvl_recv_bytes: 250.68 MB
fused combine+unpermute:                         503.37 GB/s (NVL), t: 498.0 us [min=497.7, max=498.3], combine_send_bytes: 250.68 MB

=== Kernel Benchmark (BF16, 4 ranks) ===
  Non-fused:  dispatch_kernel only  |  combine_kernel only
  Fused:      fused_permute_dispatch_kernel only  |  fused_combine_unpermute_kernel only
dispatch kernel (BF16)(NVL):                     868.04 GB/s, avg_t=288.8 us [min=284.8, max=296.5]
combine kernel(NVL):                             796.36 GB/s, avg_t=314.8 us [min=313.6, max=316.4]
fused dispatch+permute kernel (BF16)(NVL):       522.46 GB/s, avg_t=479.8 us [min=475.5, max=482.9]
fused combine+unpermute kernel(NVL):             522.09 GB/s, avg_t=480.1 us [min=477.9, max=484.8]

test command:

MASTER_ADDR=xxxxx MASTER_PORT=29500 WORLD_SIZE=1 RANK=0 NUM_SMS_DISPATCH=32 NUM_SMS_COMBINE=32 NUM_BLOCKS_PERMUTE=120 NUM_BLOCKS_UNPERMUTE=120 NUM_OF_TOKENS_PER_CHUNK_DISPATCH_API=128 NUM_OF_TOKENS_PER_CHUNK_COMBINE_API=128 NUM_OF_TOKENS_PER_CHUNK_PREPROCESSING_API=128 CUDA_HOME=/usr/local/cuda HIDDEN_DIM=4096 MAX_NUM_OF_TOKENS_PER_RANK=8192 NUM_TOKENS_PER_RANK=8192 NUM_LOCAL_EXPERTS=8 TOPK=8 USE_MNNVL=1 NVSHMEM_DEBUG=WARN NCCL_DEBUG=WARN nohup python3 /elasticdl/share/models/lisiyuan.li/autolab/fujianhao.fjh/hybridep-opensouce/tests/test_hybrid_ep_fix.py --num-processes 4

before default args:

=== Correctness Check (BF16, 4 ranks) ===
  dispatch+combine API: PASS
  dispatch_with_permute + combine_with_unpermute API (non-fused): PASS
  dispatch_with_permute + combine_with_unpermute API (fused): PASS

=== Torch API Benchmark (BF16, 4 ranks) ===
  Non-permute:  dispatch = dispatch_kernel + d2d + misc
                combine  = d2d + combine_kernel + misc
  Permute:      dispatch = dispatch_kernel + permute_kernel + misc
                combine  = unpermute_kernel + combine_kernel + misc
  Fused:        dispatch = fused_permute_dispatch_kernel + misc
                combine  = fused_combine_unpermute_kernel + misc
  (misc = device_sync, update_flag, etc.)
[rank 0] HybridEP dispatch torch API (BF16): 547.18 GB/s (NVL), t: 458.13 us, nvl_recv_bytes: 250.68 MB
[rank 1] HybridEP dispatch torch API (BF16): 542.09 GB/s (NVL), t: 458.07 us, nvl_recv_bytes: 248.32 MB
[rank 2] HybridEP dispatch torch API (BF16): 543.17 GB/s (NVL), t: 458.55 us, nvl_recv_bytes: 249.07 MB
[rank 3] HybridEP dispatch torch API (BF16): 542.89 GB/s (NVL), t: 458.48 us, nvl_recv_bytes: 248.91 MB
dispatch (BF16):                                 546.96 GB/s (NVL), t: 458.3 us [min=458.1, max=458.6], nvl_recv_bytes: 250.68 MB
combine:                                         390.57 GB/s (NVL), t: 641.8 us [min=641.5, max=642.1], combine_send_bytes: 250.68 MB
dispatch+permute (BF16):                         374.76 GB/s (NVL), t: 668.9 us [min=668.7, max=669.1], nvl_recv_bytes: 250.68 MB
combine+unpermute:                               211.94 GB/s (NVL), t: 1182.8 us [min=1182.5, max=1183.0], combine_send_bytes: 250.68 MB
fused dispatch+permute (BF16):                   607.63 GB/s (NVL), t: 412.5 us [min=412.3, max=412.8], nvl_recv_bytes: 250.68 MB
fused combine+unpermute:                         381.64 GB/s (NVL), t: 656.8 us [min=656.6, max=657.1], combine_send_bytes: 250.68 MB

=== Kernel Benchmark (BF16, 4 ranks) ===
  Non-fused:  dispatch_kernel only  |  combine_kernel only
  Fused:      fused_permute_dispatch_kernel only  |  fused_combine_unpermute_kernel only
dispatch kernel (BF16)(NVL):                     867.71 GB/s, avg_t=288.9 us [min=284.0, max=297.2]
combine kernel(NVL):                             480.82 GB/s, avg_t=521.4 us [min=510.9, max=536.9]
fused dispatch+permute kernel (BF16)(NVL):       631.72 GB/s, avg_t=396.8 us [min=394.1, max=398.6]
fused combine+unpermute kernel(NVL):             401.59 GB/s, avg_t=624.2 us [min=611.9, max=645.2]

test command:

MASTER_ADDR=xxxxxx MASTER_PORT=29500 WORLD_SIZE=1 RANK=0 NUM_SMS_DISPATCH=32 NUM_SMS_COMBINE=32 NUM_BLOCKS_PERMUTE=120 NUM_BLOCKS_UNPERMUTE=120 CUDA_HOME=/usr/local/cuda HIDDEN_DIM=4096 MAX_NUM_OF_TOKENS_PER_RANK=8192 NUM_TOKENS_PER_RANK=8192 NUM_LOCAL_EXPERTS=8 TOPK=8 USE_MNNVL=1 NVSHMEM_DEBUG=WARN NCCL_DEBUG=WARN nohup python3 /elasticdl/share/models/lisiyuan.li/autolab/fujianhao.fjh/hybridep-opensouce/tests/test_hybrid_ep_fix.py --num-processes

@alpha-baby

Copy link
Copy Markdown
Contributor

Although the performance of the fused combine+unpermute kernel (NVL) has improved, the performance of the fused dispatch+permute kernel (BF16)(NVL) has degraded.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants