Speedup bernoulli_scalar_cuda_kernel with grid-stride loop #20626
syed-ahmed wants to merge 20 commits into gh/syed-ahmed/7/base from gh/syed-ahmed/7/head
Conversation
After these PRs, will the CUDA RNGs be thread and stream-safe in PyTorch?
```cpp
if (std::is_same<scalar_t, double>::value) {
  distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter,
    gen,
    [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
```
so @ngimel told me that curand's uniform with double is actually a lie because there is not enough precision. I don't know if things have changed or not.
How many bits of randomness does it provide? 53 bits is standard, so that it generates all the rationals x / 2^53 for x in {0, 1, ..., 2^53 - 1}.
That doesn't generate all valid doubles in [0, 1), but basically nobody does that outside of toy programs.
curand_uniform_double is a lie (uses 32 bits), curand_uniform2_double is not (uses 53 bits per value).
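To illustrate the distinction, here is a minimal sketch (the helper name is hypothetical, not curand's actual implementation) of building a double in [0, 1) from 53 random bits, i.e. the x / 2^53 grid described above:

```cpp
#include <cstdint>

// Hypothetical helper: take the top 53 bits of a 64-bit random word
// and scale by 2^-53, yielding a double on the grid x / 2^53.
// A 32-bit source, by contrast, can only reach 2^32 distinct values.
double bits53_to_double(uint64_t r) {
    // 2^53 = 9007199254740992; (r >> 11) lies in [0, 2^53)
    return (r >> 11) * (1.0 / 9007199254740992.0);
}
```

With all 64 input bits set, the result is (2^53 - 1) / 2^53, just below 1.0, so the half-open interval [0, 1) is respected.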
Speedup bernoulli_scalar_cuda_kernel with grid-stride loop gh-metadata: pytorch pytorch 20626 gh/syed-ahmed/7/head
That is the intention. These PRs achieve thread and stream-safety by replacing curandStateMTGP with curandStatePhilox: #19508 (comment)
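For context on the title, the grid-stride pattern lets a fixed-size grid cover a tensor of any length: each thread starts at its global index and strides by the total thread count. A minimal CPU-side sketch of that indexing (illustrative names only, not the PR's actual kernel):

```cpp
#include <vector>
#include <cstddef>

// Hypothetical CPU simulation of a grid-stride loop: each "thread"
// begins at its global id and advances by the total thread count,
// so every element is visited exactly once regardless of grid size.
void grid_stride_fill(std::vector<int>& out, int grid_dim, int block_dim) {
    const std::size_t total_threads =
        static_cast<std::size_t>(grid_dim) * block_dim;
    for (std::size_t tid = 0; tid < total_threads; ++tid) {
        for (std::size_t i = tid; i < out.size(); i += total_threads) {
            out[i] += 1;  // in the real kernel: write one random sample
        }
    }
}
```

Because each thread handles many elements, the launch configuration can be tuned for occupancy instead of being tied to the tensor size.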
I checked locally and got the following output: So maybe it fixed the issue? I didn't triage and find the source of that bug. Maybe the tensor iterator magic in the normal PR has fixed this.
Stack from ghstack:
Differential Revision: D15454046
Effective Bandwidth Benchmark
Float Type
Before:
After:
Double Type
Before:
After: