Softmax/LogSoftMax refactor (wrapped up) #3245
Conversation
I forgot that the z axis is also log in all plots, so the 2-8x speedup is really 4-256x. I've updated the description.
Force-pushed fff05a8 to 4427327
Added a new instantiation of the spatial kernel for low inner_size and larger dim_size.
Force-pushed 4427327 to 866f51e
killeent
left a comment
This mostly looks good to me; I didn't check the validity of the algorithm in any way.
    ReduceOp<T> r;
    shared += threadIdx.y * blockDim.x;
    __syncthreads();
    template <typename T, typename AccumT>
    struct MaxFloat
    template <template<typename> class Reduction, typename AccumT>
    __device__ __forceinline__ AccumT
    blockReduce(AccumT* smem, AccumT val,
    int last = size % (ILP * blockDim.x);

    // Body (unroll by ILP times)
    for (; offset < size - last; offset += blockDim.x * ILP) {
@apaszke, do you want to wait until someone reviews the math, or merge it sooner?
No preference. If anyone is up to reviewing it then I'll wait; otherwise there's no point. I can add more reference functions to make sure it's all ok.
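Such reference functions (a sketch of my own, not the PR's actual test code) can be written against plain NumPy; the max-subtraction below is the standard trick that keeps both softmax and log-softmax numerically stable:

```python
import numpy as np

def ref_softmax(x, dim):
    # Subtract the per-slice max before exponentiating; exp() of large
    # inputs would overflow otherwise, and the result is unchanged.
    m = x.max(axis=dim, keepdims=True)
    e = np.exp(x - m)
    return e / e.sum(axis=dim, keepdims=True)

def ref_log_softmax(x, dim):
    # log_softmax(x) = (x - max) - log(sum(exp(x - max)))
    m = x.max(axis=dim, keepdims=True)
    return (x - m) - np.log(np.exp(x - m).sum(axis=dim, keepdims=True))
```

Comparing kernel output against these with a small tolerance catches most indexing and reduction bugs, since the reference makes no assumptions about layout.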
These commits wrap up the previous Softmax refactor. All the important changes are on the CUDA side. Once I unified the code, I also added a special instantiation that makes the kernels faster in certain cases (small inner dim, large softmax dim; this might be useful in NLP for short sequences?)
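The "inner dim" / "softmax dim" split can be illustrated in plain NumPy (a sketch of the indexing only, not the actual CUDA code): a softmax over an arbitrary `dim` views the tensor as a 3D `(outer_size, dim_size, inner_size)` array. The names `inner_size` and `dim_size` follow this PR's terminology; `outer_size` is my own label for the collapsed leading dimensions.

```python
import numpy as np

def softmax_via_3d_view(x, dim):
    # Collapse to (outer_size, dim_size, inner_size): everything before
    # `dim`, the softmax dimension itself, and everything after it.
    outer_size = int(np.prod(x.shape[:dim], dtype=np.int64))
    dim_size = x.shape[dim]
    inner_size = int(np.prod(x.shape[dim + 1:], dtype=np.int64))
    v = x.reshape(outer_size, dim_size, inner_size)
    # The reduction now always runs over axis 1, whatever `dim` was;
    # this is the shape the spatial kernel effectively operates on.
    m = v.max(axis=1, keepdims=True)
    e = np.exp(v - m)
    return (e / e.sum(axis=1, keepdims=True)).reshape(x.shape)
```

When `inner_size == 1` the reduced elements are contiguous in memory and a single fast kernel applies; when it is larger, adjacent threads must stride by `inner_size`, which is what the spatial instantiation handles.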
tl;dr: CUDA Softmax now supports a `dim` argument, and is usually 4x-256x faster than the previous implementation (it didn't have a spatial implementation before). It now also shares kernels with LogSoftmax, and certain optimizations benefited the log case as well, giving up to 64x speedups in certain cases.

Here are the plots that show old/new timing ratios for different sizes of `dim` and of the innermost dimensions (on the left of `dim`). Parallelizing over batch is easy, so it is fixed at 64 in all plots. Red dots are better, blue are regressions. Note that the plots are log in all axes (so a z of 8 means 2^8x faster).

Softmax
Benefits from this diff all over the place; the old kernel was written in quite an archaic way.
LogSoftmax
Benefits from adding a custom kernel for the cases when `inner_size` is no longer 1, so we can't use the super fast kernel, but `dim_size` is large, so using a single thread to reduce the values is slow. It is only enabled for the subset of the space where it provided speedups.

Overall times
These are log plots (in all axes) of the absolute times (no more ratios) for the new algorithm: Softmax on the left, LogSoftmax on the right.
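The enabling condition described in the LogSoftmax section can be sketched as a predicate. The threshold values below are hypothetical placeholders of mine; the PR's actual cutoffs were chosen empirically from the benchmark sweep shown in the plots.

```python
def use_spatial_logsoftmax_kernel(inner_size, dim_size,
                                  inner_threshold=32, dim_threshold=1024):
    # Hypothetical cutoffs for illustration only.
    if inner_size == 1:
        # The fast contiguous kernel already covers this case.
        return False
    # Spatial kernel pays off when the reduced dimension is large but the
    # inner stride is small, per the description above.
    return inner_size <= inner_threshold and dim_size >= dim_threshold
```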