Merged · 32 commits

10 changes: 7 additions & 3 deletions lmdeploy/pytorch/check_env/__init__.py
@@ -38,6 +38,9 @@ def check_env_torch():
_handle_exception(e, 'PyTorch', logger)


MAX_TRITON_VERSION = '2.2.0'


def check_env_triton():
"""check OpenAI Triton environment."""
from packaging import version
@@ -47,8 +50,9 @@ def check_env_triton():
logger.debug('Checking <Triton> environment.')
import torch
import triton
if version.parse(triton.__version__) != version.parse('2.1.0'):
logger.warning('Install triton==2.1.0'
if version.parse(
triton.__version__) > version.parse(MAX_TRITON_VERSION):
logger.warning(f'Install triton<={MAX_TRITON_VERSION}'
' if you want to get better performance.')

from .triton_custom_add import custom_add
@@ -80,7 +84,7 @@ def check_env(device_type: str):


MIN_TRANSFORMERS_VERSION = '4.33.0'
MAX_TRANSFORMERS_VERSION = '4.38.2'
MAX_TRANSFORMERS_VERSION = '4.41.2'


def check_transformers_version(model_path: str,

9 changes: 9 additions & 0 deletions lmdeploy/pytorch/config.py
@@ -73,6 +73,8 @@ class ModelConfig:
bos_token_id: int
eos_token_id: List[int]
head_dim: int
k_head_dim: int = None
v_head_dim: int = None
sliding_window: int = -1
dtype: torch.dtype = torch.float16
multi_query_attention: bool = False
@@ -106,6 +108,13 @@ def from_hf_config(cls, hf_config: Any, model_path: str = None):

model_config = AutoModelConfigBuilder.build(hf_config, model_path)

if model_config.k_head_dim is None:
assert model_config.head_dim is not None
model_config.k_head_dim = model_config.head_dim
if model_config.v_head_dim is None:
assert model_config.head_dim is not None
model_config.v_head_dim = model_config.head_dim

model_arch = model_config.hf_config.architectures[0]
model_config.model_arch = model_arch
# should after setting `hf_config` and `model_arch` attributes
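
The fallback above keeps every existing configuration builder working unchanged: a builder that only sets head_dim ends up with identical key and value cache layouts, while a builder may set the two independently, as the new DeepSeek-V2 builder below does. A minimal sketch of the defaulting behaviour (hypothetical dimensions and a stand-in dataclass rather than the real ModelConfig):

from dataclasses import dataclass
from typing import Optional

@dataclass
class _Cfg:
    """Hypothetical stand-in for ModelConfig, trimmed to the relevant fields."""
    head_dim: int
    k_head_dim: Optional[int] = None
    v_head_dim: Optional[int] = None

def _apply_head_dim_fallback(cfg: _Cfg) -> _Cfg:
    # mirrors the defaulting added to ModelConfig.from_hf_config in this diff
    if cfg.k_head_dim is None:
        cfg.k_head_dim = cfg.head_dim
    if cfg.v_head_dim is None:
        cfg.v_head_dim = cfg.head_dim
    return cfg

print(_apply_head_dim_fallback(_Cfg(head_dim=128)))                # k/v both fall back to 128
print(_apply_head_dim_fallback(_Cfg(head_dim=576, v_head_dim=0)))  # MLA-style override: v_head_dim stays 0
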
34 changes: 34 additions & 0 deletions lmdeploy/pytorch/configurations/deepseek_v2.py
@@ -0,0 +1,34 @@
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.config import ModelConfig

from .builder import AutoModelConfigBuilder


class DeepseekV2ModelConfigBuilder(AutoModelConfigBuilder):

@classmethod
def condition(cls, hf_config):
"""config."""
return hf_config.model_type == 'deepseek_v2'

@classmethod
def build(cls, hf_config, model_path: str = None):
"""build."""
head_dim = (hf_config.kv_lora_rank + hf_config.qk_rope_head_dim)
k_head_dim = head_dim
v_head_dim = 0

Collaborator: why is v_head_dim 0?

num_attention_heads = hf_config.num_attention_heads
num_key_value_heads = 1
init_kwargs = dict(attn_implementation='eager')
return ModelConfig(hidden_size=hf_config.hidden_size,
num_layers=hf_config.num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
head_dim=head_dim,
k_head_dim=k_head_dim,
v_head_dim=v_head_dim,
vocab_size=hf_config.vocab_size,
multi_query_attention=True,
init_kwargs=init_kwargs)
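
For context on the dimensions above: MLA caches a single compressed KV latent per token plus the decoupled RoPE key, so the whole cache lives on the key side and no separate value cache is needed. With the published DeepSeek-V2 config values (kv_lora_rank=512, qk_rope_head_dim=64, assumed here rather than read from a checkpoint) the arithmetic works out as:

# Hypothetical numbers matching the public DeepSeek-V2 config; adjust if the checkpoint differs.
kv_lora_rank = 512                            # compressed KV latent width per token
qk_rope_head_dim = 64                         # decoupled RoPE key width
k_head_dim = kv_lora_rank + qk_rope_head_dim  # 576, stored once per token (num_key_value_heads=1)
v_head_dim = 0                                # values are reconstructed from the latent, nothing cached
print(k_head_dim, v_head_dim)                 # 576 0
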
110 changes: 63 additions & 47 deletions lmdeploy/pytorch/engine/cache_engine.py
@@ -40,11 +40,7 @@ def __init__(
self.model_config = model_config

self.block_size = cache_config.block_size

self.head_size = model_config.get_head_size()
self.num_layers = model_config.num_layers
self.num_heads = model_config.num_key_value_heads

self.kv_cache_dtype = model_config.dtype

# Initialize the cache.
@@ -81,31 +77,46 @@ def num_cpu_blocks(self):
"""num gpu blocks."""
return self.cache_config.num_cpu_blocks

@classmethod
def _get_block_shape_impl(cls,
model_config: ModelConfig,
block_size: int,
head_size: int,
world_size: int = 1,
local: bool = True):
"""get single block shape."""
num_heads = model_config.num_key_value_heads
if local and not model_config.multi_query_attention:
assert num_heads % world_size == 0, \
f'num_heads: {num_heads}, world_size: {world_size}'
num_heads = num_heads // world_size
return (block_size, num_heads, head_size)

def get_key_block_shape(self, local: bool = False) -> Tuple[int, int, int]:
"""get shape of key block."""
num_heads = self.num_heads
if local and not self.model_config.multi_query_attention:
assert self.num_heads % self.world_size == 0, \
f'num_heads: {self.num_heads}, world_size: {self.world_size}'
num_heads = self.num_heads // self.world_size
return (
self.block_size,
num_heads,
self.head_size,
head_size = self.model_config.k_head_dim
if head_size is None:
head_size = self.model_config.head_dim
return self._get_block_shape_impl(
self.model_config,
block_size=self.block_size,
head_size=head_size,
world_size=self.world_size,
local=local,
)

def get_value_block_shape(self,
local: bool = False) -> Tuple[int, int, int]:
"""get shape of value block."""
num_heads = self.num_heads
if local and not self.model_config.multi_query_attention:
assert self.num_heads % self.world_size == 0, \
f'num_heads: {self.num_heads}, world_size: {self.world_size}'
num_heads = self.num_heads // self.world_size
return (
self.block_size,
num_heads,
self.head_size,
head_size = self.model_config.v_head_dim
if head_size is None:
head_size = self.model_config.head_dim
return self._get_block_shape_impl(
self.model_config,
block_size=self.block_size,
head_size=head_size,
world_size=self.world_size,
local=local,
)

def allocate_gpu_cache(self):
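
A quick sketch of what the two shape helpers above now return, using hypothetical dimensions for a conventional GQA model versus a DeepSeek-V2-style MLA layout:

def _block_shape(block_size, num_kv_heads, head_size, world_size=1,
                 multi_query=False, local=True):
    # mirrors CacheEngine._get_block_shape_impl from this diff
    if local and not multi_query:
        assert num_kv_heads % world_size == 0
        num_kv_heads //= world_size
    return (block_size, num_kv_heads, head_size)

# hypothetical GQA model: 8 KV heads, head_dim 128, tensor-parallel world_size 2
print(_block_shape(64, 8, 128, world_size=2))                    # (64, 4, 128) for both K and V blocks
# hypothetical DeepSeek-V2-style MLA: 1 latent head, k_head_dim 576, v_head_dim 0
print(_block_shape(32, 1, 576, world_size=2, multi_query=True))  # (32, 1, 576) key block
print(_block_shape(32, 1, 0, world_size=2, multi_query=True))    # (32, 1, 0)   value block is empty
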
@@ -190,8 +201,9 @@ def swap_out(self, src_to_dst: Dict[int, int]) -> None:
"""
self._swap(self.local_gpu_cache, self.local_cpu_cache, src_to_dst)

@staticmethod
def get_cache_block_size(block_size: int,
@classmethod
def get_cache_block_size(cls,
block_size: int,
model_config: ModelConfig,
world_size: int = 1) -> int:
"""Get the required cache size of the model.
@@ -203,27 +215,31 @@ def get_cache_block_size(block_size: int,
Return:
int: Required memory size in bytes.
"""
head_size = model_config.get_head_size()
num_layers = model_config.num_layers
num_heads = model_config.num_key_value_heads
if not model_config.multi_query_attention:
num_heads = num_heads // world_size

key_cache_block = block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)

dtype_size = _get_dtype_size(model_config.dtype)
return dtype_size * total


def _get_dtype_size(dtype: torch.dtype) -> int:
"""get size of the given dtype.

Args:
dtype (torch.dtype): Data type.

Return:
int: size in bytes.
"""
return torch.tensor([], dtype=dtype).element_size()
key_head_size = model_config.k_head_dim
value_head_size = model_config.v_head_dim
if key_head_size is None:
key_head_size = model_config.head_dim
if value_head_size is None:
value_head_size = model_config.head_dim
key_shape = cls._get_block_shape_impl(
model_config,
block_size=block_size,
head_size=key_head_size,
world_size=world_size,
local=True,
)
value_shape = cls._get_block_shape_impl(
model_config,
block_size=block_size,
head_size=value_head_size,
world_size=world_size,
local=True,
)
dtype = model_config.dtype
key_block = torch.empty(key_shape, dtype=dtype, device='meta')
value_block = torch.empty(value_shape, dtype=dtype, device='meta')
mem_key_block = key_block.numel() * key_block.element_size()
mem_value_block = value_block.numel() * value_block.element_size()
total = num_layers * (mem_key_block + mem_value_block)
return total
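
Sizing through meta tensors keeps the byte count exact even when key and value blocks differ, without allocating real memory. A rough sanity check with hypothetical DeepSeek-V2-style numbers (60 layers, block_size 32, one latent head, k/v head dims 576/0, float16):

import torch

def cache_block_bytes(num_layers, key_shape, value_shape, dtype=torch.float16):
    # same idea as CacheEngine.get_cache_block_size: meta tensors expose numel/element_size for free
    key_block = torch.empty(key_shape, dtype=dtype, device='meta')
    value_block = torch.empty(value_shape, dtype=dtype, device='meta')
    per_layer = (key_block.numel() * key_block.element_size() +
                 value_block.numel() * value_block.element_size())
    return num_layers * per_layer

print(cache_block_bytes(60, (32, 1, 576), (32, 1, 0)))  # 2211840 bytes, roughly 2.1 MB per cache block
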
15 changes: 15 additions & 0 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -81,6 +81,21 @@ def __get_free_gpu_mem_size(cache_block_size: int):
f' {runtime_cache_size>>20} mb')
return gpu_mem_physical_free * cache_config.cache_max_entry_count

def __adjust_block_size():
"""adjust block_size."""
# TODO: support kernel with both large head dim and large block size.
if model_config.k_head_dim >= 512 and cache_config.block_size > 32:
cache_config.block_size = 32

Collaborator: Will this affect models other than DeepSeek v2?

Collaborator (author): Yes. The MHA kernel needs enough shared memory to hold the KV-cache block and the query block, so any model with such a large head_dim has to be limited. Among the models the PyTorch engine currently supports, only DeepSeek-V2 with the MLA implementation meets this condition.


rank = 0
if dist.is_initialized():
rank = dist.get_rank()
if rank == 0:
logger.warning(
f'Update `block_size={cache_config.block_size}`'
f' for large `head_dim={model_config.k_head_dim}`.')

__adjust_block_size()

cache_block_size = CacheEngine.get_cache_block_size(
cache_config.block_size, model_config, world_size)
gpu_mem = __get_free_gpu_mem_size(cache_block_size)
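
The shared-memory constraint mentioned in the review thread above can be made concrete. The attention kernel has to stage a full KV-cache block (plus the query tile and scratch) in shared memory, and per-block shared memory on current NVIDIA GPUs ranges from the 48 KB default up to roughly 160-230 KB with opt-in, depending on the architecture. A back-of-the-envelope estimate, assuming an fp16 cache and a 576-wide key head:

def kv_block_smem_bytes(block_size, head_dim, dtype_bytes=2):
    # rough lower bound: just the key block itself, ignoring the query tile and other buffers
    return block_size * head_dim * dtype_bytes

print(kv_block_smem_bytes(64, 576))  # 73728 bytes (~72 KB): tight or impossible once other buffers are added
print(kv_block_smem_bytes(32, 576))  # 36864 bytes (~36 KB): leaves room for the query tile and scratch
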
23 changes: 16 additions & 7 deletions lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py
@@ -40,6 +40,7 @@ def _div_up(val, other):
stride_boff=int,
BLOCK=torch.int32,
BLOCK_D=torch.int32,
BLOCK_DV=torch.int32,
BLOCK_H=torch.int32,
))
@triton.jit
@@ -71,6 +72,7 @@ def _fill_kv_cache_kernel(
stride_boff,
BLOCK: tl.constexpr,
BLOCK_D: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_H: tl.constexpr,
):
"""fill kv cache kernel."""
@@ -118,13 +120,17 @@ def _fill_kv_cache_kernel(
k,
mask=mask)

v = tl.load(vs_ptr + sidx * stride_vss + h_off[:, None] * stride_vsh +
d_off[None, :] * stride_vsd,
mask=mask)
tl.store(vc_ptr + bidx * stride_vcb + h_off[:, None] * stride_vch +
d_off[None, :] * stride_vcd,
v,
mask=mask)
if BLOCK_DV > 0:
dv_off = tl.arange(0, BLOCK_DV)
maskv = (h_off[:, None] < num_heads) & (dv_off[None, :] < head_dim)
v = tl.load(vs_ptr + sidx * stride_vss +
h_off[:, None] * stride_vsh +
dv_off[None, :] * stride_vsd,
mask=maskv)
tl.store(vc_ptr + bidx * stride_vcb + h_off[:, None] * stride_vch +
dv_off[None, :] * stride_vcd,
v,
mask=maskv)


def fill_kv_cache(k_states: Tensor, v_states: Tensor, k_caches: Tensor,
@@ -136,11 +142,13 @@ def fill_kv_cache(k_states: Tensor, v_states: Tensor, k_caches: Tensor,
block_offsets = block_offsets.contiguous()
batch_size = block_offsets.size(0)
block_size, num_heads, head_dim = k_caches.size()[1:]
head_dim_v = v_states.size(-1)
max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1

BLOCK = block_size
BLOCK_H = triton.next_power_of_2(num_heads)
BLOCK_D = triton.next_power_of_2(head_dim)
BLOCK_DV = triton.next_power_of_2(head_dim_v)
grid = [batch_size, max_num_blocks]
kernel_meta = get_kernel_meta(k_states)
_fill_kv_cache_kernel[grid](
@@ -171,6 +179,7 @@ def fill_kv_cache(k_states: Tensor, v_states: Tensor, k_caches: Tensor,
stride_boff=block_offsets.stride(0),
BLOCK=BLOCK,
BLOCK_D=BLOCK_D,
BLOCK_DV=BLOCK_DV,
BLOCK_H=BLOCK_H,
num_warps=4,
num_stages=3,
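
Since BLOCK_DV is a tl.constexpr, the value-copy branch above is resolved at compile time: for an MLA-style model whose value head dim is 0 it drops out entirely and only keys are written, while ordinary models keep the previous behaviour. A small sketch of the host-side sizing (hypothetical dims; it is assumed here that triton.next_power_of_2(0) returns 0):

import triton

# hypothetical head dims: DeepSeek-V2-style key width 576, no value cache
head_dim, head_dim_v = 576, 0
BLOCK_D = triton.next_power_of_2(head_dim)     # 1024
BLOCK_DV = triton.next_power_of_2(head_dim_v)  # assumed 0, which makes the kernel skip the value copy
print(BLOCK_D, BLOCK_DV)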