Add TensorScatter op for in-place kv cache update #7114
Conversation
Force-pushed 85561bb to 9488105
@onnx/sig-operators-approvers
Is this PR ready?
I think this PR may have some undesired consequences. The nodes in a graph must be ordered such that all of a node's inputs exist by the time the node executes, and many different orders are valid for the same computation graph. But in the case of this operator, we cannot change its position, because it modifies both its first input and its output at the same time: they are the same buffer. Any existing reordering algorithm may change the order in a way that changes the computation. The runtime can determine whether this operator can be executed in place: if no later node uses the first input, then it can be done in place.
ONNX itself does not need to take I/O aliasing into account, as that is strictly a runtime property. For the runtime, in the common KV cache use cases I think only the output will be used by another op, not the input, so I don't think this will become a problem. Not sure if I'm missing some considerations or not.
I'm not sure I understand what "in-place" actually means here. I looked into the reference implementation you made and I found
This op allows backends to do an in-place optimization in their implementation, I think? So it doesn't change the functional nature of ONNX.
Right. This op is functional: it is a special variant of the Scatter op and has the exact same nature (it is a functional representation of what would be an imperative update in imperative/non-functional languages). It is intended to let backends recognize it and implement it in place, but backends need to verify that it is safe to do so. So there is no problem in that regard.
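As a rough illustration of the verification a backend would need, here is a minimal sketch assuming the backend inspects an `onnx.GraphProto` directly (the helper name is hypothetical, not part of this PR):

```python
from onnx import GraphProto, NodeProto


def can_run_inplace(graph: GraphProto, node: NodeProto) -> bool:
    """Hypothetical safety check: may `node` overwrite its first input?"""
    cache_name = node.input[0]
    # The buffer must not be a graph output, and no node other than this
    # one may read it; otherwise overwriting it would change observable results.
    if cache_name in {out.name for out in graph.output}:
        return False
    readers = [n for n in graph.node if cache_name in n.input]
    return readers == [node]
```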
Force-pushed 2d53255 to 7f740f8
```cpp
ONNX_OPERATOR_SET_SCHEMA(
    TensorScatter,
```

Check notice — Code scanning / CodeQL: Unused static variable (Note)
Codecov Report — Attention: Patch coverage is

✅ All tests successful. No failed tests found.

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main    #7114      +/-   ##
==========================================
- Coverage   53.75%   53.73%   -0.03%
==========================================
  Files         510      512       +2
  Lines       32204    32271      +67
  Branches     2972     2982      +10
==========================================
+ Hits        17311    17340      +29
- Misses      14127    14161      +34
- Partials      766      770       +4
```

☔ View full report in Codecov by Sentry.
@gramalingam Thanks for the comments. I've updated the PR. Please lmk if there are further comments. |
cbourjau left a comment
Apologies for the late review! It might be obvious, but could you spell out why the existing scatter operators don't allow runtimes to do the same optimizations that this operator would allow?
The reason is twofold:
Signed-off-by: Yuan Yao <yuanyao@nvidia.com>
Force-pushed 2635bc4 to 5971a15
IMHO, stabilizing this operator in its current shape is a mistake. I believe this operation to be a narrow special case of a more general scatter operation that we should tackle instead. Consider the following NumPy snippet, which appears to work fine:

```python
import numpy as np

from onnx.reference.op_run import OpRun


class TensorScatter(OpRun):
    def _run(self, past_cache, update, write_indices=None, mode="linear", axis=-2):
        if write_indices is None:
            write_indices = np.zeros((past_cache.shape[0],), dtype=np.int64)
        input_shape = past_cache.shape
        update_shape = update.shape
        max_sequence_length = input_shape[axis]
        sequence_length = update_shape[axis]
        batch_size = input_shape[0]
        # Build dense indices of shape (prod(input_shape[:axis]), sequence_length).
        starts = write_indices
        dense_indices = starts[:, None] + np.arange(sequence_length)
        if mode == "circular":
            dense_indices = dense_indices % max_sequence_length
        dense_indices += (np.arange(batch_size) * np.prod(input_shape[1:axis + 1]))[:, None]
        assert dense_indices.shape == (np.prod(input_shape[:axis]), sequence_length)
        # Flatten everything before `axis` so the scatter becomes a row update.
        common_shape = (-1,) + input_shape[axis + 1:]
        present_cache = np.copy(past_cache).reshape(common_shape)
        present_cache[dense_indices.flatten(), ...] = update.reshape(common_shape)
        return (present_cache.reshape(input_shape),)
```

It highlights that this operator's semantics do not necessarily require huge index matrices when expressed with fairly standard tensor operations. The index tensor would be larger than the one of …

Concerning the question of reusing the existing cache, I'm not sure if runtimes can do much with the "hint" given by this operation. It seems like a rather obvious optimization for the implementation of any operator to try to reuse a buffer of the inputs if it is not needed by other operators downstream. I would think that this optimization is either already in place regardless of any hint given by the operator description, or it is not easy to implement for unrelated reasons that cannot be mitigated by this hint either.
### Description

To accompany the [TensorScatter-24](#7114) op for managing in-place KV cache updates, this PR makes the following changes to the Attention op:

- Add `nonpad_kv_seqlen` to indicate the number of valid (non-padded) tokens in the K and V inputs when the K and V inputs are the entire cache tensors (where the valid tokens can potentially make up only a small proportion of the cache tensors). The `nonpad_kv_seqlen` input provides optimization opportunities for backends to skip unnecessary computation on the padding tokens.
- Allow the kv_seqlen dimension (the -1 dimension) of the `attn_mask` input to be shorter than K and V; the missing portion is assumed to be -inf. The length should still be larger than the max value in `nonpad_kv_seqlen`. Also, allow `attn_mask` and `is_causal` to be present at the same time. This allows for easier export of HF models later. A sketch of the mask semantics follows below.

Signed-off-by: Yuan Yao <yuanyao@nvidia.com>
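As I read the new `attn_mask` semantics, a mask whose last dimension is shorter than the KV length is implicitly right-padded with -inf before the softmax. A rough NumPy sketch (the helper is hypothetical, not part of the spec):

```python
import numpy as np


def pad_attn_mask(attn_mask: np.ndarray, kv_seqlen: int) -> np.ndarray:
    """Right-pad the kv dimension of attn_mask to kv_seqlen with -inf.

    Positions beyond the provided mask are treated as fully masked, so
    padded cache entries contribute nothing after the softmax.
    """
    pad = kv_seqlen - attn_mask.shape[-1]
    if pad < 0:
        raise ValueError("attn_mask is longer than the KV sequence")
    width = [(0, 0)] * (attn_mask.ndim - 1) + [(0, pad)]
    return np.pad(attn_mask, width, constant_values=-np.inf)
```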
@cbourjau Sorry I missed your reply. It would be harder for me to miss with direct tagging. :)
Description
This PR adds the TensorScatter op for in-place KV cache updates, to be used in Attention computations.
This op uses the same shape for the input past_cache and the output present_cache (with the sequence-length dimension being max_seqlen), so that backends are free to alias them and achieve efficient in-place updates during the autoregressive iterations of Transformer models.
Motivation and Context
This work is part of the GenAI WG's effort to enable LLM features in ONNX.
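For anyone wanting to try the op, here is a minimal sketch of a graph that uses it (opset 24; the input and attribute names follow the reference implementation discussed above, so treat the exact signature as illustrative rather than normative):

```python
import onnx
from onnx import TensorProto, helper

# One TensorScatter node writing `update` into `past_cache` at `write_indices`.
node = helper.make_node(
    "TensorScatter",
    inputs=["past_cache", "update", "write_indices"],
    outputs=["present_cache"],
    mode="linear",  # or "circular" for ring-buffer style caches
    axis=-2,        # the sequence-length axis of the cache
)

graph = helper.make_graph(
    [node],
    "kv_cache_update",
    inputs=[
        helper.make_tensor_value_info("past_cache", TensorProto.FLOAT, ["B", "H", "max_seq", "D"]),
        helper.make_tensor_value_info("update", TensorProto.FLOAT, ["B", "H", "seq", "D"]),
        helper.make_tensor_value_info("write_indices", TensorProto.INT64, ["B"]),
    ],
    outputs=[
        # Same shape as past_cache, so a backend may alias the two buffers.
        helper.make_tensor_value_info("present_cache", TensorProto.FLOAT, ["B", "H", "max_seq", "D"]),
    ],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 24)])
onnx.checker.check_model(model)  # requires an onnx build that includes TensorScatter-24
```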