
[Inductor][MPS] Fix half-precision type mismatches in Metal shader codegen (#176436) #177193

Merged
malfet merged 1 commit into release/2.11 from malfet/cp-/176436
Mar 11, 2026

@malfet malfet commented Mar 11, 2026

Metal Shading Language rejects implicit float-to-bfloat conversions, so
bare float literals like 0.0 in generated shaders cause compilation
failures when the target variable is bfloat (or half). Three codegen
methods were affected:

  • constant() ignored its dtype parameter and returned raw literals.
  • masked() assigned a bare literal in the else-branch (} else tmp = 0.0;).
  • where() passed a bare literal through the ternary without casting.

All three now emit static_cast<bfloat>(...) / static_cast<half>(...)
where needed. Tests added for half-precision constants, reductions, and
conditionals.
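
The three fixes share one idea: a literal must carry the Metal type of the variable it feeds. The sketch below illustrates that idea only. The helper names (`emit_constant`, `emit_masked_else`, `emit_where`) and the `METAL_TYPES` table are hypothetical and are not the Inductor MPS backend's actual code; special values such as inf and nan are not handled.

```python
# Hypothetical sketch of dtype-aware literal emission for Metal shaders.
# None of these names come from PyTorch; they only illustrate the fix.

# Assumed mapping from a dtype name to its Metal Shading Language type.
METAL_TYPES = {"float32": "float", "float16": "half", "bfloat16": "bfloat"}


def emit_constant(value: float, dtype: str) -> str:
    """Return a Metal expression for `value` typed as `dtype`."""
    literal = repr(float(value))
    metal_type = METAL_TYPES[dtype]
    if metal_type == "float":
        # The PR text treats a bare literal as float, so no cast is emitted.
        return literal
    return f"static_cast<{metal_type}>({literal})"


def emit_masked_else(var: str, other: float, dtype: str) -> str:
    """Else-branch of a masked load: assign the fill value with a cast."""
    return f"}} else {var} = {emit_constant(other, dtype)};"


def emit_where(cond: str, a: str, b: float, dtype: str) -> str:
    """Ternary whose literal branch is cast to match the other operand."""
    return f"{cond} ? {a} : {emit_constant(b, dtype)}"


print(emit_constant(0.0, "bfloat16"))           # static_cast<bfloat>(0.0)
print(emit_masked_else("tmp", 0.0, "float16"))  # } else tmp = static_cast<half>(0.0);
print(emit_where("mask", "x", 0.0, "float32"))  # mask ? x : 0.0
```

Routing the masked and ternary cases through the same constant helper keeps the cast logic in one place, which is one plausible way to avoid the three methods drifting apart again.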

Pull Request resolved: #176436
Approved by: https://github.com/malfet

Test plan: Run python -c "import torch;F=torch.nn.functional;print(torch.compile(lambda x: F.pad(F.gelu(x), [1, 0]))(torch.randn(4, device='mps', dtype=torch.bfloat16)))"

(cherry picked from commit 3b161e7)

pytorch-bot bot commented Mar 11, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177193

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Cancelled Job, 67 Pending

As of commit 57929f7 with merge base 0fd766e (image):

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor ciflow/mps Run MPS tests (subset of trunk) ciflow/torchtitan Run TorchTitan integration tests module: inductor labels Mar 11, 2026
pytorch-bot bot commented Mar 11, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@atalman atalman left a comment

lgtm

@malfet malfet merged commit fa384de into release/2.11 Mar 11, 2026
73 of 144 checks passed
@yangw-dev

Verify the bug is fixed in release 2.11.0.

Repro Summary: Metal Shading Language bfloat/half implicit float conversion

Test: Ran the MRE from the bug report using torch.compile with bfloat16 on MPS.

Result: Fixed

If the bug were not resolved, Metal shader compilation would fail with an implicit conversion error:

error: cannot implicitly convert type 'float' to 'bfloat'

Local Reproduce

Setup

uv venv mps-repro
source mps-repro/bin/activate

uv pip install torch==2.11.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cpu

python -c "import torch; print(torch.__version__)"
# output: 2.11.0

python -c "import torch;F=torch.nn.functional;print(torch.compile(lambda x: F.pad(F.gelu(x), [1, 0]))(torch.randn(4, device='mps', dtype=torch.bfloat16)))"

Output

tensor([ 0.0000,  0.3047,  0.2344, -0.1680,  1.1797], device='mps:0',
       dtype=torch.bfloat16)
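
The printed tensor can also be checked programmatically rather than by eye. The script below is a sketch under a few assumptions: the function name `run_repro`, the float32 CPU reference, and the tolerances are illustrative choices, not part of the original report. It returns None when PyTorch or an MPS device is unavailable.

```python
import importlib.util


def run_repro():
    """Return True if the compiled MPS result matches a CPU reference.

    Returns None when torch or an MPS device is not available.
    """
    if importlib.util.find_spec("torch") is None:
        return None
    import torch
    import torch.nn.functional as F

    if not torch.backends.mps.is_available():
        return None

    def pad_gelu(x):
        # Same computation as the one-line repro: GELU, then one pad on the left.
        return F.pad(F.gelu(x), [1, 0])

    x = torch.randn(4, device="mps", dtype=torch.bfloat16)

    # Compiled path: before the fix this failed at Metal shader compilation.
    actual = torch.compile(pad_gelu)(x).cpu()

    # Reference computed in float32 on CPU, then rounded to bfloat16.
    expected = pad_gelu(x.cpu().float()).to(torch.bfloat16)

    # Loose tolerances (an assumption) to absorb bfloat16 rounding differences.
    torch.testing.assert_close(actual, expected, rtol=2e-2, atol=1e-2)
    return True


print(run_repro())
```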

Versions

PyTorch version: 2.11.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.3.1 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.6.4.2)
CMake version: version 4.3.0
Libc version: N/A

Python version: 3.13.12 (main, Feb 3 2026, 17:53:27) [Clang 17.0.0 (clang-1700.6.3.2)] (64-bit runtime)
Python platform: macOS-26.3.1-arm64-arm-64bit-Mach-O
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU: Apple M1 Pro

Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect
