Optimize all_reduce by porting the shared memory kernel of deepspeed #5
Conversation
mingfeima
left a comment
Generally LGTM! Try to use a TAB of 2 spaces to align with the PyTorch coding style.
Later on we can simplify this piece of code by applying vec.h dtype conversion and AT_DISPATCH_xxx macros.
```cpp
__m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
```
we can remove these cvt functions later on once I upload vec.h
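For context, these helpers rely on BF16 being the upper 16 bits of an FP32 value; a minimal sketch of what such conversions typically look like (illustrative only, not necessarily the exact bodies in this PR or in the upcoming vec.h):

```cpp
#include <immintrin.h>

// Widen 16 BF16 values (packed in a 256-bit register) to 16 FP32 values.
// BF16 is the high 16 bits of an FP32, so zero-extend each lane to 32 bits
// and shift it into the upper half. (Real code may carry a target attribute
// such as __attribute__((target("avx512bw"))) as in the declaration above.)
inline __m512 cvt_bf16_to_fp32_sketch(const __m256i src) {
  __m512i widened = _mm512_cvtepu16_epi32(src);
  return _mm512_castsi512_ps(_mm512_slli_epi32(widened, 16));
}

// Inverse direction by truncation; production code usually rounds to
// nearest-even before dropping the low 16 bits.
inline __m256i cvt_fp32_to_bf16_sketch(const __m512 src) {
  __m512i bits = _mm512_castps_si512(src);
  return _mm512_cvtepi32_epi16(_mm512_srli_epi32(bits, 16));
}
```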
```cpp
return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}

void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
```
we can use ATen-style AT_DISPATCH_xx macros to simplify the code. You can just leave it as it is and make the change later on.
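For reference, the ATen dispatch mentioned here would collapse the per-dtype `reduce_fp32_buffers` / `reduce_bf16_buffers` pair into one templated kernel; a minimal sketch with illustrative names (not the PR's actual code):

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical templated reducer: sums the same element range from several
// rank-local buffers into `to`. A real kernel would use the vectorized
// helpers from vec.h instead of this scalar loop.
template <typename scalar_t>
void reduce_buffers_impl(int64_t start, int64_t num,
                         scalar_t* to, scalar_t** buffers, int num_buffers) {
  for (int64_t i = start; i < start + num; ++i) {
    float acc = 0.0f;
    for (int b = 0; b < num_buffers; ++b) {
      acc += static_cast<float>(buffers[b][i]);
    }
    to[i] = static_cast<scalar_t>(acc);
  }
}

// One dispatch site serves both FP32 and BF16: scalar_t is bound to float or
// at::BFloat16 inside the lambda.
void reduce_buffers(at::ScalarType dtype, int64_t start, int64_t num,
                    void* to, void** buffers, int num_buffers) {
  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, dtype,
                                 "shm_reduce_buffers", [&] {
    reduce_buffers_impl<scalar_t>(start, num,
                                  static_cast<scalar_t*>(to),
                                  reinterpret_cast<scalar_t**>(buffers),
                                  num_buffers);
  });
}
```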
```cpp
static bool is_initialized = 0;

void shm_initialize(int size, int rank, char* addr_string, char* port_string)
{
  if (is_initialized) return;
```
if (is_initialized) { return; }
```cpp
#define positive_mod(num, mod) ((((num) % (mod)) + (mod)) % (mod))
#define rank_mod(rank) positive_mod(rank, world_size)
size_t slice_size(size_t chunk_el, int slice_idx)
```
compute slice_size before calling into this function to remove integer div.
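In other words, the division by `world_size` could be done once per chunk and the per-slice sizes derived from it; a rough sketch of that idea (assuming the last slice absorbs the remainder, which may not match the PR's exact partitioning):

```cpp
#include <cstddef>
#include <vector>

// Precompute every slice size for a chunk so the hot path performs a single
// integer division instead of one inside each slice_size() call.
// (Illustrative only; the actual partitioning in shm.cpp may differ.)
std::vector<size_t> precompute_slice_sizes(size_t chunk_el, int world_size) {
  size_t base = chunk_el / world_size;  // one div/mod per chunk
  size_t tail = chunk_el % world_size;  // remainder handled by the last slice
  std::vector<size_t> sizes(world_size, base);
  sizes[world_size - 1] += tail;
  return sizes;
}
```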
```cpp
if (i != world_rank) wait_buffer_state_until_2(i, reduce_current, copy_next, state_group);
}

auto t4 = std::chrono::system_clock::now();
```
Re-formatted the C++ files in 1f1218d following PyTorch coding styles.
Motivation
Optimize `all_reduce` by porting the shared memory implementation in DeepSpeed.

Modifications
We added a `shm_allreduce` operator in `sgl-kernel`. The implementation is ported from DeepSpeed:
- `sgl-kernel/src/sgl-kernel/csrc/cpu/shm.h` ---> DeepSpeed: shm.h
- `sgl-kernel/src/sgl-kernel/csrc/cpu/shm.cpp` ---> DeepSpeed: shm.cpp
- `sgl-kernel/src/sgl-kernel/csrc/cpu/interface.cpp` ---> DeepSpeed: ccl.cpp
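As a rough idea of how the binding layer hooks the kernel into Python (a sketch only; the actual registration, signatures, and op names in `interface.cpp` may differ), the shared-memory all_reduce is exposed as a torch extension op that the SGLang wrapper can call on CPU tensors:

```cpp
#include <torch/extension.h>

// Hypothetical binding sketch. The real interface.cpp (ported from
// DeepSpeed's ccl.cpp) sets up the per-rank shared-memory segments and
// reduces across all local ranks.
void shm_allreduce_sketch(at::Tensor& data) {
  TORCH_CHECK(data.device().is_cpu(), "shm all_reduce only handles CPU tensors");
  // ... copy into the rank's shared-memory buffer, reduce, copy back ...
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("shm_allreduce", &shm_allreduce_sketch,
        "Shared-memory all_reduce across local ranks (CPU) - sketch");
}
```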
To build the kernel:

```bash
cd sgl-kernel
python setup.py develop
```
We added a wrapper, `tensor_model_parallel_all_reduce_wrapper`, in SGLang that calls `shm_allreduce` on CPU and the original `tensor_model_parallel_all_reduce` API in vLLM on other devices, so that we don't need to change the `tensor_model_parallel_all_reduce` function in vLLM.
In `sgl-kernel/src/sgl-kernel/ops/__init__.py`, we only import CUDA kernels if CUDA is available.

Benchmarks
Accuracy
The score on MMLU (higher is better) for tp=2:

- score without this PR (using `torch.distributed.all_reduce`): 0.594
- score with this PR: 0.594
Note: `--disable-overlap-schedule` is needed in the args for CPU; otherwise, a new thread will be created for the forward batch here and the OMP thread binding will not work on this newly created thread.

Command line:
```bash
# Client side
python3 -m sglang.test.run_eval --eval-name mmlu --num-examples 64 --port 30000
```

Performance
We can observe 25% and 37% speedup on first and next token latency, respectively, after switching from `torch.distributed.all_reduce` to `shm_allreduce` for the below tp=2 command line on GNR.

```bash
SGLANG_CPU_OMP_THREADS_BIND="0-39|40-79" python3 -m sglang.bench_one_batch --batch-size 1 --input 1024 --output 8 --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --trust-remote-code --device cpu --attention-backend torch_native --disable-mla --tp 2
```

Limitations
The shm `all_reduce` only supports FP32 and BF16 and the cases where all the ranks are local. We fall back to `torch.distributed.all_reduce` for unsupported cases.
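For illustration, the kind of guard that decides between the shm path and the fallback might look like this (hypothetical helper, not the PR's actual code):

```cpp
#include <ATen/ATen.h>

// Hypothetical predicate: the shm path is only usable for FP32/BF16 CPU
// tensors when every rank in the group is on the same host; otherwise the
// caller falls back to torch.distributed.all_reduce.
bool can_use_shm_allreduce(const at::Tensor& t, bool all_ranks_local) {
  return all_ranks_local && t.device().is_cpu() &&
         (t.scalar_type() == at::kFloat || t.scalar_type() == at::kBFloat16);
}
```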