
[user-streams] Add stream support to scheduler #165505

Closed
mlazos wants to merge 23 commits into gh/mlazos/44/base from gh/mlazos/44/head

Conversation

@mlazos
Contributor

@mlazos mlazos commented Oct 15, 2025

Summary

Add CUDA stream context management to the inductor scheduler's codegen phase. This is the
foundational infrastructure for multi-stream code generation: it tracks which stream is currently
active and generates the appropriate with torch.cuda.stream(...) enter/exit wrappers around
kernels.

Key pieces:

  • _current_stream_ctx field tracks the active EnterCudaStreamContextLine during codegen
  • current_stream_idx / current_stream_name properties expose the current stream state (also used by
    the wrapper's nesting guard)
  • generate_stream_ctx_switching() is the main entry point called per-node during codegen — it
    compares the node's assigned stream to the current stream and emits enter/exit code only when
    switching is needed. NopKernelSchedulerNodes inherit the previous stream context since they
    generate no kernel code.

This commit adds the codegen-time stream management; the stream assignment logic (node_to_stream,
_populate_stream_assignments) and the codegen callsite are in the next commit in the stack.
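The switching behavior described above can be sketched as a self-contained toy (names like `NopNode` and the string-based emission are stand-ins for illustration; the real implementation in torch/_inductor/scheduler.py emits wrapper IR lines such as `EnterCudaStreamContextLine`, not strings):

```python
class NopNode:
    """Stand-in for a NopKernelSchedulerNode: it emits no kernel code."""
    stream = None

class KernelNode:
    def __init__(self, stream):
        self.stream = stream  # stream index, or None for the default stream

class ToyScheduler:
    """Toy model of generate_stream_ctx_switching's compare-then-emit logic."""
    def __init__(self):
        self.current = None  # stream of the open context; None = default stream
        self.emitted = []    # wrapper lines we would generate

    def switch(self, node):
        if isinstance(node, NopNode):
            return  # nop nodes inherit the previous context: emit nothing
        if node.stream == self.current:
            return  # already in the right context: emit nothing
        if self.current is not None:
            self.emitted.append("# exit stream context")  # close the open 'with'
        if node.stream is not None:
            self.emitted.append(f"with torch.cuda.stream(s{node.stream}):")
        self.current = node.stream

sched = ToyScheduler()
for node in [KernelNode(0), KernelNode(0), NopNode(), KernelNode(1), KernelNode(None)]:
    sched.switch(node)
print(sched.emitted)
```

Note that consecutive nodes on the same stream, and nop nodes, produce no output at all; enter/exit lines appear only at actual stream boundaries.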

Test plan

  • Tests are in the third commit of the stack (test/inductor/test_user_streams.py)

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @chenyang78

@pytorch-bot

pytorch-bot bot commented Oct 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165505

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit b9a8960 with merge base d2d12ef:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Dec 14, 2025
@github-actions github-actions bot closed this Jan 13, 2026
@mlazos mlazos reopened this Jan 13, 2026
mlazos added a commit that referenced this pull request Jan 20, 2026
ghstack-source-id: 971c8f5
Pull Request resolved: #165505
mlazos added a commit that referenced this pull request Jan 21, 2026
ghstack-source-id: 971c8f5
Pull Request resolved: #165505
@mlazos mlazos added release notes: inductor ciflow/trunk Trigger trunk jobs on your pull request and removed Stale labels Jan 21, 2026
mlazos added a commit that referenced this pull request Jan 21, 2026
ghstack-source-id: 971c8f5
Pull Request resolved: #165505
mlazos added a commit that referenced this pull request Jan 21, 2026
ghstack-source-id: 9553c75
Pull Request resolved: #165505
mlazos added a commit that referenced this pull request Jan 21, 2026
ghstack-source-id: 971c8f5
Pull Request resolved: #165505
mlazos added a commit that referenced this pull request Jan 22, 2026
ghstack-source-id: 66a91d4
Pull Request resolved: #165505
suncapitalllc007-star pushed a commit to suncapitalllc007-star/pytorch that referenced this pull request Jan 25, 2026
mlazos added a commit that referenced this pull request Feb 3, 2026
ghstack-source-id: 66a91d4
Pull Request resolved: #165505
mlazos added a commit that referenced this pull request Feb 3, 2026
ghstack-source-id: 66a91d4
Pull Request resolved: #165505
@mlazos mlazos requested a review from eellison March 6, 2026 18:07
mlazos added a commit that referenced this pull request Mar 6, 2026
ghstack-source-id: 6fd3c4d
Pull Request resolved: #165505
@mlazos mlazos closed this Mar 8, 2026
@mlazos mlazos reopened this Mar 9, 2026
@mlazos mlazos closed this Mar 9, 2026
@mlazos mlazos reopened this Mar 9, 2026
@desertfire
Contributor

@claude, review this PR

@claude

claude bot commented Mar 11, 2026

Claude finished @desertfire's task in 4m 28s


PR Review: [user-streams] Add stream support to scheduler

This PR adds stream context tracking and switching infrastructure to the Scheduler class. The changes are well-scoped: a _current_stream_ctx field, two properties, and three methods that handle entering/exiting/switching CUDA stream contexts during code generation.


Overall Assessment

The design is clean. The scheduler tracks only which stream context it's currently emitting into and generates with torch.cuda.stream(...) blocks accordingly, leaving cross-stream synchronization to user-explicit event operations. This aligns with the PR description's philosophy of not duplicating synchronization that users already express.


Issues

1. Typo: "plaining" should be "planning" (scheduler.py:7437)

# Don't generate ctx switching. Memory plaining code (e.g., delete buffers) on current

Should be "planning" (or perhaps "cleanup" would be more descriptive here).


2. node_to_stream is not defined in this PR (scheduler.py:7408, 7427, 7432)

self.node_to_stream is accessed but never initialized on the Scheduler class. There are three # pyrefly: ignore[missing-attribute] suppressions for it. This presumably comes from another PR in the ghstack (#176700 or #174223), but it does mean this PR cannot function standalone — every call to generate_stream_ctx_enter or generate_stream_ctx_switching will raise AttributeError. If this is intentional for the stack, a comment noting where node_to_stream is defined would help reviewers.

3. generate_stream_ctx_switching — the current_stream_idx == stream comparison when both are None (scheduler.py:7434)

When current_stream_idx is None and stream is None (a NopKernelSchedulerNode), the first branch current_stream_idx == stream evaluates to True and returns early. This is correct behavior (no-op nodes before any stream context shouldn't trigger switching), but it's a subtle case that falls through without an explicit comment. Worth verifying this is the intended semantic — a no-op node encountered before any stream is entered silently does nothing.
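The subtle case the review describes can be reproduced in one line (assumed values, matching the review's reading of the code):

```python
# Minimal reproduction: no stream context has been entered yet, and the
# node is a nop (its assigned stream is None).
current_stream_idx = None  # scheduler has not entered any stream context
stream = None              # NopKernelSchedulerNode maps to no stream

# The equality branch fires because None == None, so the method returns
# early and emits nothing -- the semantic the review asks to verify.
print(current_stream_idx == stream)
```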


Minor Observations

  • The current_stream_name property (scheduler.py:7397-7402) is currently only consumed by codegen_cuda_stream_enter in wrapper.py:1684 for the nested-context error check and by wrapper.py:4157 for raw stream name generation. This is a reasonable coupling since the wrapper needs to know the scheduler's stream state to validate and generate correct code.

  • The generate_stream_ctx_exit method asserts self._current_stream_ctx is not None but doesn't check that the wrapper's indentation state is consistent. The wrapper's ExitCudaStreamContextLine.codegen calls code.do_unindent(), so if exit is ever called without a matching enter in the wrapper's buffer, this would produce malformed output. The assert on the scheduler side provides a reasonable guard, but the invariant spans two objects.

  • The conditional logic in generate_stream_ctx_switching handles four cases clearly. The structure could arguably be simplified (the elif current_stream_idx is not None and stream is None early return could be folded into the first == check since NopKernelSchedulerNode maps to None), but the explicit case breakdown is more readable as-is.


Summary

The PR looks good structurally. The main items to address:

  1. Fix the "plaining" → "planning" typo
  2. Consider adding a note about node_to_stream being defined in a dependent PR

No blocking concerns beyond the typo.

"""Code-gen to enter the Stream context assigned to node."""
assert not isinstance(node, NopKernelSchedulerNode)
# pyrefly: ignore[missing-attribute]
node_stream = self.node_to_stream[node]
Contributor

You have a PR ordering issue, since node_to_stream is defined in the next PR. (Thanks, Claude, for pointing this out.)

Contributor Author

Yeah let me fix this

@desertfire
Copy link
Contributor

The PR's commit msg doesn't match its content.

mlazos added a commit that referenced this pull request Mar 11, 2026
ghstack-source-id: 6fd3c4d
Pull Request resolved: #165505
@pytorch-bot pytorch-bot bot added the ciflow/torchtitan Run TorchTitan integration tests label Mar 11, 2026
mlazos added a commit that referenced this pull request Mar 11, 2026
ghstack-source-id: 6fd3c4d
Pull Request resolved: #165505
@mlazos mlazos requested a review from shunting314 March 12, 2026 01:49
sandy-gags pushed a commit to sandy-gags/pytorch that referenced this pull request Mar 12, 2026
sandy-gags pushed a commit to sandy-gags/pytorch that referenced this pull request Mar 12, 2026
mlazos added a commit that referenced this pull request Mar 12, 2026
ghstack-source-id: 746cc10
Pull Request resolved: #165505
@pytorchmergebot
Copy link
Collaborator

Starting merge as part of PR stack under #174223

pytorchmergebot pushed a commit that referenced this pull request Mar 13, 2026
…74223)

Read the 'custom.stream' FX metadata to determine which stream each
scheduler node should run on. This enables the inductor scheduler to:

1. Populate stream assignments BEFORE fusion to prevent fusing nodes
   across stream boundaries
2. Check stream assignments in can_fuse() to block cross-stream fusion
3. Handle stream context switching in codegen with proper device guard
   nesting

The stream metadata is set by dynamo when tracing torch.cuda.stream()
context managers.

Pull Request resolved: #174223
Approved by: https://github.com/shunting314
ghstack dependencies: #165505
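The cross-stream fusion guard described in point 2 of that commit message might be sketched like this (the node names and the helper are hypothetical; only the stream-equality comparison itself follows the commit message):

```python
# Hypothetical sketch of the can_fuse() stream check described above.
# node_to_stream maps scheduler node names to their assigned stream index.
node_to_stream = {"buf0": 0, "buf1": 0, "buf2": 1}

def streams_allow_fusion(name1, name2):
    # Fusing two nodes forces them into a single kernel on one stream,
    # so fusion is blocked whenever the assigned streams differ.
    return node_to_stream[name1] == node_to_stream[name2]

print(streams_allow_fusion("buf0", "buf1"))  # same stream: fusion allowed
print(streams_allow_fusion("buf1", "buf2"))  # cross-stream: fusion blocked
```

Populating these assignments before fusion, as the commit message notes, is what makes this check possible at `can_fuse()` time.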
pytorchmergebot pushed a commit that referenced this pull request Mar 13, 2026
Add comprehensive tests for user stream support in inductor:

- Stream utility tests (pool, context manager, naming)
- Event factory tests (creation, ordering, hashing)
- Wrapper codegen tests (stream context enter/exit)
- Compile tests for stream semantics preservation

The compile tests verify that torch.compile() correctly handles:
- Stream context managers
- Event record/wait operations
- Multi-stream synchronization patterns
- Fusion behavior within and across streams

Test assertions check for generated code patterns that may appear as
either custom ops (record_event/wait_event) or method calls.

Pull Request resolved: #174224
Approved by: https://github.com/shunting314
ghstack dependencies: #165505, #174223
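A tolerant pattern check of the kind that commit message describes, accepting either spelling of an event record, might look roughly like this (the code snippets and the helper are hypothetical illustrations, not the actual test code in test/inductor/test_user_streams.py):

```python
import re

# Hypothetical generated-code snippets; the real tests inspect actual
# inductor wrapper output.
as_custom_op = "torch.ops.streams.record_event(ev0, s1)"
as_method = "ev0.record(s1)"

def mentions_event_record(code):
    # Accept either the custom-op spelling or the method-call spelling,
    # since the generated pattern may appear as either.
    return re.search(r"record_event\(|\.record\(", code) is not None

print(mentions_event_record(as_custom_op), mentions_event_record(as_method))
```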

Labels

ciflow/inductor ciflow/torchtitan Run TorchTitan integration tests ciflow/trunk Trigger trunk jobs on your pull request Merged module: inductor release notes: inductor
