Fix 3D attention transpose bug in reference implementation #7129
Conversation
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
❌ 1 test failed:
To view more test analytics, go to the Test Analytics Dashboard.
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@copilot, can you make the exact same fix within the function definition of the op in its opschema?
Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com>
Applied the exact same fix to the opschema function definition.
```cpp
.Add("QReshaped = Reshape (Q, QNewShape)")
.Add("KReshaped = Reshape (K, KVNewShape)")
.Add("VReshaped = Reshape (V, KVNewShape)")
// First reshape to [batch_size, seq_length, num_heads, head_size]
```
I think the Reshape and Transpose should be added only for 3D inputs, not for 4D inputs. Currently, we use the num_heads attributes to determine whether the input is 3D or 4D. Can you make this change? Please introduce a boolean variable is_3d_input for this condition to make the code more readable.
Please see the code at the very bottom of the function definition that does the matching Transpose/Reshape for the output for guidance.
Added the is_3d_input boolean variable as requested. The reshape and transpose operations are now applied only for 3D inputs (when q_num_heads > 0 && kv_num_heads > 0), while 4D inputs use the original direct reshape behavior. This makes the code more readable and preserves backward compatibility for 4D cases. Commit: 9e5e5c3
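The gating described above can be sketched in numpy. This is an illustrative sketch only: `split_heads` is a hypothetical helper, and it keys off `ndim` for brevity, whereas the op itself derives `is_3d_input` from the `q_num_heads`/`kv_num_heads` attributes.

```python
import numpy as np

def split_heads(x: np.ndarray, num_heads: int) -> np.ndarray:
    """Bring an attention input into [batch, num_heads, seq_len, head_size] layout.

    Illustrative sketch: the real op derives is_3d_input from the
    q_num_heads/kv_num_heads attributes rather than from ndim.
    """
    is_3d_input = x.ndim == 3
    if is_3d_input:
        batch, seq_len, hidden = x.shape
        head_size = hidden // num_heads
        # Step 1: [batch, seq_len, hidden] -> [batch, seq_len, num_heads, head_size]
        x = x.reshape(batch, seq_len, num_heads, head_size)
        # Step 2: swap the seq_len and num_heads axes
        x = x.transpose(0, 2, 1, 3)
    # 4D inputs pass through unchanged (Identity)
    return x

q3 = np.arange(2 * 5 * 8, dtype=np.float32).reshape(2, 5, 8)
print(split_heads(q3, num_heads=2).shape)  # (2, 2, 5, 4)
```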
Updated the output handling to use the is_3d_input boolean variable for consistency, matching the pattern used in the input processing section. Commit: 7673bc3
Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com>
```cpp
// For 4D inputs: Direct reshape to [batch_size, num_heads, seq_length, head_size]
builder.Add("QNewShape = Concat <axis = 0> (BatchSize, QNumHeadsAttr, QSeqLen, NegOne)")
    .Add("KVNewShape = Concat <axis = 0> (BatchSize, KVNumHeadsAttr, KVSeqLen, NegOne)")
    .Add("QReshaped = Reshape (Q, QNewShape)")
```
There is no need to reshape in the 4D case, it is already in the desired shape. We can use QReshaped = Identity(Q) to ensure the desired behavior. Same for K and V.
Changed 4D input handling to use Identity operations instead of unnecessary Reshape operations. The tensors are already in the correct shape [batch_size, num_heads, seq_length, head_size] for 4D inputs. Commit: f4ebea3
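As a small numpy illustration of why Identity suffices here (the tensor and its dimensions are hypothetical): a 4D input already has the [batch_size, num_heads, seq_length, head_size] layout, so reshaping it to that same shape moves no data and computes nothing new.

```python
import numpy as np

# Hypothetical example tensor: a 4D input is already in the
# [batch_size, num_heads, seq_length, head_size] layout.
q4 = np.random.rand(2, 4, 5, 8).astype(np.float32)

identity_out = q4                     # Identity(Q): no shape computation needed
reshape_out = q4.reshape(2, 4, 5, 8)  # the old path recomputed the same shape

assert identity_out.shape == reshape_out.shape == (2, 4, 5, 8)
assert np.array_equal(identity_out, reshape_out)
```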
Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com>
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Replaced by #7142.
### Description

Replaces #7129. Fixes issue #7118. It does:

* fixes class Attention in onnx.reference.ops.op_attention
* fixes the expanded C++ function in onnx/defs/nn/defs.cc
* adds an example using float16 and with 4D masks

---------

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Signed-off-by: xadupre <xadupre@microsoft.com>
Signed-off-by: Xavier Dupré <xadupre@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com>
Problem
The reference implementation for the Attention operator incorrectly handles 3D input tensors by directly reshaping them without the necessary transpose operation, violating the ONNX specification.
According to the ONNX spec, 3D inputs with shape [batch_size, seq_length, hidden_size] should be:

1. Reshaped to [batch_size, seq_length, num_heads, head_size]
2. Transposed to [batch_size, num_heads, seq_length, head_size]

However, the current implementation directly reshapes to [batch_size, num_heads, seq_length, head_size], skipping the intermediate transpose step.

Root Cause

In onnx/reference/ops/op_attention.py, lines 55-65 showed the bug. This violates the specification because it incorrectly interprets the hidden dimension layout.
Solution
Applied a minimal fix to implement the correct two-step process: reshape to [batch_size, seq_length, num_heads, head_size], then transpose to [batch_size, num_heads, seq_length, head_size].
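The effect of the fix can be sketched in numpy (variable names are illustrative, not the reference implementation itself): the buggy direct reshape and the correct reshape-then-transpose yield tensors of the same shape but with different element placement, which is exactly why results differed from the spec.

```python
import numpy as np

batch, seq_len, num_heads, head_size = 1, 3, 2, 4
hidden = num_heads * head_size
x = np.arange(batch * seq_len * hidden, dtype=np.float32).reshape(batch, seq_len, hidden)

# Buggy: direct reshape to [batch, num_heads, seq_len, head_size]
wrong = x.reshape(batch, num_heads, seq_len, head_size)

# Fixed: reshape to [batch, seq_len, num_heads, head_size], then transpose
right = x.reshape(batch, seq_len, num_heads, head_size).transpose(0, 2, 1, 3)

# Same shape, different element placement: the direct reshape
# misinterprets the hidden-dimension layout.
assert wrong.shape == right.shape == (1, 2, 3, 4)
assert not np.array_equal(wrong, right)
```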
Changes Made
* onnx/reference/ops/op_attention.py: Fixed 3D to 4D conversion logic for Q, K, and V tensors
* onnx/backend/test/case/node/attention.py: Added export_attention_3d_transpose_verification test case

Verification
The fix ensures ONNX reference implementation correctly follows the specification for 3D attention inputs, which is critical for model compatibility and correctness.
Fixes #7118.