Skip to content

[Inductor] support masked vectorization for the tail_loop#126526

Closed
jiayisunx wants to merge 48 commits intogh/jiayisunx/10/basefrom
gh/jiayisunx/10/head
Closed

[Inductor] support masked vectorization for the tail_loop#126526
jiayisunx wants to merge 48 commits intogh/jiayisunx/10/basefrom
gh/jiayisunx/10/head

Conversation

@jiayisunx
Copy link
Copy Markdown
Collaborator

@jiayisunx jiayisunx commented May 17, 2024

Stack from ghstack (oldest at bottom):

Currently the tail_loop always uses the scalar kernel. This PR supports masked vectorization for the tail_loop to improve the performance.

Example:

import torch
import torch.nn as nn

class GN(nn.Module):
    def __init__(self, num_groups, num_channels):
        super(GN, self).__init__()
        self.gn = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        return self.gn(x)

input = torch.randn(2, 960, 96, 96).to(memory_format=torch.channels_last)
m = GN(32, 960).eval()
compiled_m = torch.compile(m)

with torch.no_grad():
    for _ in range(3):
        compiled_m(input)

Generated code:

  • Before:
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_jiayisun/ky/cky2bufythacofebk7ujv36e4pxyqcqbpsy5r4vojoprjiwcwfxf.h"
extern "C"  void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr2)
{
    #pragma omp parallel num_threads(112)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(1L))
                {
                    {
                        Welford<float> tmp_acc0 = Welford<float>();
                        Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        static WeightRecp<at::vec::Vectorized<float>> weight_recps(static_cast<long>(17280L));
                        for(long x2=static_cast<long>(0L); x2<static_cast<long>(9216L); x2+=static_cast<long>(1L))
                        {
                            for(long x3=static_cast<long>(0L); x3<static_cast<long>(16L); x3+=static_cast<long>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 16);
                                tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &weight_recps);
                            }
                            #pragma omp simd simdlen(8) 
                            for(long x3=static_cast<long>(16L); x3<static_cast<long>(30L); x3+=static_cast<long>(1L))
                            {
                                auto tmp0 = in_ptr0[static_cast<long>(x3 + (30L*x1) + (960L*x2) + (8847360L*x0))];
                                tmp_acc0 = welford_combine(tmp_acc0, tmp0);
                            }
                        }
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                        out_ptr0[static_cast<long>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.mean);
                        out_ptr1[static_cast<long>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.m2);
                    }
                }
            }
        }
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(9216L); x1+=static_cast<long>(1L))
                {
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(960L); x2+=static_cast<long>(16L))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x2 + (960L*x1) + (8847360L*x0)), 16);
                        auto tmp1 =
                        [&]
                        {
                            __at_align__ std::array<float, 16> tmpbuf;
                            #pragma GCC unroll 16
                            for (long x2_inner = 0; x2_inner < 16; x2_inner++)
                            {
                                tmpbuf[x2_inner] = out_ptr0[static_cast<long>((32L*x0) + (c10::div_floor_integer((x2 + x2_inner), 30L)))];
                            }
                            return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
                        }
                        ()
                        ;
                        auto tmp3 =
                        [&]
                        {
                            __at_align__ std::array<float, 16> tmpbuf;
                            #pragma GCC unroll 16
                            for (long x2_inner = 0; x2_inner < 16; x2_inner++)
                            {
                                tmpbuf[x2_inner] = out_ptr1[static_cast<long>((32L*x0) + (c10::div_floor_integer((x2 + x2_inner), 30L)))];
                            }
                            return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
                        }
                        ()
                        ;
                        auto tmp12 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x2), 16);
                        auto tmp14 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x2), 16);
                        auto tmp2 = tmp0 - tmp1;
                        auto tmp4 = static_cast<float>(276480.0);
                        auto tmp5 = at::vec::Vectorized<float>(tmp4);
                        auto tmp6 = tmp3 / tmp5;
                        auto tmp7 = static_cast<float>(1e-05);
                        auto tmp8 = at::vec::Vectorized<float>(tmp7);
                        auto tmp9 = tmp6 + tmp8;
                        auto tmp10 = tmp9.rsqrt();
                        auto tmp11 = tmp2 * tmp10;
                        auto tmp13 = tmp11 * tmp12;
                        auto tmp15 = tmp13 + tmp14;
                        tmp15.store(out_ptr2 + static_cast<long>(x2 + (960L*x1) + (8847360L*x0)));
                    }
                }
            }
        }
    }
}
''')


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1, arg2_1 = args
    args.clear()
    assert_size_stride(arg0_1, (960, ), (1, ))
    assert_size_stride(arg1_1, (960, ), (1, ))
    assert_size_stride(arg2_1, (2, 960, 96, 96), (8847360, 1, 92160, 960))
    buf0 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32)
    buf1 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32)
    buf3 = empty_strided_cpu((2, 960, 96, 96), (8847360, 1, 92160, 960), torch.float32)
    cpp_fused_native_group_norm_0(arg2_1, arg0_1, arg1_1, buf0, buf1, buf3)
    del arg0_1
    del arg1_1
    del arg2_1
    return (buf3, )
  • After:
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*'], '''
#include "/tmp/torchinductor_jiayisun/em/cemtujj65j5txpqlxc7w4pcunpmvz3qtiudkc5ocxxhcmdlknw2m.h"
extern "C"  void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr2)
{
    #pragma omp parallel num_threads(112)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(1L))
                {
                    {
                        Welford<float> tmp_acc0 = Welford<float>();
                        Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        static WeightRecp<at::vec::Vectorized<float>> wrecps0(static_cast<long>(17280L));
                        for(long x2=static_cast<long>(0L); x2<static_cast<long>(9216L); x2+=static_cast<long>(1L))
                        {
                            for(long x3=static_cast<long>(0L); x3<static_cast<long>(16L); x3+=static_cast<long>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 16);
                                tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &wrecps0);
                            }
                            for(long x3=static_cast<long>(16L); x3<static_cast<long>(30L); x3+=static_cast<long>(14L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x3 + (30L*x1) + (960L*x2) + (8847360L*x0)), 14);
                                masked_tmp_acc0_vec = welford_combine(masked_tmp_acc0_vec, tmp0, 14, &wrecps0);
                            }
                        }
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec));
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                        out_ptr0[static_cast<long>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.mean);
                        out_ptr1[static_cast<long>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.m2);
                    }
                }
            }
        }
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(9216L); x1+=static_cast<long>(1L))
                {
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(960L); x2+=static_cast<long>(16L))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x2 + (960L*x1) + (8847360L*x0)), 16);
                        auto tmp1 =
                        [&]
                        {
                            __at_align__ std::array<float, 16> tmpbuf;
                            #pragma GCC unroll 16
                            for (long x2_inner = 0; x2_inner < 16; x2_inner++)
                            {
                                tmpbuf[x2_inner] = out_ptr0[static_cast<long>((32L*x0) + (c10::div_floor_integer((x2 + x2_inner), 30L)))];
                            }
                            return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
                        }
                        ()
                        ;
                        auto tmp3 =
                        [&]
                        {
                            __at_align__ std::array<float, 16> tmpbuf;
                            #pragma GCC unroll 16
                            for (long x2_inner = 0; x2_inner < 16; x2_inner++)
                            {
                                tmpbuf[x2_inner] = out_ptr1[static_cast<long>((32L*x0) + (c10::div_floor_integer((x2 + x2_inner), 30L)))];
                            }
                            return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
                        }
                        ()
                        ;
                        auto tmp12 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x2), 16);
                        auto tmp14 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x2), 16);
                        auto tmp2 = tmp0 - tmp1;
                        auto tmp4 = static_cast<float>(276480.0);
                        auto tmp5 = at::vec::Vectorized<float>(tmp4);
                        auto tmp6 = tmp3 / tmp5;
                        auto tmp7 = static_cast<float>(1e-05);
                        auto tmp8 = at::vec::Vectorized<float>(tmp7);
                        auto tmp9 = tmp6 + tmp8;
                        auto tmp10 = tmp9.rsqrt();
                        auto tmp11 = tmp2 * tmp10;
                        auto tmp13 = tmp11 * tmp12;
                        auto tmp15 = tmp13 + tmp14;
                        tmp15.store(out_ptr2 + static_cast<long>(x2 + (960L*x1) + (8847360L*x0)));
                    }
                }
            }
        }
    }
}
''')


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1, arg2_1 = args
    args.clear()
    assert_size_stride(arg0_1, (960, ), (1, ))
    assert_size_stride(arg1_1, (960, ), (1, ))
    assert_size_stride(arg2_1, (2, 960, 96, 96), (8847360, 1, 92160, 960))
    buf0 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32)
    buf1 = empty_strided_cpu((2, 32, 1, 1), (32, 1, 64, 64), torch.float32)
    buf3 = empty_strided_cpu((2, 960, 96, 96), (8847360, 1, 92160, 960), torch.float32)
    cpp_fused_native_group_norm_0(arg2_1, arg0_1, arg1_1, buf0, buf1, buf3)
    del arg0_1
    del arg1_1
    del arg2_1
    return (buf3, )

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @peterbell10

[ghstack-poisoned]
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented May 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126526

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 37d7d32 with merge base 6753ee1 (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jiayisunx added a commit that referenced this pull request May 17, 2024
@jiayisunx jiayisunx marked this pull request as draft May 17, 2024 09:17
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
jiayisunx added a commit that referenced this pull request May 20, 2024
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
jiayisunx added a commit that referenced this pull request May 21, 2024
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@jiayisunx
Copy link
Copy Markdown
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

jiayisunx and others added 3 commits August 3, 2024 19:19
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@pytorch-bot pytorch-bot Bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Aug 6, 2024
[ghstack-poisoned]
[ghstack-poisoned]
@jiayisunx jiayisunx requested a review from jgong5 August 7, 2024 00:49
std::min(count, (int64_t)Vectorized<T>::size()));
count -= Vectorized<T>::size();
} else {
result.values[i] = a.values[i];
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix.

@jiayisunx
Copy link
Copy Markdown
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR ciflow/trunk Trigger trunk jobs on your pull request Merged module: cpu CPU specific problem (e.g., perf, algorithm) module: inductor open source release notes: fx release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants