@@ -195,6 +195,7 @@ bool is_nonzero(const Tensor& self) {
 
 namespace {
 
+// DO NOT USE THIS -- it's just an implementation detail of wrapped_scalar_tensor below.
 at::Tensor scalar_to_tensor_default_dtype(
     Scalar s,
     const Device device = at::kCPU) {
@@ -207,11 +208,16 @@ at::Tensor scalar_to_tensor_default_dtype(
     return at::scalar_tensor(
         s, at::device(device).dtype(at::get_default_complex_dtype()));
   } else {
-    AT_ASSERT(s.isIntegral(false));
+    TORCH_INTERNAL_ASSERT(s.isIntegral(false));
     return at::scalar_tensor(s, at::device(device).dtype(at::kLong));
   }
 }
 
+// TLDR: Don't call with `use_default_dtype` true -- this is only necessary to support the partial
+// type-promotion that torch.where supports. Once torch.where fully supports type promotion, we
+// won't need this function.
+//
+// Longer explanation:
 // `use_default_dtype` is a bit of a hack because torch.where doesn't support type promotion, but
 // does support `torch.where(tensor, scalar1, scalar2)` with default scalar types. The trickiness is we
 // usually convert double scalars to doubles, and `set_wrapped_number` defines type promotion priority
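
For context, a short illustrative sketch in Python (not part of this diff) of the partial type promotion the new comment describes: scalar arguments to torch.where take the default scalar types, so a Python float yields the default dtype rather than float64. The scalar-accepting overloads shown here are assumptions about the public API, not something this change introduces.

    import torch

    cond = torch.tensor([True, False])

    # Two Python float scalars: the result takes the default dtype
    # (float32 unless changed via torch.set_default_dtype), not float64 --
    # the "default scalar types" behavior the comment refers to.
    out = torch.where(cond, 1.0, 0.0)
    print(out.dtype)  # torch.float32

    # Tensor + scalar: the scalar is wrapped as a 0-dim tensor so the
    # kernel only ever sees tensors; marking it as a wrapped number keeps
    # its dtype from overriding the real tensor's dtype during promotion.
    x = torch.zeros(2, dtype=torch.float32)
    print(torch.where(cond, x, 2.0).dtype)  # torch.float32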