
[Hardware] Enable Intel Gaudi (HPU) support#1066

Closed
shepark wants to merge 42 commits into LMCache:dev from shepark:intel-main-v1-rebase-upstream

Conversation

@shepark

@shepark shepark commented Jul 15, 2025

Enable Intel Gaudi HPU support.
There is a corresponding PR (HabanaAI/vllm-fork#1369) in vllm-fork.
The vllm-fork PR has examples that utilize LMCache (e.g., a PD use case).
We need these changes merged in.

  • Install LMCache based on the HPU components:
    PT_HPU_GPU_MIGRATION=1 pip install -e .

@shepark shepark marked this pull request as ready for review July 15, 2025 21:54
@YaoJiayi YaoJiayi self-requested a review July 15, 2025 22:04
Comment thread lmcache/v1/gpu_connector.py Outdated
hd_shape = h*d
for i in range(len(kvcaches)):
kvcaches[i][0].view(b, hd_shape).index_copy_(0, slot_mapping[start:end], tmp_gpu_buffer[0][i])
kvcaches[i][1].view(b, hd_shape).index_copy_(0, slot_mapping[start:end], tmp_gpu_buffer[1][i])
Collaborator

Instead of having a bunch of if-else branches (also in the following code), why not have an HPUConnector?

Author
@shepark shepark Jul 17, 2025

@YaoJiayi Thank you for the review and suggestion. That would be good too.
Simply put, the changes here provide a different path for KV cache transfer when lmc_ops does not exist, as you know.
Right, it looks like a "bunch of" if-else, but it's clearly distinguishable and only in "to_gpu" and "from_gpu".
If we add a new hpu_connector now, there might be much more code needed to select hpu_connector in multiple locations.
But this is the first change, so we will definitely consider having a separate connector for HPU.

@shepark
Author

shepark commented Jul 17, 2025

@YaoJiayi I see 2 failing checks.
The 1st one was canceled by you.
The 2nd one failed due to some permission issue.
Can you check this out?
I'm not sure whether this is something I can fix or not.
Thank you.

--
# Host "github.com" already in list of known hosts at "/root/.ssh/known_hosts"
$ git clone -v -- git@github.com:LMCache/LMCache.git .
Cloning into '.'...
git@github.com: Permission denied (publickey).

fatal: Could not read from remote repository.

Please make sure you have the correct access rights
and the repository exists.
⚠️ Warning: Checkout failed! cloning git repository: exit status 128 (Attempt 3/3)
# Removing /var/lib/buildkite-agent/builds/unit-end-to-end-test-2-2/lmcache/lmcache-vllm-integration-tests
# Creating "/var/lib/buildkite-agent/builds/unit-end-to-end-test-2-2/lmcache/lmcache-vllm-integration-tests"
🚨 Error: cloning git repository: exit status 128

@shepark shepark force-pushed the intel-main-v1-rebase-upstream branch 3 times, most recently from 0824f9a to 07f2519 Compare July 18, 2025 15:47
@shepark shepark changed the title [Feature] Enable Intel Gaudi (HPU) support [Hardware] Enable Intel Gaudi (HPU) support Jul 18, 2025
@shepark shepark force-pushed the intel-main-v1-rebase-upstream branch from 07f2519 to f50ebc5 Compare July 19, 2025 01:21
@shepark
Author

shepark commented Jul 19, 2025

@YaoJiayi all tests passed, can you review this PR again?

@shepark shepark force-pushed the intel-main-v1-rebase-upstream branch 6 times, most recently from 57d6b69 to 1f0a331 Compare July 23, 2025 05:17
@shepark
Author

shepark commented Jul 23, 2025

@sammshen could you review the PR?

Collaborator
@YaoJiayi YaoJiayi left a comment

LGTM! @shepark We do want to support more hardware! Just two minor comments.

flattened.extend(elem)
else:
flattened.append(elem)
new_block_ids = flattened
Collaborator

Can you explain the code change here?

Author

@YaoJiayi Thank you for the review.
Actually, I described it in the issue ticket as well, but I closed it (#1091).
We expect allocated_block_ids to always be a flat list, like [1,2,3,4,5].
But with the original code, recently added by #1072, it becomes nested when new_block_ids is a list of lists.
So I need to flatten it again for our case.
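A minimal standalone sketch of the flattening described here (the names mirror the diff above; this is illustrative, not the exact PR code):

# Illustrative version of the flattening step discussed above.
# new_block_ids may arrive flat ([1, 2, 3, 4, 5]) or as a list of lists
# ([[1, 2], [3, 4, 5]]); downstream code expects a flat list of block ids.
def flatten_block_ids(new_block_ids):
    flattened = []
    for elem in new_block_ids:
        if isinstance(elem, (list, tuple)):
            flattened.extend(elem)
        else:
            flattened.append(elem)
    return flattened

assert flatten_block_ids([[1, 2], [3, 4, 5]]) == [1, 2, 3, 4, 5]
assert flatten_block_ids([1, 2, 3]) == [1, 2, 3]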

Comment thread lmcache/v1/gpu_connector.py Outdated
self.use_mla,
)
else:
if self.gpu_buffer is not None:
Collaborator

Can we move the code from VLLMPagedMemGPUConnectorV2 to something like VLLMPagedMemHPUConnector?

Author

@YaoJiayi yes, I totally agree with your suggestion, but it's the same as your earlier comment about hpu_connector.
We are going to improve this continuously and will definitely consider a separate connector.
In the future it will be VLLMPagedMemHPUConnector in hpu_connector, not in gpu_connector.

Contributor

Hi @shepark , would it be ok to do this in this PR? :)

Author
@shepark shepark Jul 25, 2025

ok will think about it :)


@YaoJiayi @sammshen The most recent commit does this. Can you review again?

@shepark shepark force-pushed the intel-main-v1-rebase-upstream branch from ce961f2 to f2849bb Compare July 24, 2025 15:25
@shepark
Author

shepark commented Jul 24, 2025

@YaoJiayi @sammshen can you check the Buildkite failure?
It looks like a test-infra issue, not a problem with this PR.

@sammshen
Contributor

Integration tests are still WIP; they do not need to pass for now.

Comment thread lmcache/v1/hpu_connector.py Outdated
self.gpu_buffer = torch.empty(
shape, dtype=kwargs["dtype"], device=kwargs["device"]
)
self.store_stream = torch.cuda.Stream()
Contributor
@sammshen sammshen Jul 31, 2025

If torch.device() is set to hpu in the adapter, can we use torch.cuda here? Sorry if I misunderstood.


Removed line 87 as it's not used on HPU.

VLLMPagedMemHPUConnectorV2
]

if use_mla and config.use_layerwise:
Contributor

Maybe add just one more small check under if config.use_layerwise: if device.type == "hpu", report that it's not supported yet.
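A minimal sketch of the suggested guard (illustrative only; the helper name is hypothetical and the config attribute mirrors the discussion, this is not the PR's actual code):

import torch

def check_layerwise_supported(config) -> None:
    # Illustrative guard: reject the layerwise path early on HPU,
    # since layerwise connectors are not supported there yet.
    if getattr(config, "use_layerwise", False):
        if hasattr(torch, "hpu") and torch.hpu.is_available():
            raise NotImplementedError(
                "Layerwise KV cache transfer is not supported on HPU yet"
            )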


fixed this

Comment thread lmcache/v1/hpu_connector.py Outdated

memory_obj.tensor.copy_(tmp_gpu_buffer, non_blocking=True)

if not memory_obj.tensor.is_cuda:
Contributor

Same here, can we call is_cuda here? tensor.device.type != "cuda" might be safer.


I removed the if block here since we don't support CUDA streams in hpu_connector.

Comment thread requirements/hpu.txt
infinistore
msgspec
numpy
nvtx
Contributor
@sammshen sammshen Jul 31, 2025

nvtx (and cufile-python) are NVIDIA-specific. I agree we should keep them though, if that helps avoid breaking the installation (since all the code is annotated with nvtx_annotate).

Contributor

Do cufile-python and nvtx need to be kept?

@sammshen
Contributor

sammshen commented Jul 31, 2025

Hi @skaulintel @shepark (don't worry about the CI, it will be back up very soon), apologies if this is nitpicky, but my overall comment would just be about the implicit reliance on lmc_ops == None inside the HPU connector. Would it be possible to clean it up and only keep the code paths that use if not lmc_ops?

Approving first as mostly LGTM

@hsubramony hsubramony force-pushed the intel-main-v1-rebase-upstream branch 2 times, most recently from 300f504 to 18d0780 Compare August 12, 2025 21:04
Collaborator
@DongDongJu DongDongJu left a comment

Hello, happy new year.
I left a few comments and questions.
One last thing:
Is there any way to force torch.cuda.is_available() to return False when hpu.is_available() is True on the torch side? That would be really helpful.

VLLMPagedMemLayerwiseGPUConnector,
)

if hasattr(torch, "hpu") and torch.hpu.is_available():
Collaborator

IMO, it would be good to have a helper function like is_hpu_available().


torch.cuda.is_available() returns True by design when using PT_HPU_GPU_MIGRATION=1; this is used for running GPU code on HPU with minimal changes. For reference, please check https://docs.habana.ai/en/latest/PyTorch/PyTorch_Model_Porting/GPU_Migration_Toolkit/GPU_Migration_Toolkit.html

# First Party
import lmcache.c_ops as lmc_ops
except (ModuleNotFoundError, ImportError):
lmc_ops = None
Collaborator

same

torch_dev.set_device(local_rank)
device = torch.device(f"{dev_name}:{local_rank}")
num_gpus = torch_dev.device_count()
local_rank = parallel_config.rank % num_gpus
Collaborator

So basically, Gaudi does not support the parallel feature. Is that correct?

Comment thread lmcache/v1/hpu_connector.py Outdated
# First Party
import lmcache.c_ops as lmc_ops
except (ModuleNotFoundError, ImportError):
lmc_ops = None
Collaborator

Do we need this import check for the HPU case? It seems lmc_ops is always None in this case, so all the following lmc_ops checks are also useless.
Please correct me if I'm wrong.

Comment thread lmcache/v1/kv_layer_groups.py Outdated
shape = kv_cache.shape
dtype = kv_cache.dtype

if shape is not None and dtype is not None:
Collaborator

In which case can shape and dtype be None here?

Comment thread lmcache/v1/memory_management.py Outdated
array_type = ctypes.c_uint8 * size
buf = array_type.from_address(ptr)
buffer = torch.frombuffer(buf, dtype=torch.uint8)
buffer = torch.empty(size, dtype=torch.uint8)
Collaborator

I think it should go inside the existing logic.
As written, this breaks the behavior for the numa_mapping case.


This is for when lmc_ops is None.

Collaborator

That's what I'm saying. It previously worked in a NUMA-aware manner, but it won't anymore with this indentation.

Contributor

I'm okay with something like

if is_hpu_environment():
    return torch.empty(size, dtype=torch.uint8)

The other part of the code should not be touched in this case

@DongDongJu
Collaborator

@DongDongJu I am working on the machine; it might take some time. Meanwhile I can provide the logs of the tests. Will that be OK?

Sure, please post them here or send the log in the community Slack with DEBUG-level logging. Thanks!

@DongDongJu
Collaborator

@hsubramony will do tomorrow. Thanks for the quick response!

@ApostaC
Contributor

ApostaC commented Jan 8, 2026

Hey, thanks for the contribution 🙏! I would also like to take a look at this PR over the weekend.

Contributor
@ApostaC ApostaC left a comment

Thanks for the terrific work! My understanding is that this PR includes at least 3 parts:

  1. Introduce the HPU connector for passing KV cache between vLLM and LMCache
  2. Fix a lot of import issues
  3. Handle the differences in the KV cache data structure between HPU and GPU

Part 1 seems good to me. But we probably want a better way (i.e., with clear function definitions and fewer code changes) to achieve parts 2 and 3.

It would be great and also easier for the code maintainers to review if you can split this into 3 PRs.

# First Party
import lmcache.c_ops as lmc_ops
except (ModuleNotFoundError, ImportError):
lmc_ops = None
Contributor

What's wrong if we don't do it here? IIUC, CacheGen won't be supported unless we have the kernels on Intel GPUs.
Therefore, functions in this file and in cachegen_encoder.py won't be called anyway.

Comment thread lmcache/v1/cache_engine.py Outdated
Comment on lines +378 to +382
# required for VLLMPagedMemHPUConnectorV2
if hasattr(torch, "hpu") and torch.hpu.is_available():
kv_shapes = self.gpu_connector.get_shape(num_tokens)
else:
kv_shapes = self.metadata.get_shapes(num_tokens)
Contributor

We did a refactoring in #2284 and a series of related PRs by @chunxiaozheng.
We should reuse metadata.get_shapes instead of having an if branch here.

Comment thread lmcache/v1/gpu_connector.py Outdated

if torch.cuda.is_available():
# First Party
if hasattr(torch, "hpu") and torch.hpu.is_available():
Contributor

I've seen this in many different places. Can we do two things:

  1. make hasattr(torch, "hpu") and torch.hpu.is_available(): a common util function
  2. Try to see if there is a way to avoid this if-checking during import

The goal is to avoid code duplication and make it more maintainable
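A minimal sketch of such a shared helper, assuming a location like lmcache/utils.py (the name and placement are illustrative, not the actual implementation):

import functools

import torch

@functools.lru_cache(maxsize=1)
def is_hpu_available() -> bool:
    # Single place to detect Intel Gaudi (HPU); cached so that repeated
    # calls at runtime have negligible overhead.
    return hasattr(torch, "hpu") and torch.hpu.is_available()

Call sites would then use is_hpu_available() instead of repeating the hasattr(torch, "hpu") and torch.hpu.is_available() expression.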

Comment thread lmcache/v1/kv_layer_groups.py Outdated
Comment on lines +177 to +185
shape, dtype = None, None

if isinstance(kv_cache, (tuple, list)):
# HPU has a tuple list (K,V) with same shape and dtype
for tensor in kv_cache:
if tensor is not None:
shape = tensor.shape
dtype = tensor.dtype
break
Contributor

Here, you are effectively making shape and dtype optional values (which means they could be None).
Not sure why the linter doesn't complain here, but we should make this code cleaner. Probably define a clear function to extract the dtype and shape for HPU?
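A minimal sketch of such an extraction helper (illustrative; it assumes the HPU layout is a (K, V) tuple/list of tensors with identical shape and dtype, as the comment in the diff states):

def get_kv_shape_and_dtype(kv_cache):
    # HPU layout: kv_cache is a (K, V) tuple/list of tensors sharing the
    # same shape and dtype; GPU layout: kv_cache is a single tensor.
    if isinstance(kv_cache, (tuple, list)):
        for tensor in kv_cache:
            if tensor is not None:
                return tensor.shape, tensor.dtype
        raise ValueError("kv_cache contains no tensors")
    return kv_cache.shape, kv_cache.dtype

Returning the pair directly (and raising on an empty cache) avoids making shape and dtype optional at the call site.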

Contributor

We are refactoring this part of the code in #2380. Let's not touch it for now.

Comment thread lmcache/v1/memory_management.py Outdated
array_type = ctypes.c_uint8 * size
buf = array_type.from_address(ptr)
buffer = torch.frombuffer(buf, dtype=torch.uint8)
buffer = torch.empty(size, dtype=torch.uint8)
Contributor

I'm okay with something like

if is_hpu_environment():
    return torch.empty(size, dtype=torch.uint8)

The other part of the code should not be touched in this case

Comment thread lmcache/v1/memory_management.py Outdated
Comment on lines +1606 to +1613
if lmc_ops:
ptr = lmc_ops.alloc_pinned_ptr(size, 0)
array_type = ctypes.c_uint8 * size
buf = array_type.from_address(ptr)
self.buffer = torch.frombuffer(buf, dtype=torch.uint8)
else:
self.buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)

Contributor

Similar to above, creating an extra indent is not ideal and will be confusing to other contributors.

Comment thread lmcache/v1/memory_management.py Outdated

def close(self):
if not self._unregistered:
if lmc_ops and not self._unregistered:
Contributor

Here, we use whether lmc_ops is None to determine whether we are in the HPU environment (and similar logic is used in multiple places).
This may create confusion and maintenance overhead in the future. We should define an explicit function to determine whether it's HPU or not.

Contributor

And such a function should have negligible overhead if it is called at runtime.

Comment thread lmcache/v1/protocol.py Outdated
from lmcache.v1.memory_management import MemoryFormat

MAX_KEY_LENGTH = 150
MAX_KEY_LENGTH = 250
Contributor

Just wondering why we make this change here?

@DongDongJu
Collaborator

While checking the log, I noticed that HPU support is not being merged into vLLM officially.
And it seems multiple forks (https://github.com/vllm-project/vllm-gaudi, https://github.com/HabanaAI/vllm-fork, ...) exist.
Even this basic code (hpu_runner from vllm_gaudi, vllm_fork) is different now.
Can we know what the plan is for it?

…pstream

refactor vllm_v1_adapter.py and manager.py
Signed-off-by: Harish Subramony <hsubramony@habana.ai>
Signed-off-by: Harish Subramony <hsubramony@habana.ai>
@hsubramony

While checking the log, I noticed that HPU support is not being merged into vLLM officially. And it seems multiple forks (https://github.com/vllm-project/vllm-gaudi, https://github.com/HabanaAI/vllm-fork, ...) exist. Even this basic code (hpu_runner from vllm_gaudi, vllm_fork) is different now. Can we know what the plan is for it?

please use https://github.com/vllm-project/vllm-gaudi.git

@hsubramony

@sammshen @ApostaC I have updated with the suggested changes. Please review. Thanks.

Signed-off-by: Harish Subramony <hsubramony@habana.ai>
Collaborator
@DongDongJu DongDongJu left a comment

Thanks for the hard work!
I left a few questions.

Comment thread lmcache/v1/kv_layer_groups.py Outdated
Comment thread lmcache/utils.py
Signed-off-by: Harish Subramony <hsubramony@habana.ai>
@hsubramony

@DongDongJu @sammshen @ApostaC I updated the branch and pushed. Please let me know if there are any other issues. Please help merge. Thanks.

@DongDongJu
Collaborator

Hello, thanks for the great work.
IMO, I'm not sure it's a good choice to include this inside the cache engine itself.
It's not like a storage-backend-level module.
I believe the https://github.com/LMCache/LMCache-Ascend approach is better than merging it in here.
But I'll leave this to others' choice.

@libinta

libinta commented Jan 20, 2026

Hello, thanks for the great work. IMO, I'm not sure it's a good choice to include this inside the cache engine itself. It's not like a storage-backend-level module. I believe the https://github.com/LMCache/LMCache-Ascend approach is better than merging it in here. But I'll leave this to others' choice.

@DongDongJu thanks for your review. Since we don't have heavy changes to LMCache at this stage, do you think this PR can be merged to the main git?

hlin99 added a commit to hlin99/LMCache that referenced this pull request Jan 26, 2026
This reverts commit 5666a1c.
hlin99 added a commit to hlin99/LMCache that referenced this pull request Feb 4, 2026
NO_CUDA_EXT=1 BUILD_WITH_HPU=1 PT_HPU_GPU_MIGRATION=1 pip install -e .
@github-actions

This pull request has been automatically marked as stale because it has not had activity within 60 days. It will be automatically closed if no further activity occurs within 30 days.

@github-actions

This pull request has been automatically closed due to inactivity. Please feel free to reopen if you intend to continue working on it!

@github-actions github-actions Bot closed this Apr 21, 2026
