Skip to content

Commit 5ed8aee

Browse files
committed
Update
[ghstack-poisoned]
2 parents 39590b1 + 46754a2 commit 5ed8aee

3 files changed

Lines changed: 2 additions & 24 deletions

File tree

test/test_meta.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1772,27 +1772,6 @@ def test_triangular_solve_out(self):
17721772
self.assertEqual(out[1].shape, meta_out[1].shape)
17731773
self.assertEqual(out[1].dtype, meta_out[1].dtype)
17741774

1775-
def test_meta_consistency_out_dtype_mismatch_pow_Tensor_Scalar(self):
1776-
S = (5,)
1777-
1778-
def run(device):
1779-
a = torch.rand(S, device=device, dtype=torch.float32)
1780-
b = 2
1781-
out = torch.empty(S, device=device, dtype=torch.float64)
1782-
1783-
try:
1784-
torch.pow(a, b, out=out)
1785-
except Exception as e:
1786-
return e
1787-
1788-
cpu_err = run("cpu")
1789-
meta_err = run("meta")
1790-
1791-
if cpu_err is None and meta_err is not None:
1792-
raise RuntimeError("cpu didn't fail, but meta did.") from meta_err
1793-
elif cpu_err is not None and meta_err is None:
1794-
raise RuntimeError("cpu failed, but meta didn't.") from cpu_err
1795-
17961775

17971776
instantiate_device_type_tests(TestMeta, globals())
17981777

test/test_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def reduction_dtype_filter(op):
207207
xfail("softmax"),
208208
xfail("sort"),
209209
xfail("sparse.sampled_addmm"),
210+
xfail("square"),
210211
xfail("squeeze_copy"),
211212
xfail("t_copy"),
212213
xfail("take"),

torch/_refs/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,6 @@ def _make_elementwise_binary_reference(
10621062
supports_rhs_python_scalar=True,
10631063
supports_two_python_scalars=False,
10641064
should_register_decomposition=True,
1065-
exact_dtype=False,
10661065
) -> Callable:
10671066
def inner(prim: Callable):
10681067
nonlocal aten_op, name
@@ -1098,7 +1097,7 @@ def _ref(
10981097
return handle_noncontiguous_outputs([a, b], output)
10991098

11001099
if has_out:
1101-
_ref = out_wrapper(exact_dtype=exact_dtype)(_ref) # type: ignore[assignment]
1100+
_ref = out_wrapper()(_ref) # type: ignore[assignment]
11021101

11031102
_ref.__name__ = name
11041103
if aten_op is infer_aten_op:
@@ -1241,7 +1240,6 @@ def eq(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
12411240

12421241
@_make_elementwise_binary_reference(
12431242
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
1244-
exact_dtype=True,
12451243
)
12461244
def pow(
12471245
a: Union[TensorLikeType, NumberType],

0 commit comments

Comments (0)