
feat: DeepSeek new v3.2 encoding #14249

Merged
Fridge003 merged 9 commits into sgl-project:main from Eva20150932-atlascloud:v32_encoding
Dec 2, 2025

Conversation

@Eva20150932-atlascloud
Contributor

Motivation

#14227
DeepSeek officially released a new encoding function to replace the chat template, and I made a workable version (though it is hard-coded and breaks other models).

If you still use the old chat template with the formal v3.2 release, tool calling works badly, so we need the new encoding to run the new v3.2 model.

Modifications

Accuracy Tests

Start a server with python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3.2 --trust-remote-code --tp-size 8 --host 0.0.0.0 --tool-call-parser deepseekv32 --enable-metrics --max-queued-requests 3 --max-running-requests 64 --cuda-graph-max-bs 64 --reasoning-parser deepseek-v3; my tool-calling tests passed.
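For reference, a minimal smoke test in the same spirit, assuming the server above is reachable at http://localhost:30000/v1 (SGLang's default port); the tool definition and prompt here are illustrative, not the actual test:

# Minimal tool-calling smoke test against the server launched above.
# Assumptions: localhost:30000 is the default SGLang port; the tool and
# prompt below are illustrative, not the PR author's actual test.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="none")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

resp = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-V3.2",
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
)

# With the new v3.2 encoding, the model should answer with a tool call
# rather than plain text.
print(resp.choices[0].message.tool_calls)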

Benchmarking and Profiling

Checklist

@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@Fridge003
Collaborator

@Eva20150932-atlascloud Can we stay compatible with the former template of DeepSeek-V3.2-Exp?

@Eva20150932-atlascloud
Contributor Author

I've tried the old v3.2 chat template, but the model doesn't pass my tool-call tests.

@Fridge003
Collaborator

I've tried the old v3.2 chat template, but the model doesn't pass my tool-call tests.

I mean, can we put the different chat templates in separate files and apply them to the different models (V3.2 / V3.2-Exp)?

@Johnsonms
Contributor

Verified it works

@Eva20150932-atlascloud
Contributor Author

@Fridge003 Possible, though it requires setting XML attributes, and I'm not experienced in building Jinja templates.

…=ChoiceDeltaToolCallFunction(arguments={}, name=None), type=function)] when streaming
…Added detection logic for using DPSK V3.2 encoding based on tokenizer configuration and architecture. Updated tests to validate encoding path and functionality. Adapted encoding_dsv32.py from Hugging Face repository.

Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
@JustinTong0323 JustinTong0323 changed the title "hard code hacking for DeepSeek new v3.2 encoding" to "feat: DeepSeek new v3.2 encoding" Dec 2, 2025
@Fridge003
Collaborator

/tag-and-rerun-ci

@github-actions github-actions Bot added the run-ci label Dec 2, 2025
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
@JustinTong0323
Collaborator

JustinTong0323 commented Dec 2, 2025

I believe this PR is ready. cc @Fridge003

Comment on lines +86 to +96

        self.use_dpsk_v32_encoding = self._use_dpsk_v32_encoding()

    def _use_dpsk_v32_encoding(self) -> bool:
        has_chat_template = (
            self.tokenizer_manager.tokenizer is not None
            and self.tokenizer_manager.tokenizer.chat_template is not None
        )
        architectures = self.tokenizer_manager.server_args.get_hf_config().architectures
        is_dpsk_v32 = "DeepseekV3" in architectures[0] if architectures else False
        return not has_chat_template and is_dpsk_v32

Contributor

@jimmy-evo jimmy-evo Dec 2, 2025

self.use_dpsk_v32_encoding = self.tokenizer_manager.server_args.tool_call_parser == "deepseekv32"

Just don't determine this from "architectures"; determine it from "tool_call_parser" instead.

Collaborator

@JustinTong0323 JustinTong0323 Dec 2, 2025

We should not; the tool_call_parser is not required in some cases, but this code path still matters.

Contributor

We could just add an environment variable, SGLANG_USE_DPSKV32_ENCODING=True; then there's no need to worry about how to determine this.
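A sketch of what that gate could look like (hypothetical; SGLANG_USE_DPSKV32_ENCODING is the commenter's proposed variable, not an existing SGLang flag):

import os

# Hypothetical escape hatch as proposed above, not an existing SGLang flag:
# force the v3.2 encoding on or off via an environment variable, and fall
# back to automatic detection when the variable is unset.
env_flag = os.environ.get("SGLANG_USE_DPSKV32_ENCODING")
if env_flag is not None:
    self.use_dpsk_v32_encoding = env_flag.lower() in ("1", "true")
else:
    self.use_dpsk_v32_encoding = self._use_dpsk_v32_encoding()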

Contributor

@JustinTong0323 I think this custom encoding is a temporary approach. It is effectively a kind of def apply_chat_template.

Collaborator

Not quite sure; do you mean we should not enable it by default? But this code is adapted from DeepSeek's HF repo, so I think it should be enabled by default.

Contributor

Not quite sure; do you mean we should not enable it by default? But this code is adapted from DeepSeek's HF repo, so I think it should be enabled by default.

Do you remember, back when huggingface transformers ~= 4.2x (2023/2024), open-source models usually provided a tokenizer.py with def apply_chat_template?

This encoding_dsv32.py is that def apply_chat_template.
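Schematically, the pattern being referenced (purely illustrative; the class name and prompt format here are hypothetical):

# Illustrative only: the old-style pattern where a model repo ships custom
# tokenizer code that renders the prompt in Python instead of via a Jinja
# chat_template. The class name and prompt format are hypothetical.
from transformers import PreTrainedTokenizerFast

class CustomChatTokenizer(PreTrainedTokenizerFast):
    def apply_chat_template(self, conversation, tokenize=True, **kwargs):
        # Build the prompt string in code rather than from a template.
        prompt = "".join(
            f"<|{turn['role']}|>{turn['content']}" for turn in conversation
        )
        return self.encode(prompt) if tokenize else prompt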

JustinTong0323 and others added 2 commits December 2, 2025 08:22
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>

# Check if invoke_content is empty or whitespace only
# If so, skip this tool call entirely (it's likely incomplete or malformed)
if not invoke_content.strip():

This will ignore parameterless functions, e.g.:

<|DSML|invoke name="get_current_time">

</|DSML|invoke>
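One possible fix, sketched here (hypothetical, not the merged code): emit the call with empty arguments instead of dropping the block.

# Hypothetical adjustment, not the merged code: treat a blank invoke body
# as a parameterless call rather than skipping the tool call entirely.
if not invoke_content.strip():
    arguments = "{}"  # e.g. get_current_time takes no parameters
else:
    arguments = parse_invoke_arguments(invoke_content)  # hypothetical helper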

Collaborator

Will fix that later.

@soaringk
Contributor

soaringk commented Dec 2, 2025

There is a parse_tool_calls function in encoding_dsv32.py. I reckon we should use that one to parse function calls?

@Fridge003 Fridge003 merged commit 7c38eca into sgl-project:main Dec 2, 2025
174 of 185 checks passed
harvenstar pushed a commit to harvenstar/sglang that referenced this pull request Dec 4, 2025
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
yingluosanqian pushed a commit to yingluosanqian/sglang that referenced this pull request Dec 4, 2025
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
tonyluj pushed a commit to openanolis/sglang that referenced this pull request Dec 5, 2025
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
@Muqi1029
Contributor

Muqi1029 commented Dec 8, 2025

Hi, may I ask why you use a while loop in parse_streaming_increment? @Eva20150932-atlascloud

@Eva20150932-atlascloud
Contributor Author

Do you mean we only need to prepare to parse one invoke block, since the model generates only one token per forward pass?

PR #11652, which adds MTP support for v3.2, makes generating more than one invoke block in a single step possible (though with very low probability). And in any case, I think the while loop is harmless, as it breaks once the invoke regex no longer matches. @Muqi1029
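For context, a minimal sketch of the pattern under discussion (the regex and helper are illustrative, not the PR's code): the while loop drains every complete invoke block from the buffer and exits as soon as the regex stops matching.

import re

# Illustrative sketch: with MTP, one decode step can append more than one
# complete invoke block to the buffer, so we drain every match instead of
# assuming at most one block per increment.
INVOKE_RE = re.compile(r"<\|DSML\|invoke.*?</\|DSML\|invoke>", re.DOTALL)

def drain_invokes(buffer: str):
    calls = []
    while True:
        match = INVOKE_RE.search(buffer)
        if match is None:  # exits as soon as no complete block remains
            break
        calls.append(match.group(0))
        buffer = buffer[match.end():]
    return calls, buffer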

yuchengz816-bot pushed a commit to yuchengz816-bot/sglang that referenced this pull request Dec 8, 2025
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
@Muqi1029
Contributor

Muqi1029 commented Dec 8, 2025

But even if MTP generates more than one token at once, the logic without the while loop can still handle that on the next increment, since the text is still in the buffer, right?

@Eva20150932-atlascloud
Contributor Author

That logic sounds good, and it makes me rethink things.

Could there be a case where the response doesn't have a 'next time'? For instance, what if the MTP forward pass generates the EOS token?

@Muqi1029
Contributor

Muqi1029 commented Dec 9, 2025

@Eva20150932-atlascloud Thanks for answering! I think you may be right; the while loop should be kept!

But I've run into another question: why do you use these markers here?

# Check if buffer contains any DSML markers or ends with potential tag prefix
# This handles partial/streaming DSML content
dsml_markers = ["|DSML|", "<|", "</|"]
potentially_dsml = any(marker in current_text for marker in dsml_markers)
# Also check if text ends with start of a tag (to handle "<" arriving separately)
dsml_prefixes = ["<", "<|", "</", "</|"]
ends_with_prefix = any(
    current_text.rstrip().endswith(prefix) for prefix in dsml_prefixes
)
if not has_tool_call and not potentially_dsml and not ends_with_prefix:
    self._buffer = ""
    for e_token in [self.eot_token, self.invoke_end_token]:
        if e_token in new_text:
            new_text = new_text.replace(e_token, "")
    return StreamingParseResult(normal_text=new_text)

I think the model outputs at the token level; you can use the following script to see the tokens:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3.2")

special_tokens = [
    "<|DSML|function_calls>",
    "</|DSML|function_calls>",
    "<|DSML|invoke",
    "</|DSML|invoke",
]

for token_str in special_tokens:
    print("\n\n")
    print(f" Processing {token_str} ".center(80, "-"))
    ids = tokenizer.encode(token_str, add_special_tokens=False)
    for token_id in ids:
        piece = tokenizer.decode(token_id)
        print(f"'{piece}' : {token_id}")

The output is as follows:


---------------------- Processing <|DSML|function_calls> -----------------------
'<' : 30
'|DSML|' : 128793
'function' : 8701
'_c' : 4941
'alls' : 12548
'>' : 32



---------------------- Processing </|DSML|function_calls> ----------------------
'</' : 1718
'|DSML|' : 128793
'function' : 8701
'_c' : 4941
'alls' : 12548
'>' : 32



--------------------------- Processing <|DSML|invoke ---------------------------
'<' : 30
'|DSML|' : 128793
'inv' : 40148
'oke' : 5406



-------------------------- Processing </|DSML|invoke ---------------------------
'</' : 1718
'|DSML|' : 128793
'inv' : 40148
'oke' : 5406

So <| will never be generated as a single standalone token, right?

@jxz542189

@Eva20150932-atlascloud Can you ensure that the function calls are output in the expected streaming manner? #14711

@Muqi1029
Contributor

Muqi1029 commented Dec 9, 2025

@Eva20150932-atlascloud Can you ensure that the function calls are output in the expected streaming manner? #14711

@jxz542189
It indeed doesn't support streaming output. I'm working on this and plan to submit a PR tonight.

whybeyoung added a commit to whybeyoung/sglang that referenced this pull request Jan 5, 2026
@whybeyoung
Collaborator

When using smg and gRPC mode, I think it should do something similar to this PR. @slin1237
