Remove curandStateMTGP32 usage #20886
syed-ahmed wants to merge 15 commits into gh/syed-ahmed/8/base from gh/syed-ahmed/8/head
Conversation
aten/src/THC/THCTensorRandom.cuh (Outdated)
@@ -286,6 +288,10 @@ sampleMultinomialWithReplacement(curandStateMtgp32* state,
// what, all block threads must participate in the curand_uniform
// call to update the generator state.
This comment is no longer valid (without MTGP, individual threads can participate in the RNG call).
@@ -296,7 +302,8 @@ sampleMultinomialWithReplacement(curandStateMtgp32* state,
int sample = sampleBase + threadIdx.y;

// All threads participate in this
aten/src/THCUNN/generic/RReLU.cu (Outdated)
// each thread will utilize one random, however, since we have to use
// curand_uniform4 (See Note [Register spilling in curand call for CUDA < 10]),
// offset is 4.
uint64_t offset = gen->state.philox_seed_offset.fetch_add(4);
Note that NUM_BLOCKS in most cases will be set to 64 (that's a poor choice, but that's for the next PR), so you'll have a grid-stride loop inside the kernel and will generate multiple randoms; adjust the offset accordingly.
Updated the offset calculation to (numel / block_size * grid.x) * 4.
Remove curandStateMTGP32 usage gh-metadata: pytorch pytorch 20886 gh/syed-ahmed/8/head
@syed-ahmed could you rebase this stack on master? (I can do it myself, but if I do you'll have to force update your own local branch pointer -- let me know if you'd prefer me to do it)
Remove curandStateMTGP32 usage gh-metadata: pytorch pytorch 20886 gh/syed-ahmed/8/head
@ezyang rebased :).
Remove curandStateMTGP32 usage gh-metadata: pytorch pytorch 20886 gh/syed-ahmed/8/head
Sorry, you rebased on top of broken master. Once the breakage is reverted we'll need another rebase :/
A little more text in the PR description would have been appreciated for this poor reviewer ^^
THArgCheck(THByteTensor_nElement(rng_state) == total_size, 1, "RNG state is wrong size");
THArgCheck(THByteTensor_isContiguous(rng_state), 1, "RNG state must be contiguous");
THCudaCheck(cudaMemcpy(THByteTensor_data(rng_state), gen->state.gen_states,
                       states_size, cudaMemcpyDeviceToHost));
It might be a good idea to fill in this memory with deterministic garbage, so that if someone tries to use it (improperly) the failure won't be a random error.
Filled in the memory with -1 and verified locally that torch.cuda.get_rng_state() gives 255 in the first few elements.
THArgCheck(THByteTensor_isContiguous(rng_state), 1, "RNG state must be contiguous");

THCudaCheck(cudaMemcpy(gen->state.gen_states, THByteTensor_data(rng_state),
                       states_size, cudaMemcpyHostToDevice));
Is this necessary? Since I made all the gen_states memory hold -1 in getRNGState, this function will just not affect that value. If I were to do cudaMemcpy or memset here, I'd need to allocate the gen_states (which I deleted, i.e. the initializeGenerator function).
You're right, please don't do that :) This can be kept as is (I just saw something that looked similar to the previous pattern.)
template <typename T>
__global__ void
sampleMultinomialWithReplacement(curandStateMtgp32* state,
sampleMultinomialWithReplacement(std::pair<uint64_t, uint64_t> seeds,
To be fair, the second element of this pair isn't really a seed, it's an offset, right?
That's true, it's a little hand-wavy here, I agree. But you could interpret it as: since the seed decides where an RNG sequence starts from, the offset just gives finer control over that for the philox engine. So "seed" for philox could be an umbrella term for the actual seed value plus the offset 🤷‍♂️. If you want I can change the name (seed_and_offset maybe?), but then we would have to change the variable name everywhere and use it like seed_and_offset.first, seed_and_offset.second.
OK, if you like it, let's keep it :)
How can I tell if the offset calculations were done right? Do tests cover this at all? It seems very fiddly.
The philox offset calculation for RRelu.cu should be good, since it runs the exact same way as the kernels tested in |
Remove curandStateMTGP32 usage gh-metadata: pytorch pytorch 20886 gh/syed-ahmed/8/head
Sorry, we need another rebase; master was a disaster yesterday.
Remove curandStateMTGP32 usage gh-metadata: pytorch pytorch 20886 gh/syed-ahmed/8/head
Stack from ghstack:
Differential Revision: D15535503
Summary:
This PR removes curandStateMTGP32 usage, since it is not stream-safe.
Main changes are: