
Commit 4926192

mergennachin authored and pytorchmergebot committed
[Inductor][MPS] Fix half-precision type mismatches in Metal shader codegen (#176436)
Metal Shading Language rejects implicit float-to-bfloat conversions, so bare float literals like `0.0` in generated shaders cause compilation failures when the target variable is `bfloat` (or `half`). Three codegen methods were affected:

- `constant()` ignored its `dtype` parameter and returned raw literals.
- `masked()` assigned a bare literal in the else-branch (`} else tmp = 0.0;`).
- `where()` passed a bare literal through the ternary without casting.

All three now emit `static_cast<bfloat>(...)` / `static_cast<half>(...)` where needed. Tests added for half-precision constants, reductions, and conditionals.

Pull Request resolved: #176436
Approved by: https://github.com/malfet
1 parent ca5f3b9 commit 4926192
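
As a reading aid, here is a minimal sketch of the cast-wrapping pattern the commit message describes for `constant()`. This is not the actual Inductor source: `DTYPE_TO_METAL` and this standalone `constant` helper are illustrative names invented for the example, not PyTorch APIs.

```python
import torch

# Hypothetical mapping for illustration; the real codegen keeps its own
# dtype-to-Metal-type table.
DTYPE_TO_METAL = {torch.bfloat16: "bfloat", torch.float16: "half"}

def constant(value: float, dtype: torch.dtype) -> str:
    """Emit a Metal literal, casting when the target type is low-precision.

    Before the fix, the dtype argument was ignored and a bare literal came
    back, which MSL rejects when assigned to a bfloat/half variable.
    """
    metal_type = DTYPE_TO_METAL.get(dtype)
    if metal_type is None:
        return repr(value)  # float32 etc.: a bare literal compiles fine
    return f"static_cast<{metal_type}>({value!r})"

print(constant(0.0, torch.bfloat16))  # static_cast<bfloat>(0.0)
print(constant(0.0, torch.float32))   # 0.0
```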

2 files changed: 35 additions & 3 deletions


test/inductor/test_torchinductor.py

Lines changed: 28 additions & 0 deletions
@@ -15914,6 +15914,34 @@ def fn(x, index, source):
         result = torch.compile(fn)(x_base.clone()[:, 2:, :], index, source)
         self.assertEqual(result, expected)
 
+    def test_bfloat_constant(self):
+        if not self.is_dtype_supported(torch.bfloat16):
+            raise unittest.SkipTest("bfloat16 not supported")
+        self.common(
+            lambda x: x + 1.0,
+            (make_tensor(1024, dtype=torch.bfloat16, device=self.device),),
+        )
+
+    @parametrize("dtype", [torch.float16, torch.bfloat16])
+    def test_lowp_reduction(self, dtype):
+        if not self.is_dtype_supported(dtype):
+            raise unittest.SkipTest(f"{dtype} not supported")
+        self.common(
+            lambda x: x.sum(),
+            (make_tensor(1024, dtype=dtype, device=self.device),),
+            check_lowp=False,
+        )
+
+    @parametrize("dtype", [torch.float16, torch.bfloat16])
+    def test_lowp_where(self, dtype):
+        if not self.is_dtype_supported(dtype):
+            raise unittest.SkipTest(f"{dtype} not supported")
+        self.common(
+            lambda x: torch.where(x > 0.5, x, x.new_zeros(())),
+            (make_tensor(1024, dtype=dtype, device=self.device),),
+            check_lowp=False,
+        )
+
     # end of class CommonTemplate - add new tests here
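A standalone repro sketch for the `where` path the last test covers, assuming an MPS-capable machine and this fix applied (`torch.rand` stands in for the suite's `make_tensor` helper). Before the fix, compiling this emitted a Metal ternary with a bare float literal and shader compilation failed for half/bfloat inputs:

```python
import torch

def fn(x):
    return torch.where(x > 0.5, x, x.new_zeros(()))

x = torch.rand(1024, dtype=torch.float16, device="mps")
compiled = torch.compile(fn)
# With the fix, the generated ternary casts the zero branch to the
# tensor's Metal type, so this compiles and matches eager mode.
torch.testing.assert_close(compiled(x), fn(x))
```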

torch/_inductor/codegen/mps.py

Lines changed: 7 additions & 3 deletions
@@ -243,13 +243,17 @@ def masked(mask: CSEVariable, body: sympy.Expr, other: CSEVariable) -> str:
         )
         with V.kernel.compute.indent():
             V.kernel.compute.splice(scoped_body)
-            V.kernel.compute.writeline(f"{var} = {rc};")
-        V.kernel.compute.writeline(f"}} else {var} = {other_str};")
+            V.kernel.compute.writeline(
+                f"{var} = static_cast<decltype({var})>({rc});"
+            )
+        V.kernel.compute.writeline(
+            f"}} else {var} = static_cast<decltype({var})>({other_str});"
+        )
         return var
 
     @staticmethod
     def where(a: OpVarT, b: OpVarT, c: OpVarT) -> str:
-        return f"{a} ? {b} : {value_to_metal(c)}"
+        return f"{a} ? {b} : static_cast<decltype({b})>({value_to_metal(c)})"
 
     @staticmethod
     def remainder(a: OpVarT, b: OpVarT) -> str:
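
To make the `masked` change concrete, here is an illustrative sketch (variable names invented, not the Inductor source) of the Metal text the two rewritten `writeline` calls produce. Note the design choice: `decltype(var)` defers the cast's target type to the Metal compiler, so the Python codegen never needs to know whether the destination is `float`, `half`, or `bfloat` at these call sites.

```python
def close_masked_block(var: str, rc: str, other: str) -> str:
    """Mimic the generated if/else tail of a masked load after the fix."""
    return (
        f"    {var} = static_cast<decltype({var})>({rc});\n"
        f"}} else {var} = static_cast<decltype({var})>({other});"
    )

print(close_masked_block("tmp2", "tmp1", "0.0"))
#     tmp2 = static_cast<decltype(tmp2)>(tmp1);
# } else tmp2 = static_cast<decltype(tmp2)>(0.0);
```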
