
[Inductor] Add proper regression test for Voxtral compilation on MPS#177207

Closed
malfet wants to merge 2 commits into gh/malfet/765/base from gh/malfet/765/head

Conversation

@malfet
Contributor

@malfet malfet commented Mar 11, 2026

Stack from ghstack (oldest at bottom):


  • Remove `test_bfloat_constant`, `test_lowp_reduction`, and `test_lowp_where`, as they don't test anything beyond what existing tests already cover.
  • Add `test_pad_after_gelu` as a regression test for Voxtral compilation on MPS, exercising `pad(gelu(x))` across fp32, fp16, and bfloat16.
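The new test can be sketched roughly as follows. This is a minimal illustration of the `pad(gelu(x))` pattern described above, not the PR's actual test code; function names, shapes, and the device fallback are illustrative assumptions.

```python
# Rough sketch of a pad-after-gelu regression check, assuming the pattern
# described in the PR: pad(gelu(x)) compiled with torch.compile.
# Names, shapes, and tolerances here are illustrative, not the PR's code.
import torch
import torch.nn.functional as F


def pad_after_gelu(x):
    # Pad the last dim by one on the left: a (4, 16) input becomes (4, 17),
    # matching the `% 17` indexing in the generated kernel shown below.
    return F.pad(F.gelu(x), (1, 0))


def run_check(dtype, device):
    x = torch.rand(4, 16, dtype=dtype, device=device)
    eager = pad_after_gelu(x)
    compiled = torch.compile(pad_after_gelu)(x)
    torch.testing.assert_close(eager, compiled)


if __name__ == "__main__":
    # On a machine without MPS this falls back to CPU; the MPS codegen bug
    # itself only reproduces on the "mps" device.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    for dtype in (torch.float32, torch.float16, torch.bfloat16):
        run_check(dtype, device)
```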

Before #176436 this test would fail with:

```
torch._inductor.exc.InductorError: SyntaxError: failed to compile
    #include <c10/metal/utils.h>
    #include <c10/metal/special_math.h>
    kernel void generated_kernel(
        device bfloat* out_ptr0,
        constant bfloat* in_ptr0,
        uint xindex [[thread_position_in_grid]]
    ) {
        int x0 = (xindex) % (17);
        int x1 = c10::metal::floor_divide(xindex, 17);
        int x2 = xindex;
        auto tmp0 = (-1) + x0;
        auto tmp1 = static_cast<long>(tmp0);
        auto tmp2 = 0;
        auto tmp3 = tmp1 >= tmp2;
        bfloat tmp4;
        if (tmp3) {
            auto tmp_scoped_0 = static_cast<float>(in_ptr0[(-1) + x0 + 16*x1]);
            auto tmp_scoped_1 = static_cast<float>(tmp_scoped_0);
            auto tmp_scoped_2 = 0.5;
            auto tmp_scoped_3 = tmp_scoped_1 * tmp_scoped_2;
            auto tmp_scoped_4 = 0.7071067811865476;
            auto tmp_scoped_5 = tmp_scoped_1 * tmp_scoped_4;
            auto tmp_scoped_6 = c10::metal::erf(tmp_scoped_5);
            auto tmp_scoped_7 = 1.0;
            auto tmp_scoped_8 = tmp_scoped_6 + tmp_scoped_7;
            auto tmp_scoped_9 = tmp_scoped_3 * tmp_scoped_8;
            auto tmp_scoped_10 = static_cast<bfloat>(tmp_scoped_9);
            tmp4 = tmp_scoped_10;
        } else tmp4 = 0.0;
        out_ptr0[x2] = static_cast<bfloat>(tmp4);
    }
 with program_source:4495:23: error: assigning to 'bfloat' from incompatible type 'float'
        } else tmp4 = 0.0;
                      ^~~
```

Authored with Claude.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Mar 11, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177207

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit bb6423e with merge base ad67e7a:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

malfet added a commit that referenced this pull request Mar 11, 2026
ghstack-source-id: 7919b53
Pull-Request: #177207
@malfet malfet added the ciflow/mps Run MPS tests (subset of trunk) label Mar 11, 2026
@malfet malfet requested review from jansel and mergennachin March 11, 2026 22:22
@malfet
Contributor Author

malfet commented Mar 11, 2026

@pytorchbot fix-lint

Contributor

@atalman atalman left a comment


lgtm

@pytorchmergebot
Collaborator

Successfully applied lint patches in https://github.com/pytorch/pytorch/actions/runs/22979241446. Please pull locally before pushing more changes.

pytorchmergebot pushed a commit that referenced this pull request Mar 11, 2026
ghstack-source-id: f075662
Pull-Request: #177207
@malfet
Contributor Author

malfet commented Mar 12, 2026

@pytorchbot merge -f "I do have enough signal on this one"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f only as a last resort and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This allows currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@malfet malfet added autorevert: disable Disable autorevert for a specific PR and removed autorevert: disable Disable autorevert for a specific PR labels Mar 12, 2026
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
[Inductor] Add proper regression test for Voxtral compilation on MPS (pytorch#177207)


Pull Request resolved: pytorch#177207
Approved by: https://github.com/atalman, https://github.com/mergennachin, https://github.com/jansel


5 participants