implement faster RoPE embedding #238
Conversation
Thanks a lot @HuyNguyen-hust! As per our discussion on Discord - I just want to say thank you again - super appreciate this! Will do some tests on my end and I'll expedite this PR!
@HuyNguyen-hust I tested the kernel! Can confirm RoPE itself is faster. The effect on a full training run is sadly less pronounced: per PyTorch's Profiler, RoPE now takes around 1% of the total runtime, with matrix multiplications taking the bulk of the time. DPO, for example - with your RoPE fix: 1553 seconds. Original: 1542 seconds. So within the margin of error. This was on a Colab T4, so I'm pretty sure A100s will see a more noticeable effect. However, your kernel works absolute wonders when long sequence lengths come into play! The RoPE kernel then creeps up to around 2-3% of the total runtime, so the savings are well deserved! Thanks so much for the wonderful contribution - added this in! :) I'll probably play around with the group size - it seems like it might be an auto-tunable number!!!
awesome @HuyNguyen-hust, congrats on your great work!
thanks |
cool :O |
Congrats @HuyNguyen-hust! Great contribution! |
This PR proposes a small change to the current RoPE embedding kernel:
Benchmark with batch_size=4, head_dim=128, n_heads=32 ("// 2" denotes BLOCK_SIZE = head_dim // 2; otherwise BLOCK_SIZE = head_dim):

The benchmark figure indicates that the proposed kernel is more sensitive to BLOCK_SIZE.
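For context on why BLOCK_SIZE = head_dim // 2 is a natural choice: RoPE rotates channels in pairs, so only head_dim // 2 rotation angles are needed per position, and a kernel can process one half-dimension block while loading its paired half. Below is a minimal NumPy sketch of the rotate-half layout (the function name `rope_reference` and the exact pairing convention are illustrative assumptions, not the PR's actual Triton code):

```python
import numpy as np

def rope_reference(x, base=10000.0):
    """Apply rotary position embeddings to x of shape (seq_len, head_dim).

    Uses the rotate-half layout: channel i is paired with channel
    i + head_dim // 2, which is why only head_dim // 2 angles are needed.
    """
    seq_len, d = x.shape
    half = d // 2
    # theta_i = base^(-2i/d) = base^(-i/half), one frequency per channel pair
    inv_freq = base ** (-np.arange(half) / half)
    # angle for position m, pair i is m * theta_i
    angles = np.outer(np.arange(seq_len), inv_freq)      # (seq_len, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    # 2D rotation applied independently to each (x1_i, x2_i) pair
    return np.concatenate([x1 * cos - x2 * sin,
                           x1 * sin + x2 * cos], axis=-1)
```

Two sanity properties follow directly from the rotation structure: at position 0 all angles are zero, so the output equals the input, and each rotation preserves the per-vector norm.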