
THCReduce noncontigdim kernel improvements. #751

Closed
csarofeen wants to merge 1 commit into torch:master from csarofeen:reduction

Conversation

@csarofeen
Contributor

THCReduce noncontigdim kernel improvements. Added an extra kernel and heuristics to improve smaller tensor reductions.

@ngimel

ngimel commented Apr 18, 2017

@soumith, @apaszke This would allow switching to expand_as/sum (instead of addr/gemv) when adding bias in linear functions, with performance gains, especially for smaller linear sizes.
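[Editor's sketch, not part of the original thread.] The equivalence ngimel is pointing at can be seen in a few lines of plain Python: the bias gradient of a Linear layer is a reduction of grad_output over the batch dimension, which can be computed either as a gemv against a ones-vector (the addr/gemv path) or as a plain sum (the expand_as/sum path). Shapes and values here are illustrative only.

```python
# grad_output for a Linear layer, shape (batch=4, out_features=3)
grad_output = [[float(3 * i + j) for j in range(3)] for i in range(4)]

# gemv style: ones(batch) @ grad_output, as in the fill+addr/gemv path
ones = [1.0] * 4
grad_bias_gemv = [sum(ones[i] * grad_output[i][j] for i in range(4))
                  for j in range(3)]

# sum style: reduce directly over the batch dimension
grad_bias_sum = [sum(grad_output[i][j] for i in range(4)) for j in range(3)]

assert grad_bias_gemv == grad_bias_sum
```

Both paths produce the same bias gradient; the point of the PR is that the sum-based reduction can now be faster for small sizes.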

Contributor

@apaszke apaszke left a comment

@ngimel how does it help with bias addition? The output of mm and the bias are always contiguous, and this PR only changes the noncontig kernels.

for (IndexType i = 0; i < reductionSize; ++i) {
  r = reduceOp(r, modifyOp(in.data[inOffset]));
  inOffset += reductionStride;
  __syncthreads();
Contributor

Is this necessary?

Contributor Author

It is unfortunately necessary. We're trying to prevent warps from getting too far ahead, which would have negative effects on the memory system.

Contributor

How can they get too far ahead of each other? If the ops have uneven branches?

} else {
  // x dim does different slices
  // y dim helps with a slice
  // If we only have 8 loops, don't bother sharing work across ydim
Contributor

Should it be 16 loops?

Contributor Author

Yes, I'll fix this comment.

if (!getNoncontigReduceGrid(outElements, grid)) {
  return false;

// If there are a large number of outputs to the reduction, avoid syncthreads
Contributor

Both kernels have syncthreads right now

Contributor Author

Yes, I'll fix this comment.

long gridx = THCCeilDiv(outElements, (long)block.x);
if (gridx > 1024) {
  long n_loops = THCCeilDiv(outElements, (long)(1024 * block.x));
  gridx = outElements / (block.x * n_loops);
Contributor

Are you sure this is ok? If you remove the ceil it is equivalent to setting gridx to 1024.

Contributor Author

I will review this again to make sure it is correct. It's mainly there to load-balance the internal slice loop.
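[Editor's sketch, not part of the original thread.] The heuristic under discussion can be restated in plain Python (function and parameter names here are mine, not from the patch): instead of clamping gridx at 1024, it first computes how many loop iterations each thread would need at the cap, then recomputes gridx from that loop count so the per-thread work divides evenly. The result is close to, but not always exactly, 1024, which is what the reviewer is probing.

```python
def ceil_div(a, b):
    # Integer ceiling division, as THCCeilDiv does
    return (a + b - 1) // b

def noncontig_grid_x(out_elements, block_x, max_grid_x=1024):
    # One block-column per output element group; cap the grid, then
    # rebalance gridx against the number of per-thread loop iterations.
    gridx = ceil_div(out_elements, block_x)
    if gridx > max_grid_x:
        n_loops = ceil_div(out_elements, max_grid_x * block_x)
        gridx = out_elements // (block_x * n_loops)
    return gridx
```

For example, with 1,000,000 output elements and block_x = 32 this yields gridx = 1008 rather than a flat clamp to 1024, so each of the 1008 block-columns runs the same number of slice loops.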

__device__ __forceinline__ IndexType getReduceNoncontigDimSliceIndex() {
  // Each thread handles one slice
  return getLinearBlockId<IndexType>() * THC_NONCONTIG_REDUCE_BLOCK_SIZE + threadIdx.x;

#define LOCAL_MAX_BLOCK_SIZE 512
Contributor

It seems that this constant is used for shared mem size, but is not used when computing the block size. Is that ok?

Contributor Author

https://github.com/csarofeen/cutorch/blob/master/lib/THC/THCReduce.cuh#L239-L244
This ensures block size = 512. The name was a bit of a misnomer, as I enforced 512 instead of treating it as a max.

Contributor

@apaszke apaszke Apr 19, 2017

I know it enforces it, but I think it would be better to use the constant in both places. Otherwise there's no point in separating it from the code, because the two can get out of sync.

    *shmem = reduceOp(*shmem, *(shmem + blockDim.x * i));
  }
  out.data[outOffset] = *shmem;
}
Contributor

Why is this just limited to groupID == 0? Wouldn't reducing half the groups at each step be faster?

Contributor Author

It might be. I could actually try it, as I forced blockDim.y to be a multiple of 2, so the logic shouldn't be too bad. Will check.
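[Editor's sketch, not part of the original thread.] The two strategies being compared can be simulated in plain Python over the per-group partial results. The serial version is what the kernel does now (group 0 walks everyone else's partial); the tree version is the reviewer's suggestion, halving the number of active groups each step, so it takes log2(groups) steps instead of groups - 1. The tree version assumes the group count is a power of 2. Names are illustrative.

```python
def serial_final_reduce(partials, reduce_op):
    # Current kernel behavior: one group accumulates all other partials
    acc = partials[0]
    for p in partials[1:]:
        acc = reduce_op(acc, p)
    return acc

def tree_final_reduce(partials, reduce_op):
    # Suggested alternative: halve the active groups each step
    # (requires len(partials) to be a power of 2)
    partials = list(partials)
    stride = len(partials) // 2
    while stride > 0:
        for i in range(stride):
            partials[i] = reduce_op(partials[i], partials[i + stride])
        stride //= 2
    return partials[0]
```

Both produce the same result for any associative, commutative reduce op; the tree variant just exposes more parallelism across the y groups.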

@fmassa
Contributor

fmassa commented Apr 19, 2017

@apaszke I think @ngimel meant that we could use expand + add instead of addr in the forward of Linear, and sum the gradOutputs in the backward (instead of gemv). It also avoids creating, resizing, and filling the add_buffer with 1s at every iteration. I think this is related to this discussion in the Slack.

@apaszke
Contributor

apaszke commented Apr 19, 2017

@fmassa I know what the deal is with expand+add vs fill+addr; I'm just asking how it is related to this change. I don't know why I thought that expanded tensors are contiguous, nvm.

Contributor

@killeent killeent left a comment

What is the test plan for this? Do we have some benchmarking that shows this is faster?

    T init,
    ModifyOp modifyOp,
    ReduceOp reduceOp) {
  IndexType threadLane = threadIdx.x;
Contributor

threadLane seems like a bit of a misnomer here. I'm not sure how this corresponds to the lane in the warp.

Contributor Author

You're correct; it was a remnant from when I was using a 1-D block. I will name it something more appropriate.

IndexType threadLane = threadIdx.x;
IndexType groupID = threadIdx.y;
IndexType sliceIndex = blockIdx.x * blockDim.x + threadLane;
IndexType sliceStride = gridDim.x * blockDim.x;
Contributor

Similarly, sliceStride is a bit confusing - this is actually the stride with which to get the next slice for reduction, but the variable name makes it sound like the stride for elements within a slice.
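[Editor's sketch, not part of the original thread.] The reviewer's point about the name can be made concrete by simulating the indexing: sliceIndex/sliceStride form a grid-stride loop over output slices, so sliceStride is the distance between successive slices a single thread owns, not the element stride within a slice. Function and parameter names below are mine.

```python
def slices_for_thread(block_idx_x, thread_x, block_dim_x, grid_dim_x, num_slices):
    # Mirrors: sliceIndex = blockIdx.x * blockDim.x + threadIdx.x
    #          sliceStride = gridDim.x * blockDim.x
    # Each thread walks the output slices at intervals of sliceStride.
    slice_index = block_idx_x * block_dim_x + thread_x
    slice_stride = grid_dim_x * block_dim_x
    return list(range(slice_index, num_slices, slice_stride))
```

For example, with gridDim.x = 2 and blockDim.x = 4, thread (block 1, lane 2) starts at slice 6 and then jumps 8 slices at a time; collectively the 8 threads cover every slice exactly once.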

Contributor Author

Do you have a suggestion on this name?

IndexType stride = reductionStride * blockDim.y;

for (IndexType i = groupID; i < reductionSize; i += blockDim.y) {
  (*shmem) = reduceOp(*shmem, modifyOp(in.data[inOffset]));
Contributor

I'm not sure exactly how this works. I could be wrong, but aren't we hitting shared memory every time here? If we want different threads in the "group" to reduce things in registers wouldn't we need a local variable?

Contributor Author

Will check, but the compiler tends to optimize it into registers (this is why there's a volatile flag for shared memory).
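[Editor's sketch, not part of the original thread.] The alternative the reviewer is describing, written out explicitly: each thread in a y-group accumulates its strided share of the reduction in a local ("register") variable and touches shared memory only once at the end, rather than read-modify-writing *shmem on every iteration. This Python simulation of one thread's loop uses illustrative names.

```python
def group_partial_with_register(values, group_id, group_count,
                                init, modify, reduce_op):
    # Each thread of the y-group takes every group_count-th element,
    # accumulating locally instead of through shared memory.
    acc = init
    for i in range(group_id, len(values), group_count):
        acc = reduce_op(acc, modify(values[i]))
    return acc  # the kernel would then do a single store: *shmem = acc
```

Whether this matters in practice is exactly the open question in the thread: if the compiler already promotes the non-volatile shared-memory accumulator to a register, the two forms compile to the same thing.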

@csarofeen
Contributor Author

https://gist.github.com/csarofeen/80e8e567d49e3a2511d6bcd7bd891a98
can be used for benchmarking the linear improvements.
@apaszke The tensor is contiguous, but the reduction is along a non-contiguous dimension.

@csarofeen csarofeen closed this Apr 25, 2017
@csarofeen
Contributor Author

Still working on this.
