Skip to content

[Inductor] support masked vectorization for the tail_loop of the 2d tiles kernel#130724

Closed
jiayisunx wants to merge 17 commits into gh/jiayisunx/14/base from
gh/jiayisunx/14/head
Closed

[Inductor] support masked vectorization for the tail_loop of the 2d tiles kernel#130724
jiayisunx wants to merge 17 commits into gh/jiayisunx/14/base from
gh/jiayisunx/14/head

Conversation

@jiayisunx
Copy link
Copy Markdown
Collaborator

@jiayisunx jiayisunx commented Jul 15, 2024

Stack from ghstack (oldest at bottom):

This PR adds masked vectorization for the tail loops (tail_loop) of the 2D tile kernel to improve performance.

Example:

import torch

def fn(a):
    return torch.permute(a, (2, 0, 1)).contiguous()

input = torch.randn(2, 20, 40)
compiled_fn = torch.compile(fn)

with torch.no_grad():
    for _ in range(3):
        compiled_fn(input)

Generated code:

  • Before:
cpp_fused_clone_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_jiayisun/z2/cz2ry4ghylembzwx7hkbanur76fi3mkiu7s6jm3zdi2amy5egq4b.h"
extern "C"  void kernel(const float* in_ptr0,
                       float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(16L))
        {
            #pragma GCC ivdep
            for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(16L))
            {
                float tmp0[16*16] __attribute__ ((aligned (16)));
                at::vec::transpose_mxn<float,16,16>(in_ptr0 + static_cast<long>(x0 + (40L*x1)), static_cast<long>(40L), tmp0, 16);
                for (long x0_inner = 0; x0_inner < 16; x0_inner++)
                {
                    auto tmp1 = at::vec::Vectorized<float>::loadu(tmp0 + static_cast<long>(16L*x0_inner), 16);
                    tmp1.store(out_ptr0 + static_cast<long>(x1 + (40L*x0) + (40L*x0_inner)));
                }
            }
            #pragma GCC ivdep
            for(long x1=static_cast<long>(32L); x1<static_cast<long>(40L); x1+=static_cast<long>(1L))
            {
                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x0 + (40L*x1)), 16);
                [&]
                {
                    __at_align__ std::array<float, 16> tmpbuf;
                    tmp0.store(tmpbuf.data(), 16);
                    #pragma GCC unroll 16
                    for (long x0_inner = 0; x0_inner < 16; x0_inner++)
                    {
                        out_ptr0[static_cast<long>(x1 + (40L*x0) + (40L*x0_inner))] = tmpbuf[x0_inner];
                    }
                }
                ()
                ;
            }
        }
        #pragma GCC ivdep
        for(long x0=static_cast<long>(32L); x0<static_cast<long>(40L); x0+=static_cast<long>(1L))
        {
            #pragma GCC ivdep
            for(long x1=static_cast<long>(0L); x1<static_cast<long>(40L); x1+=static_cast<long>(1L))
            {
                auto tmp0 = in_ptr0[static_cast<long>(x0 + (40L*x1))];
                out_ptr0[static_cast<long>(x1 + (40L*x0))] = tmp0;
            }
        }
    }
}
''')


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (2, 20, 40), (800, 40, 1))
    buf0 = empty_strided_cpu((40, 2, 20), (40, 20, 1), torch.float32)
    cpp_fused_clone_0(arg0_1, buf0)
    del arg0_1
    return (buf0, )
  • After:
cpp_fused_clone_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_jiayisun/z2/cz2ry4ghylembzwx7hkbanur76fi3mkiu7s6jm3zdi2amy5egq4b.h"
extern "C"  void kernel(const float* in_ptr0,
                       float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(16L))
        {
            #pragma GCC ivdep
            for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(16L))
            {
                float tmp0[16*16] __attribute__ ((aligned (16)));
                at::vec::transpose_mxn<float,16,16>(in_ptr0 + static_cast<long>(x0 + (40L*x1)), static_cast<long>(40L), tmp0, 16);
                for (long x0_inner = 0; x0_inner < 16; x0_inner++)
                {
                    auto tmp1 = at::vec::Vectorized<float>::loadu(tmp0 + static_cast<long>(16L*x0_inner), 16);
                    tmp1.store(out_ptr0 + static_cast<long>(x1 + (40L*x0) + (40L*x0_inner)));
                }
            }
            #pragma GCC ivdep
            for(long x1=static_cast<long>(32L); x1<static_cast<long>(40L); x1+=static_cast<long>(8L))
            {
                float tmp0[16*8] __attribute__ ((aligned (16)));
                at::vec::transpose_mxn<float,8,16>(in_ptr0 + static_cast<long>(x0 + (40L*x1)), static_cast<long>(40L), tmp0, 8);
                for (long x0_inner = 0; x0_inner < 16; x0_inner++)
                {
                    auto tmp1 = at::vec::Vectorized<float>::loadu(tmp0 + static_cast<long>(8L*x0_inner), 8);
                    tmp1.store(out_ptr0 + static_cast<long>(x1 + (40L*x0) + (40L*x0_inner)), 8);
                }
            }
        }
        #pragma GCC ivdep
        for(long x0=static_cast<long>(32L); x0<static_cast<long>(40L); x0+=static_cast<long>(8L))
        {
            #pragma GCC ivdep
            for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(16L))
            {
                float tmp0[8*16] __attribute__ ((aligned (16)));
                at::vec::transpose_mxn<float,16,8>(in_ptr0 + static_cast<long>(x0 + (40L*x1)), static_cast<long>(40L), tmp0, 16);
                for (long x0_inner = 0; x0_inner < 8; x0_inner++)
                {
                    auto tmp1 = at::vec::Vectorized<float>::loadu(tmp0 + static_cast<long>(16L*x0_inner), 16);
                    tmp1.store(out_ptr0 + static_cast<long>(x1 + (40L*x0) + (40L*x0_inner)));
                }
            }
            #pragma GCC ivdep
            for(long x1=static_cast<long>(32L); x1<static_cast<long>(40L); x1+=static_cast<long>(8L))
            {
                float tmp0[8*8] __attribute__ ((aligned (16)));
                at::vec::transpose_mxn<float,8,8>(in_ptr0 + static_cast<long>(x0 + (40L*x1)), static_cast<long>(40L), tmp0, 8);
                for (long x0_inner = 0; x0_inner < 8; x0_inner++)
                {
                    auto tmp1 = at::vec::Vectorized<float>::loadu(tmp0 + static_cast<long>(8L*x0_inner), 8);
                    tmp1.store(out_ptr0 + static_cast<long>(x1 + (40L*x0) + (40L*x0_inner)), 8);
                }
            }
        }
    }
}
''')


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (2, 20, 40), (800, 40, 1))
    buf0 = empty_strided_cpu((40, 2, 20), (40, 20, 1), torch.float32)
    cpp_fused_clone_0(arg0_1, buf0)
    del arg0_1
    return (buf0, )

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

