Commit 4926192
[Inductor][MPS] Fix half-precision type mismatches in Metal shader codegen (#176436)
Metal Shading Language rejects implicit float-to-bfloat conversions, so
bare float literals like `0.0` in generated shaders cause compilation
failures when the target variable is `bfloat` (or `half`). Three codegen
methods were affected:
- `constant()` ignored its `dtype` parameter and returned raw literals.
- `masked()` assigned a bare literal in the else-branch (`} else tmp = 0.0;`).
- `where()` passed a bare literal through the ternary without casting.
All three now emit `static_cast<bfloat>(...)` / `static_cast<half>(...)`
where needed. Tests added for half-precision constants, reductions, and
conditionals.
Pull Request resolved: #176436
Approved by: https://github.com/malfet1 parent ca5f3b9 commit 4926192
2 files changed
Lines changed: 35 additions & 3 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
15914 | 15914 | | |
15915 | 15915 | | |
15916 | 15916 | | |
| 15917 | + | |
| 15918 | + | |
| 15919 | + | |
| 15920 | + | |
| 15921 | + | |
| 15922 | + | |
| 15923 | + | |
| 15924 | + | |
| 15925 | + | |
| 15926 | + | |
| 15927 | + | |
| 15928 | + | |
| 15929 | + | |
| 15930 | + | |
| 15931 | + | |
| 15932 | + | |
| 15933 | + | |
| 15934 | + | |
| 15935 | + | |
| 15936 | + | |
| 15937 | + | |
| 15938 | + | |
| 15939 | + | |
| 15940 | + | |
| 15941 | + | |
| 15942 | + | |
| 15943 | + | |
| 15944 | + | |
15917 | 15945 | | |
15918 | 15946 | | |
15919 | 15947 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
243 | 243 | | |
244 | 244 | | |
245 | 245 | | |
246 | | - | |
247 | | - | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
248 | 252 | | |
249 | 253 | | |
250 | 254 | | |
251 | 255 | | |
252 | | - | |
| 256 | + | |
253 | 257 | | |
254 | 258 | | |
255 | 259 | | |
| |||
0 commit comments