@@ -149,7 +149,9 @@ def __init__(self, model_runner: ModelRunner):
self.retrieve_parent_token_list = []
self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
- self.conv_states_shape: tuple[int, int] = None
+ self.conv_states_shape: tuple[int, int] = (
+ self.req_to_token_pool.mamba_pool.mamba_cache.conv[0].shape
+ )

def _forward_metadata(self, forward_batch: ForwardBatch):
bs = forward_batch.batch_size
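With this change the base Mamba attention backend derives conv_states_shape from the shared cache pool at construction time, so per-model backends such as GDNAttnBackend (next file) no longer assign it themselves. A standalone sketch of the bookkeeping this enables, assuming conv[0] is a per-layer conv-state buffer laid out roughly as (num_cache_slots, conv_dim, kernel_size - 1); the concrete sizes and FLA_CHUNK_SIZE value below are placeholders, not taken from the diff:

    import torch

    FLA_CHUNK_SIZE = 64  # placeholder for the real constant used in gdn_backend.py
    conv_buffer = torch.zeros(1024, 2048, 3)  # stands in for mamba_cache.conv[0]
    conv_states_shape = tuple(conv_buffer.shape)
    # The last dim is the conv window length; GDNAttnBackend keeps this check after the move.
    assert conv_states_shape[-1] < FLA_CHUNK_SIZE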
python/sglang/srt/layers/attention/linear/gdn_backend.py (0 additions, 3 deletions)
@@ -244,9 +244,6 @@ class GDNAttnBackend(MambaAttnBackendBase):

def __init__(self, model_runner: ModelRunner):
super().__init__(model_runner)
- self.conv_states_shape = (
- model_runner.req_to_token_pool.mamba_pool.mamba_cache.conv[0].shape
- )
if not is_cpu() and not is_npu():
assert (
self.conv_states_shape[-1] < FLA_CHUNK_SIZE
python/sglang/srt/models/lfm2.py (98 additions, 13 deletions)
@@ -12,7 +12,7 @@
"""

import logging
- from typing import Iterable, Optional, Set, Tuple
+ from typing import Iterable, List, Optional, Set, Tuple

import torch
import torch.nn.functional as F
@@ -24,6 +24,10 @@
causal_conv1d_fn,
causal_conv1d_update,
)
+ from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
+ causal_conv1d_update as causal_conv1d_update_triton,
+ )
+ from sglang.srt.mem_cache.memory_pool import MambaPool
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
@@ -120,17 +124,17 @@ def __init__(
)
self.scaling = self.head_dim**-0.5

- rope_parameters = getattr(config, "rope_parameters", None)
- if rope_parameters is not None and "rope_theta" in rope_parameters:
- rope_theta = rope_parameters["rope_theta"]
- else:
- rope_theta = config.rope_parameters["rope_theta"]
+ rope_parameters = getattr(config, "rope_parameters", None) or {}
+ rope_theta = rope_parameters.get("rope_theta") or getattr(
+ config, "rope_theta", 10000
+ )
+ rope_scaling = getattr(config, "rope_scaling", None)

self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=self.head_dim,
max_position=getattr(config, "max_position_embeddings", 8192),
- rope_scaling=config.rope_parameters,
+ rope_scaling=rope_scaling,
base=rope_theta,
is_neox_style=True,
dtype=torch.get_default_dtype(),
@@ -255,6 +259,15 @@ def __init__(
else:
self.register_parameter("conv_bias", None)

+ # Pre-allocated arange buffer for TARGET_VERIFY's intermediate_state_indices
+ # (avoids per-call torch.arange in the spec-decode hot path). Sized to the
+ # default cuda_graph_max_bs; resized lazily if a larger batch shows up.
+ self.register_buffer(
+ "_intermediate_state_indices",
+ torch.arange(256, dtype=torch.int32),
+ persistent=False,
+ )

def forward(
self,
hidden_states: torch.Tensor,
@@ -266,6 +279,10 @@
layer_cache = forward_batch.req_to_token_pool.mamba2_layer_cache(self.layer_idx)
conv_state = layer_cache.conv[0]
req_pool_indices = forward_batch.req_pool_indices
+ # The req-to-token pool and the Mamba state pool are independent allocators,
+ # so map request slots to Mamba cache slots; extra_buffer adds tracking slots.
+ mamba_cache_indices = forward_batch.req_to_token_pool.get_mamba_indices(
+ req_pool_indices
+ )

# Project and split into gates: B (pre-conv), C (post-conv), x (input)
proj, _ = self.in_proj(hidden_states)
@@ -280,8 +297,33 @@ def forward(
self.conv_weight,
self.conv_bias,
activation=None,
- conv_state_indices=req_pool_indices.to(torch.int32),
+ conv_state_indices=mamba_cache_indices,
)
+ elif forward_batch.forward_mode.is_target_verify():
+ # Tape per-step conv windows into intermediate_conv_window so the
+ # backend can roll state back to the accepted boundary after verify.
+ assert isinstance(layer_cache, MambaPool.SpeculativeState), (
+ "LFM2 TARGET_VERIFY requires --mamba-scheduler-strategy extra_buffer."
+ )
+ draft_token_num = forward_batch.spec_info.draft_token_num
+ bs = req_pool_indices.shape[0]
+ Bx_reshaped = Bx.view(bs, draft_token_num, -1).transpose(1, 2)
+ if self._intermediate_state_indices.shape[0] < bs:
+ self._intermediate_state_indices = torch.arange(
+ bs, dtype=torch.int32, device=Bx.device
+ )
+ intermediate_state_indices = self._intermediate_state_indices[:bs]
+ conv_out = causal_conv1d_update_triton(
+ Bx_reshaped,
+ conv_state,
+ self.conv_weight,
+ self.conv_bias,
+ activation=None,
+ conv_state_indices=mamba_cache_indices,
+ intermediate_conv_window=layer_cache.intermediate_conv_window[0],
+ intermediate_state_indices=intermediate_state_indices,
+ )
+ conv_out = conv_out.transpose(1, 2).reshape(bs * draft_token_num, -1)
else:
# Prefill: multiple tokens, use varlen kernel
T = hidden_states.shape[0]
@@ -298,12 +340,12 @@ def forward(
),
]
)
- cache_indices = req_pool_indices.to(torch.int32)
+ cache_indices = mamba_cache_indices
else:
query_start_loc = torch.tensor(
[0, T], dtype=torch.int32, device=hidden_states.device
)
- cache_indices = req_pool_indices[:1].to(torch.int32)
+ cache_indices = mamba_cache_indices[:1]

conv_out = causal_conv1d_fn(
Bx_t,
@@ -365,9 +407,13 @@ def forward(
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
forward_batch: ForwardBatch,
+ captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
if not forward_batch.forward_mode.is_idle():
+ if captured_last_layer_outputs is not None:
+ captured_last_layer_outputs.append(hidden_states)

residual = hidden_states
normed = self.operator_norm(hidden_states)

@@ -419,6 +465,13 @@ def get_layer(idx: int, prefix: str, **kwargs):
)
self.embedding_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)

+ self.layers_to_capture: List[int] = []

+ def set_dflash_layers_to_capture(self, layers_to_capture: List[int]):
+ self.layers_to_capture = list(layers_to_capture)
+ for layer_id in self.layers_to_capture:
+ setattr(self.layers[layer_id], "_is_layer_to_capture", True)

def forward(
self,
input_ids: torch.Tensor,
@@ -431,16 +484,26 @@ def forward(
)

residual = None
+ aux_hidden_states: List[torch.Tensor] = []
for i in range(len(self.layers)):
- hidden_states, residual = self.layers[i](
+ layer = self.layers[i]
+ hidden_states, residual = layer(
layer_id=i,
positions=positions,
hidden_states=hidden_states,
residual=residual,
forward_batch=forward_batch,
+ captured_last_layer_outputs=(
+ aux_hidden_states
+ if getattr(layer, "_is_layer_to_capture", False)
+ else None
+ ),
)

- return self.embedding_norm(hidden_states)
+ normalized = self.embedding_norm(hidden_states)
+ if not aux_hidden_states:
+ return normalized
+ return normalized, aux_hidden_states


class Lfm2ForCausalLM(nn.Module):
@@ -477,6 +540,18 @@ def get_num_kv_cache_layers(self) -> int:
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens

+ @property
+ def capture_aux_hidden_states(self) -> bool:
+ return bool(self.model.layers_to_capture)

+ def set_dflash_layers_to_capture(self, layer_ids: List[int]):
+ if not self.pp_group.is_last_rank:
+ return
+ # Mark layer L to capture its input, which is the output of layer L-1.
+ # Draft `target_layer_ids` refer to output indices, so shift by +1.
+ self.model.set_dflash_layers_to_capture([val + 1 for val in layer_ids])


@torch.no_grad()
def forward(
self,
@@ -487,8 +562,18 @@ def forward(
**kwargs,
):
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)

+ aux_hidden_states = None
+ if self.capture_aux_hidden_states:
+ hidden_states, aux_hidden_states = hidden_states


return self.logits_processor(
- input_ids, hidden_states, self.lm_head, forward_batch
+ input_ids,
+ hidden_states,
+ self.lm_head,
+ forward_batch,
+ aux_hidden_states,
)

def load_weights(
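The capture plumbing above shifts the draft model's target layer ids by one: a layer marked _is_layer_to_capture appends its input, which is the previous layer's output. A small standalone illustration of that index arithmetic (the layer ids are hypothetical; the real call goes through Lfm2ForCausalLM.set_dflash_layers_to_capture and only runs on the last pipeline-parallel rank):

    # Draft wants the outputs of layers 5 and 10.
    target_layer_ids = [5, 10]
    # Lfm2ForCausalLM shifts by +1, so layers 6 and 11 capture their *inputs*,
    # which are exactly the outputs of layers 5 and 10.
    layers_to_capture = [i + 1 for i in target_layer_ids]
    assert layers_to_capture == [6, 11]
    # Once any layer is marked, Lfm2Model.forward returns (normalized, aux_hidden_states)
    # instead of a bare tensor, and Lfm2ForCausalLM unpacks it before the logits processor.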
python/sglang/srt/server_args.py (4 additions, 1 deletion)
@@ -2096,7 +2096,10 @@ def _handle_model_specific_adjustments(self):
self._handle_mamba_radix_cache(
model_arch=model_arch,
support_mamba_cache=True,
- support_mamba_cache_extra_buffer=False,
+ # Allow extra_buffer so DFLASH spec-v2 can run with the radix cache.
+ # Requires the Mamba2/ShortConv attn backend + intermediate_conv_window
+ # tape to roll back partial verifies; turn off if rollback isn't wired.
+ support_mamba_cache_extra_buffer=True,
sm100_default_attention_backend="flashinfer",
)
assert self.attention_backend != "triton", (
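Enabling support_mamba_cache_extra_buffer here is what lets the LFM2 TARGET_VERIFY path above tape per-step conv windows and have the backend restore state after a partial accept. A hedged sketch of the rollback this assumes (the helper name and tensor layout are assumptions; the actual restore lives in the attention backend, not in this diff):

    import torch

    def rollback_conv_state(conv_state, window_tape, cache_idx, accepted_len):
        # window_tape[slot, step] holds the conv window taped at draft step `step`
        # during TARGET_VERIFY; copy the window at the last accepted step back
        # into the live conv state for this request's Mamba cache slot.
        conv_state[cache_idx].copy_(window_tape[cache_idx, accepted_len - 1])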