@@ -2657,46 +2657,47 @@ def sample_inputs_msort(op_info, device, dtype, requires_grad):
26572657
26582658 return sample
26592659
2660- def sample_inputs_lerp (op_info , device , dtype , requires_grad ):
2661- def _make_tensor_helper (shape , low = None , high = None ):
2662- return make_tensor (shape , device , dtype , low = low , high = high , requires_grad = requires_grad )
2660+ def sample_inputs_lerp (op_info , device , dtype , requires_grad , ** kwargs ):
2661+ make_arg = partial (make_tensor , dtype = dtype , device = device , requires_grad = requires_grad )
26632662
26642663 samples = (
26652664 # no broadcast
2666- SampleInput (_make_tensor_helper ((S , S )), args = (_make_tensor_helper ((S , S )), 0.4 )),
2665+ SampleInput (make_arg ((S , S )), args = (make_arg ((S , S )), 0.4 )),
26672666 # broadcast rhs
2668- SampleInput (_make_tensor_helper ((S , S )), args = (_make_tensor_helper ((S ,)), 0.4 )),
2667+ SampleInput (make_arg ((S , S )), args = (make_arg ((S ,)), 0.4 )),
26692668 # scalar tensor
2670- SampleInput (_make_tensor_helper (()), args = (_make_tensor_helper (()), 0.4 )),
2669+ SampleInput (make_arg (()), args = (make_arg (()), 0.4 )),
26712670 # broadcast rhs scalar-tensor
2672- SampleInput (_make_tensor_helper ((S , S )), args = (_make_tensor_helper (()), 0.4 )),
2671+ SampleInput (make_arg ((S , S )), args = (make_arg (()), 0.4 )),
26732672 # broadcast rhs with weight tensor
2674- SampleInput (_make_tensor_helper ((S , S )), args = (_make_tensor_helper ((S ,)), _make_tensor_helper ((S , S )))),
2673+ SampleInput (make_arg ((S , S )), args = (make_arg ((S ,)), make_arg ((S , S )))),
26752674 # broadcast rhs and weight tensor
2676- SampleInput (_make_tensor_helper ((S , S )), args = (_make_tensor_helper ((S , 1 )), _make_tensor_helper ((S ,)))),
2677-
2678- # Broadcasts `self` : Issue with inplace-variants
2679- # Reference: https://github.com/pytorch/pytorch/issues/50747
2680- # SampleInput((_make_tensor_helper((S,)), _make_tensor_helper((S, S)), 0.4)),
2681- # SampleInput((_make_tensor_helper(()), _make_tensor_helper((S, S)), 0.4)),
2682- # SampleInput((_make_tensor_helper((S, 1)), _make_tensor_helper((S, S)), 0.4)),
2683- # SampleInput((_make_tensor_helper((S, 1)), _make_tensor_helper((S, S)), _make_tensor_helper((S, 1)))),
2675+ SampleInput (make_arg ((S , S )), args = (make_arg ((S , 1 )), make_arg ((S ,)))),
2676+ # broadcast_lhs
2677+ SampleInput (make_arg ((S ,)), args = (make_arg ((S , S )), 0.4 ), broadcasts_input = True ),
2678+ # scalar broadcast_lhs
2679+ SampleInput (make_arg (()), args = (make_arg ((S , S )), 0.4 ), broadcasts_input = True ),
2680+ # broadcast all
2681+ SampleInput (make_arg ((S , 1 )), args = (make_arg ((S , S )), 0.4 ), broadcasts_input = True ),
2682+ # tensor broadcast all
2683+ SampleInput (make_arg ((S , 1 )), args = (make_arg ((S , S )), make_arg ((S , 1 ))),
2684+ broadcasts_input = True ),
26842685 ) # type: ignore
26852686
26862687 if dtype .is_complex :
26872688 samples = samples + ( # type: ignore
26882689 # no broadcast
2689- SampleInput (_make_tensor_helper ((S , S )), args = (_make_tensor_helper ((S , S )), 0.4j )),
2690- SampleInput (_make_tensor_helper ((S , S )), args = (_make_tensor_helper ((S , S )), 1.2 + 0.1j )),
2690+ SampleInput (make_arg ((S , S )), args = (make_arg ((S , S )), 0.4j )),
2691+ SampleInput (make_arg ((S , S )), args = (make_arg ((S , S )), 1.2 + 0.1j )),
26912692 # broadcast rhs
2692- SampleInput (_make_tensor_helper ((S , S )), args = (_make_tensor_helper ((S ,)), 0.4j )),
2693- SampleInput (_make_tensor_helper ((S , S )), args = (_make_tensor_helper ((S , S )), 5.4 + 9j )),
2693+ SampleInput (make_arg ((S , S )), args = (make_arg ((S ,)), 0.4j )),
2694+ SampleInput (make_arg ((S , S )), args = (make_arg ((S , S )), 5.4 + 9j )),
26942695 # scalar tensor
2695- SampleInput (_make_tensor_helper (()), args = (_make_tensor_helper (()), 0.4j )),
2696- SampleInput (_make_tensor_helper (()), args = (_make_tensor_helper (()), 6.1 + 0.004j )),
2696+ SampleInput (make_arg (()), args = (make_arg (()), 0.4j )),
2697+ SampleInput (make_arg (()), args = (make_arg (()), 6.1 + 0.004j )),
26972698 # broadcast rhs scalar-tensor
2698- SampleInput (_make_tensor_helper ((S , S )), args = (_make_tensor_helper (()), 0.4j )),
2699- SampleInput (_make_tensor_helper ((S , S )), args = (_make_tensor_helper (()), 1 + 2j )),
2699+ SampleInput (make_arg ((S , S )), args = (make_arg (()), 0.4j )),
2700+ SampleInput (make_arg ((S , S )), args = (make_arg (()), 1 + 2j )),
27002701 )
27012702
27022703 return samples
@@ -4405,9 +4406,6 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
44054406 dtypesIfCUDA = floating_and_complex_types_and (torch .half ),
44064407 dtypesIfROCM = floating_and_complex_types_and (torch .half ),
44074408 sample_inputs_func = sample_inputs_lerp ,
4408- skips = (
4409- SkipInfo ('TestOpInfo' , 'test_duplicate_method_tests' ),
4410- ),
44114409 assert_autodiffed = True ),
44124410 OpInfo ('linalg.inv' ,
44134411 aten_name = 'linalg_inv' ,
@@ -5220,10 +5218,6 @@ def method_tests():
52205218 ('remainder' , (S , 1 , S ), (non_differentiable (torch .rand (S , S ) + 1.5 ),), 'tensor_broadcast_all' ),
52215219 ('remainder' , (), (non_differentiable (uniform_scalar (1.5 )),), 'scalar_tensor' ),
52225220 ('remainder' , (), (non_differentiable (torch .rand (S , S , S ) + 1.5 ),), 'scalar_tensor_broadcast_lhs' ),
5223- ('lerp' , (S ,), ((S , S , S ), 0.4 ), 'broadcast_lhs' , (True ,)),
5224- ('lerp' , (S , 1 , S ), ((S , S ), 0.4 ), 'broadcast_all' , (True ,)),
5225- ('lerp' , (), ((S , S , S ), 0.4 ), 'scalar_broadcast_lhs' , (True ,)),
5226- ('lerp' , (S , 1 , S ), ((S , S ), (S , 1 , 1 , S )), 'tensor_broadcast_all' , (True ,)),
52275221 ('kthvalue' , (S , S , S ), (2 ,)),
52285222 ('kthvalue' , (S , S , S ), (2 , 1 ,), 'dim' , (), [1 ]),
52295223 ('kthvalue' , (S , S , S ), (2 , 1 ,), 'dim_alert_nondeterministic' , (), [1 ],
0 commit comments