
Commit 0f412aa

gchanan authored and facebook-github-bot committed
Move scalar_to_tensor_default_dtype out of ScalarOps.h because it's only useful for torch.where. (#50111)
Summary:
Pull Request resolved: #50111

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D25789638

Pulled By: gchanan

fbshipit-source-id: 4254e11e08606b64e393433ef2c169889ff2ac07
1 parent 186fe48 commit 0f412aa

2 files changed

Lines changed: 29 additions & 27 deletions


aten/src/ATen/ScalarOps.h

Lines changed: 0 additions & 26 deletions
@@ -45,30 +45,4 @@ inline at::Tensor scalar_to_tensor(Scalar s, const Device device = at::kCPU) {
   }
 }
 
-// The above function is useful for type promotion
-// in Binary Ops where one argument is a `Tensor` and the other is a `Scalar`.
-// In the above function, we generate the wrapped tensor with the highest
-// range and precision based on the scalar's type (to support type promotion).
-// E.g. Floating Point Types -> Double
-//      Complex Types -> Complex Double
-//
-// However, for a `Scalar-Scalar` binary op, we default the type of the wrapped
-// tensor to the default type corresponding to the scalar's type.
-inline at::Tensor scalar_to_tensor_default_dtype(
-    Scalar s,
-    const Device device = at::kCPU) {
-  if (s.isFloatingPoint()) {
-    return at::scalar_tensor(
-        s, at::device(device).dtype(at::get_default_dtype()));
-  } else if (s.isBoolean()) {
-    return at::scalar_tensor(s, at::device(device).dtype(at::kBool));
-  } else if (s.isComplex()) {
-    return at::scalar_tensor(
-        s, at::device(device).dtype(at::get_default_complex_dtype()));
-  } else {
-    AT_ASSERT(s.isIntegral(false));
-    return at::scalar_tensor(s, at::device(device).dtype(at::kLong));
-  }
-}
-
 }
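
For intuition about the two wrapping policies contrasted above, here is a minimal standalone sketch. It is not part of the commit: it assumes the global default dtype is still kFloat and uses only the at::scalar_tensor factory that appears in the diff; the literal 1.5 and the kCPU device are arbitrary choices for illustration.

// Contrast full-precision wrapping with default-dtype wrapping of a scalar.
#include <ATen/ATen.h>

int main() {
  // Binary-op policy (what scalar_to_tensor implements): wrap a
  // floating-point scalar at the highest precision, so type promotion
  // can narrow it against a real tensor operand later.
  at::Tensor hi_prec =
      at::scalar_tensor(1.5, at::device(at::kCPU).dtype(at::kDouble));

  // Default-dtype policy (what scalar_to_tensor_default_dtype implements):
  // wrap to the global default dtype, i.e. kFloat out of the box.
  at::Tensor defaulted =
      at::scalar_tensor(1.5, at::device(at::kCPU).dtype(at::get_default_dtype()));

  TORCH_CHECK(hi_prec.scalar_type() == at::kDouble);
  TORCH_CHECK(defaulted.scalar_type() == at::kFloat);
  return 0;
}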

aten/src/ATen/native/TensorCompare.cpp

Lines changed: 29 additions & 1 deletion
@@ -195,7 +195,35 @@ bool is_nonzero(const Tensor& self) {
 
 namespace {
 
-static Tensor wrapped_scalar_tensor(
+// DO NOT USE THIS -- it's just an implementation detail of wrapped_scalar_tensor below.
+at::Tensor scalar_to_tensor_default_dtype(
+    Scalar s,
+    const Device device = at::kCPU) {
+  if (s.isFloatingPoint()) {
+    return at::scalar_tensor(
+        s, at::device(device).dtype(at::get_default_dtype()));
+  } else if (s.isBoolean()) {
+    return at::scalar_tensor(s, at::device(device).dtype(at::kBool));
+  } else if (s.isComplex()) {
+    return at::scalar_tensor(
+        s, at::device(device).dtype(at::get_default_complex_dtype()));
+  } else {
+    TORCH_INTERNAL_ASSERT(s.isIntegral(false));
+    return at::scalar_tensor(s, at::device(device).dtype(at::kLong));
+  }
+}
+
+// TLDR: Don't call with `use_default_dtype` true -- this is only necessary to support the partial
+// type promotion that torch.where supports. Once torch.where fully supports type promotion, we
+// won't need this function.
+//
+// Longer explanation:
+// `use_default_dtype` is a bit of a hack because torch.where doesn't support type promotion, but
+// does support `torch.where(tensor, scalar1, scalar2)` with default scalar types. The trickiness is
+// that we usually convert double scalars to doubles, and `set_wrapped_number` defines type promotion
+// priority as being below tensor types rather than as the default dtype (perhaps we should?).
+// This wouldn't matter if we just supported normal type promotion on torch.where, however.
+Tensor wrapped_scalar_tensor(
     Scalar scalar,
     Device device,
     bool use_default_dtype = false) {