[Inductor] support masked vectorization for the tail_loop of the 2d tiles kernel #130724
jiayisunx wants to merge 17 commits into gh/jiayisunx/14/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130724. Note: links to docs will display an error until the doc builds have been completed.
✅ No failures as of commit f81ff80 with merge base dc00eeb. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
```python
)
inner_main_loop.set_kernel(tile2d_kernel)
inner_tail_loop.set_kernel(vec_kernel)
if could_masked_vec:
```
Can we combine the logic under the two `if could_masked_vec` branches into one?
```python
    inner_tail_loop_of_outer_tail_loop.steps,
    outer_tail_loop.steps,
)
inner_tail_loop_of_outer_tail_loop.set_kernel(masked_tile2d_kernel3)
```
Can we simplify the code with loops or functions to avoid dup?
There are three types of tail loops for the 2D tiles kernel: inner_tail_loop (the inner tail loop of the outer main loop), inner_main_loop_of_outer_tail_loop, and inner_tail_loop_of_outer_tail_loop. Each of these tail loops needs its own codegen_kernel and set_kernel call.
I mean you can perhaps further simplify the code. This code pattern now looks like an unrolled loop...
Thanks for your advice, I have simplified the code.
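For reference, the loop structure being discussed looks roughly like this (an illustrative sketch, not Inductor's actual codegen; kernel choices are noted as comments and the tile sizes are hypothetical):

```cpp
#include <cstdint>

constexpr int64_t T0 = 16, T1 = 16;  // hypothetical tile sizes

void tiled_2d(int64_t N0, int64_t N1) {
  const int64_t n0_main = N0 / T0 * T0;
  const int64_t n1_main = N1 / T1 * T1;
  for (int64_t i = 0; i < n0_main; i += T0) {  // outer main loop
    for (int64_t j = 0; j < n1_main; j += T1) {
      // tile2d_kernel: full T0 x T1 tile
    }
    if (n1_main < N1) {
      // inner_tail_loop (of the outer main loop): full T0 rows,
      // N1 - n1_main columns -> masked vectorization over columns
    }
  }
  if (n0_main < N0) {  // outer tail loop
    for (int64_t j = 0; j < n1_main; j += T1) {
      // inner_main_loop_of_outer_tail_loop: N0 - n0_main rows,
      // full T1 columns -> masked vectorization over rows
    }
    if (n1_main < N1) {
      // inner_tail_loop_of_outer_tail_loop: partial in both dims
      // -> masked vectorization over rows and columns
    }
  }
}
```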
```cpp
    input[i] = _mm512_setzero_ps();
  }

  // unpacking and interleaving 32-bit elements
```
The logic below is the same as that in transpose_mxn<float, 16, 16>. Can we factor out a common function to share it and avoid duplication?
I factored out a common function and removed the template specialization transpose_mxn<float, 16, 16>.
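For reference, such a shared core could look like the following loop-rolled sketch (the function name is hypothetical; this is the standard four-round AVX-512 16x16 32-bit transpose, assuming AVX-512F):

```cpp
#include <immintrin.h>

// Transpose 16 rows of 16 floats held in ZMM registers, in place.
inline void transpose_16x16_32bit(__m512 (&r)[16]) {
  __m512 t[16];
  // round 1: interleave adjacent rows at 32-bit granularity
  for (int i = 0; i < 8; ++i) {
    t[2 * i]     = _mm512_unpacklo_ps(r[2 * i], r[2 * i + 1]);
    t[2 * i + 1] = _mm512_unpackhi_ps(r[2 * i], r[2 * i + 1]);
  }
  // round 2: interleave at 64-bit granularity
  for (int i = 0; i < 4; ++i) {
    r[4 * i]     = _mm512_castpd_ps(_mm512_unpacklo_pd(
        _mm512_castps_pd(t[4 * i]), _mm512_castps_pd(t[4 * i + 2])));
    r[4 * i + 1] = _mm512_castpd_ps(_mm512_unpackhi_pd(
        _mm512_castps_pd(t[4 * i]), _mm512_castps_pd(t[4 * i + 2])));
    r[4 * i + 2] = _mm512_castpd_ps(_mm512_unpacklo_pd(
        _mm512_castps_pd(t[4 * i + 1]), _mm512_castps_pd(t[4 * i + 3])));
    r[4 * i + 3] = _mm512_castpd_ps(_mm512_unpackhi_pd(
        _mm512_castps_pd(t[4 * i + 1]), _mm512_castps_pd(t[4 * i + 3])));
  }
  // rounds 3 and 4: exchange 128-bit lanes across register pairs
  for (int i = 0; i < 4; ++i) {
    t[i]      = _mm512_shuffle_f32x4(r[i], r[i + 4], 0x88);
    t[i + 4]  = _mm512_shuffle_f32x4(r[i], r[i + 4], 0xdd);
    t[i + 8]  = _mm512_shuffle_f32x4(r[i + 8], r[i + 12], 0x88);
    t[i + 12] = _mm512_shuffle_f32x4(r[i + 8], r[i + 12], 0xdd);
  }
  for (int i = 0; i < 8; ++i) {
    r[i]     = _mm512_shuffle_f32x4(t[i], t[i + 8], 0x88);
    r[i + 8] = _mm512_shuffle_f32x4(t[i], t[i + 8], 0xdd);
  }
}
```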
Are you going to support avx2 as well?
```cpp
    typename std::enable_if_t<std::is_same<T, Half>::value && ((M < 32 && M != 16) || (N < 32 && N != 16)), int> = 0>
inline void transpose_mxn(const Half* src, int64_t ld_src, Half* dst, int64_t ld_dst) {
  // load from src
  __mmask32 src_mask = (1 << N) - 1;
```
Looks like the same implementation for Half and BFloat16. Can we move it into a common util function?
We must specialize this function here so that it takes precedence during overload resolution.
We can still provide a util function such as

```cpp
template <typename T,
          typename std::enable_if_t<std::is_same_v<T, Half> || std::is_same_v<T, BF16>>>
transpose_mxn_32_32(const T* src, int64_t ld_src, T* dst, int64_t ld_dst)
```

and then reuse it for BF16 and FP16; that saves lines of code.
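For illustration, a shared masked-load helper along those lines might look like this (the name `load_rows_16bit` is hypothetical; it works for both Half and BFloat16 since only the 16-bit bit pattern matters, and assumes AVX-512BW with M, N <= 32):

```cpp
#include <immintrin.h>
#include <cstdint>
#include <type_traits>
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>

template <typename T,
          std::enable_if_t<std::is_same_v<T, c10::Half> ||
                           std::is_same_v<T, c10::BFloat16>, int> = 0>
inline void load_rows_16bit(const T* src, int64_t ld_src,
                            __m512i (&r)[32], int M, int N) {
  // low N lanes valid; lanes >= N are zero-filled by the masked load
  const __mmask32 src_mask = (N == 32) ? 0xffffffffu : ((1u << N) - 1);
  for (int i = 0; i < M; ++i) {
    r[i] = _mm512_maskz_loadu_epi16(src_mask, src + i * ld_src);
  }
  for (int i = M; i < 32; ++i) {
    r[i] = _mm512_setzero_si512();  // pad missing rows with zeros
  }
}
```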
@jansel, could you please review this PR? Thanks!
@pytorchbot merge
Merge failed. Reason: This PR needs a `release notes:` label. If not, please add the `topic: not user facing` label. To add a label, you can comment to pytorchbot, for example `@pytorchbot label "topic: not user facing"`. For more information, see the bot commands wiki. (Details for Dev Infra team: raised by workflow job.)
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
[Inductor] support masked vectorization for the tail_loop for torch.uint8 and torch.int8 datatype (#131155)

This PR supports masked vectorization for the tail_loop for the torch.uint8 and torch.int8 datatypes to improve performance. BTW, I fixed the UT of `byte` by setting the range of the sample inputs to [0, 255], since the range of `torch.uint8` is [0, 255].

Pull Request resolved: #131155
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
ghstack dependencies: #130724
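For context, masked tail-loop vectorization of this kind can be illustrated with ATen's count-based `loadu`/`store` (a sketch only, not the actual generated kernel; `add_one_u8` is a hypothetical function):

```cpp
#include <ATen/cpu/vec/vec.h>
#include <cstdint>

// Add 1 to every uint8 element of a buffer.
void add_one_u8(const uint8_t* in, uint8_t* out, int64_t n) {
  using Vec = at::vec::Vectorized<uint8_t>;
  int64_t i = 0;
  for (; i + Vec::size() <= n; i += Vec::size()) {
    // main loop: full-width vector operations
    auto v = Vec::loadu(in + i);
    (v + Vec(static_cast<uint8_t>(1))).store(out + i);
  }
  if (i < n) {
    // tail loop: load/store only the remaining n - i lanes instead of
    // falling back to a scalar loop
    auto v = Vec::loadu(in + i, n - i);
    (v + Vec(static_cast<uint8_t>(1))).store(out + i, n - i);
  }
}
```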
Stack from ghstack (oldest at bottom):
This PR supports masked vectorization for the tail_loop of the 2d tiles kernel to improve performance.
Example:
Generated code:
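For illustration only (not this PR's actual output): a 2D tile's tail region can be handled with masked, count-based vector loads and stores rather than a scalar fallback. A sketch, assuming `cols` does not exceed the vector width:

```cpp
#include <ATen/cpu/vec/vec.h>
#include <cstdint>

// Copy the tail tile of a 2D region (rows and cols both smaller than
// the full tile) using masked, count-based vector loads/stores.
void copy_tail_tile(const float* src, int64_t ld_src,
                    float* dst, int64_t ld_dst,
                    int64_t rows, int64_t cols) {
  using Vec = at::vec::Vectorized<float>;
  for (int64_t i = 0; i < rows; ++i) {
    // only the first `cols` lanes are read and written per row
    auto v = Vec::loadu(src + i * ld_src, cols);
    v.store(dst + i * ld_dst, cols);
  }
}
```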
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang