Add TensorScatter op for in-place kv cache update #7114

Merged
yuanyao-nv merged 9 commits into onnx:main from yuanyao-nv:dev-tensorscatter on Jul 24, 2025

Conversation

@yuanyao-nv
Contributor

@yuanyao-nv yuanyao-nv commented Jul 9, 2025

Description

This PR adds a TensorScatter op for in-place kv cache updates, to be used in Attention computations.

This op gives the input past_cache and the output present_cache the same shape (with the sequence-length dimension being max_seqlen), so that backends are free to alias the two tensors and achieve efficient in-place updates during the autoregressive iterations of Transformer models.
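For concreteness, here is a minimal sketch (not part of the PR) of driving the new op through the ONNX reference evaluator. The BNSH layout, shapes, and tensor names are illustrative assumptions taken from the reference implementation discussed later in this thread:

```python
import numpy as np
from onnx import TensorProto, helper
from onnx.reference import ReferenceEvaluator

# One-node model: present_cache = TensorScatter(past_cache, update, write_indices)
node = helper.make_node(
    "TensorScatter", ["past_cache", "update", "write_indices"], ["present_cache"]
)
graph = helper.make_graph(
    [node],
    "kv_cache_update",
    [
        helper.make_tensor_value_info("past_cache", TensorProto.FLOAT, [2, 4, 16, 8]),
        helper.make_tensor_value_info("update", TensorProto.FLOAT, [2, 4, 3, 8]),
        helper.make_tensor_value_info("write_indices", TensorProto.INT64, [2]),
    ],
    [helper.make_tensor_value_info("present_cache", TensorProto.FLOAT, [2, 4, 16, 8])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 24)])

sess = ReferenceEvaluator(model)
past = np.zeros((2, 4, 16, 8), dtype=np.float32)  # (B, N, max_seqlen, H)
upd = np.ones((2, 4, 3, 8), dtype=np.float32)     # 3 new tokens per batch
idx = np.array([0, 5], dtype=np.int64)            # per-batch write positions
(present,) = sess.run(None, {"past_cache": past, "update": upd, "write_indices": idx})
print(present.shape)  # (2, 4, 16, 8), same shape as past_cache
```

Because the input and output shapes match, a backend that proves past_cache has no other consumers can service both names with a single buffer.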

Motivation and Context

This work is part of the effort of the GenAI WG to enable LLM features in ONNX.

@yuanyao-nv yuanyao-nv requested a review from a team as a code owner July 9, 2025 06:20
@github-project-automation github-project-automation Bot moved this to In progress in PR Tracker Jul 9, 2025
@yuanyao-nv yuanyao-nv marked this pull request as draft July 9, 2025 06:21
@justinchuby justinchuby added this to the 1.19 milestone Jul 9, 2025
@yuanyao-nv yuanyao-nv marked this pull request as ready for review July 10, 2025 22:34
@yuanyao-nv yuanyao-nv requested a review from a team as a code owner July 10, 2025 22:34
@yuanyao-nv yuanyao-nv changed the title Draft: Add tensorscatter op for in-place kv cache update Add tensorscatter op for in-place kv cache update Jul 10, 2025
@justinchuby
Member

@onnx/sig-operators-approvers

@yuanyao-nv yuanyao-nv changed the title Add tensorscatter op for in-place kv cache update Add TensorScatter op for in-place kv cache update Jul 11, 2025
@justinchuby
Member

Is this PR ready?

@xadupre
Contributor

xadupre commented Jul 16, 2025

I think this PR may have some undesired consequences. The nodes in a graph must be ordered so that all of a node's inputs exist by the time the node runs, and many different orders are possible for the same computation graph. But in the case of this operator, we cannot change its position, because it modifies both the first input and the output when they are the same buffer. Any existing reordering algorithm may change the order in a way that changes the computation. The runtime can determine whether or not this operator can be executed in place: if no later node uses the first input, then it can be done in place.

@yuanyao-nv
Contributor Author

> I think this PR may have some undesired consequences. The nodes in a graph must be ordered so that all of a node's inputs exist by the time the node runs, and many different orders are possible for the same computation graph. But in the case of this operator, we cannot change its position, because it modifies both the first input and the output when they are the same buffer. Any existing reordering algorithm may change the order in a way that changes the computation. The runtime can determine whether or not this operator can be executed in place: if no later node uses the first input, then it can be done in place.

ONNX itself does not need to take I/O aliasing into account, since that is strictly a runtime property. For the runtime, in the common kv cache use cases I think only the output will be consumed by another op, not the input, so I don't think this will become a problem. Not sure if I'm missing some considerations.
cc @gramalingam

@xadupre
Contributor

xadupre commented Jul 16, 2025

> I think this PR may have some undesired consequences. The nodes in a graph must be ordered so that all of a node's inputs exist by the time the node runs, and many different orders are possible for the same computation graph. But in the case of this operator, we cannot change its position, because it modifies both the first input and the output when they are the same buffer. Any existing reordering algorithm may change the order in a way that changes the computation. The runtime can determine whether or not this operator can be executed in place: if no later node uses the first input, then it can be done in place.

> ONNX itself does not need to take I/O aliasing into account, since that is strictly a runtime property. For the runtime, in the common kv cache use cases I think only the output will be consumed by another op, not the input, so I don't think this will become a problem. Not sure if I'm missing some considerations. cc @gramalingam

I'm not sure I understood what "in-place" means here. I looked into your reference implementation and found `present_cache = np.copy(past_cache)`, so it is not really in place; it is just a new scatter op with a syntax specific to cache updates. I had in mind something like this:

updated_cache = TensorScatter(cache, ...)   # if it is really in place, updated_cache and cache alias the same buffer; the old cache no longer exists
w = cache + 1                               # so this read of cache is no longer possible

@justinchuby
Member

This op allows backends to do an in-place optimization in their implementation, I think? So it does not change the functional nature of ONNX.

@gramalingam
Contributor

Right. This op is functional: it is a special variant of the Scatter op and has exactly the same nature (a functional representation of what would be an imperative update in imperative, non-functional languages). It is intended to enable backends to recognize it and implement it in place, but backends need to verify that it is safe to do so.

So, there is no problem in that regard.
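As a hedged sketch of that verification step (the helper below is hypothetical, not an ONNX API), a backend might alias the two buffers only when the first input has no other consumer and is not itself a graph output:

```python
from onnx import GraphProto, NodeProto


def can_alias_past_cache(graph: GraphProto, ts_node: NodeProto) -> bool:
    # Hypothetical backend-side check; a complete version would also walk
    # subgraph attributes and consider graph inputs owned by the caller.
    past = ts_node.input[0]
    consumers = sum(past in n.input for n in graph.node)  # includes ts_node itself
    is_graph_output = any(out.name == past for out in graph.output)
    return consumers == 1 and not is_graph_output
```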

@justinchuby justinchuby added the topic: operator Issues related to ONNX operators label Jul 21, 2025
Comment on lines +3969 to +3985 in onnx/defs/tensor/defs.cc

    ONNX_OPERATOR_SET_SCHEMA(
        TensorScatter,

Check notice: Code scanning / CodeQL, Unused static variable

Static variable dbg_count_check_Onnx_24_verTensorScatter is never read.
@codecov

codecov Bot commented Jul 22, 2025

Codecov Report

Attention: Patch coverage is 43.28358% with 38 lines in your changes missing coverage. Please review.

Project coverage is 53.73%. Comparing base (88d94b8) to head (5971a15).
Report is 1 commit behind head on main.

✅ All tests successful. No failed tests found.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| onnx/backend/test/case/node/tensorscatter.py | 0.00% | 30 Missing ⚠️ |
| onnx/reference/ops/op_tensor_scatter.py | 74.19% | 4 Missing and 4 partials ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #7114      +/-   ##
==========================================
- Coverage   53.75%   53.73%   -0.03%     
==========================================
  Files         510      512       +2     
  Lines       32204    32271      +67     
  Branches     2972     2982      +10     
==========================================
+ Hits        17311    17340      +29     
- Misses      14127    14161      +34     
- Partials      766      770       +4     


@yuanyao-nv
Contributor Author

@gramalingam Thanks for the comments. I've updated the PR. Please let me know if there are further comments.

@github-project-automation github-project-automation Bot moved this from In progress to Reviewer approved in PR Tracker Jul 22, 2025
Contributor

@cbourjau cbourjau left a comment

Apologies for the late review! It might be obvious, but could you spell out why the existing scatter operators don't allow runtimes to do optimizations similar to the ones this operator would allow?

@yuanyao-nv
Contributor Author

> Apologies for the late review! It might be obvious, but could you spell out why the existing scatter operators don't allow runtimes to do optimizations similar to the ones this operator would allow?

The reason is twofold:

  • The ScatterND op can cover the functionality of this op, but it requires the index to be given for each head and each token as well (taking the BNSH format as an example), so the indices input would need to be a much larger tensor; see the sketch below.
  • Having a dedicated op for kv cache updates signals more clearly to backends that this is not a regular scatter operation and that the I/O tensors should probably be aliased for efficiency.
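To make the size difference concrete, here is a small NumPy sketch (sizes are illustrative) comparing TensorScatter's write_indices input with the indices tensor an equivalent ScatterND would need:

```python
import numpy as np

B, N, S = 2, 8, 128  # batch, heads, update length (illustrative)
write_indices = np.zeros(B, dtype=np.int64)  # TensorScatter: one start per batch

# An equivalent ScatterND into a (B, N, max_seqlen, H) cache needs one
# (b, n, s) triple per updated row:
b, n, s = np.meshgrid(np.arange(B), np.arange(N), np.arange(S), indexing="ij")
scatternd_indices = np.stack([b, n, write_indices[:, None, None] + s], axis=-1)
print(write_indices.size, scatternd_indices.size)  # 2 vs B*N*S*3 = 6144
```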

@yuanyao-nv yuanyao-nv merged commit 7c3b330 into onnx:main Jul 24, 2025
38 checks passed
@github-project-automation github-project-automation Bot moved this from Reviewer approved to Done in PR Tracker Jul 24, 2025
@cbourjau
Copy link
Copy Markdown
Contributor

IMHO, stabilizing this operator in its current shape is a mistake. I believe this operation to be a narrow special case of a more general scatter operation that we should tackle instead. Consider the following NumPy snippet, which appears to work fine for TensorScatter as discussed here:

import numpy as np
from onnx.reference.op_run import OpRun


class TensorScatter(OpRun):
    def _run(self, past_cache, update, write_indices=None, mode="linear", axis=-2):
        if write_indices is None:
            write_indices = np.zeros((past_cache.shape[0],), dtype=np.int64)

        input_shape = past_cache.shape
        update_shape = update.shape

        max_sequence_length = input_shape[axis]
        sequence_length = update_shape[axis]
        batch_size = input_shape[0]

        # Sequence positions written for each batch: shape (batch_size, sequence_length).
        seq_positions = write_indices[:, None] + np.arange(sequence_length)
        if mode == "circular":
            seq_positions = seq_positions % max_sequence_length

        # View the cache as rows of shape input_shape[axis+1:]; every
        # (batch, head, ...) slice owns a contiguous block of
        # max_sequence_length rows, so broadcast the per-batch positions
        # across the head dims between batch and the sequence axis.
        num_slices = int(np.prod(input_shape[:axis]))  # e.g. B*N for BNSH
        slices_per_batch = num_slices // batch_size    # e.g. N
        row_offsets = np.arange(num_slices) * max_sequence_length
        dense_indices = row_offsets[:, None] + np.repeat(
            seq_positions, slices_per_batch, axis=0
        )

        assert dense_indices.shape == (num_slices, sequence_length)

        common_shape = (-1,) + input_shape[axis + 1:]

        present_cache = np.copy(past_cache).reshape(common_shape)
        present_cache[dense_indices.flatten(), ...] = update.reshape(common_shape)

        return (present_cache.reshape(input_shape),)
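A quick sanity check of the snippet above (shapes are illustrative; `_run` is called unbound here since it does not touch `self`):

```python
import numpy as np

past = np.zeros((2, 4, 16, 8), dtype=np.float32)  # (B, N, max_seqlen, H)
upd = np.ones((2, 4, 3, 8), dtype=np.float32)     # (B, N, S, H)
idx = np.array([0, 5], dtype=np.int64)
(out,) = TensorScatter._run(None, past, upd, write_indices=idx)
assert out[0, :, 0:3].all() and out[1, :, 5:8].all()        # rows written
assert not out[0, :, 3:].any() and not out[1, :, :5].any()  # rows untouched
```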

It highlights that this operator's semantics do not necessarily require huge index matrices when expressed with fairly standard tensor operations. The index tensor would be larger than that of TensorScatter, but given the motivating context I don't readily believe this would add significant overhead. Do we have any data showing that this operator measurably reduces runtime compared to using existing operators?

Concerning the question of reusing the existing cache, I'm not sure runtimes can do much with the "hint" given by this operation. It seems like a rather obvious optimization for any operator's implementation to try to reuse an input buffer when it is not needed by downstream operators. I would think this optimization is either already in place regardless of any hint in the operator description, or it is hard to implement for unrelated reasons that this hint cannot mitigate either.

yuanyao-nv added a commit that referenced this pull request Jul 30, 2025
### Description

To accompany the
[TensorScatter-24](#7114) op for
managing in-place KV cache updates, this PR makes the following changes
to the Attention op:
- Add `nonpad_kv_seqlen` to indicate the number of valid (non-padded)
tokens in the K and V inputs when the K and V inputs are the entire
cache tensors (where the valid tokens can potentially make up only a
small proportion of the cache tensors). The `nonpad_kv_seqlen` input
provides optimization opportunities for backends to skip the
unnecessary computation on the padding tokens.
- Allow the kv_seqlen dimension (dimension -1) of the `attn_mask` input
to be shorter than K and V. The missing portion is assumed to be -inf.
The length should still be larger than the max value in
`nonpad_kv_seqlen`.

Also, allow `attn_mask` and `is_causal` to be present at the same time.
This will allow easier export of HF models later.

---------

Signed-off-by: Yuan Yao <yuanyao@nvidia.com>
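A small NumPy sketch (not from the PR; shapes are illustrative) of the shorter-`attn_mask` rule above, extending the missing tail with -inf:

```python
import numpy as np

total_kv_len = 8                                      # padded K/V length
attn_mask = np.zeros((1, 1, 4, 5), dtype=np.float32)  # last dim < total_kv_len
pad = total_kv_len - attn_mask.shape[-1]
full_mask = np.pad(attn_mask, [(0, 0)] * 3 + [(0, pad)], constant_values=-np.inf)
print(full_mask.shape)  # (1, 1, 4, 8)
```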
alx256 pushed a commit to alx256/onnx that referenced this pull request Aug 1, 2025
alx256 pushed a commit to alx256/onnx that referenced this pull request Aug 1, 2025
@yuanyao-nv
Contributor Author

@cbourjau Sorry, I missed your reply. It would have been harder to miss with a direct tag. :)
To answer your questions:

  • Regarding the size of the indices input: as I mentioned earlier, it would be a factor of H*S larger if we used ScatterND. For the BNSH format, H typically ranges from 4 to 16, and S can be on the order of tens to hundreds for prefill.
  • Regarding tensor aliasing: for intermediate ops whose tensors do not appear as model I/O, it's true that reuse happens all the time; in fact, they might not even materialize if they are part of a fused pattern. For model I/O tensors, though, separate buffers are typically allocated, and the runtime/backend does need additional logic to achieve aliasing.

MagellaX pushed a commit to MagellaX/onnx that referenced this pull request Aug 9, 2025
MagellaX pushed a commit to MagellaX/onnx that referenced this pull request Aug 9, 2025

Labels

topic: operator Issues related to ONNX operators

Projects

Status: Done

6 participants