Add xformers and support training on V100s
Why are these changes needed?
We are going to use xformers instead of flash attention. xformers is better because:
- It supports more GPU architectures than flash attention, including the V100.
- It has a similar memory footprint and FLOP count to flash attention.
- It is developed and maintained by Meta and offers more useful functionality.
We can gradually deprecate flash attention.
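For context, the xformers entry point for this is `xformers.ops.memory_efficient_attention`, which computes standard scaled dot-product attention without materializing the full attention matrix. Below is a minimal plain-PyTorch sketch of the computation it performs; the xformers call itself is shown only in a comment so the snippet runs without the library installed (shapes and names here are illustrative, not taken from this PR).

```python
import math
import torch

def reference_attention(q, k, v):
    # Standard scaled dot-product attention:
    # softmax(q @ k^T / sqrt(head_dim)) @ v
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)
    return torch.softmax(scores, dim=-1) @ v

# Reference layout: (batch, heads, seq_len, head_dim).
# Note: xformers.ops.memory_efficient_attention instead expects
# (batch, seq_len, heads, head_dim).
q = torch.randn(2, 4, 16, 32)
k = torch.randn(2, 4, 16, 32)
v = torch.randn(2, 4, 16, 32)

out = reference_attention(q, k, v)

# With xformers installed, the memory-efficient kernel produces the
# same result (up to numerical tolerance) with a much smaller peak
# memory footprint:
#   import xformers.ops as xops
#   out = xops.memory_efficient_attention(q_bshd, k_bshd, v_bshd)
```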
cc @DachengLi1
Co-authored-by: Dacheng Li[email protected]
Related issue number (if applicable)
Checks
- [x] I've run `format.sh` to lint the changes in this PR.
- [x] I've included any doc changes needed.
- [ ] I've made sure the relevant tests are passing (if applicable).
The authors of Flash Attention have also developed a Triton-based implementation (https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py). How about replacing their original implementation with the Triton-based one?
This works fantastically on V100 GPUs; please merge it ASAP! Appreciate it!
@ss-zheng thanks, will merge soon.
Thanks for this great PR! Can you also apply it to the other finetuning scripts, such as `train_lora.py`?
FYI, I just learned that xformers' memory-efficient attention has been upstreamed into PyTorch as `torch.nn.functional.scaled_dot_product_attention`: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html .
However, using it would require code changes in transformers to enable and pin down that backend. I think this PR is still worth merging.
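As a sketch of that upstreamed API (assuming PyTorch >= 2.0), `F.scaled_dot_product_attention` matches a manual softmax attention while dispatching to a flash or memory-efficient kernel when the hardware supports one; the tensor shapes below are illustrative.

```python
import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 64, 32)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 64, 32)
v = torch.randn(1, 8, 64, 32)

# Fused attention; PyTorch selects the best available backend
# (flash, memory-efficient, or a math fallback) for the hardware.
out_fused = F.scaled_dot_product_attention(q, k, v)

# Manual reference implementation for comparison.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
out_ref = torch.softmax(scores, dim=-1) @ v

assert torch.allclose(out_fused, out_ref, atol=1e-4)
```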
@zhisbug just a reminder that the developer of Flash Attention gave up on the V100s.
https://github.com/Dao-AILab/flash-attention/issues/148#issuecomment-1573216640
> The authors of Flash Attention have also developed the triton-based implementation (https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py). How about replacing their original implementation with the triton-based?
That would force everyone to use PyTorch 2.0, which is not deployed in many supercomputing centres.
So can we run flash-attention models like Llama 2 on V100s?
@jshin49 It doesn't seem so; see the comment above.