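Many large language models rely on relative position encodings such as ALiBi (Attention with Linear Biases), which adds a head-specific linear bias to the attention scores. With FlashBias, attention with an ALiBi bias can be computed much faster.
The ALiBi bias can be decomposed exactly into the product of two rank-2 basis functions, one on the query side and one on the key side: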
Figure 1. Exact decomposition of ALiBi bias.
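This assumes the standard causal ALiBi bias, where head $h$ with slope $m_h$ adds $-m_h(i-j)$ to the score of query $i$ and key $j \le i$. Under that assumption, one exact rank-2 factorization is:

$$
b^{(h)}_{ij} = -m_h\,(i - j) = m_h\, j - m_h\, i
= \underbrace{\begin{bmatrix} 1 & i \end{bmatrix}}_{\phi_q(i)^\top}
  \underbrace{\begin{bmatrix} m_h\, j \\ -m_h \end{bmatrix}}_{\phi_k^{(h)}(j)}
$$

The factor dimension is 2, so adding $\phi_q(i)^\top \phi_k^{(h)}(j)$ to the scores reproduces the bias exactly rather than approximating it. The causal mask discards every $j > i$ anyway. The particular basis functions here are an illustration. The ones used in ./attention_with_alibi_bias.py may be arranged differently, for example with the slope on the query side.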
We consider a GPT-2-style model with 50 heads and 1,600 hidden channels in attention, i.e., a head dimension of 32.
- Implementation
ALiBi is straightforward to implement with FlashBias; see ./attention_with_alibi_bias.py for details.
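The NumPy sketch below checks the algebra of the decomposition on the CPU. It builds the rank-2 factors, compares them with an explicit ALiBi bias, and shows one standard way to use a low-rank bias: concatenating the factors onto the (scaled) queries and keys. It is not the repo's implementation, and we do not claim the FlashBias kernel works exactly this way. All function names are hypothetical. The slope formula $2^{-8h/H}$ is the ALiBi paper's rule for power-of-two head counts. Using it for other counts, such as the 50 heads here, is an assumption.

```python
import numpy as np


def alibi_slopes(num_heads: int) -> np.ndarray:
    """Geometric ALiBi slopes m_h = 2^(-8h/H), h = 1..H.

    This is the ALiBi formula for power-of-two head counts; applying it to
    other counts (e.g. 50) is a simplifying assumption of this sketch.
    """
    h = np.arange(1, num_heads + 1)
    return 2.0 ** (-8.0 * h / num_heads)


def alibi_bias_dense(seq_len: int, slopes: np.ndarray) -> np.ndarray:
    """Explicit (H, N, N) causal ALiBi bias: -m_h * (i - j)."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return -slopes[:, None, None] * (i - j)


def alibi_factors(seq_len: int, slopes: np.ndarray):
    """Hypothetical helper: rank-2 factors with phi_q @ phi_k^T == -m_h * (i - j)."""
    num_heads = slopes.shape[0]
    pos = np.arange(seq_len, dtype=np.float64)
    # Query side: phi_q(i) = [1, i], shared by all heads.
    phi_q = np.stack([np.ones_like(pos), pos], axis=-1)       # (N, 2)
    phi_q = np.broadcast_to(phi_q, (num_heads, seq_len, 2))   # (H, N, 2)
    # Key side: phi_k(j) = [m_h * j, -m_h], one set per head.
    m = slopes[:, None]                                       # (H, 1)
    phi_k = np.stack([m * pos, np.broadcast_to(-m, (num_heads, seq_len))], axis=-1)
    return phi_q, phi_k                                       # phi_k: (H, N, 2)


def causal_softmax(scores: np.ndarray) -> np.ndarray:
    """Row-wise softmax with future keys (j > i) masked out."""
    n = scores.shape[-1]
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


def attention_dense_bias(q, k, v, bias):
    """Reference: softmax(q k^T / sqrt(d) + bias) v with an explicit N x N bias."""
    d = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d) + bias
    return causal_softmax(scores) @ v


def attention_low_rank_bias(q, k, v, phi_q, phi_k):
    """Same result via [q / sqrt(d), phi_q] @ [k, phi_k]^T = q k^T / sqrt(d) + phi_q phi_k^T.

    A fused kernel can use this identity to avoid loading an N x N bias; this
    NumPy sketch still forms the full score matrix and only checks the algebra.
    """
    d = q.shape[-1]
    q_ext = np.concatenate([q / np.sqrt(d), phi_q], axis=-1)
    k_ext = np.concatenate([k, phi_k], axis=-1)
    return causal_softmax(q_ext @ np.swapaxes(k_ext, -1, -2)) @ v


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    H, N, D = 4, 16, 8
    q, k, v = (rng.standard_normal((H, N, D)) for _ in range(3))
    slopes = alibi_slopes(H)
    phi_q, phi_k = alibi_factors(N, slopes)
    dense = attention_dense_bias(q, k, v, alibi_bias_dense(N, slopes))
    low_rank = attention_low_rank_bias(q, k, v, phi_q, phi_k)
    print("max abs difference:", np.abs(dense - low_rank).max())
```

The two attention outputs should agree up to floating-point rounding. The concatenation adds only `RANK_DIM = 2` extra channels per head to the queries and keys.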
- Efficiency comparison
In ../flash_bias/config.py, change the attention benchmark configuration as follows:
```python
BATCH, N_HEADS, HEAD_DIM, RANK_DIM, CAUSAL = 1, 50, 32, 2, True
```
Then run the following commands for the efficiency comparison:
```bash
cd ../flash_bias
python benchmark.py
```
Note that vanilla flash_attn_triton is unstable in the backward pass when the sequence length is 8192. FlashBias and all the other kernels produce correct results at every tested sequence length.
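A rough estimate shows why avoiding a dense bias matters at these sizes. It is a back-of-the-envelope calculation, not a measurement from benchmark.py. It assumes fp16 storage (2 bytes per element) and a bias materialized once per head with shape (heads, N, N). It compares that against one (heads, N, rank) factor on each of the query and key sides. `bias_memory_bytes` is a hypothetical helper. Kernels that generate ALiBi on the fly also avoid the dense cost, so this only quantifies the materialized-bias path.

```python
def bias_memory_bytes(heads: int, seq_len: int, rank: int, bytes_per_elem: int = 2):
    """Bytes for a dense (heads, N, N) bias vs. its (heads, N, rank) query/key factors."""
    dense = heads * seq_len * seq_len * bytes_per_elem
    factors = 2 * heads * seq_len * rank * bytes_per_elem  # phi_q and phi_k
    return dense, factors


if __name__ == "__main__":
    # N_HEADS = 50 and RANK_DIM = 2 as in the config; N = 8192 as in the note above.
    dense, factors = bias_memory_bytes(heads=50, seq_len=8192, rank=2)
    print(f"dense bias: {dense / 2**30:.2f} GiB, rank-2 factors: {factors / 2**20:.3f} MiB")
    # prints: dense bias: 6.25 GiB, rank-2 factors: 3.125 MiB
```

The ratio is N / (2 × rank), which is 2048× at N = 8192 and grows linearly with sequence length.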
Figure 2. Efficiency comparison for the forward pass on GPT-2-size attention with ALiBi bias.
Figure 3. Efficiency comparison for backward propagation on GPT-2-size attention with ALiBi bias.
If you find this repo useful, please cite our paper.
```bibtex
@inproceedings{wu2025flashbias,
  title={FlashBias: Fast Computation of Attention with Bias},
  author={Haixu Wu and Minghao Guo and Yuezhou Ma and Yuanxu Sun and Jianmin Wang and Wojciech Matusik and Mingsheng Long},
  booktitle={Advances in Neural Information Processing Systems},
  year={2025}
}
```
If you have any questions or would like to use the code, please contact wuhaixu98@gmail.com.