Skip to content

Commit 5189f92

Browse files
committed
Update on "std/var: Deprecate overloads with "unbiased" argument"
[ghstack-poisoned]
2 parents b4f0373 + 703b72a commit 5189f92

2 files changed

Lines changed: 9 additions & 10 deletions

File tree

test/test_fx_experimental.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1264,7 +1264,8 @@ class TestNormalizeOperators(JitTestCase):
12641264
@ops(op_db, allowed_dtypes=(torch.float,))
12651265
def test_normalize_operator_exhaustive(self, device, dtype, op):
12661266
# Unsupported input types
1267-
if op.name in {'index_put', '__getitem__', 'unfold', 'repeat', 'polygamma'}:
1267+
if op.name in {'index_put', '__getitem__', 'unfold', 'repeat', 'polygamma',
1268+
'std', 'var'}:
12681269
return
12691270
# These ops currently don't trace in FX for various reasons (i.e. they take a list of tensors)
12701271
fx_fail = {'stack', 'hstack', 'vstack', 'dstack',

torch/fx/operator_schemas.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,9 @@ def is_homogeneous_int_tuple(t):
126126
return True
127127
return all(c is int or (c is Ellipsis) for c in contained)
128128

129-
if signature_type is List[int]:
129+
if signature_type is List[int] and is_homogeneous_int_tuple(argument_type):
130130
# Tuple[int] is accepted for List[int] parameters
131-
return is_homogeneous_int_tuple(argument_type)
131+
return True
132132

133133
# Dtype is an int in schemas
134134
if signature_type is int and argument_type is torch.dtype:
@@ -207,16 +207,14 @@ def normalize_function(
207207
arg_types = arg_types if arg_types else cast(Tuple[Any], ())
208208
kwarg_types = kwarg_types if kwarg_types else {}
209209
for candidate_signature in torch_op_schemas:
210+
sig_matches = True
210211
try:
211212
bound_types = candidate_signature.bind(*arg_types, **kwarg_types)
213+
for arg_name, arg_type in bound_types.arguments.items():
214+
param = candidate_signature.parameters[arg_name]
215+
sig_matches = sig_matches and type_matches(param.annotation, arg_type)
212216
except TypeError as e:
213-
continue
214-
215-
sig_matches = True
216-
for arg_name, arg_type in bound_types.arguments.items():
217-
param = candidate_signature.parameters[arg_name]
218-
sig_matches = sig_matches and type_matches(param.annotation, arg_type)
219-
217+
sig_matches = False
220218
if sig_matches:
221219
new_kwargs = _args_kwargs_to_normalized_kwargs(candidate_signature, args, kwargs)
222220
break

0 commit comments

Comments
 (0)