Softmax/LogSoftMax refactor (wrapped up)#3245

Merged
soumith merged 2 commits into master from softmax_refactor on Oct 25, 2017
Conversation

@apaszke
Contributor

@apaszke apaszke commented Oct 23, 2017

These commits wrap up the previous Softmax refactor. All the important changes are on the CUDA side. Once I unified the code, I also added a special instantiation that makes the kernels faster in certain cases (small inner dim, large softmax dim; this might be useful in NLP for short sequences).

tl;dr CUDA Softmax now supports a dim argument and is usually 4x-256x faster than the previous implementation (which had no spatial implementation at all). It now also shares kernels with LogSoftmax, and some of the optimizations benefited the log case too, giving up to a 64x speedup in certain cases.
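For reference, the semantics of a softmax/log-softmax over an arbitrary dim can be sketched as a NumPy reduction along that axis (an illustrative reference, not the CUDA code from this PR):

```python
import numpy as np

def softmax(x, dim):
    # Subtract the max along `dim` for numerical stability,
    # then exponentiate and normalize along that same axis.
    shifted = x - x.max(axis=dim, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=dim, keepdims=True)

def log_softmax(x, dim):
    # Log-softmax reuses the same shifted values:
    # x - max - log(sum(exp(x - max))), avoiding a separate log(softmax).
    shifted = x - x.max(axis=dim, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=dim, keepdims=True))
```

Sharing the shifted intermediate is what lets the two operations share kernels: only the final normalization step differs.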


Here are the plots showing old/new timing ratios for different sizes of dim and of the innermost dimensions (on the left of dim). Parallelizing over the batch is easy, so batch size is fixed at 64 in all plots. Red dots are improvements, blue are regressions. Note that the plot is logarithmic in all axes (so z = 8 means 2^8x faster).

Softmax

Benefits from this diff all over the place; the old kernel was written in quite an archaic way.

[plot: Softmax old/new speedup ratios]

LogSoftmax

Benefits from adding a custom kernel for the cases where inner_size is no longer 1 (so we can't use the super-fast kernel) but dim_size is large (so using a single thread to reduce the values is slow). It is enabled only for the subset of the space where it provided speedups.
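For context, "spatial" here refers to the case where dim is not the last axis: the tensor can be viewed as (outer_size, dim_size, inner_size), with one independent reduction per (outer, inner) pair. A NumPy sketch of that decomposition (an illustration of the indexing, not the kernel itself):

```python
import numpy as np

def spatial_softmax(x, dim):
    # View the tensor as (outer_size, dim_size, inner_size); the softmax
    # reduction runs over the middle axis, one independent slice per
    # (outer, inner) pair. inner_size == 1 recovers the contiguous case.
    outer = int(np.prod(x.shape[:dim]))
    inner = int(np.prod(x.shape[dim + 1:]))
    v = x.reshape(outer, x.shape[dim], inner)
    shifted = v - v.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    return (e / e.sum(axis=1, keepdims=True)).reshape(x.shape)
```

When inner_size is small but dim_size is large, each (outer, inner) slice is long but there are few of them, which is the regime the new instantiation targets.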

[plot: LogSoftmax old/new speedup ratios]

Overall times

These are log plots (in all axes) of the absolute time (no more ratios) for the new algorithm. Softmax on the left, LogSoftmax on the right:

[plot: absolute times for the new kernels]

@apaszke
Contributor Author

apaszke commented Oct 23, 2017

I forgot that the z axis is also log in all plots, so what looked like a 2-8x speedup is really 4x-256x. I've updated the description.

Added a new instantiation of the spatial kernel for
low inner_size and larger dim_size.
Contributor

@killeent killeent left a comment


This mostly looks good to me; I didn't check the validity of the algorithm in any way.

ReduceOp<T> r;
shared += threadIdx.y * blockDim.x;

__syncthreads();



template <typename T, typename AccumT>
struct MaxFloat



template <template<typename> class Reduction, typename AccumT>
__device__ __forceinline__ AccumT
blockReduce(AccumT* smem, AccumT val,
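(The truncated signature above is quoted from the kernel under review.) As a rough model of what such a block-wide reduction does — an assumed pattern, not the actual CUDA code — each step halves the number of active threads, combining pairs of shared-memory slots until one value remains:

```python
def block_reduce(smem, op):
    # Model of a shared-memory tree reduction: at each step, slot i
    # absorbs slot i + stride, and the active range halves until the
    # result sits in smem[0]. Assumes len(smem) is a power of two,
    # mirroring a power-of-two CUDA block size; on the GPU each step
    # would be followed by a __syncthreads() barrier.
    stride = len(smem) // 2
    while stride > 0:
        for i in range(stride):  # one "thread" per active slot
            smem[i] = op(smem[i], smem[i + stride])
        stride //= 2
    return smem[0]
```

The same template works for the max pass and the sum pass of softmax by swapping the combining op.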


int last = size % (ILP * blockDim.x);

// Body (unroll by ILP times)
for (; offset < size - last; offset += blockDim.x * ILP) {
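The loop quoted above splits the range into an ILP-unrolled body and a `last`-element tail. A plain-Python model of which indices the body covers (assumed names, for illustration only):

```python
def ilp_body_tail(size, block_dim, ilp):
    # Mirror the kernel's partition: the body covers the first
    # size - (size % (ilp * block_dim)) elements, each outer iteration
    # advancing by block_dim * ilp, with each "thread" t making ilp
    # strided accesses; the remaining `last` elements form the tail loop.
    last = size % (ilp * block_dim)
    body = []
    for offset in range(0, size - last, block_dim * ilp):
        for t in range(block_dim):   # one entry per thread
            for j in range(ilp):     # the unrolled accesses
                body.append(offset + j * block_dim + t)
    tail = list(range(size - last, size))
    return sorted(body), tail
```

Every element below the cutoff is touched exactly once by the body, so the tail loop only has to mop up the final `last` elements.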


@ezyang
Copy link
Contributor

ezyang commented Oct 24, 2017

@apaszke, do you want to wait until someone reviews the math, or merge it sooner?

@apaszke
Copy link
Contributor Author

apaszke commented Oct 24, 2017

No preference. If anyone is up for reviewing it then I'll wait; otherwise there's no point. I can add more reference functions to make sure it's all OK.

@soumith soumith merged commit b3642b3 into master Oct 25, 2017
@soumith soumith deleted the softmax_refactor branch November 21, 2017 18:20
