[CUDA][Green Context] Expose green context streams#171116

Closed
eqy wants to merge 14 commits into pytorch:main from eqy:greenstreamexposed

Conversation

@eqy
Collaborator

@eqy eqy commented Dec 22, 2025

Also uses a non-default stream in the green context, as passing around a default (null) stream seems sketchy.
The set/pop-context APIs still use the default stream.

cc @ptrblck @msaroufim @jerryzh168 @tinglvv @nWEIdia

@eqy eqy added module: cuda Related to torch.cuda, and CUDA support in general open source labels Dec 22, 2025
@eqy eqy requested a review from syed-ahmed as a code owner December 22, 2025 19:10
@eqy eqy added the release notes: cuda release notes category label Dec 22, 2025
@eqy eqy requested a review from Aidyn-A as a code owner December 22, 2025 19:10
@pytorch-bot

pytorch-bot bot commented Dec 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/171116

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit 38dca67 with merge base 5e30b70:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/b200 ciflow/h100 ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Dec 22, 2025
@@ -97,6 +100,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
Collaborator

You should ifdef the entire move ctor and have these all be member initializers. Also that way the TORCH_CHECK error would properly give a stack trace instead of just terminating.
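The suggestion above, sketched with a placeholder handle type (`GreenContextSketch` and `Handle` are hypothetical names for illustration, not the actual PyTorch class; the real code would use `CUgreenCtx`, `CUcontext`, etc. behind the `HAS_CUDA_GREEN_CONTEXT()` ifdef):

```cpp
#include <cassert>
#include <utility>

// Placeholder for driver handle types like CUgreenCtx / CUcontext.
using Handle = void*;

class GreenContextSketch {
 public:
  GreenContextSketch() = default;
  explicit GreenContextSketch(Handle h) : green_ctx_(h) {}

  // Move ctor as suggested: every field has a member initializer, and
  // std::exchange in the init-list leaves the moved-from object null,
  // so nothing needs to run in the (possibly ifdef'd-out) body.
  GreenContextSketch(GreenContextSketch&& other) noexcept
      : green_ctx_(std::exchange(other.green_ctx_, nullptr)),
        context_(std::exchange(other.context_, nullptr)) {}

  Handle green_ctx() const { return green_ctx_; }
  Handle context() const { return context_; }

 private:
  Handle green_ctx_ = nullptr;  // member initializers give a defined
  Handle context_ = nullptr;    // default even when CUDA is disabled
};
```

With member initializers, a default-constructed or moved-from object is always in a well-defined null state, so a later TORCH_CHECK can fail cleanly with a stack trace instead of the process terminating on an uninitialized handle.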

CUgreenCtx green_ctx_ = nullptr;
CUcontext context_ = nullptr;
cudaStream_t parent_stream_ = nullptr;
CUstream green_ctx_stream_;
Collaborator

Shouldn't this also be nullptr initialized?

auto default_stream = c10::cuda::getDefaultCUDAStream();
ev.block(default_stream);
c10::cuda::setCurrentCUDAStream(default_stream);
auto green_ctx_stream = c10::cuda::getStreamFromExternal(green_ctx_stream_, device_id_);
Collaborator

can we create green_ctx_stream_ as a CUDAStream so you can use it directly here? nbd if no

Collaborator Author

Can just revert to using the default stream in this case given the comment below; no real reason this has to be the same stream as the one returned by Stream()


CUDAStream GreenContext::Stream() {
#if HAS_CUDA_GREEN_CONTEXT()
return c10::cuda::getStreamFromExternal(green_ctx_stream_, device_id_);
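The key property of wrapping an externally created raw handle this way is that the wrapper is non-owning. A toy sketch of that pattern (`ExternalStream` and `RawStream` are placeholder names for illustration, not the actual c10 types):

```cpp
#include <cassert>

// Placeholder for a raw driver handle like cudaStream_t / CUstream.
using RawStream = void*;

class ExternalStream {
 public:
  // Non-owning: whoever created `raw` (here, the green context) remains
  // responsible for its lifetime; the wrapper only carries the handle
  // together with the device index it belongs to.
  ExternalStream(RawStream raw, int device) : raw_(raw), device_(device) {}
  RawStream raw() const { return raw_; }
  int device_index() const { return device_; }

 private:
  RawStream raw_;
  int device_;
};
```

Because the wrapper never destroys the handle, whoever calls the driver create API has to manage teardown, which is what the leak discussion below is about.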
Collaborator

Ugh, this limits users to just one stream per green context? People are used to writing s1 = torch.cuda.Stream(); s2 = torch.cuda.Stream(); if ctx.Stream() has drastically different behavior this will be confusing. Also, is there a real reason for this?

Collaborator Author

I don't think so, if we make the tracking the user's responsibility

CUstream green_ctx_side_stream;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxStreamCreate_(
&green_ctx_side_stream, green_ctx_, CU_STREAM_NON_BLOCKING, 0));
// implies we leak side-streams, but this has precedent in e.g., c10/cuda/CUDAStream.cpp
Collaborator

not really, CUDAStream.cpp creates a fixed number of streams; getStreamFromExternal implies that the external library that created the stream can also destroy it, but here we can potentially create and leak an unbounded number of streams, because it's very common to have code that creates and "destroys" streams like there's no tomorrow.
Can we instead go the CUDAStream.cpp route: precreate a fixed number of streams and dole them out as needed?
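The fixed-pool approach suggested here could look roughly like this (a sketch: `StreamPool`, `FakeStream`, and `kPoolSize` are hypothetical names, and the real code would fill the pool by calling cuGreenCtxStreamCreate once per slot at construction time, keeping the leak bounded by the pool size):

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Placeholder for CUstream; the real pool would hold driver stream
// handles created once up front and never destroyed (a bounded leak,
// like the fixed pools in c10/cuda/CUDAStream.cpp).
struct FakeStream { int id; };

constexpr std::size_t kPoolSize = 4;  // fixed pool size, chosen for illustration

class StreamPool {
 public:
  StreamPool() {
    for (std::size_t i = 0; i < kPoolSize; ++i) {
      streams_[i] = FakeStream{static_cast<int>(i)};  // precreate every stream
    }
  }

  // Dole streams out round-robin; callers never destroy them, so the
  // total number of live streams stays bounded by kPoolSize no matter
  // how often user code "creates" and "destroys" streams.
  FakeStream next() { return streams_[idx_++ % kPoolSize]; }

 private:
  std::array<FakeStream, kPoolSize> streams_{};
  std::size_t idx_ = 0;
};
```

Handing out pool slots round-robin mirrors how CUDAStream.cpp recycles its per-device stream pools, trading stream uniqueness for a bounded resource footprint.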

@eqy
Collaborator Author

eqy commented Jan 5, 2026

@pytorchmergebot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased greenstreamexposed onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout greenstreamexposed && git pull --rebase)

CUcontext context_ = nullptr;
cudaStream_t parent_stream_ = nullptr;
std::array<CUstream, kStreamPerGreenContextPool> green_ctx_streams_;
int32_t curr_stream_idx_ = -1;
Collaborator

needs to be atomic?
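If the next-stream selection can race across threads, a plain int32_t counter is a data race; an atomic fetch_add makes the round-robin index safe. A sketch (`next_stream_slot` and `kStreamsPerPool` are illustrative names, not the PR's actual identifiers):

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>

constexpr uint32_t kStreamsPerPool = 4;  // illustrative pool size (power of 2)

// Thread-safe round-robin index: fetch_add is a single atomic
// read-modify-write, so two threads can never observe the same value,
// unlike a plain `curr_stream_idx_++`, which is undefined behavior
// under concurrent access.
std::atomic<uint32_t> curr_stream_idx{0};

uint32_t next_stream_slot() {
  // Relaxed ordering suffices: only the counter itself must be atomic.
  // With a power-of-2 pool size, 2^32 % kStreamsPerPool == 0, so the
  // rotation stays consistent even when the uint32_t counter wraps.
  return curr_stream_idx.fetch_add(1, std::memory_order_relaxed) % kStreamsPerPool;
}
```

With a non-atomic counter, two threads incrementing concurrently could be handed the same pool slot (or tear the value entirely), which is presumably the concern behind this question.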

@zou3519 zou3519 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jan 5, 2026
@eqy
Collaborator Author

eqy commented Jan 5, 2026

@pytorchmergebot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 5, 2026
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Successfully rebased greenstreamexposed onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout greenstreamexposed && git pull --rebase)

@eqy
Collaborator Author

eqy commented Jan 12, 2026

@pytorchmergebot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@pytorchmergebot
Collaborator

Merge failed

Reason: 1 job has failed: Meta Internal-Only Changes Check

Details for Dev Infra team: raised by workflow job.

@eqy
Collaborator Author

eqy commented Jan 12, 2026

@pytorchmergebot merge -i

@pytorchmergebot
Collaborator

hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Jan 12, 2026
…71116)"

This reverts commit 5ecb35e.

Reverted pytorch#171116 on behalf of https://github.com/jeanschmidt due to breaks internal builds, see D90148243 ([comment](pytorch#171116 (comment)))
@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see the pytorch-bot wiki.

@ngimel
Collaborator

ngimel commented Jan 13, 2026

@pytorchbot merge -f "merge keeps timing out"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.



Labels

ci-no-td (Do not run TD on this PR), ciflow/b200, ciflow/h100, ciflow/rocm-mi300 (Trigger "default" config CI on ROCm MI300), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: cuda (Related to torch.cuda, and CUDA support in general), open source, release notes: cuda (release notes category), Reverted, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
