[Inductor][MPS] Fix half-precision type mismatches in Metal shader codegen #176436
mergennachin wants to merge 2 commits into main from
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176436
Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 unrelated failures)
As of commit cc0589b with merge base e45dfba:
BROKEN TRUNK - The following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs a
Hmm, this quirk is indeed present, but only for
test/inductor/test_mps_basic.py
Outdated

    [torch.bfloat16] if MACOS_VERSION < 14.0 else []
)
MPS_DTYPES = [t for t in get_all_dtypes() if t not in MPS_UNSUPPORTED_TYPES]
MPS_HALF_DTYPES = [torch.float16] + ([torch.bfloat16] if MACOS_VERSION >= 14.0 else [])
We (and Apple) don't support macOS 13 anymore. But I believe this list already exists.

Suggested change:
-MPS_HALF_DTYPES = [torch.float16] + ([torch.bfloat16] if MACOS_VERSION >= 14.0 else [])
+MPS_HALF_DTYPES = [torch.float16, torch.bfloat16]
test/inductor/test_mps_basic.py
Outdated

    )

    @parametrize("dtype", MPS_HALF_DTYPES)
    def test_half_masked(self, dtype):
Err, why is it called test_half_masked when you call sum()?
torch/_inductor/codegen/mps.py
Outdated

-            return value_to_metal(val)
+            raw = value_to_metal(val)
+            if (
+                dtype in (torch.bfloat16, torch.float16)
Do you know if this condition is ever hit? (I would have checked by adding an assert and running the test suite.)
But I would argue that math ops should return float values, and only downcast to the low-precision dtype when the result is written out (as opmath_t<bfloat> == float).
Yes, the condition is hit. Verified empirically: `x + 1.0` with bf16 input calls `constant(val=1.0, dtype=torch.bfloat16)`. Same for `x * 2`, `x - 0.5`, and `torch.where` with scalar constants.
You're right that math ops should return float values. The current fix does a pointless round-trip: the literal is float, we cast it to bfloat, then Metal implicitly promotes it back to float for arithmetic (since loads are already upcast to float32). I'll drop the cast in `constant()` and just return the raw literal. It's already float-compatible and matches the float32 compute dtype established by `load()`.
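The two strategies being compared can be sketched as toy versions of the `constant()` codegen. These helper names and the dtype strings are illustrative stand-ins, not the real Inductor API:

```python
# Toy sketch of the two constant() strategies discussed above.
# DTYPE_TO_METAL keys/values mirror the mapping used by the MPS backend,
# but the functions themselves are hypothetical.
DTYPE_TO_METAL = {"bfloat16": "bfloat", "float16": "half", "float32": "float"}

def constant_with_cast(val: float, dtype: str) -> str:
    # Original fix: wrap low-precision literals in an explicit cast so the
    # Metal compiler never sees a bare float assigned to bfloat/half.
    raw = repr(val)
    if dtype in ("bfloat16", "float16"):
        return f"static_cast<{DTYPE_TO_METAL[dtype]}>({raw})"
    return raw

def constant_raw(val: float, dtype: str) -> str:
    # Follow-up simplification: emit the literal as-is, since loads are
    # already upcast to float32 and arithmetic happens in float anyway.
    return repr(val)

print(constant_with_cast(1.0, "bfloat16"))  # static_cast<bfloat>(1.0)
print(constant_raw(1.0, "bfloat16"))        # 1.0
```

The second form avoids the float → bfloat → float round-trip described in the comment above; the cast only matters at the final store, which is handled separately.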
torch/_inductor/codegen/mps.py
Outdated

@@ -244,12 +251,18 @@ def masked(mask: CSEVariable, body: sympy.Expr, other: CSEVariable) -> str:
        with V.kernel.compute.indent():
            V.kernel.compute.splice(scoped_body)
        V.kernel.compute.writeline(f"{var} = {rc};")
Don't you need to do it here as well?
torch/_inductor/codegen/mps.py
Outdated

            V.kernel.compute.writeline(f"{var} = {rc};")
-            V.kernel.compute.writeline(f"}} else {var} = {other_str};")
+            V.kernel.compute.writeline(
+                f"}} else {var} = static_cast<{DTYPE_TO_METAL[rc.dtype]}>({other_str});"
Nit (rc.dtype is often wrong/undefined):

-                f"}} else {var} = static_cast<{DTYPE_TO_METAL[rc.dtype]}>({other_str});"
+                f"}} else {var} = static_cast<decltype({var})>({other_str});"

(The `{var}` inside `decltype(...)` must be interpolated so the cast names the generated variable, e.g. `decltype(tmp4)`.)
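The suggested `decltype`-based emission can be sketched as a small helper (illustrative name; the real code lives inside `masked()` in mps.py):

```python
def emit_masked_else(var: str, other_str: str) -> str:
    # decltype({var}) lets the Metal compiler pick the target type, so
    # codegen never needs rc.dtype (which may be wrong or undefined here).
    # The doubled "}}" is f-string escaping for a literal "}".
    return f"}} else {var} = static_cast<decltype({var})>({other_str});"

print(emit_masked_else("tmp4", "0.0"))
# } else tmp4 = static_cast<decltype(tmp4)>(0.0);
```

This is exactly the line that previously failed to compile as `} else tmp4 = 0.0;` when `tmp4` was declared `bfloat`.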
torch/_inductor/codegen/mps.py
Outdated

        c_str = value_to_metal(c)
        if isinstance(b, CSEVariable) and b.dtype in (torch.bfloat16, torch.float16):
            assert b.dtype is not None
            c_str = f"static_cast<{DTYPE_TO_METAL[b.dtype]}>({c_str})"
        return f"{a} ? {b} : {c_str}"
Nit (i.e. always leave it to the compiler rather than codegen):

-        c_str = value_to_metal(c)
-        if isinstance(b, CSEVariable) and b.dtype in (torch.bfloat16, torch.float16):
-            assert b.dtype is not None
-            c_str = f"static_cast<{DTYPE_TO_METAL[b.dtype]}>({c_str})"
-        return f"{a} ? {b} : {c_str}"
+        return f"{a} ? {b} : static_cast<decltype({b})>({value_to_metal(c)})"

(Casting to `decltype({b})` matches the other arm of the ternary; `{a}` is the boolean condition, so its type would be wrong here.)
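The compiler-driven `where()` emission can be sketched the same way. Helper names are illustrative, and `value_to_metal` is stubbed as a plain float formatter; note the scalar arm is cast to the type of the variable arm, not the condition:

```python
def value_to_metal(c) -> str:
    # Stub for illustration: format a Python scalar as a Metal literal.
    return repr(float(c))

def emit_where(a: str, b: str, c) -> str:
    # Cast the scalar false-arm to decltype of the variable true-arm and
    # let the Metal compiler resolve the conversion, instead of consulting
    # a dtype-to-Metal table in codegen.
    return f"{a} ? {b} : static_cast<decltype({b})>({value_to_metal(c)})"

print(emit_where("tmp3", "tmp2", 0.5))
# tmp3 ? tmp2 : static_cast<decltype(tmp2)>(0.5)
```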
test/inductor/test_torchinductor.py
Outdated

    def test_half_constant(self):
        for dtype in [torch.float16, torch.bfloat16]:
            if not self.is_dtype_supported(dtype):
                continue
            self.common(
                lambda x: x + 1.0,
                (make_tensor(1024, dtype=dtype, device=self.device),),
                check_lowp=False,
            )
Nit (check_lowp indeed checks for torch.half):

-    def test_half_constant(self):
-        for dtype in [torch.float16, torch.bfloat16]:
-            if not self.is_dtype_supported(dtype):
-                continue
-            self.common(
-                lambda x: x + 1.0,
-                (make_tensor(1024, dtype=dtype, device=self.device),),
-                check_lowp=False,
-            )
+    def test_bfloat_constant(self):
+        if not self.is_dtype_supported(torch.bfloat16):
+            return
+        self.common(
+            lambda x: x + 1.0,
+            (make_tensor(1024, dtype=torch.bfloat16, device=self.device),),
+        )

(The suggestion originally used `continue`, which is invalid outside a loop; `return` skips the test instead.)
test/inductor/test_torchinductor.py
Outdated

    def test_half_reduction(self):
        for dtype in [torch.float16, torch.bfloat16]:
            if not self.is_dtype_supported(dtype):
                continue
            self.common(
                lambda x: x.sum(),
                (make_tensor(1024, dtype=dtype, device=self.device),),
                check_lowp=False,
            )
Alternatively use parametrize:

-    def test_half_reduction(self):
-        for dtype in [torch.float16, torch.bfloat16]:
-            if not self.is_dtype_supported(dtype):
-                continue
-            self.common(
-                lambda x: x.sum(),
-                (make_tensor(1024, dtype=dtype, device=self.device),),
-                check_lowp=False,
-            )
+    @parametrize("dtype", [torch.float16, torch.bfloat16])
+    def test_lowp_reduction(self, dtype):
+        if not self.is_dtype_supported(dtype):
+            return
+        self.common(
+            lambda x: x.sum(),
+            (make_tensor(1024, dtype=dtype, device=self.device),),
+            check_lowp=False,
+        )
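The parametrize pattern suggested above can be sketched with a minimal stand-in decorator. The real one is `parametrize` from torch.testing._internal.common_utils, which generates one test method per value; this toy version just runs the body once per value to show why it beats a for-loop inside the test:

```python
# Hypothetical minimal parametrize: the real decorator creates separate
# test methods (test_x_float16, test_x_bfloat16), so a failure names the
# dtype that broke instead of hiding inside a loop iteration.
def parametrize(arg_name, values):
    def decorator(fn):
        def run_all(*args):
            return [fn(*args, **{arg_name: v}) for v in values]
        return run_all
    return decorator

@parametrize("dtype", ["float16", "bfloat16"])
def test_lowp_reduction(dtype):
    # Stand-in for self.common(lambda x: x.sum(), ...) from the suggestion.
    return f"checked sum() for {dtype}"

print(test_lowp_reduction())
# ['checked sum() for float16', 'checked sum() for bfloat16']
```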
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed: trunk / linux-jammy-cuda13.0-py3.10-gcc11 / test (default, 1, 5, lf.linux.g6.4xlarge.experimental.nvidia.gpu). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge -f "Lint + MPS are green, hopefully other failures are just broken trunk"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed: trunk / macos-py3-arm64 / test (default, 2, 3, macos-m1-stable). Details for Dev Infra team: raised by workflow job.
…degen

Metal Shading Language rejects implicit float-to-bfloat conversions, so bare float literals like `0.0` in generated shaders cause compilation failures when the target variable is `bfloat` (or `half`). Three codegen methods were affected:

- `constant()` ignored its `dtype` parameter and returned raw literals.
- `masked()` assigned a bare literal in the else-branch (`} else tmp = 0.0;`).
- `where()` passed a bare literal through the ternary without casting.

All three now emit `static_cast<bfloat>(...)` / `static_cast<half>(...)` where needed. Tests added for half-precision constants, reductions, and conditionals.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot cherry-pick --onto release/2.11 -c critical --fixes "Bug fix for MPS backend using inductor"
Cherry picking #176436. Command details for Dev Infra team: raised by workflow job.
…degen (#176436)

Metal Shading Language rejects implicit float-to-bfloat conversions, so bare float literals like `0.0` in generated shaders cause compilation failures when the target variable is `bfloat` (or `half`). Three codegen methods were affected:

- `constant()` ignored its `dtype` parameter and returned raw literals.
- `masked()` assigned a bare literal in the else-branch (`} else tmp = 0.0;`).
- `where()` passed a bare literal through the ternary without casting.

All three now emit `static_cast<bfloat>(...)` / `static_cast<half>(...)` where needed. Tests added for half-precision constants, reductions, and conditionals.

Pull Request resolved: #176436
Approved by: https://github.com/malfet

Test plan: Run `python -c "import torch;F=torch.nn.functional;print(torch.compile(lambda x: F.pad(F.gelu(x), [1, 0]))(torch.randn(4, device='mps', dtype=torch.bfloat16)))"`

(cherry picked from commit 3b161e7)
…degen (#176436) (#177193) — cherry-pick of the commit above onto release/2.11. Co-authored-by: Mergen Nachin <mnachin@meta.com>
- Remove `test_bfloat_constant`, `test_lowp_reduction`, and `test_lowp_where` as they don't test for anything beyond what existing tests cover.
- Add `test_pad_after_gelu` as a regression test for Voxtral compilation on MPS, exercising pad(gelu(x)) across fp32, fp16, and bfloat16.

Before #176436 the test fails with

```
torch._inductor.exc.InductorError: SyntaxError: failed to compile

#include <c10/metal/utils.h>
#include <c10/metal/special_math.h>
kernel void generated_kernel(
    device bfloat* out_ptr0,
    constant bfloat* in_ptr0,
    uint xindex [[thread_position_in_grid]]
) {
    int x0 = (xindex) % (17);
    int x1 = c10::metal::floor_divide(xindex, 17);
    int x2 = xindex;
    auto tmp0 = (-1) + x0;
    auto tmp1 = static_cast<long>(tmp0);
    auto tmp2 = 0;
    auto tmp3 = tmp1 >= tmp2;
    bfloat tmp4;
    if (tmp3) {
        auto tmp_scoped_0 = static_cast<float>(in_ptr0[(-1) + x0 + 16*x1]);
        auto tmp_scoped_1 = static_cast<float>(tmp_scoped_0);
        auto tmp_scoped_2 = 0.5;
        auto tmp_scoped_3 = tmp_scoped_1 * tmp_scoped_2;
        auto tmp_scoped_4 = 0.7071067811865476;
        auto tmp_scoped_5 = tmp_scoped_1 * tmp_scoped_4;
        auto tmp_scoped_6 = c10::metal::erf(tmp_scoped_5);
        auto tmp_scoped_7 = 1.0;
        auto tmp_scoped_8 = tmp_scoped_6 + tmp_scoped_7;
        auto tmp_scoped_9 = tmp_scoped_3 * tmp_scoped_8;
        auto tmp_scoped_10 = static_cast<bfloat>(tmp_scoped_9);
        tmp4 = tmp_scoped_10;
    } else tmp4 = 0.0;
    out_ptr0[x2] = static_cast<bfloat>(tmp4);
}

with program_source:4495:23: error: assigning to 'bfloat' from incompatible type 'float'
    } else tmp4 = 0.0;
                  ^~~
```

Authored with Claude.

ghstack-source-id: 7919b53
Pull-Request: #177207
…177207) — final merged version of the ghstack commit above. Pull Request resolved: #177207. Approved by: https://github.com/atalman, https://github.com/mergennachin, https://github.com/jansel
…hader codegen (pytorch#176436)"

This reverts commit 4926192. Reverted pytorch#176436 on behalf of https://github.com/zou3519 due to: "sorry I need to revert this in order to revert pytorch#176606" ([comment](pytorch#176436 (comment)))
…degen (pytorch#176436) Metal Shading Language rejects implicit float-to-bfloat conversions, so bare float literals like `0.0` in generated shaders cause compilation failures when the target variable is `bfloat` (or `half`). Three codegen methods were affected: - `constant()` ignored its `dtype` parameter and returned raw literals. - `masked()` assigned a bare literal in the else-branch (`} else tmp = 0.0;`). - `where()` passed a bare literal through the ternary without casting. All three now emit `static_cast<bfloat>(...)` / `static_cast<half>(...)` where needed. Tests added for half-precision constants, reductions, and conditionals. Pull Request resolved: pytorch#176436 Approved by: https://github.com/malfet
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo