[Inductor][MPS] Fix half-precision type mismatches in Metal shader codegen #176436

Closed

mergennachin wants to merge 2 commits into main from mps_bf16

Conversation

mergennachin (Contributor) commented Mar 4, 2026

Metal Shading Language rejects implicit float-to-bfloat conversions, so
bare float literals like 0.0 in generated shaders cause compilation
failures when the target variable is bfloat (or half). Three codegen
methods were affected:

  • constant() ignored its dtype parameter and returned raw literals.
  • masked() assigned a bare literal in the else-branch (} else tmp = 0.0;).
  • where() passed a bare literal through the ternary without casting.

All three now emit static_cast<bfloat>(...) / static_cast<half>(...)
where needed. Tests added for half-precision constants, reductions, and
conditionals.
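
In codegen terms, the fix boils down to making these helpers consult their dtype before emitting a literal. A minimal Python sketch of the idea (the string dtype keys and the `emit_constant` helper are illustrative stand-ins, not the real Inductor code, though `DTYPE_TO_METAL` mirrors an actual table in `torch/_inductor/codegen/mps.py`):

```python
# Illustrative mapping from dtype names to Metal scalar types; plain strings
# stand in for torch.dtype objects here.
DTYPE_TO_METAL = {"float32": "float", "float16": "half", "bfloat16": "bfloat"}

def emit_constant(val: float, dtype: str) -> str:
    """Sketch of a dtype-aware constant(): wrap bare float literals in a
    static_cast when the target is half-precision, since MSL rejects
    implicit float-to-bfloat conversions."""
    literal = repr(val)
    if dtype in ("float16", "bfloat16"):
        return f"static_cast<{DTYPE_TO_METAL[dtype]}>({literal})"
    return literal

print(emit_constant(0.0, "bfloat16"))  # static_cast<bfloat>(0.0)
print(emit_constant(1.0, "float32"))   # 1.0
```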

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

pytorch-bot commented Mar 4, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176436

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit cc0589b with merge base e45dfba:

BROKEN TRUNK - The following jobs failed but were already present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor ciflow/mps Run MPS tests (subset of trunk) module: inductor labels Mar 4, 2026
pytorch-bot commented Mar 4, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

malfet (Contributor) commented Mar 5, 2026

Hmm, this quirk is indeed present, but only for bfloat16 (which isn't a real HW type, as far as I understand the architecture):

>>> torch.mps.compile_shader("kernel void foo(device bfloat &x) { x = 0.0;}")
Traceback (most recent call last):
  File "<python-input-7>", line 1, in <module>
    torch.mps.compile_shader("kernel void foo(device bfloat &x) { x = 0.0;}")
    ~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/nshulga/git/pytorch/pytorch/torch/mps/__init__.py", line 163, in compile_shader
    return torch._C._mps_compileShader(source)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^
SyntaxError: program_source:1:41: error: assigning to 'bfloat' from incompatible type 'float'
kernel void foo(device bfloat &x) { x = 0.0;}
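
For comparison, the same shader compiles once the cast is explicit. A hedged sketch of that check (guarded so it only invokes the Metal compiler on machines where MPS is available):

```python
# The uncast literal fails to compile; the explicitly cast one succeeds.
bad_src = "kernel void foo(device bfloat &x) { x = 0.0; }"
fixed_src = "kernel void foo(device bfloat &x) { x = static_cast<bfloat>(0.0); }"

try:
    import torch
    has_mps = torch.backends.mps.is_available()
except Exception:
    has_mps = False  # lets the sketch run on machines without torch/MPS

if has_mps:
    try:
        torch.mps.compile_shader(bad_src)
        raise AssertionError("expected Metal to reject the bare float literal")
    except SyntaxError:
        pass  # implicit float->bfloat conversion rejected, as shown above
    torch.mps.compile_shader(fixed_src)  # compiles once the cast is explicit
```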

[torch.bfloat16] if MACOS_VERSION < 14.0 else []
)
MPS_DTYPES = [t for t in get_all_dtypes() if t not in MPS_UNSUPPORTED_TYPES]
MPS_HALF_DTYPES = [torch.float16] + ([torch.bfloat16] if MACOS_VERSION >= 14.0 else [])
Review comment (Contributor):

We (and Apple) don't support macOS 13 anymore. But I believe this list already exists

Suggested change
MPS_HALF_DTYPES = [torch.float16] + ([torch.bfloat16] if MACOS_VERSION >= 14.0 else [])
MPS_HALF_DTYPES = [torch.float16, torch.bfloat16]

)

@parametrize("dtype", MPS_HALF_DTYPES)
def test_half_masked(self, dtype):
Review comment (Contributor):

Err, why is it called test_half_masked when you call sum?

return value_to_metal(val)
raw = value_to_metal(val)
if (
dtype in (torch.bfloat16, torch.float16)
Review comment (Contributor):

Do you know if this condition is ever hit? (I would have checked by adding an assert and running the test suite.)

But I would argue that math ops should return float values, and only downcast to the low-precision dtype when the value is written out (as `opmath_t<bfloat> == float`).

Review comment (Contributor Author):

Yes, the condition is hit. Verified empirically: `x + 1.0` with a bf16 input calls `constant(val=1.0, dtype=torch.bfloat16)`. Same for `x * 2`, `x - 0.5`, and `torch.where` with scalar constants.

You're right that math ops should return float values. The current fix does a pointless round-trip: the literal is float, we cast it to bfloat, then Metal implicitly promotes it back to float for arithmetic (since loads are already upcast to float32). I'll drop the cast in `constant()` and just return the raw literal; it's already float-compatible and matches the float32 compute dtype established by `load()`.
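
The revised approach can be sketched as follows; `constant_literal`, `emit_store`, and the string dtype keys are hypothetical illustrations of the idea, not the actual codegen:

```python
DTYPE_TO_METAL = {"float32": "float", "float16": "half", "bfloat16": "bfloat"}

def constant_literal(val: float) -> str:
    # No cast: the literal already matches the float32 compute dtype,
    # since loads are upcast before any arithmetic.
    return repr(val)

def emit_store(ptr: str, idx: str, var: str, out_dtype: str) -> str:
    # The only narrowing conversion happens once, at the write-out.
    return f"{ptr}[{idx}] = static_cast<{DTYPE_TO_METAL[out_dtype]}>({var});"

print(constant_literal(1.0))                             # 1.0
print(emit_store("out_ptr0", "x2", "tmp4", "bfloat16"))
# out_ptr0[x2] = static_cast<bfloat>(tmp4);
```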

@@ -244,12 +251,18 @@ def masked(mask: CSEVariable, body: sympy.Expr, other: CSEVariable) -> str:
with V.kernel.compute.indent():
V.kernel.compute.splice(scoped_body)
V.kernel.compute.writeline(f"{var} = {rc};")
Review comment (Contributor):

Don't you need to do it here as well?

V.kernel.compute.writeline(f"{var} = {rc};")
V.kernel.compute.writeline(f"}} else {var} = {other_str};")
V.kernel.compute.writeline(
f"}} else {var} = static_cast<{DTYPE_TO_METAL[rc.dtype]}>({other_str});"
Review comment (Contributor):

Nit (rc.dtype is often wrong/undefined)

Suggested change
f"}} else {var} = static_cast<{DTYPE_TO_METAL[rc.dtype]}>({other_str});"
f"}} else {var} = static_cast<decltype({var})>({other_str});"
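
The decltype approach can be sketched in isolation (hypothetical `emit_masked_else` helper): the Metal compiler derives the cast target from the destination variable itself, rather than from the often-unset `rc.dtype`.

```python
def emit_masked_else(var: str, other: str) -> str:
    # decltype(tmp4) resolves to tmp4's declared Metal type at compile time,
    # so the codegen never needs to know the dtype itself.
    return f"}} else {var} = static_cast<decltype({var})>({other});"

print(emit_masked_else("tmp4", "0.0"))
# } else tmp4 = static_cast<decltype(tmp4)>(0.0);
```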

Comment on lines +261 to +265
c_str = value_to_metal(c)
if isinstance(b, CSEVariable) and b.dtype in (torch.bfloat16, torch.float16):
    assert b.dtype is not None
    c_str = f"static_cast<{DTYPE_TO_METAL[b.dtype]}>({c_str})"
return f"{a} ? {b} : {c_str}"
Review comment (Contributor):

Nit (i.e. always leave it to the compiler rather than codegen)

Suggested change
c_str = value_to_metal(c)
if isinstance(b, CSEVariable) and b.dtype in (torch.bfloat16, torch.float16):
    assert b.dtype is not None
    c_str = f"static_cast<{DTYPE_TO_METAL[b.dtype]}>({c_str})"
return f"{a} ? {b} : {c_str}"
return f"{a} ? {b} : static_cast<decltype({b})>({value_to_metal(c)})"

Comment on lines +15841 to +15849
def test_half_constant(self):
    for dtype in [torch.float16, torch.bfloat16]:
        if not self.is_dtype_supported(dtype):
            continue
        self.common(
            lambda x: x + 1.0,
            (make_tensor(1024, dtype=dtype, device=self.device),),
            check_lowp=False,
        )
Review comment (Contributor):

Nit (check_lowp indeed checks for torch.half)

Suggested change
def test_half_constant(self):
    for dtype in [torch.float16, torch.bfloat16]:
        if not self.is_dtype_supported(dtype):
            continue
        self.common(
            lambda x: x + 1.0,
            (make_tensor(1024, dtype=dtype, device=self.device),),
            check_lowp=False,
        )
def test_bfloat_constant(self):
    if not self.is_dtype_supported(torch.bfloat16):
        return
    self.common(
        lambda x: x + 1.0,
        (make_tensor(1024, dtype=torch.bfloat16, device=self.device),),
    )

Comment on lines +15851 to +15859
def test_half_reduction(self):
    for dtype in [torch.float16, torch.bfloat16]:
        if not self.is_dtype_supported(dtype):
            continue
        self.common(
            lambda x: x.sum(),
            (make_tensor(1024, dtype=dtype, device=self.device),),
            check_lowp=False,
        )
Review comment (Contributor):

Alternatively use parametrize

Suggested change
def test_half_reduction(self):
    for dtype in [torch.float16, torch.bfloat16]:
        if not self.is_dtype_supported(dtype):
            continue
        self.common(
            lambda x: x.sum(),
            (make_tensor(1024, dtype=dtype, device=self.device),),
            check_lowp=False,
        )
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_lowp_reduction(self, dtype):
    if not self.is_dtype_supported(dtype):
        return
    self.common(
        lambda x: x.sum(),
        (make_tensor(1024, dtype=dtype, device=self.device),),
        check_lowp=False,
    )

mergennachin (Contributor Author):
@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 6, 2026
pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot (Collaborator):

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-jammy-cuda13.0-py3.10-gcc11 / test (default, 1, 5, lf.linux.g6.4xlarge.experimental.nvidia.gpu)

Details for Dev Infra team Raised by workflow job

@malfet malfet added this to the 2.11.0 milestone Mar 6, 2026
@malfet malfet added the module: regression It used to work, and now it doesn't label Mar 6, 2026
malfet (Contributor) commented Mar 6, 2026

@pytorchbot merge -f "Lint + MPS are green, hopefully other failures are just broken trunk"

pytorchmergebot (Collaborator):

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot (Collaborator):

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / test (default, 2, 3, macos-m1-stable)

Details for Dev Infra team Raised by workflow job

…degen

Metal Shading Language rejects implicit float-to-bfloat conversions, so
bare float literals like `0.0` in generated shaders cause compilation
failures when the target variable is `bfloat` (or `half`). Three codegen
methods were affected:

- `constant()` ignored its `dtype` parameter and returned raw literals.
- `masked()` assigned a bare literal in the else-branch (`} else tmp = 0.0;`).
- `where()` passed a bare literal through the ternary without casting.

All three now emit `static_cast<bfloat>(...)` / `static_cast<half>(...)`
where needed. Tests added for half-precision constants, reductions, and
conditionals.
mergennachin (Contributor Author):

@pytorchbot merge

pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

mergennachin (Contributor Author):

@pytorchbot cherry-pick --onto release/2.11 -c critical --fixes "Bug fix for MPS backend using inductor"

pytorchbot (Collaborator):

Cherry picking #176436

Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x 3b161e7a756798e6eb1ab096f4ef1232d163a68d returned non-zero exit code 1

Auto-merging test/inductor/test_torchinductor.py
CONFLICT (content): Merge conflict in test/inductor/test_torchinductor.py
Auto-merging torch/_inductor/codegen/mps.py
error: could not apply 3b161e7a756... [Inductor][MPS] Fix half-precision type mismatches in Metal shader codegen (#176436)
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Details for Dev Infra team Raised by workflow job

@malfet malfet deleted the mps_bf16 branch March 11, 2026 18:00
malfet pushed a commit that referenced this pull request Mar 11, 2026
…degen (#176436)


Pull Request resolved: #176436
Approved by: https://github.com/malfet

Test plan: Run `python -c "import torch;F=torch.nn.functional;print(torch.compile(lambda x: F.pad(F.gelu(x), [1, 0]))(torch.randn(4, device='mps', dtype=torch.bfloat16)))"`

(cherry picked from commit 3b161e7)
malfet added a commit that referenced this pull request Mar 11, 2026
…degen (#176436) (#177193)


Pull Request resolved: #176436
Approved by: https://github.com/malfet

Test plan: Run `python -c "import torch;F=torch.nn.functional;print(torch.compile(lambda x: F.pad(F.gelu(x), [1, 0]))(torch.randn(4, device='mps', dtype=torch.bfloat16)))"`

(cherry picked from commit 3b161e7)

Co-authored-by: Mergen Nachin <mnachin@meta.com>
malfet added a commit that referenced this pull request Mar 11, 2026
- Remove `test_bfloat_constant`, `test_lowp_reduction`, and `test_lowp_where` as they don't test for anything beyond what existing tests cover.
- Add test_pad_after_gelu as a regression test for Voxtral compilation on MPS, exercising pad(gelu(x)) across fp32, fp16, and bfloat16.

Before #176436 the test would fail with
```
torch._inductor.exc.InductorError: SyntaxError: failed to compile
    #include <c10/metal/utils.h>
    #include <c10/metal/special_math.h>
    kernel void generated_kernel(
        device bfloat* out_ptr0,
        constant bfloat* in_ptr0,
        uint xindex [[thread_position_in_grid]]
    ) {
        int x0 = (xindex) % (17);
        int x1 = c10::metal::floor_divide(xindex, 17);
        int x2 = xindex;
        auto tmp0 = (-1) + x0;
        auto tmp1 = static_cast<long>(tmp0);
        auto tmp2 = 0;
        auto tmp3 = tmp1 >= tmp2;
        bfloat tmp4;
        if (tmp3) {
            auto tmp_scoped_0 = static_cast<float>(in_ptr0[(-1) + x0 + 16*x1]);
            auto tmp_scoped_1 = static_cast<float>(tmp_scoped_0);
            auto tmp_scoped_2 = 0.5;
            auto tmp_scoped_3 = tmp_scoped_1 * tmp_scoped_2;
            auto tmp_scoped_4 = 0.7071067811865476;
            auto tmp_scoped_5 = tmp_scoped_1 * tmp_scoped_4;
            auto tmp_scoped_6 = c10::metal::erf(tmp_scoped_5);
            auto tmp_scoped_7 = 1.0;
            auto tmp_scoped_8 = tmp_scoped_6 + tmp_scoped_7;
            auto tmp_scoped_9 = tmp_scoped_3 * tmp_scoped_8;
            auto tmp_scoped_10 = static_cast<bfloat>(tmp_scoped_9);
            tmp4 = tmp_scoped_10;
        } else tmp4 = 0.0;
        out_ptr0[x2] = static_cast<bfloat>(tmp4);
    }
 with program_source:4495:23: error: assigning to 'bfloat' from incompatible type 'float'
        } else tmp4 = 0.0;
                      ^~~
```

Authored with Claude.


ghstack-source-id: 7919b53
Pull-Request: #177207
pytorchmergebot pushed a commit that referenced this pull request Mar 11, 2026

ghstack-source-id: f075662
Pull-Request: #177207
pytorchmergebot pushed a commit that referenced this pull request Mar 12, 2026
…177207)

----


Pull Request resolved: #177207
Approved by: https://github.com/atalman, https://github.com/mergennachin, https://github.com/jansel
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…degen (pytorch#176436)


Pull Request resolved: pytorch#176436
Approved by: https://github.com/malfet
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…hader codegen (pytorch#176436)"

This reverts commit 4926192.

Reverted pytorch#176436 on behalf of https://github.com/zou3519 due to sorry I need to revert this in order to revert pytorch#176606 ([comment](pytorch#176436 (comment)))
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…degen (pytorch#176436)


Pull Request resolved: pytorch#176436
Approved by: https://github.com/malfet
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…ytorch#177207)

----


Pull Request resolved: pytorch#177207
Approved by: https://github.com/atalman, https://github.com/mergennachin, https://github.com/jansel

Labels

ci-no-td, ciflow/inductor, ciflow/mps, ciflow/trunk, Merged, module: inductor, module: regression, release notes: mps, Reverted, topic: bug fixes


5 participants