Skip to content

Commit 0ea4eb7

Browse files
kshitij12345facebook-github-bot
authored andcommitted
[opinfo] torch.lerp: move remaining cases from tensor_methods to opinfo (#55665)
Summary: Fixes : #54304 Reference: #54261 Pull Request resolved: #55665 Reviewed By: bdhirsh Differential Revision: D27845528 Pulled By: mruberry fbshipit-source-id: 36bdf14c4923a83fb8e4f4d361467d9568784011
1 parent df8bb5a commit 0ea4eb7

1 file changed

Lines changed: 25 additions & 31 deletions

File tree

torch/testing/_internal/common_methods_invocations.py

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2657,46 +2657,47 @@ def sample_inputs_msort(op_info, device, dtype, requires_grad):
26572657

26582658
return sample
26592659

2660-
def sample_inputs_lerp(op_info, device, dtype, requires_grad):
2661-
def _make_tensor_helper(shape, low=None, high=None):
2662-
return make_tensor(shape, device, dtype, low=low, high=high, requires_grad=requires_grad)
2660+
def sample_inputs_lerp(op_info, device, dtype, requires_grad, **kwargs):
2661+
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
26632662

26642663
samples = (
26652664
# no broadcast
2666-
SampleInput(_make_tensor_helper((S, S)), args=(_make_tensor_helper((S, S)), 0.4)),
2665+
SampleInput(make_arg((S, S)), args=(make_arg((S, S)), 0.4)),
26672666
# broadcast rhs
2668-
SampleInput(_make_tensor_helper((S, S)), args=(_make_tensor_helper((S,)), 0.4)),
2667+
SampleInput(make_arg((S, S)), args=(make_arg((S,)), 0.4)),
26692668
# scalar tensor
2670-
SampleInput(_make_tensor_helper(()), args=(_make_tensor_helper(()), 0.4)),
2669+
SampleInput(make_arg(()), args=(make_arg(()), 0.4)),
26712670
# broadcast rhs scalar-tensor
2672-
SampleInput(_make_tensor_helper((S, S)), args=(_make_tensor_helper(()), 0.4)),
2671+
SampleInput(make_arg((S, S)), args=(make_arg(()), 0.4)),
26732672
# broadcast rhs with weight tensor
2674-
SampleInput(_make_tensor_helper((S, S)), args=(_make_tensor_helper((S,)), _make_tensor_helper((S, S)))),
2673+
SampleInput(make_arg((S, S)), args=(make_arg((S,)), make_arg((S, S)))),
26752674
# broadcast rhs and weight tensor
2676-
SampleInput(_make_tensor_helper((S, S)), args=(_make_tensor_helper((S, 1)), _make_tensor_helper((S,)))),
2677-
2678-
# Broadcasts `self` : Issue with inplace-variants
2679-
# Reference: https://github.com/pytorch/pytorch/issues/50747
2680-
# SampleInput((_make_tensor_helper((S,)), _make_tensor_helper((S, S)), 0.4)),
2681-
# SampleInput((_make_tensor_helper(()), _make_tensor_helper((S, S)), 0.4)),
2682-
# SampleInput((_make_tensor_helper((S, 1)), _make_tensor_helper((S, S)), 0.4)),
2683-
# SampleInput((_make_tensor_helper((S, 1)), _make_tensor_helper((S, S)), _make_tensor_helper((S, 1)))),
2675+
SampleInput(make_arg((S, S)), args=(make_arg((S, 1)), make_arg((S,)))),
2676+
# broadcast_lhs
2677+
SampleInput(make_arg((S,)), args=(make_arg((S, S)), 0.4), broadcasts_input=True),
2678+
# scalar broadcast_lhs
2679+
SampleInput(make_arg(()), args=(make_arg((S, S)), 0.4), broadcasts_input=True),
2680+
# broadcast all
2681+
SampleInput(make_arg((S, 1)), args=(make_arg((S, S)), 0.4), broadcasts_input=True),
2682+
# tensor broadcast all
2683+
SampleInput(make_arg((S, 1)), args=(make_arg((S, S)), make_arg((S, 1))),
2684+
broadcasts_input=True),
26842685
) # type: ignore
26852686

26862687
if dtype.is_complex:
26872688
samples = samples + ( # type: ignore
26882689
# no broadcast
2689-
SampleInput(_make_tensor_helper((S, S)), args=(_make_tensor_helper((S, S)), 0.4j)),
2690-
SampleInput(_make_tensor_helper((S, S)), args=(_make_tensor_helper((S, S)), 1.2 + 0.1j)),
2690+
SampleInput(make_arg((S, S)), args=(make_arg((S, S)), 0.4j)),
2691+
SampleInput(make_arg((S, S)), args=(make_arg((S, S)), 1.2 + 0.1j)),
26912692
# broadcast rhs
2692-
SampleInput(_make_tensor_helper((S, S)), args=(_make_tensor_helper((S,)), 0.4j)),
2693-
SampleInput(_make_tensor_helper((S, S)), args=(_make_tensor_helper((S, S)), 5.4 + 9j)),
2693+
SampleInput(make_arg((S, S)), args=(make_arg((S,)), 0.4j)),
2694+
SampleInput(make_arg((S, S)), args=(make_arg((S, S)), 5.4 + 9j)),
26942695
# scalar tensor
2695-
SampleInput(_make_tensor_helper(()), args=(_make_tensor_helper(()), 0.4j)),
2696-
SampleInput(_make_tensor_helper(()), args=(_make_tensor_helper(()), 6.1 + 0.004j)),
2696+
SampleInput(make_arg(()), args=(make_arg(()), 0.4j)),
2697+
SampleInput(make_arg(()), args=(make_arg(()), 6.1 + 0.004j)),
26972698
# broadcast rhs scalar-tensor
2698-
SampleInput(_make_tensor_helper((S, S)), args=(_make_tensor_helper(()), 0.4j)),
2699-
SampleInput(_make_tensor_helper((S, S)), args=(_make_tensor_helper(()), 1 + 2j)),
2699+
SampleInput(make_arg((S, S)), args=(make_arg(()), 0.4j)),
2700+
SampleInput(make_arg((S, S)), args=(make_arg(()), 1 + 2j)),
27002701
)
27012702

27022703
return samples
@@ -4405,9 +4406,6 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
44054406
dtypesIfCUDA=floating_and_complex_types_and(torch.half),
44064407
dtypesIfROCM=floating_and_complex_types_and(torch.half),
44074408
sample_inputs_func=sample_inputs_lerp,
4408-
skips=(
4409-
SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),
4410-
),
44114409
assert_autodiffed=True),
44124410
OpInfo('linalg.inv',
44134411
aten_name='linalg_inv',
@@ -5220,10 +5218,6 @@ def method_tests():
52205218
('remainder', (S, 1, S), (non_differentiable(torch.rand(S, S) + 1.5),), 'tensor_broadcast_all'),
52215219
('remainder', (), (non_differentiable(uniform_scalar(1.5)),), 'scalar_tensor'),
52225220
('remainder', (), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'scalar_tensor_broadcast_lhs'),
5223-
('lerp', (S,), ((S, S, S), 0.4), 'broadcast_lhs', (True,)),
5224-
('lerp', (S, 1, S), ((S, S), 0.4), 'broadcast_all', (True,)),
5225-
('lerp', (), ((S, S, S), 0.4), 'scalar_broadcast_lhs', (True,)),
5226-
('lerp', (S, 1, S), ((S, S), (S, 1, 1, S)), 'tensor_broadcast_all', (True,)),
52275221
('kthvalue', (S, S, S), (2,)),
52285222
('kthvalue', (S, S, S), (2, 1,), 'dim', (), [1]),
52295223
('kthvalue', (S, S, S), (2, 1,), 'dim_alert_nondeterministic', (), [1],

0 commit comments

Comments
 (0)