[user-streams] Enable FX metadata stream annotations in scheduler #174223

Closed
mlazos wants to merge 19 commits into gh/mlazos/93/base from gh/mlazos/93/head

@mlazos mlazos commented Feb 3, 2026

Stack from ghstack (oldest at bottom):

Read the 'custom.stream' FX metadata to determine which stream each
scheduler node should run on. This enables the inductor scheduler to:

  1. Populate stream assignments BEFORE fusion to prevent fusing nodes
    across stream boundaries
  2. Check stream assignments in can_fuse() to block cross-stream fusion
  3. Handle stream context switching in codegen with proper device guard
    nesting

The stream metadata is set by dynamo when tracing torch.cuda.stream()
context managers.
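
As a rough illustration of points 1 and 2, the sketch below uses plain Python stand-ins rather than the real inductor classes. `ToyNode`, `populate_stream_assignments`, and this `can_fuse` are hypothetical names; the `meta["custom"]["stream"]` layout is an assumption based on the 'custom.stream' key named above, and treating unannotated nodes as stream 0 is also an assumption.

```python
from dataclasses import dataclass, field

DEFAULT_STREAM = 0  # assumption: unannotated nodes stay on the default stream


@dataclass
class ToyNode:
    """Hypothetical stand-in for a scheduler node; not an inductor class."""
    name: str
    # Assumed layout: meta["custom"]["stream"] identifies the user stream.
    meta: dict = field(default_factory=dict)


def populate_stream_assignments(nodes):
    """Map node name -> stream index. Meant to run once, before any fusion."""
    user_stream_to_idx = {}
    node_to_stream = {}
    for node in nodes:
        user_stream = node.meta.get("custom", {}).get("stream")
        if user_stream is None:
            node_to_stream[node.name] = DEFAULT_STREAM
            continue
        # Hand out indices incrementally (1, 2, 3, ...) per distinct user stream.
        idx = user_stream_to_idx.setdefault(
            user_stream, len(user_stream_to_idx) + 1
        )
        node_to_stream[node.name] = idx
    return node_to_stream


def can_fuse(node_to_stream, a, b):
    """Block fusion when the two nodes are assigned to different streams."""
    return node_to_stream.get(a.name) == node_to_stream.get(b.name)


nodes = [
    ToyNode("a"),
    ToyNode("b", {"custom": {"stream": "s1"}}),
    ToyNode("c", {"custom": {"stream": "s1"}}),
    ToyNode("d", {"custom": {"stream": "s2"}}),
]
assignments = populate_stream_assignments(nodes)
```

With these toy inputs, `b` and `c` share a stream and remain fusable, while `a`/`b` and `c`/`d` sit on different streams and are rejected.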

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

pytorch-bot bot commented Feb 3, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/174223

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit 2f618c0 with merge base d2d12ef:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot bot commented Feb 3, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

mlazos added a commit that referenced this pull request Feb 18, 2026
ghstack-source-id: b459478
Pull Request resolved: #174223
mlazos added a commit that referenced this pull request Mar 2, 2026
ghstack-source-id: b208af6
Pull Request resolved: #174223
mlazos added a commit that referenced this pull request Mar 3, 2026
ghstack-source-id: 9263147
Pull Request resolved: #174223
mlazos added 11 commits March 3, 2026 11:55
@mlazos mlazos requested a review from eellison March 6, 2026 18:07
mlazos added a commit that referenced this pull request Mar 6, 2026
ghstack-source-id: 7486432
Pull Request resolved: #174223
@eellison eellison left a comment

Can you say more about how we guarantee the meta stream is correct? E.g., do we need to add support to FakeTensorUpdater? Do we need to make register_replacement, replace_by_example, etc. propagate stream assignment?

At one point we discussed a fork stream op that would take inputs and produce them on the output stream. Do you have any graphs I could look at?

mlazos commented Mar 9, 2026

> Can you say more about how we guarantee the meta stream is correct? E.g., do we need to add support to FakeTensorUpdater? Do we need to make register_replacement, replace_by_example, etc. propagate stream assignment?
>
> At one point we discussed a fork stream op that would take inputs and produce them on the output stream. Do you have any graphs I could look at?

Yeah, I have a draft of the fork/join stuff, but wanted to experiment with this first to see how complex it would be to propagate. This impl is a lot simpler, and I wanted to let Ads experiment with this first. I do remember we discussed this, do you think propagation would be complicated? With the frontier stuff people have been relying on meta quite a bit, especially for the regional inductor work so I was thinking it was solid.

eellison commented Mar 9, 2026

I think that's reasonable, but I do think we should re-propagate, or have other safety mechanisms. Any graph pattern today will lose the stream annotation.

mlazos commented Mar 11, 2026

@claude review this pr

claude bot commented Mar 11, 2026

Claude finished @mlazos's task in 3m 19s.


  • Read the PR diff and understand changes
  • Read surrounding context in scheduler.py
  • Analyze code quality, correctness, and edge cases
  • Post review feedback

Review: [user-streams] Enable FX metadata stream annotations in scheduler

Overall the approach is sound — reading custom.stream FX metadata before fusion and blocking cross-stream fusion in can_fuse() is a clean design. Here are some issues I spotted:


1. node_to_stream missing entries for fused/grouped/combo nodes (potential KeyError)

_populate_stream_assignments() runs before fusion (line 3146), populating node_to_stream with pre-fusion nodes. After fusion, fuse_two_nodes() (line 4706) creates new FusedSchedulerNode objects that replace the original nodes, but node_to_stream is never updated with these new fused node entries.

  • can_fuse() (line 5775) uses .get() which safely returns None — this is fine.
  • But generate_stream_ctx_enter() (line 7489) and generate_stream_ctx_switching() (line 7508) use direct [] access and assert node in self.node_to_stream, which will KeyError/AssertionError for any fused node.

Similarly, create_combo_kernel_nodes() and process_grouped_nodes() create new node objects post-fusion that won't have entries in node_to_stream.

The codegen loop at line 7282 iterates over self.nodes (post-fusion) and calls generate_stream_ctx_switching(node), so fused nodes will hit this path. You need to propagate stream assignments to newly created fused/grouped nodes — e.g., in fuse_two_nodes(), assert both nodes have the same stream (since can_fuse blocks cross-stream fusion) and assign the fused node the same stream index.
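
One way the suggested propagation could look, as a self-contained sketch. `ToyScheduler` and the string node names are hypothetical stand-ins, not the real `Scheduler` or `FusedSchedulerNode`; only the idea (assert equal streams, then copy the index to the fused node) comes from the comment above.

```python
class ToyScheduler:
    """Hypothetical scheduler that tracks stream indices by node name."""

    def __init__(self, node_to_stream):
        self.node_to_stream = dict(node_to_stream)

    def can_fuse(self, a, b):
        # .get() returns None for unknown nodes instead of raising.
        return self.node_to_stream.get(a) == self.node_to_stream.get(b)

    def fuse_two_nodes(self, a, b):
        # can_fuse() already blocks cross-stream fusion, so a mismatch here
        # would indicate a bug upstream.
        assert self.can_fuse(a, b), f"cross-stream fusion of {a} and {b}"
        fused = f"{a}_{b}"
        # Register the new fused node so later [] lookups in codegen succeed.
        self.node_to_stream[fused] = self.node_to_stream[a]
        return fused


sched = ToyScheduler({"n0": 1, "n1": 1, "n2": 0})
fused_name = sched.fuse_two_nodes("n0", "n1")
```

The same registration step would presumably be needed wherever combo or grouped nodes are created, since those are also new objects.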


2. num_streams calculation is fragile

At line 7272-7274:

unique_streams = OrderedSet(self.node_to_stream.values())
num_streams = max(unique_streams) + 1 if unique_streams else 1

This computes num_streams as max(stream_idx) + 1, which assumes stream indices are dense and contiguous. Since user_obj_to_stream_idx assigns indices incrementally (1, 2, 3...) this works in practice, but it's worth noting that if nodes get filtered/removed between _populate_stream_assignments and codegen, you could allocate more streams than needed (not incorrect, just wasteful). This is minor and acceptable for now.
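
To make the dense-index assumption concrete, here is a small standalone comparison. It uses a plain `set` in place of `OrderedSet`, and both helper names are illustrative only.

```python
def num_streams_by_max(node_to_stream):
    """Mirror of the max(idx) + 1 computation quoted above."""
    unique = set(node_to_stream.values())
    return max(unique) + 1 if unique else 1


def num_streams_by_count(node_to_stream):
    """Count distinct streams in use, always including default stream 0."""
    return len(set(node_to_stream.values()) | {0})


dense = {"a": 0, "b": 1, "c": 2}
sparse = {"a": 0, "c": 2}  # e.g. the only node on stream 1 was removed

dense_counts = (num_streams_by_max(dense), num_streams_by_count(dense))
sparse_counts = (num_streams_by_max(sparse), num_streams_by_count(sparse))
```

For the sparse case the max-based value is one larger than the number of streams actually used. Switching to the count would also require remapping indices, so the over-allocation may well be the simpler trade-off.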


3. Origin iteration takes only the first stream annotation

At lines 3315-3326, when iterating over FX node origins, you break on the first origin that has a custom.stream annotation:

for fx_node in origins:
    ...
    if "stream" in custom_meta:
        ...
        break

If a scheduler node has multiple origins with different stream annotations, this silently picks one. This could happen if an operation's inputs come from different streams. A diagnostic warning or assertion that all origins agree on the stream assignment would make debugging easier.
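
A possible shape for such a diagnostic, sketched with `SimpleNamespace` objects standing in for FX nodes. The function name and the `meta["custom"]["stream"]` layout are assumptions; the elided parts of the loop above are not reproduced here.

```python
import warnings
from types import SimpleNamespace


def stream_from_origins(origins):
    """Return the first stream annotation found, warning if origins disagree."""
    seen = []
    for fx_node in origins:
        custom_meta = fx_node.meta.get("custom", {})
        if "stream" in custom_meta and custom_meta["stream"] not in seen:
            seen.append(custom_meta["stream"])
    if len(seen) > 1:
        warnings.warn(
            f"origins disagree on stream annotation: {seen}; using {seen[0]!r}"
        )
    return seen[0] if seen else None


on_s1 = SimpleNamespace(meta={"custom": {"stream": "s1"}})
on_s2 = SimpleNamespace(meta={"custom": {"stream": "s2"}})
unannotated = SimpleNamespace(meta={})
```

This keeps the existing "first annotation wins" behaviour but makes a disagreement visible instead of silent.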


4. _has_multi_stream_nodes() method vs direct attribute access

_has_multi_stream_nodes() (line 3340) is a one-liner returning self._multi_stream_nodes. It's called in three hot paths (can_fuse, device guard enter, stream switching). Per the CLAUDE.md guidelines about trivial helpers, this could just be the attribute directly, though having it as a method does provide a consistent interface if the logic changes later. This is minor — just calling it out per the repo style guidelines.


5. Stream context exit ordering in device guard cleanup

At lines 7261-7263:

# Exit stream context before exiting device guard
if self.current_stream_idx is not None:
    self.generate_stream_ctx_exit()

This is correct — you need to exit the stream context before exiting the device guard. But there's no corresponding stream context exit when the codegen loop finishes (after the last node). If the last node is on a non-default stream, does the stream context get properly closed? I'd check whether there's cleanup logic at the end of the codegen method that handles this.
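
The ordering concern can be modelled with a toy emitter that records enter/exit events as strings. Everything here is hypothetical (class name, method names, and the choice to open a context for stream 0 as well); it only shows that routing the end of the loop through the device guard exit closes the last stream context.

```python
class ToyWrapperCodegen:
    """Hypothetical emitter; records events instead of generating real code."""

    def __init__(self):
        self.lines = []
        self.current_stream_idx = None

    def exit_stream_ctx(self):
        self.lines.append(f"exit stream{self.current_stream_idx}")
        self.current_stream_idx = None

    def switch_stream(self, idx):
        if idx == self.current_stream_idx:
            return
        if self.current_stream_idx is not None:
            self.exit_stream_ctx()
        self.lines.append(f"enter stream{idx}")
        self.current_stream_idx = idx

    def exit_device_guard(self):
        # Exit stream context before exiting device guard.
        if self.current_stream_idx is not None:
            self.exit_stream_ctx()
        self.lines.append("exit device_guard")

    def codegen(self, node_streams):
        self.lines.append("enter device_guard")
        try:
            for idx in node_streams:
                self.switch_stream(idx)
                self.lines.append("kernel")
        finally:
            # Runs after the last node, so a trailing stream context is closed.
            self.exit_device_guard()
        return self.lines


events = ToyWrapperCodegen().codegen([0, 1, 1])
```

If the real codegen always leaves the device guard after the final node, the existing check in the guard exit would already cover the case raised above.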


Summary: The main concern is issue #1 — fused nodes won't be in node_to_stream and will cause runtime errors during codegen when multi-stream nodes are present. The fusion-blocking logic in can_fuse() is correct, but the downstream codegen paths need the fused nodes to also have stream assignments. The rest of the changes look reasonable.

self.buff_to_stream[buf] = stream_idx

# Check if we have any nodes on non-default streams
self._multi_stream_nodes = any(

self._multi_stream_nodes = next_stream_idx > 1 ?

mlazos added a commit that referenced this pull request Mar 11, 2026
ghstack-source-id: 7486432
Pull Request resolved: #174223
@pytorch-bot pytorch-bot bot added the ciflow/torchtitan Run TorchTitan integration tests label Mar 11, 2026
sandy-gags pushed a commit to sandy-gags/pytorch that referenced this pull request Mar 12, 2026
ghstack-source-id: 4ec5175
Pull Request resolved: pytorch/pytorch#174223
@mlazos mlazos added release notes: inductor ciflow/trunk Trigger trunk jobs on your pull request labels Mar 12, 2026
mlazos commented Mar 13, 2026

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Mar 13, 2026
Add comprehensive tests for user stream support in inductor:

- Stream utility tests (pool, context manager, naming)
- Event factory tests (creation, ordering, hashing)
- Wrapper codegen tests (stream context enter/exit)
- Compile tests for stream semantics preservation

The compile tests verify that torch.compile() correctly handles:
- Stream context managers
- Event record/wait operations
- Multi-stream synchronization patterns
- Fusion behavior within and across streams

Test assertions check for generated code patterns that may appear as
either custom ops (record_event/wait_event) or method calls.

Pull Request resolved: #174224
Approved by: https://github.com/shunting314
ghstack dependencies: #165505, #174223
Labels

ciflow/inductor, ciflow/torchtitan, ciflow/trunk, Merged, module: inductor, release notes: inductor
