
[CI] Fix ColBERT HF comparison tests on AMD CI + refactor #34567

Merged
vllm-bot merged 6 commits into vllm-project:main from ROCm:akaratza_fix_lang_mod_ext_pooling
Feb 21, 2026

Conversation

@AndreasKaratzas
Collaborator

@AndreasKaratzas AndreasKaratzas commented Feb 14, 2026

This PR fixes the test_colbert_hf_comparison_modernbert crash on AMD CI and refactors the ColBERT HF comparison tests for maintainability and cross-platform robustness.

test_colbert_hf_comparison_modernbert crashes on AMD CI nightly because the HF reference model runs on CPU while ModernBERT defaults to Triton-based flash attention:

test_colbert_hf_comparison_modernbert failure log
ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)

The HF ModernBERT model defaults to flash/Triton attention kernels, which require GPU tensors. The test loads the HF reference model on CPU via AutoModel.from_pretrained(model_name) without specifying an attention backend, so the Triton rotary embedding kernel fails when it receives CPU tensors.

Introduced by #34170, caught on AMD CI nightly.

Multiple ColBERT tests emit UserWarning: To copy construct from a tensor due to torch.tensor() being called on existing tensors:

torch.tensor copy-construct warnings
tests/models/language/pooling/test_colbert.py::test_colbert_late_interaction_1_to_N[bert]
tests/models/language/pooling/test_colbert.py::test_colbert_late_interaction_1_to_N[modernbert]
tests/models/language/pooling/test_colbert.py::test_colbert_late_interaction_1_to_N[jina]
  UserWarning: To copy construct from a tensor, it is recommended to use
  sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True),
  rather than torch.tensor(sourceTensor).
    q_emb = torch.tensor(q_outputs[0])

torch.tensor() always copies data and emits this warning when given an existing tensor. Since these call sites don't need a copy, torch.as_tensor() is the correct idiom — it returns a view when possible and never warns.
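The difference between the two idioms can be demonstrated with a small standalone snippet (illustrative only, not part of the PR):

```python
import warnings

import torch

t = torch.ones(3)

# torch.tensor() always copies and warns when handed an existing tensor.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    copied = torch.tensor(t)
assert copied.data_ptr() != t.data_ptr()  # new storage: a copy was made
assert any("copy construct" in str(w.message) for w in caught)

# torch.as_tensor() returns the input tensor itself when dtype and
# device already match, so no copy and no warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    viewed = torch.as_tensor(t)
assert viewed.data_ptr() == t.data_ptr()  # shares storage
assert not caught
```

Since the tests only read the embeddings, the zero-copy behavior is safe here.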

Changes

Fix Triton crash: detect device and set attention implementation

New _load_hf_model helper detects whether inference will run on CPU or GPU and passes attn_implementation="eager" on CPU to avoid Triton kernel crashes. On GPU it uses the default (flash attention). This is the actual bug fix.

Diff details
+def _load_hf_model(model_name: str, hf_spec: dict, device: torch.device):
+    """Load HF model on the given device with a compatible attention impl."""
+    from transformers import AutoModel, BertModel
+
+    cls = BertModel if hf_spec["model_cls"] == "BertModel" else AutoModel
+    trust = hf_spec.get("trust_remote_code", False)
+
+    # Flash / Triton kernels require GPU tensors; fall back to eager on CPU.
+    extra = {}
+    if device.type == "cpu":
+        extra["attn_implementation"] = "eager"
+
+    model = cls.from_pretrained(
+        model_name,
+        trust_remote_code=trust,
+        **extra,
+    ).to(device)
+    model.eval()
+    return model

Refactor: collapse three HF comparison tests into one parametrized test

The three test_colbert_hf_comparison_{bert,modernbert,jina} functions were near-identical (~80 lines of duplication), differing only in which model class, weights file, and trust_remote_code flag to use. Added an hf_comparison sub-dict to each COLBERT_MODELS entry to capture these differences declaratively, extracted shared logic into helpers, and replaced the three functions with a single @pytest.mark.parametrize-driven test. This ensures the Triton fix (and any future fix) only needs to exist in one place.

Diff details

hf_comparison metadata added to each model spec:

     "bert": {
         "model": "answerdotai/answerai-colbert-small-v1",
         "colbert_dim": 96,
         "max_model_len": 512,
         "extra_kwargs": {},
+        "hf_comparison": {
+            "weights_file": "model.safetensors",
+            "weights_key": "linear.weight",
+            "trust_remote_code": False,
+            "model_cls": "BertModel",
+        },
     },

(Similarly for modernbert and jina: modernbert uses "1_Dense/model.safetensors", jina uses "trust_remote_code": True.)

New shared helpers for weight loading and embedding computation:

+def _load_projection_weight(model_name: str, hf_spec: dict, device: torch.device):
+    """Download and return the ColBERT linear projection weight."""
+    from huggingface_hub import hf_hub_download
+    from safetensors.torch import load_file
+
+    path = hf_hub_download(model_name, filename=hf_spec["weights_file"])
+    weights = load_file(path)
+    return weights[hf_spec["weights_key"]].to(device)
+
+
+def _compute_hf_colbert_embeddings(model, tokenizer, linear_weight, texts,
+                                   device):
+    """Run HF model + projection and return L2-normalized token embeddings."""
+    import torch.nn.functional as F
+
+    embeddings = []
+    for text in texts:
+        inputs = tokenizer(text, return_tensors="pt").to(device)
+        with torch.no_grad():
+            hidden = model(**inputs).last_hidden_state.float()
+            projected = F.linear(hidden, linear_weight.float())
+            normalized = F.normalize(projected, p=2, dim=-1)
+            embeddings.append(normalized.squeeze(0).cpu())
+    return embeddings

.float() is called unconditionally on both hidden_states and linear_weight — this is a no-op for bert/modernbert (already float32) and matches the explicit .float() calls the old Jina test required. Results are moved to CPU for device-agnostic comparison.
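The no-op claim can be checked directly (illustrative snippet, not from the PR):

```python
import torch

t32 = torch.ones(2, 2, dtype=torch.float32)
t16 = torch.ones(2, 2, dtype=torch.float16)

# .float() is self.to(torch.float32): for an already-float32 tensor
# it returns the tensor itself, so the call costs nothing.
assert t32.float() is t32

# For a float16 tensor (e.g. half-precision weights) it upcasts to a
# new float32 tensor, matching what the old Jina test did explicitly.
up = t16.float()
assert up.dtype == torch.float32
assert up is not t16
```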

Three separate test functions replaced by one:

-def test_colbert_hf_comparison_bert(vllm_runner):
-    ...
-def test_colbert_hf_comparison_modernbert(vllm_runner):
-    ...
-def test_colbert_hf_comparison_jina(vllm_runner):
-    ...
+@pytest.mark.parametrize("backend", list(COLBERT_MODELS.keys()))
+def test_colbert_hf_comparison(vllm_runner, backend):
+    """Test that vLLM ColBERT embeddings match HuggingFace for each backend."""
+    from transformers import AutoTokenizer
+
+    spec = COLBERT_MODELS[backend]
+    hf_spec = spec["hf_comparison"]
+    model_name = spec["model"]
+    test_texts = [TEXTS_1[0], TEXTS_2[0]]
+
+    with vllm_runner(
+        model_name,
+        runner="pooling",
+        dtype="float32",
+        max_model_len=spec["max_model_len"],
+        enforce_eager=True,
+        **spec["extra_kwargs"],
+    ) as vllm_model:
+        vllm_outputs = vllm_model.token_embed(test_texts)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    hf_tokenizer = AutoTokenizer.from_pretrained(
+        model_name,
+        trust_remote_code=hf_spec.get("trust_remote_code", False),
+    )
+    hf_model = _load_hf_model(model_name, hf_spec, device)
+    linear_weight = _load_projection_weight(model_name, hf_spec, device)
+
+    hf_embeddings = _compute_hf_colbert_embeddings(
+        hf_model, hf_tokenizer, linear_weight, test_texts, device,
+    )
+
+    _assert_embeddings_close(vllm_outputs, hf_embeddings)

Resolve torch.tensor copy-construct warnings

All torch.tensor(existing_tensor) calls replaced with torch.as_tensor() across every test function to eliminate the UserWarning. torch.as_tensor() returns a view when possible and is the correct idiom when a copy is not needed.

Diff details
 # _assert_embeddings_close
-        vllm_emb = torch.tensor(vllm_out).float()
+        vllm_emb = torch.as_tensor(vllm_out).float()

 # test_colbert_token_embed
-        emb = torch.tensor(outputs[0])
+        emb = torch.as_tensor(outputs[0])

 # test_colbert_late_interaction_1_to_1
-        q_emb = torch.tensor(q_outputs[0])
-        d_emb = torch.tensor(d_outputs[0])
+        q_emb = torch.as_tensor(q_outputs[0])
+        d_emb = torch.as_tensor(d_outputs[0])

 # test_colbert_late_interaction_1_to_N
-        q_emb = torch.tensor(q_outputs[0])
+        q_emb = torch.as_tensor(q_outputs[0])
         ...
-            d_emb = torch.tensor(d_out)
+            d_emb = torch.as_tensor(d_out)

 # test_colbert_late_interaction_N_to_N
-            q_emb = torch.tensor(q_out)
-            d_emb = torch.tensor(d_out)
+            q_emb = torch.as_tensor(q_out)
+            d_emb = torch.as_tensor(d_out)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@mergify mergify bot added the rocm Related to AMD ROCm label Feb 14, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Feb 14, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a fix for a crash on AMD CI in ColBERT HF comparison tests and significantly refactors the tests for better maintainability. The crash, caused by using Triton attention kernels on CPU tensors, is correctly addressed by conditionally setting attn_implementation="eager" when running on CPU. The refactoring consolidates three near-identical test functions into a single, parameterized test, which greatly reduces code duplication and improves clarity. Additionally, the PR resolves torch.tensor copy-construction warnings by correctly using torch.as_tensor. The changes are well-implemented and improve both the correctness and quality of the test suite. I have no further comments.

@AndreasKaratzas
Collaborator Author

AndreasKaratzas commented Feb 14, 2026

cc @ieBoytsov @DarkLight1337 @noooop Follow-up to #34170

cls = BertModel if hf_spec["model_cls"] == "BertModel" else AutoModel
trust = hf_spec.get("trust_remote_code", False)

# Flash / Triton kernels require GPU tensors; fall back to eager on CPU.
Contributor


Hi @AndreasKaratzas, nice job, the refactored tests look great. I don't understand so far why we need the explicit cast to eager attention. The motivation here makes sense, but the tests passed everywhere except AMD? Could you elaborate on whether there is anything specific about AMD in this context? Thanks in advance : )

Collaborator Author


So, I just tested it on ROCm right now without the eager step and it passes, at least on MI355. I would still prefer to keep it, though, since it only targets the HF API, not vLLM. There is also a related ticket on ROCm for this: #30167

It's the same reason there is this conftest.py under the parent of this TG.

Collaborator Author


Also, I am not sure the test would work on CPU CI without eager. Btw, I just tested it on MI325 as well without the eager patch, and it passes there too. So all in all, I can remove it, but I think it's beneficial, and it only affects the HF API side, not the vLLM evaluation.

Contributor


Thanks, glad it is merged now!

@AndreasKaratzas
Collaborator Author

The failure in this one has to do with /Qwen-VL/assets/SimSun.ttf not being found. This is a Qwen font file hosted on a server that refuses the connection. This failure is not related to this PR.

Collaborator

@noooop noooop left a comment


Thanks for this fix.

Like @ieBoytsov, I also think this looks great.

@noooop noooop enabled auto-merge (squash) February 20, 2026 12:39
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 20, 2026
@AndreasKaratzas
Collaborator Author

@noooop Distributed model tests failure looks completely unrelated to this PR: https://buildkite.com/vllm/ci/builds/52483/steps/canvas?jid=019c7c16-c973-4172-a5f5-06ad78c33d20&tab=output#019c7c16-c973-4172-a5f5-06ad78c33d20/L2942

Can we merge it?

@vllm-bot vllm-bot merged commit 89358f0 into vllm-project:main Feb 21, 2026
14 of 16 checks passed
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Feb 21, 2026

@AndreasKaratzas AndreasKaratzas deleted the akaratza_fix_lang_mod_ext_pooling branch February 21, 2026 04:14
DarkLight1337 pushed a commit to DarkLight1337/vllm that referenced this pull request Feb 21, 2026
joeqzzuo pushed a commit to joeqzzuo/vllm that referenced this pull request Feb 21, 2026
yugong333 pushed a commit to yugong333/vllm that referenced this pull request Feb 22, 2026
jmamou pushed a commit to jmamou/vllm that referenced this pull request Feb 23, 2026
llsj14 pushed a commit to llsj14/vllm that referenced this pull request Mar 1, 2026
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
askliar pushed a commit to askliar/vllm that referenced this pull request Mar 9, 2026
Copilot AI pushed a commit to machov/vllm that referenced this pull request Mar 10, 2026
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request Apr 1, 2026
liuchenbing2026 pushed a commit to liuchenbing2026/vllm that referenced this pull request Apr 4, 2026

Labels

ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm

Projects

Status: Done

4 participants