[CUDA] Enable causal in MultiHeadAttention #21852
Conversation
Thank you for adding this! With this PR, can we now pass a 2D attention mask of shape [batch_size, sequence_length]? For context, the reformatting from 2D to 4D causal was done in the model builder because not all EPs implement causal attention masking with the 2D mask.
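For illustration, the 2D-to-4D reformatting mentioned above amounts to roughly the following (a minimal NumPy sketch, not the actual model-builder code; the helper name and the -1e4 fill value are assumptions):

```python
import numpy as np

def to_4d_causal_bias(mask_2d: np.ndarray, mask_filter_value: float = -1e4) -> np.ndarray:
    """mask_2d: [B, S] with 1 = keep, 0 = padding. Returns a [B, 1, S, S] additive bias."""
    s = mask_2d.shape[1]
    causal = np.tril(np.ones((s, s), dtype=bool))                    # lower-triangular causal mask
    keep = causal[None, None, :, :] & mask_2d[:, None, None, :].astype(bool)
    return np.where(keep, 0.0, mask_filter_value).astype(np.float32)
```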
Yes, you can use the 2D mask directly (no need to convert it to a 4D attention bias). However, the 2D mask only goes to the unfused kernel for now. In the future, we might add a 2D-to-1D conversion to help HuggingFace models, but that is not the best choice for performance since such a conversion needs extra CUDA kernels. I would suggest using a 1D mask of shape [B] (the total sequence length of each batch entry, assuming right-side padding, i.e., the reduce-sum of the 2D mask) if you want the benefit of flash attention. Currently, memory efficient attention needs another 1D format with shape [3B+2]. In the future, we will add a conversion from the 1D mask of shape [B] to tensors compatible with memory efficient attention.
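To make the suggested 1D format concrete, here is a minimal sketch (my own illustration, assuming right-side padding as stated above) of reducing a 2D mask to the 1D [B] shape:

```python
import numpy as np

# 2D key padding mask: 1 = real token, 0 = right-side padding, shape [B, S]
mask_2d = np.array([[1, 1, 1, 0, 0],
                    [1, 1, 1, 1, 1]], dtype=np.int32)

# 1D mask of shape [B]: total sequence length per batch entry (reduce-sum of each row)
mask_1d = mask_2d.sum(axis=1).astype(np.int32)  # -> array([3, 5])
```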
Description
Enable causal in the MultiHeadAttention CUDA operator.
All formats (Q_K_V_BSNH_BSNH_BSNH, Q_K_V_BSNH_BNSH_BNSH, Q_KV_BSNH_BSN2H, and QKV_BSN3H) now support causal in CUDA. Internally, causal will be dispatched to the flash attention, efficient attention, or unfused attention kernel.
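As an illustration of the Q_K_V_BSNH_BSNH_BSNH path, the following sketch builds a com.microsoft.MultiHeadAttention node and runs it on the CUDA EP. It assumes the contrib op's `unidirectional` attribute enables causal masking, as on CPU; the shapes and float16 dtype are my own choices (half precision is what the flash/efficient attention kernels expect):

```python
import numpy as np
from onnx import TensorProto, helper
import onnxruntime as ort

B, S, N, H = 2, 8, 4, 16           # batch, sequence length, num heads, head size
hidden = N * H                     # BSNH layout packed as [B, S, N*H]

node = helper.make_node(
    "MultiHeadAttention",
    inputs=["query", "key", "value"],
    outputs=["output"],
    domain="com.microsoft",
    num_heads=N,
    unidirectional=1,              # enable causal (lower-triangular) masking
)
graph = helper.make_graph(
    [node],
    "mha_causal",
    [helper.make_tensor_value_info(n, TensorProto.FLOAT16, [B, S, hidden])
     for n in ("query", "key", "value")],
    [helper.make_tensor_value_info("output", TensorProto.FLOAT16, [B, S, hidden])],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)

sess = ort.InferenceSession(model.SerializeToString(),
                            providers=["CUDAExecutionProvider"])
feeds = {n: np.random.randn(B, S, hidden).astype(np.float16)
         for n in ("query", "key", "value")}
output = sess.run(None, feeds)[0]  # [B, S, hidden], causally masked attention
```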
Motivation and Context
Currently, MultiHeadAttention has causal enabled in the CPU EP, but not in the CUDA EP. This can cause issues during ONNX conversion; for example, a model may run on CPU but not on CUDA. Enabling causal in CUDA reduces the difference between the CPU and CUDA support matrices.