
Optimize all_reduce by porting the shared memory kernel of deepspeed #5

Merged
mingfeima merged 4 commits into mingfeima:cpu_opt from chunyuan-w:chunyuan/pr_shm_allreduce on Feb 19, 2025

Conversation

@chunyuan-w (Collaborator) commented on Feb 17, 2025

Motivation

Optimize all_reduce by porting the shared memory implementation in deepspeed.

Modifications

  1. We added a shm_allreduce operator in sgl-kernel.

    The implementation is ported from DeepSpeed:
    sgl-kernel/src/sgl-kernel/csrc/cpu/shm.h ---> DeepSpeed: shm.h
    sgl-kernel/src/sgl-kernel/csrc/cpu/shm.cpp ---> DeepSpeed: shm.cpp
    sgl-kernel/src/sgl-kernel/csrc/cpu/interface.cpp ---> DeepSpeed: ccl.cpp

    To build the kernel:

    cd sgl-kernel
    python setup.py develop
  2. We added a wrapper, tensor_model_parallel_all_reduce_wrapper, in SGLang that calls shm_allreduce on CPU and calls the original tensor_model_parallel_all_reduce API in vllm on other devices, so that we don't need to change the tensor_model_parallel_all_reduce function in vllm (a sketch of the dispatch logic is shown after this list).

  3. In sgl-kernel/src/sgl-kernel/ops/__init__.py, we only import CUDA kernels if CUDA is available.
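
For illustration only, here is a minimal sketch of the dispatch described in item 2. The import paths, the shm_allreduce signature, and the device check below are assumptions, not the merged code:

    # Hypothetical sketch of tensor_model_parallel_all_reduce_wrapper (item 2).
    # Import paths and the shm_allreduce signature are assumptions for illustration.
    import torch

    def tensor_model_parallel_all_reduce_wrapper(input_: torch.Tensor) -> torch.Tensor:
        if input_.device.type == "cpu":
            # CPU path: use the shared-memory all_reduce kernel added in sgl-kernel.
            from sgl_kernel.ops import shm_allreduce  # assumed import path
            return shm_allreduce(input_)
        # Other devices: keep calling the original vLLM API unchanged.
        from vllm.distributed import tensor_model_parallel_all_reduce  # assumed import path
        return tensor_model_parallel_all_reduce(input_)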

Benchmarks

Accuracy

The score on mmlu (higher is better) for tp=2:
score without this PR (using torch.distributed.all_reduce): 0.594
score with this PR: 0.594

Note: --disable-overlap-schedule is needed in the args for CPU; otherwise a new thread is created for the forward batch, and the OMP thread binding will not apply to this newly created thread.

Command line:

# Server side
# tp = 2:
SGLANG_CPU_OMP_THREADS_BIND="0-39|40-79" python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --disable-radix --trust-remote-code --device cpu --attention-backend torch_native --disable-mla --log-requests --disable-overlap-schedule --tp 2
# Client side
python3 -m sglang.test.run_eval --eval-name mmlu --num-examples 64 --port 30000

Performance

We observe a 25% speedup on first-token latency and a 37% speedup on next-token latency after switching from torch.distributed.all_reduce to shm_allreduce, using the tp=2 command line below on GNR.

SGLANG_CPU_OMP_THREADS_BIND="0-39|40-79" python3 -m sglang.bench_one_batch --batch-size 1 --input 1024 --output 8 --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --trust-remote-code --device cpu --attention-backend torch_native --disable-mla --tp 2

Limitations

The shm all_reduce only supports FP32 and BF16, and only the case where all ranks are local. We fall back to torch.distributed.all_reduce for unsupported cases.
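
As a rough sketch of this fallback, the condition checks and helper names below are assumptions for illustration, not the exact merged logic:

    # Hypothetical sketch of the shm all_reduce fallback; the helper name and the
    # shm_allreduce signature are assumptions for illustration.
    import torch
    import torch.distributed as dist

    def all_reduce_with_shm_fallback(input_: torch.Tensor, group=None) -> torch.Tensor:
        shm_supported = (
            input_.dtype in (torch.float32, torch.bfloat16)  # FP32 / BF16 only
            and _all_ranks_on_this_host(group)  # assumed helper: every rank is local
        )
        if shm_supported:
            from sgl_kernel.ops import shm_allreduce  # assumed import path
            return shm_allreduce(input_)
        # Unsupported dtype or non-local ranks: fall back to the default path.
        dist.all_reduce(input_, group=group)
        return input_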

@chunyuan-w chunyuan-w changed the title from "Optimize all_reduce by porting the shared memory kernel of deepspeed." to "Optimize all_reduce by porting the shared memory kernel of deepspeed" on Feb 18, 2025
@chunyuan-w chunyuan-w marked this pull request as ready for review February 18, 2025 06:47
@chunyuan-w chunyuan-w requested a review from mingfeima February 18, 2025 06:48
@mingfeima (Owner) left a comment:

Generally LGTM! Try to use 2-space indentation to align with PyTorch coding style.

Later on we can simplify this piece of code by applying vec.h dtype conversion and AT_DISPATCH_xxx macros.

}
}

__m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
mingfeima (Owner):

We can remove these cvt functions later on, once I upload vec.h.

return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}

void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
mingfeima (Owner):

We can use ATen-style AT_DISPATCH_xx macros to simplify the code. You can just leave it as it is and make the change later on.

}
}

static bool is_initialized = 0;
mingfeima (Owner):

do not use 0, use false.

chunyuan-w (Collaborator, Author):

Fixed in fe608ad


void shm_initialize(int size, int rank, char* addr_string, char* port_string)
{
if (is_initialized) return;
mingfeima (Owner):

if (is_initialized) { return; }

chunyuan-w (Collaborator, Author):

Fixed in fe608ad


#define positive_mod(num, mod) ((((num) % (mod)) + (mod)) % (mod))
#define rank_mod(rank) positive_mod(rank, world_size)
size_t slice_size(size_t chunk_el, int slice_idx)
mingfeima (Owner):

compute slice_size before calling into this function to remove integer div.

if (i != world_rank) wait_buffer_state_until_2(i, reduce_current, copy_next, state_group);
}

auto t4 = std::chrono::system_clock::now();
mingfeima (Owner):

t4 is not used.

chunyuan-w (Collaborator, Author):

Removed t4 in fe608ad

@chunyuan-w (Collaborator, Author) commented:

> Generally LGTM! Try to use 2-space indentation to align with PyTorch coding style.
>
> Later on we can simplify this piece of code by applying vec.h dtype conversion and AT_DISPATCH_xxx macros.

Re-formatted the C++ files in 1f1218d following PyTorch coding style.

@mingfeima mingfeima merged commit f90cfb7 into mingfeima:cpu_opt Feb 19, 2025
CaoE pushed a commit to CaoE/sglang that referenced this pull request Apr 17, 2026
