Commit da49573

Update on "Stop using an unnecessary scalar_to_tensor(..., device) call."

In this case, the function only dispatches on cpu anyway.

Differential Revision: [D25790155](https://our.internmc.facebook.com/intern/diff/D25790155)

2 parents: b90ff21 + 033a616

1 file changed

aten/src/ATen/native/TensorCompare.cpp

Lines changed: 7 additions & 1 deletion
@@ -195,6 +195,7 @@ bool is_nonzero(const Tensor& self) {
 
 namespace {
 
+// DO NOT USE THIS -- it's just an implementation detail of wrapped_scalar_tensor below.
 at::Tensor scalar_to_tensor_default_dtype(
     Scalar s,
     const Device device = at::kCPU) {
@@ -207,11 +208,16 @@ at::Tensor scalar_to_tensor_default_dtype(
     return at::scalar_tensor(
         s, at::device(device).dtype(at::get_default_complex_dtype()));
   } else {
-    AT_ASSERT(s.isIntegral(false));
+    TORCH_INTERNAL_ASSERT(s.isIntegral(false));
     return at::scalar_tensor(s, at::device(device).dtype(at::kLong));
   }
 }
 
+// TLDR: Don't call with `use_default_dtype` true -- this is only necessary to support the partial
+// type-promotion that torch.where supports. Once torch.where fully supports type promotion, we
+// won't need this function.
+//
+// Longer explanation:
 // `use_default_dtype` is a bit of a hack because torch.where doesn't support type promotion, but
 // does support `torch.where(tensor, scalar1, scalar2)` with default scalar types. The trickiness is we
 // usually convert double scalars to doubles, and `set_wrapped_number` defines type promotion priority
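The dtype-selection logic visible in the diff can be sketched in plain Python. The helper name `default_dtype_for_scalar` and the floating-point branch are illustrative assumptions (only the complex and integral branches appear in the hunk shown); this is a sketch of the behavior, not PyTorch's actual implementation:

```python
# Plain-Python sketch of the dtype selection in scalar_to_tensor_default_dtype.
# Helper name and the float branch are assumptions; only the complex and
# integral branches are visible in the diff above.

def default_dtype_for_scalar(s, default_dtype="float32",
                             default_complex_dtype="complex64"):
    """Return the dtype a bare Python scalar would be wrapped with."""
    if isinstance(s, bool):
        # The diff's else-branch asserts s.isIntegral(/*includeBool=*/false),
        # so a bool would trip TORCH_INTERNAL_ASSERT; mirror that here.
        raise AssertionError("bool scalars are not handled by this branch")
    if isinstance(s, complex):
        return default_complex_dtype  # at::get_default_complex_dtype()
    if isinstance(s, float):
        return default_dtype          # assumed: the default (float) dtype
    if isinstance(s, int):
        return "int64"                # integral scalars -> at::kLong
    raise TypeError(f"unsupported scalar type: {type(s).__name__}")
```

This illustrates the partial type promotion the comment describes: a Python float passed as a `torch.where` scalar argument would be wrapped with the default dtype rather than double, while an integer scalar is wrapped as long.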
