
[ROCm][CI] Fix spec decode logprobs flakiness and parametrize tree attention backends#34599

Merged
vllm-bot merged 8 commits into vllm-project:main from ROCm:akaratza_v1_others on Feb 21, 2026

Conversation

AndreasKaratzas (Collaborator) commented Feb 16, 2026

This PR fixes flakiness in the V1 "Test others" suite on ROCm.

test_logprobs.py

test_spec_decode_logprobs was intermittently failing on ROCm because logprob differences between the base and speculative-decode LLM instances were misattributed to spec decode itself. The real cause is non-determinism in ROCm's skinny GEMM: the wvSplitK kernels in gemm_kernels.cu use persistent workgroup scheduling and wave-level shuffle reductions, which produce different results across LLM instantiations even with identical configs and seeds.

The fix disables skinny GEMM via VLLM_ROCM_USE_SKINNY_GEMM=0 for this test. Descriptive assertion messages are also added to make future triage easier.
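As a self-contained sketch of the fix (the env var name comes from the PR; the helper itself is illustrative, not vLLM code), the determinism workaround amounts to pinning VLLM_ROCM_USE_SKINNY_GEMM=0 in the environment before either LLM is constructed:

```python
import os

# Illustrative helper (not part of vLLM): return a copy of an environment
# mapping with the ROCm skinny GEMM kernels disabled, so both the base
# and speculative-decode LLMs take identical, deterministic GEMM paths.
def with_skinny_gemm_disabled(env: dict) -> dict:
    patched = dict(env)
    patched["VLLM_ROCM_USE_SKINNY_GEMM"] = "0"
    return patched

# In a test this would typically be applied via pytest's monkeypatch
# fixture before the runners are instantiated.
test_env = with_skinny_gemm_disabled(dict(os.environ))
```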

Additional cleanup: standalone VllmRunner instantiations throughout the file are converted to context managers so resources are released reliably.
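The cleanup pattern can be sketched with a stand-in class (FakeRunner below is hypothetical; vLLM's VllmRunner test helper supports the same with-statement protocol):

```python
# FakeRunner is a stand-in for the test suite's VllmRunner. Implementing
# the context-manager protocol guarantees teardown runs on every exit
# path, including when an assertion fails mid-test.
class FakeRunner:
    def __init__(self, model: str):
        self.model = model
        self.closed = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.closed = True  # release resources even if the body raised
        return False        # never swallow test failures

with FakeRunner("some-model") as runner:
    pass
# runner.closed is now True regardless of how the with-block exited
```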

test_tree_attention.py

Parametrizes test_tree_attn_correctness over all reference attention backends available on the current platform rather than hardcoding FLASH_ATTN. On ROCm this includes TRITON_ATTN and the platform default. Adds KV cache layout adaptation (flash <-> block) so backends with different cache layouts can be used as references. Documents known incompatibilities with ROCM_ATTN (paged layout) and ROCM_AITER_FA (head count mismatch) as TODOs.
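The layout adaptation can be illustrated with a small transpose. The dimension orders here are assumptions for illustration, not vLLM's actual cache layouts: "flash" is taken as [block][slot][head] and "block" as [block][head][slot], so adapting between them swaps the two inner axes.

```python
# Hedged sketch of KV-cache layout adaptation between reference backends.
# Swapping the inner two axes converts a [block][slot][head] cache into a
# [block][head][slot] cache; applying it twice restores the original.
def adapt_kv_layout(cache):
    # cache[b][i][j] -> out[b][j][i]
    return [
        [[blk[i][j] for i in range(len(blk))] for j in range(len(blk[0]))]
        for blk in cache
    ]

flash_cache = [[["k00", "k01"], ["k10", "k11"]]]  # 1 block, 2 slots, 2 heads
block_cache = adapt_kv_layout(flash_cache)
```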

gemini-code-assist bot (Contributor) left a comment:

Code Review

This pull request addresses test flakiness on ROCm by disabling skinny GEMM in logprob tests and parametrizing tree attention tests over available backends. The changes include converting VllmRunner to use context managers for better resource management and adding KV cache layout adaptation. While the overall direction is correct, there are a few issues in the test setup that could lead to crashes or unexpected failures depending on the environment.

Comment on lines +102 to +104
    for backend in AttentionBackendEnum:
        if backend.value is not None and backend.get_path() == backend_path:
            return backend
Severity: high

This loop will raise a ValueError if it encounters an AttentionBackendEnum member with an empty string value (such as TORCH_SDPA). This is because get_path() explicitly checks for non-empty paths and raises an error otherwise. This could crash test collection on platforms where such backends are present.

    for backend in AttentionBackendEnum:
        try:
            if backend.get_path() == backend_path:
                return backend
        except ValueError:
            continue
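The failure mode is easy to reproduce without vLLM. The enum below mimics the shape of AttentionBackendEnum (names and paths are stand-ins): one member has no importable path, its get_path() raises, and the try/except lets resolution skip it instead of crashing collection.

```python
from enum import Enum

# Stand-in for vLLM's AttentionBackendEnum: TORCH_SDPA has no backend
# path, and get_path() raises for it, mirroring the behavior the review
# comment describes.
class BackendEnum(Enum):
    FLASH_ATTN = "pkg.flash.FlashBackend"
    TORCH_SDPA = ""  # no importable path

    def get_path(self) -> str:
        if not self.value:
            raise ValueError(f"{self.name} has no backend path")
        return self.value

def resolve(backend_path: str):
    """Find the enum member for a path, skipping members without one."""
    for backend in BackendEnum:
        try:
            if backend.get_path() == backend_path:
                return backend
        except ValueError:
            continue  # e.g. TORCH_SDPA: no path, so it cannot match
    return None
```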

backends: list[AttentionBackendEnum] = []

# 1. Whatever the platform would auto-select at runtime.
backends.append(_get_platform_default_backend())
Severity: high

The platform default backend should be filtered against known incompatible backends (as documented in the TODOs below) before being added to the reference backends list. If ROCM_AITER_FA or ROCM_ATTN is selected as the default by the platform (e.g., via environment variables like VLLM_ROCM_USE_AITER), the test will fail due to the documented incompatibilities.

        default_backend = _get_platform_default_backend()
        if default_backend not in (AttentionBackendEnum.ROCM_AITER_FA,
                                   AttentionBackendEnum.ROCM_ATTN):
            backends.append(default_backend)
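Assembling the parametrization list can be sketched as: start from the platform default, drop backends with documented incompatibilities, and deduplicate while preserving order. The set of incompatible names comes from the PR's TODOs; the helper itself is illustrative, not vLLM code.

```python
# Backends documented as incompatible with the tree-attention reference
# comparison (from the PR's TODOs).
KNOWN_INCOMPATIBLE = {"ROCM_AITER_FA", "ROCM_ATTN"}

def reference_backends(default: str, candidates: list) -> list:
    """Order-preserving filter + dedupe of reference backend names."""
    out, seen = [], set()
    for name in [default, *candidates]:
        if name in KNOWN_INCOMPATIBLE or name in seen:
            continue
        seen.add(name)
        out.append(name)
    return out
```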

SageMoore (Contributor) left a comment:
Looks reasonable. One minor nit.

    contention. Both use identical chunked prefill settings and eager
    mode to control for infrastructure differences.

    On ROCm, the custom skinny GEMM kernels are non-deterministic
Nit: I think one block comment describing why we are disabling skinny gemms is sufficient :).

AndreasKaratzas (Collaborator, Author) replied:
Done :)

gshtras enabled auto-merge (squash) February 20, 2026 16:00
github-actions bot added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) Feb 20, 2026
vllm-bot merged commit 54254f7 into vllm-project:main Feb 21, 2026
15 of 17 checks passed
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Feb 21, 2026
dosubot bot commented Feb 21, 2026: Checked 0 published document(s) in 1 knowledge base(s). No updates required.

AndreasKaratzas deleted the akaratza_v1_others branch February 21, 2026 04:26
Commits referencing this pull request (vllm-project#34599) were later pushed to the following forks: DarkLight1337/vllm (Feb 21, 2026), joeqzzuo/vllm (Feb 21, 2026), yugong333/vllm (Feb 22, 2026), jmamou/vllm (Feb 23, 2026), llsj14/vllm (Mar 1, 2026), tunglinwood/vllm (Mar 4, 2026), askliar/vllm (Mar 9, 2026), machov/vllm via Copilot (Mar 10, 2026), EricccYang/vllm (Apr 1, 2026), and liuchenbing2026/vllm (Apr 4, 2026).

Labels

ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), speculative-decoding, v1

Projects

Status: Done

3 participants