Skip to content

Commit f951860

Browse files
sriumcp and claude
committed
[Feature] Add rich request snapshot stream (PR #5)
Implements subsampled per-request detailed progress events with KV metrics: - Add step_tracing_rich_subsample_rate config (default 0.001 = 0.1%) - Emit step.REQUEST_SNAPSHOT events for running requests when subsampled - Use PR #4 get_per_request_kv_metrics() for KV cache data - Two-stage sampling: batch summary sampled AND rich subsampled - SpanAttributes: 10 new constants for per-request metrics - Emission after batch summary, before _update_after_schedule() Also fixes PR #3 CLI wiring bug: - Wire step_tracing_enabled/sample_rate through EngineArgs - Add fields to EngineArgs dataclass - Pass to ObservabilityConfig constructor - Add test_step_tracing_cli_wiring() for regression prevention Tests: 6 new tests (5 rich snapshot + 1 CLI wiring), all 15 pass Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent ab555b8 commit f951860

6 files changed

Lines changed: 388 additions & 0 deletions

File tree

tests/v1/core/test_step_tracing.py

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,3 +457,272 @@ def test_step_tracing_failure_safety():
457457

458458
# Verify exception was caught (add_event was called but didn't crash)
459459
assert mock_span.add_event.call_count == 1
460+
461+
462+
def test_rich_snapshot_rate_zero():
    """Test that rich subsample rate 0.0 produces no rich events.

    Verifies:
    - Batch summary still emitted (step.BATCH_SUMMARY)
    - No rich snapshot events (step.REQUEST_SNAPSHOT)
    - Rich sampling is independent from batch summary sampling
    """
    with patch("vllm.tracing.init_tracer") as mock_init_tracer:
        mock_tracer = Mock()
        mock_span = Mock()
        mock_tracer.start_span.return_value = mock_span
        mock_init_tracer.return_value = mock_tracer

        scheduler = create_scheduler(
            step_tracing_enabled=True,
            step_tracing_sample_rate=1.0,  # Always sample batch summary
            step_tracing_rich_subsample_rate=0.0,  # Never sample rich snapshots
            otlp_traces_endpoint="http://localhost:4317",
        )

        # Override samplers so the outcome does not depend on hashing/RNG.
        scheduler._step_sampler = lambda step_id, rate: True  # Batch summary always
        scheduler._rich_sampler = lambda step_id, rate: False  # Rich never

        # Add requests and schedule one step.
        requests = create_requests(num_requests=3)
        for request in requests:
            scheduler.add_request(request)

        output = scheduler.schedule()

        # Verify scheduler worked.
        assert output.scheduler_step == 1
        assert len(output.scheduled_new_reqs) == 3

        # Extract event names from the mocked span.
        event_names = [call[0][0] for call in mock_span.add_event.call_args_list]

        # Should have exactly 1 batch summary, no rich snapshots.
        assert event_names.count("step.BATCH_SUMMARY") == 1
        assert event_names.count("step.REQUEST_SNAPSHOT") == 0
def test_rich_snapshot_enabled():
    """Test that rich subsample rate 1.0 emits events for all running requests.

    Verifies:
    - One step.REQUEST_SNAPSHOT event per running request
    - Events have correct step.id correlation
    - All required attributes present
    - KV metrics populated
    """
    with patch("vllm.tracing.init_tracer") as mock_init_tracer:
        mock_tracer = Mock()
        mock_span = Mock()
        mock_tracer.start_span.return_value = mock_span
        mock_init_tracer.return_value = mock_tracer

        scheduler = create_scheduler(
            step_tracing_enabled=True,
            step_tracing_sample_rate=1.0,
            step_tracing_rich_subsample_rate=1.0,  # Always sample rich
            otlp_traces_endpoint="http://localhost:4317",
        )

        # Override samplers to always return True (deterministic test).
        scheduler._step_sampler = lambda step_id, rate: True
        scheduler._rich_sampler = lambda step_id, rate: True

        # Add requests and schedule one step.
        requests = create_requests(num_requests=4)
        for request in requests:
            scheduler.add_request(request)

        output = scheduler.schedule()

        # Verify scheduler worked.
        assert output.scheduler_step == 1
        assert len(output.scheduled_new_reqs) == 4

        # Extract event names and their attribute dicts from the mocked span.
        event_names = [call[0][0] for call in mock_span.add_event.call_args_list]
        event_attrs = [call[1]["attributes"] for call in mock_span.add_event.call_args_list]

        # Should have 1 batch summary + 4 rich snapshots.
        assert event_names.count("step.BATCH_SUMMARY") == 1
        assert event_names.count("step.REQUEST_SNAPSHOT") == 4

        # Verify rich snapshot attributes.
        rich_events = [
            attrs
            for name, attrs in zip(event_names, event_attrs)
            if name == "step.REQUEST_SNAPSHOT"
        ]
        assert len(rich_events) == 4

        for attrs in rich_events:
            # Required attributes.
            assert attrs[SpanAttributes.STEP_ID] == 1
            assert SpanAttributes.REQUEST_ID in attrs
            assert attrs[SpanAttributes.REQUEST_PHASE] in ("PREFILL", "DECODE")
            assert SpanAttributes.REQUEST_NUM_PROMPT_TOKENS in attrs
            assert SpanAttributes.REQUEST_NUM_COMPUTED_TOKENS in attrs
            assert SpanAttributes.REQUEST_NUM_OUTPUT_TOKENS in attrs
            assert SpanAttributes.REQUEST_NUM_PREEMPTIONS in attrs
            assert SpanAttributes.REQUEST_SCHEDULED_TOKENS_THIS_STEP in attrs
            assert SpanAttributes.KV_BLOCKS_ALLOCATED_GPU in attrs
            assert SpanAttributes.KV_BLOCKS_CACHED_GPU in attrs

            # Verify types.
            assert isinstance(attrs[SpanAttributes.STEP_ID], int)
            assert isinstance(attrs[SpanAttributes.REQUEST_ID], str)
            assert isinstance(attrs[SpanAttributes.KV_BLOCKS_ALLOCATED_GPU], int)
            assert isinstance(attrs[SpanAttributes.KV_BLOCKS_CACHED_GPU], int)
def test_rich_snapshot_gated_on_batch_summary():
    """Test that rich snapshots are only emitted when batch summary is sampled.

    Verifies the two-stage sampling:
    1. Step must be batch-summary-sampled
    2. Then rich subsampling decision
    """
    with patch("vllm.tracing.init_tracer") as mock_init_tracer:
        mock_tracer = Mock()
        mock_span = Mock()
        mock_tracer.start_span.return_value = mock_span
        mock_init_tracer.return_value = mock_tracer

        scheduler = create_scheduler(
            step_tracing_enabled=True,
            step_tracing_sample_rate=1.0,
            step_tracing_rich_subsample_rate=1.0,
            otlp_traces_endpoint="http://localhost:4317",
        )

        # Override batch summary sampler to return False (not sampled).
        scheduler._step_sampler = lambda step_id, rate: False
        # Rich sampler is irrelevant (shouldn't be called).
        scheduler._rich_sampler = lambda step_id, rate: True

        # Add requests and schedule one step.
        requests = create_requests(num_requests=3)
        for request in requests:
            scheduler.add_request(request)

        output = scheduler.schedule()

        # Verify scheduler worked.
        assert output.scheduler_step == 1
        assert len(output.scheduled_new_reqs) == 3

        # No events should be emitted (batch summary not sampled).
        assert mock_span.add_event.call_count == 0
def test_rich_snapshot_deterministic_sampling():
    """Test deterministic rich sampling for reproducible tests.

    Verifies:
    - Deterministic sampler produces stable results
    - Rich sampling decision is independent per step
    - Same seed produces same sample set
    """
    with patch("vllm.tracing.init_tracer") as mock_init_tracer:
        mock_tracer = Mock()
        mock_span = Mock()
        mock_tracer.start_span.return_value = mock_span
        mock_init_tracer.return_value = mock_tracer

        scheduler = create_scheduler(
            step_tracing_enabled=True,
            step_tracing_sample_rate=1.0,
            step_tracing_rich_subsample_rate=1.0,
            otlp_traces_endpoint="http://localhost:4317",
        )

        # Use deterministic (seeded) samplers so event counts are reproducible.
        scheduler._step_sampler = make_deterministic_step_sampler(seed=42)
        scheduler._rich_sampler = make_deterministic_step_sampler(seed=100)

        # Run multiple steps with the same two requests.
        requests = create_requests(num_requests=2)
        for request in requests:
            scheduler.add_request(request)

        for _ in range(5):
            scheduler.schedule()

        # Extract event names from the mocked span.
        event_names = [call[0][0] for call in mock_span.add_event.call_args_list]

        # With deterministic sampling, results should be stable.
        batch_summaries = event_names.count("step.BATCH_SUMMARY")
        rich_snapshots = event_names.count("step.REQUEST_SNAPSHOT")

        # Verify we got some events (exact count depends on hash outputs).
        assert batch_summaries > 0
        # Rich snapshots only emitted when batch summary was sampled.
        assert rich_snapshots % 2 == 0  # Should be even (2 requests per step)
def test_rich_snapshot_with_zero_running_requests():
    """Test that rich snapshots work correctly with empty running queue.

    Verifies:
    - Batch summary emitted even with no running requests
    - No rich snapshot events (no requests to snapshot)
    - No crashes or errors
    """
    with patch("vllm.tracing.init_tracer") as mock_init_tracer:
        mock_tracer = Mock()
        mock_span = Mock()
        mock_tracer.start_span.return_value = mock_span
        mock_init_tracer.return_value = mock_tracer

        scheduler = create_scheduler(
            step_tracing_enabled=True,
            step_tracing_sample_rate=1.0,
            step_tracing_rich_subsample_rate=1.0,
            otlp_traces_endpoint="http://localhost:4317",
        )

        # Override samplers to always return True.
        scheduler._step_sampler = lambda step_id, rate: True
        scheduler._rich_sampler = lambda step_id, rate: True

        # Schedule with no requests queued.
        output = scheduler.schedule()

        # Verify scheduler worked.
        assert output.scheduler_step == 1
        assert len(output.scheduled_new_reqs) == 0

        # Extract event names from the mocked span.
        event_names = [call[0][0] for call in mock_span.add_event.call_args_list]

        # Should have 1 batch summary, 0 rich snapshots (no running requests).
        assert event_names.count("step.BATCH_SUMMARY") == 1
        assert event_names.count("step.REQUEST_SNAPSHOT") == 0
def test_step_tracing_cli_wiring():
    """Test that CLI flags are properly wired through to ObservabilityConfig.

    This is a regression test for PR #3 and PR #5 CLI wiring.
    Ensures that step tracing flags flow from CLI -> EngineArgs -> ObservabilityConfig.
    """
    from vllm.engine.arg_utils import EngineArgs

    # Test values different from defaults so a dropped field is detectable.
    engine_args = EngineArgs(
        model="facebook/opt-125m",
        step_tracing_enabled=True,  # Default: False
        step_tracing_sample_rate=0.75,  # Default: 0.01
        step_tracing_rich_subsample_rate=0.05,  # Default: 0.001
    )

    # Create engine config and verify wiring.
    vllm_config = engine_args.create_engine_config()
    obs_config = vllm_config.observability_config

    # Verify all three fields are correctly wired.
    assert obs_config.step_tracing_enabled is True
    assert obs_config.step_tracing_sample_rate == 0.75
    assert obs_config.step_tracing_rich_subsample_rate == 0.05

tests/v1/core/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def create_scheduler(
6262
otlp_traces_endpoint: str | None = None,
6363
step_tracing_enabled: bool = False,
6464
step_tracing_sample_rate: float = 0.01,
65+
step_tracing_rich_subsample_rate: float = 0.001,
6566
) -> Scheduler | AsyncScheduler:
6667
"""Create scheduler under test.
6768
@@ -141,6 +142,7 @@ def create_scheduler(
141142
otlp_traces_endpoint=otlp_traces_endpoint,
142143
step_tracing_enabled=step_tracing_enabled,
143144
step_tracing_sample_rate=step_tracing_sample_rate,
145+
step_tracing_rich_subsample_rate=step_tracing_rich_subsample_rate,
144146
)
145147

146148
vllm_config = VllmConfig(

vllm/config/observability.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ def show_hidden_metrics(self) -> bool:
9393
"""Sampling rate for step batch summary events. Range [0.0, 1.0].
9494
Default 0.01 = 1% of steps. Only applies when step_tracing_enabled is True."""
9595

96+
step_tracing_rich_subsample_rate: float = Field(default=0.001, ge=0.0, le=1.0)
97+
"""Subsampling rate for rich per-request snapshots within sampled steps. Range [0.0, 1.0].
98+
Default 0.001 = 0.1% of steps get detailed per-request data. Only applies when
99+
step_tracing_enabled is True AND the step is batch-summary-sampled."""
100+
96101
@cached_property
97102
def collect_model_forward_time(self) -> bool:
98103
"""Whether to collect model forward time for the request."""

vllm/engine/arg_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,13 @@ class EngineArgs:
528528
)
529529
enable_mm_processor_stats: bool = ObservabilityConfig.enable_mm_processor_stats
530530
enable_journey_tracing: bool = ObservabilityConfig.enable_journey_tracing
531+
step_tracing_enabled: bool = ObservabilityConfig.step_tracing_enabled
532+
step_tracing_sample_rate: float = get_field(
533+
ObservabilityConfig, "step_tracing_sample_rate"
534+
)
535+
step_tracing_rich_subsample_rate: float = get_field(
536+
ObservabilityConfig, "step_tracing_rich_subsample_rate"
537+
)
531538
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
532539
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
533540

@@ -1082,6 +1089,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
10821089
"--step-tracing-sample-rate",
10831090
**observability_kwargs["step_tracing_sample_rate"],
10841091
)
1092+
observability_group.add_argument(
1093+
"--step-tracing-rich-subsample-rate",
1094+
**observability_kwargs["step_tracing_rich_subsample_rate"],
1095+
)
10851096

10861097
# Scheduler arguments
10871098
scheduler_kwargs = get_kwargs(SchedulerConfig)
@@ -1739,6 +1750,9 @@ def create_engine_config(
17391750
enable_mm_processor_stats=self.enable_mm_processor_stats,
17401751
enable_logging_iteration_details=self.enable_logging_iteration_details,
17411752
enable_journey_tracing=self.enable_journey_tracing,
1753+
step_tracing_enabled=self.step_tracing_enabled,
1754+
step_tracing_sample_rate=self.step_tracing_sample_rate,
1755+
step_tracing_rich_subsample_rate=self.step_tracing_rich_subsample_rate,
17421756
)
17431757

17441758
# Compilation config overrides

vllm/tracing.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,18 @@ class SpanAttributes:
275275
KV_BLOCKS_TOTAL_GPU = "kv.blocks_total_gpu"
276276
KV_BLOCKS_FREE_GPU = "kv.blocks_free_gpu"
277277

278+
# Rich request snapshot attributes (per-request step snapshots)
279+
REQUEST_ID = "request.id"
280+
REQUEST_PHASE = "request.phase"
281+
REQUEST_NUM_PROMPT_TOKENS = "request.num_prompt_tokens"
282+
REQUEST_NUM_COMPUTED_TOKENS = "request.num_computed_tokens"
283+
REQUEST_NUM_OUTPUT_TOKENS = "request.num_output_tokens"
284+
REQUEST_NUM_PREEMPTIONS = "request.num_preemptions"
285+
REQUEST_SCHEDULED_TOKENS_THIS_STEP = "request.scheduled_tokens_this_step"
286+
KV_BLOCKS_ALLOCATED_GPU = "kv.blocks_allocated_gpu"
287+
KV_BLOCKS_CACHED_GPU = "kv.blocks_cached_gpu"
288+
REQUEST_EFFECTIVE_PROMPT_LEN = "request.effective_prompt_len"
289+
278290

279291
def contains_trace_headers(headers: Mapping[str, str]) -> bool:
280292
return any(h in headers for h in TRACE_HEADERS)

0 commit comments

Comments
 (0)