[ghstack-poisoned]
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented Jul 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130724

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f81ff80 with merge base dc00eeb (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@pytorch-bot pytorch-bot Bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Jul 18, 2024
[ghstack-poisoned]
@jiayisunx jiayisunx marked this pull request as ready for review July 18, 2024 09:43
Comment thread torch/_inductor/codegen/cpp.py
[ghstack-poisoned]
@jiayisunx jiayisunx requested a review from jgong5 July 19, 2024 07:06
Comment thread torch/_inductor/codegen/cpp.py
Comment thread torch/_inductor/codegen/cpp.py
Comment thread torch/_inductor/codegen/cpp.py Outdated
)
inner_main_loop.set_kernel(tile2d_kernel)
inner_tail_loop.set_kernel(vec_kernel)
if could_masked_vec:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we combine the logic under the two `if could_masked_vec` branches into one?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!

Comment thread torch/_inductor/codegen/cpp.py Outdated
inner_tail_loop_of_outer_tail_loop.steps,
outer_tail_loop.steps,
)
inner_tail_loop_of_outer_tail_loop.set_kernel(masked_tile2d_kernel3)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we simplify the code with loops or functions to avoid dup?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are three kinds of tail loop for the 2D tile kernel: inner_tail_loop (the inner tail loop of the outer main loop), inner_main_loop_of_outer_tail_loop, and inner_tail_loop_of_outer_tail_loop. Each of these three tail loops needs its own codegen_kernel and set_kernel call.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean you can perhaps further simplify the code. This code pattern now looks like an unrolled loop...

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your advice, I have simplified the code.

input[i] = _mm512_setzero_ps();
}

// unpacking and interleaving 32-bit elements
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic below is the same as that in transpose_mxn<float, 16, 16>. Can we factor out a common function to share it and avoid duplication?

Copy link
Copy Markdown
Collaborator Author

@jiayisunx jiayisunx Jul 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I factored out a common function and removed this template specialization (transpose_mxn<float, 16, 16>).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you going to support avx2 as well?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

[ghstack-poisoned]
@jiayisunx jiayisunx requested a review from jgong5 July 22, 2024 10:49
[ghstack-poisoned]
typename std::enable_if_t<std::is_same<T, Half>::value && ((M < 32 && M != 16) || (N < 32 && N != 16)), int> = 0>
inline void transpose_mxn(const Half* src, int64_t ld_src, Half* dst, int64_t ld_dst) {
// load from src
__mmask32 src_mask = (1 << N) - 1;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the implementation is the same for Half and BFloat16. Can we move it into a common util function?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We must specialize this function here to ensure that this specialization takes priority.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can still provide a util function as
template <typename T, typename std::enable_if_t<std::is_same_v<T, Half> || std::is_same_v<T, BF16>>> transpose_mxn_32_32(const T* src, int64_t ld_src, T* dst, int64_t ld_dst)
and then reuse this util function for BF16 and FP16, which would save lines of code.

Comment thread torch/_inductor/codegen/cpp.py
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@jiayisunx jiayisunx requested a review from jansel August 12, 2024 00:56
@jiayisunx
Copy link
Copy Markdown
Collaborator Author

@jansel , could you please review this PR? Thanks!

@jiayisunx
Copy link
Copy Markdown
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team Raised by workflow job

@jiayisunx
Copy link
Copy Markdown
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Aug 13, 2024
…tatype (#131155)

This PR supports masked vectorization for the tail_loop for torch.uint8 and torch.int8 datatype to improve performance.
BTW, I fixed the UT of `byte` by setting the range of the sample inputs  to [0, 255] since the range of `torch.uint8` is [0, 255].

Pull Request resolved: #131155
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
ghstack dependencies: #130724
@github-actions github-actions Bot deleted the gh/jiayisunx/14/head branch September 12, 2024 02:03
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 26, 2024
…tatype (pytorch#131155)

This PR supports masked vectorization for the tail_loop for torch.uint8 and torch.int8 datatype to improve performance.
BTW, I fixed the UT of `byte` by setting the range of the sample inputs  to [0, 255] since the range of `torch.uint8` is [0, 255].

Pull Request resolved: pytorch#131155
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
ghstack dependencies: pytorch#130724
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged module: cpu CPU specific problem (e.g., perf, algorithm) module: inductor open source topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants