
Minor improvements to token_type_ids extension for PA#34661

Merged
p-wysocki merged 38 commits into openvinotoolkit:master from p-wysocki:attn_fixes
Apr 3, 2026

Conversation

@p-wysocki
Contributor

Details:

Tickets:

  • N/A

Signed-off-by: p-wysocki <przemyslaw.wysocki@intel.com>
…nto attn_idea_2

…into attn_idea_2

@p-wysocki p-wysocki requested a review from a team as a code owner March 12, 2026 11:34
@github-actions github-actions bot added the category: Core (OpenVINO Core, aka ngraph) and category: transformations (OpenVINO Runtime library - Transformations) labels Mar 12, 2026
// Shared flag to track whether the model is Gemma3, set when any layer matches
// the gptoss_gemma3 sliding window pattern. Combined with the token_type_ids check,
// this uniquely identifies Gemma3 (gpt-oss shares the pattern but lacks token_type_ids).
auto is_gptoss_gemma3 = std::make_shared<bool>(false);
Contributor

Can we define this variable inside the callback?

Contributor

Agree, it does look strange that it has to be defined outside.

Contributor Author

Gemma3 has a repeating sequence of attention layers: 5x sliding window attention, 1x full attention. The pattern we currently have detects sliding window attention, but token_type_ids has to be passed to the full attention layers as well.

has_token_type_ids is defined outside of the callback as a shared_ptr because it has to stay consistent across all lambda callbacks - since the lambdas capture by value (=), each callback gets its own copy of the shared pointer to the same object. Without it, token_type_ids would be routed only to the sliding window PAs, and not to the full attention PAs.

Technically we could detect the full attention pattern to avoid the gpt-oss/gemma3 mixup and do the same trick, but then the first 5x sliding window attentions would not receive the token_type_ids input, because only the first full attention layer (6th in line) would set the variable to true.

Summing up, it may not be as clean as I'd like it to be, but it works. If you insist that this piece of code will cause issues, I can keep looking for a universal pattern which would:

  1. separate gpt-oss and gemma3
  2. work for both sliding window and full attention layers
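
The shared-state trick described above can be sketched in plain C++ (a minimal, hypothetical mock of the matcher callbacks, not OpenVINO code): a bool behind a shared_ptr is captured by value into several callbacks, and a write made by one callback is visible to the others because every copy of the shared_ptr points at the same object.

```cpp
#include <functional>
#include <memory>
#include <vector>

// Returns true iff a write to the shared flag made inside one callback
// is observed by a later callback - the persistence the comment relies on.
bool flag_persists_across_callbacks() {
    // Defined outside the callbacks, as in the transformation.
    auto has_token_type_ids = std::make_shared<bool>(false);

    std::vector<std::function<void()>> callbacks;
    // Capture by value (=): each lambda copies the shared_ptr,
    // but all copies refer to the same underlying bool.
    callbacks.push_back([=] { *has_token_type_ids = true; });  // e.g. sliding window layer match
    bool observed = false;
    callbacks.push_back([=, &observed] { observed = *has_token_type_ids; });  // e.g. full attention layer

    for (auto& cb : callbacks) cb();
    return observed;
}
```

If the flag were a plain bool captured by value, each callback would get an independent copy and the second callback would still see false.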

Contributor

Yeah, it's a little dirty solution, but if there's no other option, I believe we can live with it.

I can keep looking for a universal pattern

Any ideas of what this could be?

Contributor Author

@p-wysocki p-wysocki Apr 1, 2026

I'll be improving the GenAI integration next week, I'll give finding a better pattern another go, as I'll be modifying this file anyway. If I find it, the whole thing will be solved gracefully. For now IMO the PR can be merged, as overall it's a net positive change over master.

sliding_window = std::make_shared<v1::Subtract>(v0::Constant::create(element::i32, Shape{}, {2}), offset);
} else if (pattern_map.count(gptoss_gemma3_offset)) {
*is_gptoss_gemma3 = true;
is_gemma3 = optional_model_wide_params.count("token_type_ids");
Contributor

In fact, any model with token_type_ids and a matching sliding window pattern will set this is_gemma3 flag to true, so why not simply name this variable has_token_type_ids?
Or set has_sliding_window here instead, and use it below.
Also, is_gemma3 will currently be false for the causal mask case (no sliding window) within the same model.

Contributor Author

I renamed the variable, and regarding the no sliding window case, the explanation is provided in #34661 (comment).

Comment on lines 760 to 763
if (is_gemma3) {
pa_arguments.insert(pa_arguments.begin() + 25, handle_gemma3_token_type_ids(optional_model_wide_params));
} else {
pa_arguments.insert(pa_arguments.begin() + 25, v0::Constant::create(element::i32, Shape{0}, {}));
Contributor

The variable naming is tied to gemma3, but it can be generic for any model where both has_token_type_ids and has_sliding_window are true.
It is currently applied for the sliding_window case only, but as a next step it could be extended to the causal case as well; then this if/else would reduce to a single case:

pa_arguments.insert(pa_arguments.begin() + 25, handle_token_type_ids(optional_model_wide_params));

Suggested change
-if (is_gemma3) {
-    pa_arguments.insert(pa_arguments.begin() + 25, handle_gemma3_token_type_ids(optional_model_wide_params));
-} else {
-    pa_arguments.insert(pa_arguments.begin() + 25, v0::Constant::create(element::i32, Shape{0}, {}));
+if (has_sliding_window) {
+    pa_arguments.insert(pa_arguments.begin() + 25, handle_token_type_ids(optional_model_wide_params));
+} else {
+    pa_arguments.insert(pa_arguments.begin() + 25, v0::Constant::create(element::i32, Shape{0}, {}));

Contributor Author

I changed the variable name. token_type_ids currently works for the causal case as well, see #34661 (comment).

…into attn_fixes

Signed-off-by: p-wysocki <przemyslaw.wysocki@intel.com>
Copilot AI left a comment

Pull request overview

This PR refines Gemma3 token_type_ids handling for the SDPA→PagedAttention transformation and strengthens PagedAttentionExtension type-propagation coverage around the newly-supported token_type_ids ranks.

Changes:

  • Add type-prop tests validating token_type_ids acceptance for rank-1/rank-2, dynamic shape, and invalid type/rank cases.
  • Simplify token_type_ids retrieval/conversion in the Gemma3 path by assuming presence when the Gemma3 condition is met and avoiding an internal fallback.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
src/core/tests/type_prop/paged_attention.cpp Adds dedicated type-prop tests for token_type_ids rank/type validation.
src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp Adjusts Gemma3 detection flagging and streamlines token_type_ids handling (convert-to-i32 when needed).

Comment on lines +438 to +442
// Set to true once a sliding_attention layer matching the gptoss_gemma3 pattern is found
// alongside a token_type_ids model input - the combination that uniquely identifies Gemma3
// since pattern for full attention mask in Gemma3 is different than sliding window
// it has to be persistent in the callback, so shared_ptr is used
auto has_token_type_ids = std::make_shared<bool>(false);
Copilot AI Mar 31, 2026

[LOW] has_token_type_ids is used as a persisted “Gemma3 detected / enable token_type_ids wiring” flag (it’s only updated when the gptoss_gemma3 sliding-window pattern matches), so the name is misleading—there are cases where the model may have a token_type_ids input but this flag stays false until that pattern is seen. Consider renaming it to something like is_gemma3 / enable_gemma3_token_type_ids to reflect the actual semantics and reduce the chance of future misuse.

Copilot generated this review using guidance from repository custom instructions.
Comment on lines +201 to +207
 static std::shared_ptr<ov::Node> handle_gemma3_token_type_ids(
     const std::map<std::string, std::shared_ptr<v0::Parameter>>& optional_model_wide_params) {
-    if (optional_model_wide_params.find("token_type_ids") != optional_model_wide_params.end()) {
-        auto param = optional_model_wide_params.at("token_type_ids");
-        if (param->get_element_type() != ov::element::i32) {
-            return std::make_shared<v0::Convert>(param, ov::element::i32);
-        }
-        return param;
+    auto param = optional_model_wide_params.at("token_type_ids");
+    if (param->get_element_type() != ov::element::i32) {
+        return std::make_shared<v0::Convert>(param, ov::element::i32);
+    }
-    return v0::Constant::create(ov::element::i32, ov::Shape{0}, {});
+    return param;
Contributor

Now this helper looks unsafe, as it can be used without the pre-check that the token_type_ids input exists. Also, it just adds a Convert if the type is not aligned, which is done for other inputs as well, offsets for example. It's not unique to Gemma, so maybe this handle_gemma3_token_type_ids helper can just be skipped/removed in this PR, and as a separate contribution a common apply_convert helper can be added and reused for the other inputs as well.

auto offset = pattern_map.at(phi3_offset).get_node_shared_ptr();
if (offset->get_element_type() != element::i32) {
offset = std::make_shared<v0::Convert>(offset, element::i32);
}
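
A common helper along these lines - the name apply_convert is the reviewer's suggestion, and the tiny Node stand-in below is purely hypothetical, used instead of the real ov::Node/v0::Convert types to keep the sketch self-contained - would wrap a node in a Convert only when its element type differs from the target:

```cpp
#include <memory>
#include <string>

// Minimal stand-in for an OpenVINO node: an element type plus,
// for a Convert wrapper, a pointer to its input node.
struct Node {
    std::string element_type;
    std::shared_ptr<Node> input;  // non-null only for a Convert wrapper
};

// Stand-in for std::make_shared<v0::Convert>(n, target).
std::shared_ptr<Node> make_convert(std::shared_ptr<Node> n, const std::string& target) {
    return std::make_shared<Node>(Node{target, n});
}

// The proposed generic helper: insert a Convert only when the types
// differ, mirroring the inline pattern used for offsets above.
std::shared_ptr<Node> apply_convert(std::shared_ptr<Node> n, const std::string& target) {
    if (n->element_type != target) {
        return make_convert(n, target);
    }
    return n;  // already the target type: pass through unchanged
}
```

With such a helper, the offset snippet above and the token_type_ids handling would both collapse to a single apply_convert call.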

Contributor Author

The helper has been inlined; the PR now uses just has_token_type_ids, and the convert/insertion is done inline instead of being handled by a util.

 OPENVINO_ASSERT(pa_arguments.size() == 25);

-if (*is_gptoss_gemma3) {
+if (*has_token_type_ids) {
Contributor

This handle_gemma3_token_type_ids brings more confusion than clarity; it just converts the type, without any model-specific logic. So I would recommend renaming the helper to something more generic like "apply_convert" and reusing it along the transformation, or for now just putting the Convert insertion here explicitly, as is done for other cases like offsets:

Suggested change
-if (*has_token_type_ids) {
+if (*has_token_type_ids) {
+    std::shared_ptr<ov::Node> token_type_ids = optional_model_wide_params.at("token_type_ids");
+    if (token_type_ids->get_element_type() != ov::element::i32) {
+        token_type_ids = std::make_shared<v0::Convert>(token_type_ids, ov::element::i32);
+    }
+    pa_arguments.insert(pa_arguments.begin() + 25, token_type_ids);

Contributor Author

Applied, the logic has been simplified.

Signed-off-by: p-wysocki <przemyslaw.wysocki@intel.com>
@p-wysocki p-wysocki requested a review from mitruska April 1, 2026 12:40
@mlukasze mlukasze enabled auto-merge April 1, 2026 19:08
@mlukasze mlukasze added this pull request to the merge queue Apr 3, 2026
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Apr 3, 2026
@p-wysocki p-wysocki added this pull request to the merge queue Apr 3, 2026
Merged via the queue into openvinotoolkit:master with commit b286a61 Apr 3, 2026
226 of 228 checks passed
@p-wysocki p-wysocki deleted the attn_fixes branch April 3, 2026 16:08