Torch deepseek v2 #1621
New file (`@@ -0,0 +1,34 @@`), the DeepSeek-V2 model-config builder:

```python
# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.pytorch.config import ModelConfig

from .builder import AutoModelConfigBuilder


class DeepseekV2ModelConfigBuilder(AutoModelConfigBuilder):

    @classmethod
    def condition(cls, hf_config):
        """config."""
        return hf_config.model_type == 'deepseek_v2'

    @classmethod
    def build(cls, hf_config, model_path: str = None):
        """build."""
        head_dim = (hf_config.kv_lora_rank + hf_config.qk_rope_head_dim)
        k_head_dim = head_dim
        v_head_dim = 0
        num_attention_heads = hf_config.num_attention_heads
        num_key_value_heads = 1
        init_kwargs = dict(attn_implementation='eager')
        return ModelConfig(hidden_size=hf_config.hidden_size,
                           num_layers=hf_config.num_hidden_layers,
                           num_attention_heads=num_attention_heads,
                           num_key_value_heads=num_key_value_heads,
                           bos_token_id=hf_config.bos_token_id,
                           eos_token_id=hf_config.eos_token_id,
                           head_dim=head_dim,
                           k_head_dim=k_head_dim,
                           v_head_dim=v_head_dim,
                           vocab_size=hf_config.vocab_size,
                           multi_query_attention=True,
                           init_kwargs=init_kwargs)
```

> **Collaborator** (on `v_head_dim = 0`): why

grimoire marked this conversation as resolved.
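For context on the `head_dim` arithmetic in the builder: DeepSeek-V2's MLA attention caches a single compressed KV latent plus a small RoPE key per token, and the builder packs both into one key head. A minimal sketch, assuming the values published in the DeepSeek-V2 Hugging Face config (`kv_lora_rank=512`, `qk_rope_head_dim=64`; a given checkpoint may differ):

```python
from types import SimpleNamespace

# Stand-in for the HF config object; the two values below are taken from
# the public DeepSeek-V2 config and are an assumption, not from this PR.
hf_config = SimpleNamespace(kv_lora_rank=512, qk_rope_head_dim=64)

# Same arithmetic as DeepseekV2ModelConfigBuilder.build(): the compressed
# latent and the RoPE key are packed into one key head per token.
head_dim = hf_config.kv_lora_rank + hf_config.qk_rope_head_dim
print(head_dim)  # 576
```

Note that 576 >= 512, which is exactly the threshold that triggers the block-size clamp in the second hunk of this PR, and `v_head_dim = 0` reflects that values are reconstructed from the same cached latent rather than stored separately.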
Changes to the cache-engine setup, hunk `@@ -81,6 +81,21 @@ def __get_free_gpu_mem_size(cache_block_size: int)`:

```python
            f' {runtime_cache_size>>20} mb')
        return gpu_mem_physical_free * cache_config.cache_max_entry_count

    def __adjust_block_size():
        """adjust block_size."""
        # TODO: support kernel with both large head dim and large block size.
        if model_config.k_head_dim >= 512 and cache_config.block_size > 32:
            cache_config.block_size = 32
            rank = 0
            if dist.is_initialized():
                rank = dist.get_rank()
            if rank == 0:
                logger.warning(
                    f'Update `block_size={cache_config.block_size}`'
                    f' for large `head_dim={model_config.k_head_dim}`.')

    __adjust_block_size()

    cache_block_size = CacheEngine.get_cache_block_size(
        cache_config.block_size, model_config, world_size)
    gpu_mem = __get_free_gpu_mem_size(cache_block_size)
```

> **Collaborator** (on `__adjust_block_size`): Will this affect models other than DeepSeek v2?
>
> **Author:** Yes, the mha kernel needs enough smem to cache the kv_cache block and query block. Any model with such a large head_dim should be limited.
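The author's shared-memory argument can be ballparked. A minimal sketch, with the sizing formula and fp16 assumption mine rather than from the PR: staging one K-cache block in shared memory costs roughly `block_size * head_dim * dtype_bytes`, ignoring the query tile and any other per-CTA smem the real kernel uses.

```python
def kv_block_smem_bytes(block_size: int, head_dim: int,
                        dtype_bytes: int = 2) -> int:
    """Rough bytes to stage one K-cache block in shared memory (fp16 default).

    Hypothetical helper for illustration; the real kernel also needs smem
    for the query block and intermediates, so this is a lower bound.
    """
    return block_size * head_dim * dtype_bytes

# DeepSeek-V2's packed key head_dim is 576:
print(kv_block_smem_bytes(64, 576))  # 73728 (~72 KiB)
print(kv_block_smem_bytes(32, 576))  # 36864 (~36 KiB)
```

At ~72 KiB a single K block already exceeds the 48 KiB static shared-memory limit of many CUDA compute capabilities, before counting the query block; clamping `block_size` to 32 halves the footprint, which is why the check keys off `k_head_dim >= 512` rather than the model type.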
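The clamp rule itself is small enough to restate as a pure function for unit testing; `adjust_block_size` below is a hypothetical standalone mirror of the PR's nested `__adjust_block_size`, not code from the PR:

```python
def adjust_block_size(k_head_dim: int, block_size: int) -> int:
    """Clamp the KV-cache block size to 32 when the key head dim is >= 512.

    Mirrors the PR's __adjust_block_size logic as a pure function
    (illustrative only; the PR mutates cache_config in place).
    """
    if k_head_dim >= 512 and block_size > 32:
        return 32
    return block_size

print(adjust_block_size(576, 64))  # 32 -> DeepSeek-V2 gets clamped
print(adjust_block_size(128, 64))  # 64 -> typical models are untouched
print(adjust_block_size(576, 16))  # 16 -> already-small blocks stay as-is
```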