CANN: support gated linear attn #17814
Conversation
Note that this operator might be removed at some point - see #17716 (comment)
@YushengZhao Thank you for your contribution!
```cpp
aclTensor* acl_s = ggml_cann_create_tensor(s, ne_s, nb_s, 2, ACL_FORMAT_ND, s_offset);
aclTensor* acl_s_new = ggml_cann_create_tensor(dst, ne_s, nb_s, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);
cann_copy(ctx, acl_s, acl_s_new);
for (int64_t l = 0; l < L; l++) {
```
This contains a triple loop, which may lead to poor performance. Are there better design ideas to reduce the complexity of operator calls?
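For context, here is a CPU-side sketch of the recurrence this operator computes (hypothetical reference code with assumed layouts, not the PR's kernel). It makes the loop structure explicit: the loops over `B` and `H` are fully independent, so they are candidates for folding into batched ACLNN calls, while the loop over `L` carries the recurrent state and is inherently sequential.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive reference for gated linear attention (hypothetical layout, for
// illustration only). Per head, the state S is a D x D matrix; at each
// timestep t:
//   S   = diag(g_t) * S + k_t * v_t^T   (gated state update)
//   o_t = S^T * q_t                     (readout)
// The B and H loops are independent of each other and could be batched;
// only the loop over L is sequential.
void gla_reference(int B, int H, int L, int D,
                   const std::vector<float>& q,  // [B, L, H, D]
                   const std::vector<float>& k,  // [B, L, H, D]
                   const std::vector<float>& v,  // [B, L, H, D]
                   const std::vector<float>& g,  // [B, L, H, D] per-channel gate
                   std::vector<float>& s,        // [B, H, D, D] recurrent state
                   std::vector<float>& o) {      // [B, L, H, D] output
    auto idx = [&](int b, int l, int h, int d) {
        return ((size_t(b) * L + l) * H + h) * D + d;
    };
    for (int b = 0; b < B; b++)
    for (int h = 0; h < H; h++) {
        // D x D state for this (batch, head), row-major [dk][dv].
        float* S = &s[(size_t(b) * H + h) * D * D];
        for (int l = 0; l < L; l++) {
            // Gated state update: S = diag(g) * S + k * v^T.
            for (int dk = 0; dk < D; dk++)
                for (int dv = 0; dv < D; dv++)
                    S[dk * D + dv] = g[idx(b, l, h, dk)] * S[dk * D + dv]
                                   + k[idx(b, l, h, dk)] * v[idx(b, l, h, dv)];
            // Readout: o = S^T * q.
            for (int dv = 0; dv < D; dv++) {
                float acc = 0.0f;
                for (int dk = 0; dk < D; dk++)
                    acc += q[idx(b, l, h, dk)] * S[dk * D + dv];
                o[idx(b, l, h, dv)] = acc;
            }
        }
    }
}
```

One common way to reduce host-side call overhead is to express the `B` and `H` iterations as a single batched operator invocation (e.g., batched matmul over a `B*H` leading dimension) and only loop over `L`.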
```cpp
 * @brief Computes the Gated Linear Attention for a ggml tensor using the CANN
 * backend.
 *
 * @details ...
```
Is this detailed section still not finished?
```cpp
 */
void aclnn_sin(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst);
```
```cpp
static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst);
```
If you need to call internal functions, it's better to declare them here, just as aclnn_sin is.
@hipudding In case you're wondering, I'm the one canceling the Server CI jobs; they are causing trouble, and a fix is pending merge...
Thanks for the clarification. That makes sense.
@YushengZhao Please rebase, thanks.
This PR is updated in #18653 |
Description
This PR adds support for the Gated Linear Attention (GLA) operator in the GGML CANN backend. GLA is widely used in efficient attention mechanisms (e.g., RWKV, Linear Transformer variants, etc.), which leverage gating signals and state accumulation to significantly reduce computational complexity while preserving strong modeling capacity.
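For context, the per-head, per-timestep recurrence behind GLA (one common formulation; the exact form used by ggml's `GGML_OP_GATED_LINEAR_ATTN` may differ in details) is:

$$S_t = \mathrm{diag}(g_t)\, S_{t-1} + k_t v_t^\top, \qquad o_t = S_t^\top q_t$$

where $S_t$ is the recurrent state and $g_t$ is the gating signal, so attention over the full history is replaced by a constant-size state update at each step.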
Summary of Changes:
- Registered `GGML_OP_GATED_LINEAR_ATTN` in `ggml/src/ggml-cann/ggml-cann.cpp` and bound it to a newly implemented function `ggml_cann_gated_linear_attn`.
- Implemented `ggml_cann_gated_linear_attn` in `ggml/src/ggml-cann/aclnn_ops.cpp`, using ACLNN primitives such as `Repeat`, `Mul`, `Add`, and `Mv` to compose the GLA computation.
- Tensors use the layout `(C, H, T, B)`, where `C = H * D` (total channel dimension) and `T = B * L` (flattened batch × sequence length).
- The operator takes gates `g` and the recurrent state `s` as additional inputs, enabling joint state update and output generation in a single pass.

Testing
Build with CANN backend enabled:
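A typical configure-and-build invocation would look like the following (a sketch assuming the standard llama.cpp CMake flow with the `GGML_CANN` flag; adjust toolkit paths for your environment):

```
cmake -B build -DGGML_CANN=on -DCMAKE_BUILD_TYPE=Release
cmake --build build -j
```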
Run the GLA-specific backend test (requires adding a test case for `GATED_LINEAR_ATTN` in `tests/test-backend-ops.cpp`):

```
./bin/test-backend-ops test -b CANN0 -o GATED_LINEAR_ATTN
```