Closed
Labels
function request, module: python array api, module: sorting and selection, module: type promotion, triaged
Description
```python
import numpy as np
import torch

a = torch.Tensor([1, 2, 3])
# works, correctly produces tensor([-1., 2., 3.])
b = torch.where(a == 1, torch.Tensor([-1]), a)
# fails: TypeError: where(): argument 'input' (position 2) must be Tensor, not int
c = torch.where(a == 1, -1, a)
# works in NumPy
np.where(a.numpy() == 1, -1, a.numpy())
# both fail: RuntimeError: expected scalar type Long but found Float
d = torch.where(a == 1, torch.tensor(-1), a)
d = torch.where(a == 1, torch.tensor([-1]), a)
```

cc @nairbv @mruberry @rgommers @pmeier @asmeurer @leofang @AnirudhDagar @asi1024 @emcastillo @kmaehashi
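On builds where `torch.where` rejects Python scalars and mixed dtypes, one workaround is to promote both branches to a common dtype yourself before the call. The sketch below uses a hypothetical helper, `where_promoted` (not part of PyTorch). It asks `torch.result_type` for the dtype that PyTorch's type-promotion rules would pick, then casts both branches to it. This sidesteps the "expected scalar type Long but found Float" mismatch reported above.

```python
import torch


def where_promoted(condition, x, y):
    """Hypothetical helper: torch.where with explicit type promotion.

    Python scalars are wrapped as tensors, and both branches are cast to
    the dtype torch.result_type picks, so builds of torch.where that
    require matching dtypes still accept the call.
    """
    # Promoted dtype for the pair, e.g. (int, float32 tensor) -> float32.
    dtype = torch.result_type(x, y)
    device = condition.device
    # as_tensor leaves existing tensors alone and wraps Python numbers.
    x = torch.as_tensor(x, device=device).to(dtype)
    y = torch.as_tensor(y, device=device).to(dtype)
    # 0-dim branches broadcast against the condition's shape.
    return torch.where(condition, x, y)


a = torch.tensor([1.0, 2.0, 3.0])
c = where_promoted(a == 1, -1, a)                # scalar branch
d = where_promoted(a == 1, torch.tensor(-1), a)  # 0-dim int64 branch
```

For the single-scalar case in the report, `a.masked_fill(a == 1, -1)` is another option, since `Tensor.masked_fill` accepts a Python number as its fill value.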