🚀 The feature, motivation and pitch

There are a number of unnecessary cuda synchronizations in PyTorch ops, and I think we should endeavor to remove them whenever possible.

To check for syncs, you can use `torch.cuda.set_sync_debug_mode("warn")`. I'm creating this issue to track the ones I've seen/found.

- `torch.multinomial` syncs even with `num_samples=1`:

  ```python
  A = torch.rand(10)
  torch.multinomial(A, num_samples=1)
  ```

  For this I think we should simply remove the error check causing the sync, and ideally turn it into a cuda async error. https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Distributions.cpp#L615
- `torch.repeat_interleave` is often called with `repeats` as a non-cuda tensor, and that forces a synchronization:

  ```python
  A = torch.randn(3, device='cuda')
  num_repeats = torch.tensor([2, 3, 5])
  out = torch.repeat_interleave(A, num_repeats.cuda(), dim=0)
  ```

  For this I think we should add a list-of-ints overload or allow passing a CPU tensor for `repeats`.
- Indexing with a scalar tensor performs a synchronization. See "Turn indexing with a scalar tensor into a copy into a view and avoid a D2H synchronization" (#105641) for more details.
- `torch.normal` also incurs a sync on `std`: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/DistributionTemplates.h#L222
- `nanmedian` incurs a sync: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/Sorting.cpp#L149
- `prod_backward`: "torch.prod cannot be used with cudagraphs" (#128396)
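As a minimal sketch of the detection workflow (the `warn_on_implicit_syncs` helper name is mine; it only does anything on a CUDA-capable build):

```python
import torch

def warn_on_implicit_syncs():
    """Sketch: surface implicit device syncs while running a suspect op.

    On CPU-only builds this is a no-op.
    """
    if not torch.cuda.is_available():
        return
    # "warn" emits a warning at every implicit CUDA synchronization;
    # "error" would raise instead.
    torch.cuda.set_sync_debug_mode("warn")
    try:
        A = torch.rand(10, device="cuda")
        # Currently triggers a device-to-host sync, so a warning
        # should be printed here.
        torch.multinomial(A, num_samples=1)
    finally:
        torch.cuda.set_sync_debug_mode("default")

warn_on_implicit_syncs()
```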
Alternatives
No response
Additional context
No response
cc @ptrblck