[Bug Fix] missing index/KV transfer for MTP layer in NSA disaggregation#23539
[Bug Fix] missing index/KV transfer for MTP layer in NSA disaggregation#23539ShangmingCai merged 8 commits intosgl-project:mainfrom
Conversation
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
ShangmingCai
left a comment
There was a problem hiding this comment.
Nice catch. Thx for the fix.
|
/tag-and-rerun-ci |
|
/rerun-stage stage-c-test-8-gpu-h20 |
|
✅ Triggered |
|
/rerun-stage stage-c-test-8-gpu-h20 |
|
✅ Triggered |
|
/rerun-stage stage-c-test-8-gpu-h20 |
|
✅ Triggered |
| if self.draft_token_to_kv_pool is not None and isinstance( | ||
| self.draft_token_to_kv_pool, NSATokenToKVPool | ||
| ): |
There was a problem hiding this comment.
Two small questions here
- Do we need
draft_token_to_kv_poolto also beNSATokenToKVPool? - Need to check if
hasattr(self.draft_token_to_kv_pool,"get_state_buf_infos")
There was a problem hiding this comment.
Make sense. Do we need to consider the Non-MTP spec decode cases? Not sure if this is a common use-case. @zRzRzRzRzRzRzR
There was a problem hiding this comment.
The isinstance(self.draft_token_to_kv_pool, NSATokenToKVPool) guard makes this PR a no-op for non-NSA draft pools, so non-MTP spec decode cases aren't affected. The fix kicks in only when the draft pool is also NSA, which today is the MTP-on-NSA path. If a future non-MTP spec decode also runs an NSA draft, the same logic applies and would Just Work.
|
/rerun-failed-ci |
1 similar comment
|
/rerun-failed-ci |
|
PD-related CI has passed. Let's merge. |
Motivation
In PD disaggregation with NSA + MTP, only the target model's NSA state buffers are registered for transfer. The draft model's
NSATokenToKVPoolbuffers are never appended tokv_args, so the MTP layer's index/KV state is not sent from prefill to decode, causing wrong speculative decoding results.Modifications
In
DecodePreallocQueueandPrefillBootstrapQueue, when the main pool is NSA, also appenddraft_token_to_kv_pool.get_state_buf_infos()tokv_argsif the draft pool is also NSA.