
Fix CUDA stream synchronization in sampler logprobs extraction #20064

Draft
chanh wants to merge 1 commit into sgl-project:main from chanh:cnguyen/fix-sampler-memsync

Conversation

@chanh (Contributor) commented Mar 6, 2026

Motivation

`get_token_ids_logprobs_batch_optimized` was creating tensors directly on GPU and calling `torch.repeat_interleave` with a GPU tensor, both of which force a `cudaStreamSynchronize`, a known PyTorch issue (pytorch/pytorch#108968). This stalls the GPU and prevents it from executing kernels concurrently during the sampler phase.
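
For reference, a minimal sketch of the pattern that triggers the sync, with simplified names and made-up data rather than the actual sampler code:

```python
import torch

# Hypothetical per-request token-id lists whose logprobs need to be gathered.
token_ids_per_req = [[5, 17, 42], [7], [3, 9]]

# Both statements below force an implicit cudaStreamSynchronize:

# 1) Building a tensor from Python data directly on the GPU copies through a
#    pageable host buffer that PyTorch waits on (pytorch/pytorch#108968).
lengths_gpu = torch.tensor([len(ids) for ids in token_ids_per_req], device="cuda")

# 2) repeat_interleave with a GPU `repeats` tensor has to read the repeat
#    counts back on the host to size its output, stalling the stream.
row_indices = torch.repeat_interleave(
    torch.arange(len(token_ids_per_req), device="cuda"), lengths_gpu
)
```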

Modifications

  • Compute lengths and flatten token IDs as Python lists on CPU
  • Create tensors on CPU, then transfer to GPU with non_blocking=True
  • This eliminates the implicit device sync and allows the GPU to continue executing kernels without stalling (see the sketch below)
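
A minimal sketch of the resulting pattern, with simplified names; `pin_memory=True` is an extra assumption for illustration, the PR itself only relies on `non_blocking=True`:

```python
import torch

token_ids_per_req = [[5, 17, 42], [7], [3, 9]]  # hypothetical inputs

# Lengths, flattened token IDs, and row indices are computed purely in Python.
lengths = [len(ids) for ids in token_ids_per_req]
flat_token_ids = [t for ids in token_ids_per_req for t in ids]
row_indices = [i for i, n in enumerate(lengths) for _ in range(n)]

# Tensors are created on the CPU and moved to the GPU asynchronously; no
# GPU-resident `repeats` tensor is needed, so nothing forces a stream sync.
flat_ids_gpu = torch.tensor(flat_token_ids, dtype=torch.int64,
                            pin_memory=True).to("cuda", non_blocking=True)
row_idx_gpu = torch.tensor(row_indices, dtype=torch.int64,
                           pin_memory=True).to("cuda", non_blocking=True)

# The requested logprobs can then be gathered with plain advanced indexing,
# e.g. logprobs[row_idx_gpu, flat_ids_gpu].
```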

Accuracy Tests

No model output changes — this is a performance fix only.

Benchmarking and Profiling

Profiling confirms the cudaStreamSynchronize call is removed after this change.

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@Qiaolin-Yu (Collaborator) left a comment

Is it possible to do something similar here? For example, add a `no_copy_to_cpu` flag to

def get_token_ids_logprobs(logprobs, token_ids_logprobs, no_copy_to_cpu=False):

and do the actual GPU -> CPU transfer in

def copy_to_cpu(self, return_logprob: bool):
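
A speculative sketch of that suggestion, for illustration only: it uses just the two signatures quoted above, and the holder class and attribute names below are invented rather than SGLang's actual code.

```python
import torch

def get_token_ids_logprobs(logprobs, token_ids_logprobs, no_copy_to_cpu=False):
    # Gather the requested logprobs on the GPU for each request in the batch.
    gathered = [logprobs[i, ids] for i, ids in enumerate(token_ids_logprobs)]
    if no_copy_to_cpu:
        # Leave the results on the GPU; the transfer is deferred to copy_to_cpu().
        return gathered
    return [t.cpu() for t in gathered]

class LogprobOutput:  # hypothetical holder for the sampler's logprob results
    def __init__(self, token_ids_logprobs_gpu):
        self.token_ids_logprobs = None
        self.token_ids_logprobs_gpu = token_ids_logprobs_gpu

    def copy_to_cpu(self, return_logprob: bool):
        # Single deferred GPU -> CPU transfer, issued once per batch.
        if return_logprob:
            self.token_ids_logprobs = [
                t.to("cpu", non_blocking=True) for t in self.token_ids_logprobs_gpu
            ]
```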
