
feat(observability): add OpenTelemetry tracing for pipeline parallelism#23169

Merged
ShangmingCai merged 1 commit into sgl-project:main from jiangyinzuo:feat/pp-tace
Apr 28, 2026

Conversation

@jiangyinzuo
Contributor

@jiangyinzuo jiangyinzuo commented Apr 19, 2026

Motivation

Implement pipeline-parallelism (PP) OpenTelemetry tracing, as described in roadmap #13511.


Modifications

Add pp_forward metrics.

Accuracy Tests

Speed Tests and Profiling

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces observability for pipeline parallelism by adding tracing for PP forward passes. It includes a new PP_FORWARD request stage and updates the scheduler to record timing statistics and metadata during batch execution. A potential issue was identified in set_time_batch where passing attributes as positional arguments could cause a TypeError for methods that do not support them.

Comment on lines +1142 to +1145
if attrs is None:
    method(ts)
else:
    method(ts, attrs)
Contributor


Severity: medium

Passing attrs as a positional argument to method is risky because many existing set_*_time methods in SchedulerReqTimeStats (like set_forward_entry_time) do not accept a second positional argument. If set_time_batch is called with attrs for one of those methods, it will raise a TypeError. It would be safer to pass attrs as a keyword argument, provided the target methods are updated to accept it.
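The keyword-argument pattern the bot suggests can be sketched as follows. The names mirror the discussion (SchedulerReqTimeStats, set_forward_entry_time, set_time_batch), but the signatures here are illustrative assumptions, not the actual sglang implementation.

```python
# Minimal sketch of the safer dispatch; names and signatures are assumptions
# based on the review comment, not the real sglang code.
from typing import Optional


class SchedulerReqTimeStats:
    def __init__(self) -> None:
        self.forward_entry_time: Optional[float] = None
        self.forward_entry_attrs: Optional[dict] = None

    def set_forward_entry_time(self, ts: float, attrs: Optional[dict] = None) -> None:
        # Accepting attrs as an optional keyword keeps older call sites
        # (which pass only the timestamp) working unchanged.
        self.forward_entry_time = ts
        self.forward_entry_attrs = attrs


def set_time_batch(reqs, method_name: str, ts: float, attrs: Optional[dict] = None) -> None:
    for req in reqs:
        method = getattr(req.time_stats, method_name)
        if attrs is None:
            method(ts)
        else:
            # Keyword form: a target method that wants attrs declares it with
            # a default, so no TypeError is raised for timestamp-only methods.
            method(ts, attrs=attrs)
```

With this shape, adding attrs support to one `set_*_time` method does not force every other method to change its signature at the same time.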

@jiangyinzuo jiangyinzuo changed the title feat(observability): add OpenTelemetry tracing for pipeline parallism feat(observability): add OpenTelemetry tracing for pipeline parallelism Apr 19, 2026
@jiangyinzuo jiangyinzuo force-pushed the feat/pp-tace branch 2 times, most recently from 3c8f2b8 to a16345c Compare April 20, 2026 02:53
Comment thread python/sglang/srt/managers/scheduler_pp_mixin.py Outdated
Collaborator

@ShangmingCai ShangmingCai left a comment


Looks good. cc: @sufeng-buaa please double-check.

)

# pipeline parallelism
PP_FORWARD = RequestStageConfig(
Collaborator


The stage name is somewhat ambiguous. The segment you are tracing is actually the CPU-side run_batch, so run_batch_cpu would be a more appropriate name.

last_decode_scheduled_time: float = 0.0
last_forward_entry_time: float = 0.0
last_prefill_finished_time: float = 0.0
pp_forward_start_time: float = 0.0
Collaborator


Similarly, please rename it accordingly.

    "is_last_pp_rank": self.pp_group.is_last_rank,
}
if mb_id is not None:
    attrs["pp_mb_id"] = mb_id
Collaborator


I think we can keep mb_id, and move the other scheduler-related attributes to the thread span.

with torch.profiler.record_function("run_batch"):
    with self.forward_stream_ctx:
        self.forward_stream.wait_stream(self.schedule_stream)
        if trace_enabled:
Collaborator


If the function parameters are simple, there’s no need to check trace_enabled here.

# pipeline parallelism
PP_FORWARD = RequestStageConfig(
    "pp_forward",
    level=2,
Collaborator


Under prefill_forward and decode_forward, there may be chunked_prefill and decode_loop, so this span will be attached one level deeper. Let’s set the level to 4 for now. I’ll refactor the levels consistently later.

@sufeng-buaa
Collaborator

trace_set_thread_info and __create_thread_context do not include the PP rank information. Please add it.

Please also update scripts/convert_otel_2_perfetto.py:

diff --git a/scripts/convert_otel_2_perfetto.py b/scripts/convert_otel_2_perfetto.py
index 3a82969a4..89534a38f 100644
--- a/scripts/convert_otel_2_perfetto.py
+++ b/scripts/convert_otel_2_perfetto.py
@@ -237,6 +237,10 @@ def generate_perfetto_span(engine_root_spans, smg_otel_spans, thread_meta_data):
             pid = int(thread_span["attributes"]["pid"])
             host_id = thread_span["attributes"]["host_id"]
             thread_name = f'{thread_span["attributes"]["host_id"][:8]}:{thread_span["attributes"]["thread_label"]}'
+            if "pp_rank" in thread_span["attributes"]:
+                thread_name += f"-PP{thread_span['attributes']['pp_rank']}"
+            if "dp_rank" in thread_span["attributes"]:
+                thread_name += f"-DP{thread_span['attributes']['dp_rank']}"
             if "tp_rank" in thread_span["attributes"]:
                 thread_name += f"-TP{thread_span['attributes']['tp_rank']}"

@jiangyinzuo
Contributor Author

jiangyinzuo commented Apr 26, 2026

@sufeng-buaa I have addressed the review comments above. Could you please take another look?

attrs=attrs,
)
result = self.run_batch(self.cur_batch, pp_proxy_tensors)
set_time_batch(
Collaborator


set_time_batch(
    self.cur_batch.reqs,
    "set_run_batch_cpu_start_time",
    trace_only=True,
    attrs={"pp_mb_id": mb_id},
)

    self.cur_batch.reqs,
    "set_run_batch_cpu_start_time",
    trace_only=True,
    attrs=attrs,
Collaborator


No attrs need to be passed in here.

mb_metadata: List[Optional[PPBatchMetadata]],
last_rank_comm_queue: deque,
):
attrs = (
Collaborator


Simple attributes can be passed directly to set_time_batch().

@jiangyinzuo jiangyinzuo force-pushed the feat/pp-tace branch 2 times, most recently from 7a5b731 to c974f66 Compare April 26, 2026 11:17
@jiangyinzuo
Contributor Author

@sufeng-buaa OK, the code has been simplified.

@sufeng-buaa
Collaborator

/tag-and-rerun-ci

Collaborator

@ShangmingCai ShangmingCai left a comment


Please fix this test.

Begin (11/17):
python3 /home/runner/work/sglang/sglang/test/registered/unit/observability/test_trace.py
.
.

Skipping import of cpp extensions due to incompatible torch version. Please upgrade to torch >= 2.11.0 (found 2.9.1+cu130).
..E
======================================================================
ERROR: test_trace_thread_context (__main__.TestDataclasses)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/runner/work/sglang/sglang/test/registered/unit/observability/test_trace.py", line 82, in test_trace_thread_context
    info = TraceThreadInfo("h", 1, "l", 0, 0)
TypeError: TraceThreadInfo.__init__() missing 1 required positional argument: 'pp_rank'

----------------------------------------------------------------------
Ran 3 tests in 0.000s

FAILED (errors=1)
.

Signed-off-by: Yinzuo Jiang <jiangyinzuo@foxmail.com>
@jiangyinzuo
Copy link
Copy Markdown
Contributor Author

Please fix this test.


@ShangmingCai This unit test has been fixed, but CI still has some errors. Are they related to this PR?
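The fix the earlier traceback implies can be sketched as follows: TraceThreadInfo gained a required pp_rank field, which broke the test's five-argument constructor call. Field names other than pp_rank are guesses based on the positional call TraceThreadInfo("h", 1, "l", 0, 0) in the test; the actual dataclass layout in sglang may differ.

```python
# Sketch only: one way to add pp_rank without breaking existing call sites
# is to give it a default. Field names besides pp_rank are assumptions.
from dataclasses import dataclass


@dataclass
class TraceThreadInfo:
    host_id: str
    pid: int
    thread_label: str
    tp_rank: int
    dp_rank: int
    pp_rank: int = 0  # default keeps pre-existing 5-argument calls valid
```

Alternatively, pp_rank can be made required and every call site (including the test) updated to pass it explicitly, which is what the "unit test has been fixed" reply suggests happened here.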

@ShangmingCai
Collaborator

/rerun-failed-ci

@jiangyinzuo
Contributor Author

/rerun-failed-ci

@ShangmingCai ShangmingCai merged commit 71160e4 into sgl-project:main Apr 28, 2026
257 of 285 checks passed
vguduruTT pushed a commit to vguduruTT/sglang that referenced this pull request May 2, 2026
…sm (sgl-project#23169)

Signed-off-by: Yinzuo Jiang <jiangyinzuo@foxmail.com>
