[Inductor][MPS] Fix half-precision type mismatches in Metal shader codegen (#176436) #177193
malfet merged 1 commit into release/2.11
…degen (#176436)

Metal Shading Language rejects implicit float-to-bfloat conversions, so bare float literals like `0.0` in generated shaders cause compilation failures when the target variable is `bfloat` (or `half`). Three codegen methods were affected:

- `constant()` ignored its `dtype` parameter and returned raw literals.
- `masked()` assigned a bare literal in the else-branch (`} else tmp = 0.0;`).
- `where()` passed a bare literal through the ternary without casting.

All three now emit `static_cast<bfloat>(...)` / `static_cast<half>(...)` where needed. Tests added for half-precision constants, reductions, and conditionals.

Pull Request resolved: #176436
Approved by: https://github.com/malfet

Test plan: Run `python -c "import torch;F=torch.nn.functional;print(torch.compile(lambda x: F.pad(F.gelu(x), [1, 0]))(torch.randn(4, device='mps', dtype=torch.bfloat16)))"`

(cherry picked from commit 3b161e7)
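As a rough illustration of the casting rule described in this summary, the sketch below shows how a string-emitting codegen helper might wrap literals for half-precision targets. It is not the actual Inductor MPS backend code: `HALF_TYPES`, `cast_literal`, `emit_masked_else`, and `emit_where` are hypothetical names chosen for this example, and the real methods have different signatures.

```python
# Hypothetical helpers for illustration only; the real Inductor MPS
# codegen methods (constant(), masked(), where()) are structured differently.
HALF_TYPES = {"half", "bfloat"}


def cast_literal(literal: str, metal_dtype: str) -> str:
    """Wrap a literal in static_cast when the target is half-precision.

    Metal rejects implicit float-to-bfloat conversion, so a bare `0.0`
    must become `static_cast<bfloat>(0.0)` for a bfloat target.
    """
    if metal_dtype in HALF_TYPES:
        return f"static_cast<{metal_dtype}>({literal})"
    return literal


def emit_masked_else(var: str, other: str, metal_dtype: str) -> str:
    """Emit the else-branch of a masked load with a correctly typed fill value."""
    return f"}} else {var} = {cast_literal(other, metal_dtype)};"


def emit_where(cond: str, a: str, other: str, metal_dtype: str) -> str:
    """Emit a ternary whose fallback literal matches the target dtype."""
    return f"{cond} ? {a} : {cast_literal(other, metal_dtype)}"


if __name__ == "__main__":
    print(cast_literal("0.0", "bfloat"))
    print(emit_masked_else("tmp", "0.0", "half"))
    print(emit_where("tmp0", "tmp1", "0.0", "bfloat"))
```

For `float` targets the literal is returned unchanged, which keeps the generated shader text identical to the previous output in the full-precision case.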
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177193

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Cancelled Job, 67 Pending
As of commit 57929f7 with merge base 0fd766e:
NEW FAILURES - The following jobs have failed:
UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Verify the bug is fixed in release 2.11.0.

Repro Summary: Metal Shading Language bfloat/half implicit float conversion
Test: Ran the MRE from the bug report using
Result: Fixed

if not resolved:

Local Reproduce

Setup

uv venv mps-repro
source mps-repro/bin/activate
uv pip install torch==2.11.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cpu
python -c "import torch; print(torch.__version__)"
# output: 2.11.0
python -c "import torch;F=torch.nn.functional;print(torch.compile(lambda x: F.pad(F.gelu(x), [1, 0]))(torch.randn(4, device='mps', dtype=torch.bfloat16)))"

Output

Versions