
Commit 11d8311

Update on "Making ops c10-full: list of optional tensors"
See for details: https://fb.quip.com/QRtJAin66lPN

We need to model optional types explicitly, mostly for schema inference. So we cannot pass a `Tensor?[]` as `ArrayRef<Tensor>`; instead we need to pass it as an optional type. This PR changes it to `torch::List<c10::optional<Tensor>>`. It also makes the ops c10-full that were blocked by this.

## Benchmarks (C++ instruction counts)

### Forward

#### Script

```py
from torch.utils.benchmark import Timer

counts = Timer(
    stmt="""
        auto t = {{op call to measure}};
    """,
    setup="""
        using namespace torch::indexing;
        auto x = torch::ones({4, 4, 4});
    """,
    language="cpp",
).collect_callgrind(number=1_000)
print(counts)
```

#### Results

| Op call | before | after | delta | delta % |
|---|---|---|---|---|
| x[0] = 1 | 11566015 | 11566015 | 0 | 0.00% |
| x.index({0}) | 6807019 | 6801019 | -6000 | -0.09% |
| x.index({0, 0}) | 13529019 | 13557019 | 28000 | 0.21% |
| x.index({0, 0, 0}) | 10677004 | 10692004 | 15000 | 0.14% |
| x.index({"..."}) | 5512015 | 5506015 | -6000 | -0.11% |
| x.index({Slice(None, None, None)}) | 6866016 | 6936016 | 70000 | 1.02% |
| x.index({None}) | 8554015 | 8548015 | -6000 | -0.07% |
| x.index({false}) | 22400000 | 22744000 | 344000 | 1.54% |
| x.index({true}) | 27624088 | 27264393 | -359695 | -1.30% |
| x.index({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})}) | 123472000 | 123463306 | -8694 | -0.01% |

### Autograd

#### Script

```py
from torch.utils.benchmark import Timer

counts = Timer(
    stmt="""
        auto t = {{op call to measure}};
    """,
    setup="""
        using namespace torch::indexing;
        auto x = torch::ones({4, 4, 4}, torch::requires_grad());
    """,
    language="cpp",
).collect_callgrind(number=1_000)
print(counts)
```

Note: the script measures the **forward** path of an op call with autograd enabled (i.e. calls into VariableType). It does not measure the backward path.

#### Results

| Op call | before | after | delta | delta % |
|---|---|---|---|---|
| x.index({0}) | 14839019 | 14833019 | -6000 | 0.00% |
| x.index({0, 0}) | 28342019 | 28370019 | 28000 | 0.00% |
| x.index({0, 0, 0}) | 24434004 | 24449004 | 15000 | 0.00% |
| x.index({"..."}) | 12773015 | 12767015 | -6000 | 0.00% |
| x.index({Slice(None, None, None)}) | 14837016 | 14907016 | 70000 | 0.47% |
| x.index({None}) | 15926015 | 15920015 | -6000 | 0.00% |
| x.index({false}) | 36958000 | 37477000 | 519000 | 1.40% |
| x.index({true}) | 41971408 | 42426094 | 454686 | 1.08% |
| x.index({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})}) | 168184392 | 164545682 | -3638710 | -2.16% |

## Backwards Compatibility

- This should not break the Python API: the representation in Python is the same, and python_arg_parser just transforms the Python list into a `List<optional<Tensor>>` instead of into a `List<Tensor>`.
- This should not break serialized models: there is logic that allows loading a serialized `List<Tensor>` as `List<optional<Tensor>>`, see https://github.com/pytorch/pytorch/pull/49138/files#diff-9315f5dd045f47114c677174dcaa2f982721233eee1aa19068a42ff3ef775315R57
- This will break backwards compatibility for the C++ API. There is no implicit conversion from `ArrayRef<Tensor>` (which was the old argument type) to `List<optional<Tensor>>`. One common call pattern, `tensor.index({indices_tensor})` where `indices_tensor` is another `Tensor`, will continue working because the `{}` initializer_list constructor for `List<optional<Tensor>>` can take `Tensor` elements, which are implicitly converted to `optional<Tensor>`. But another common call pattern, `tensor.index(indices_tensor)`, previously relied on the `Tensor` being implicitly converted to an `ArrayRef<Tensor>`; going `Tensor -> optional<Tensor> -> List<optional<Tensor>>` would be two implicit conversions, and C++ doesn't allow chaining two implicit conversions. So those call sites have to be rewritten to `tensor.index({indices_tensor})` (see the sketch below).

Differential Revision: [D25454632](https://our.internmc.facebook.com/intern/diff/D25454632/)

[ghstack-poisoned]
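A minimal sketch of the break and its fix, assuming only what the message above states (the post-PR `Tensor::index` overload takes a `torch::List<c10::optional<Tensor>>`); the includes and `main` harness are illustrative:

```cpp
#include <ATen/ATen.h>
#include <ATen/core/List.h>

int main() {
  auto x = at::ones({4, 4, 4});
  auto indices_tensor = at::tensor({0, 2});

  // Still compiles: the braced initializer builds a
  // List<optional<Tensor>>, converting each Tensor element to
  // c10::optional<Tensor> (one implicit conversion per element).
  auto a = x.index({indices_tensor});

  // No longer compiles after this PR: Tensor -> optional<Tensor>
  // -> List<optional<Tensor>> would chain two implicit conversions.
  // auto b = x.index(indices_tensor);

  // The explicit form of what the braces build:
  c10::List<c10::optional<at::Tensor>> indices;
  indices.push_back(c10::optional<at::Tensor>(indices_tensor));
  auto c = x.index(indices);
  return 0;
}
```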
2 parents 1ce09f1 + 15d11f3 commit 11d8311

155 files changed

Lines changed: 2540 additions & 665 deletions

aten/src/ATen/core/function_schema.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -107,7 +107,7 @@ struct Argument {
   c10::optional<int32_t> N_;

   c10::optional<IValue> default_value_;
-  // is this only specifyable as a keyword argument?
+  // is this only specifiable as a keyword argument?
   bool kwarg_only_;
   c10::optional<AliasInfo> alias_info_;
 };
```

aten/src/ATen/core/interned_strings.h

Lines changed: 9 additions & 0 deletions
```diff
@@ -17,6 +17,7 @@ namespace c10 {
 #define FORALL_NS_SYMBOLS(_) \
   _(namespaces, prim) \
   _(namespaces, aten) \
+  _(namespaces, cuda) \
   _(namespaces, onnx) \
   _(namespaces, attr) \
   _(namespaces, scope) \
@@ -284,6 +285,9 @@ namespace c10 {
   _(aten, zero_) \
   _(aten, fill_) \
   _(aten, masked_fill_) \
+  _(cuda, _set_device) \
+  _(cuda, set_stream) \
+  _(cuda, _current_device) \
   _(aten, swapaxes) \
   _(aten, swapaxes_) \
   _(aten, swapdims) \
@@ -383,6 +387,7 @@ namespace c10 {
 #define FORALL_NS_SYMBOLS(_) \
   _(namespaces, prim) \
   _(namespaces, aten) \
+  _(namespaces, cuda) \
   _(namespaces, onnx) \
   _(namespaces, attr) \
   _(namespaces, scope) \
@@ -453,6 +458,7 @@ struct TORCH_API Symbol {
   // (and if it's not, you should add it to the built-ins list above.)
   static Symbol attr(const std::string & s);
   static Symbol aten(const std::string & s);
+  static Symbol cuda(const std::string & s);
   static Symbol onnx(const std::string & s);
   static Symbol prim(const std::string & s);
   static Symbol user(const std::string & s);
@@ -463,6 +469,7 @@ struct TORCH_API Symbol {

   bool is_attr() const;
   bool is_aten() const;
+  bool is_cuda() const;
   bool is_prim() const;
   bool is_onnx() const;
   bool is_user() const;
@@ -523,6 +530,7 @@ FORALL_NS_SYMBOLS(DEFINE_SYMBOL)

 inline Symbol Symbol::attr(const std::string & s) { return Symbol::fromQualString("attr::" + s); }
 inline Symbol Symbol::aten(const std::string & s) { return Symbol::fromQualString("aten::" + s); }
+inline Symbol Symbol::cuda(const std::string & s) { return Symbol::fromQualString("cuda::" + s); }
 inline Symbol Symbol::onnx(const std::string & s) { return Symbol::fromQualString("onnx::" + s); }
 inline Symbol Symbol::prim(const std::string & s) { return Symbol::fromQualString("prim::" + s); }
 inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); }
@@ -531,6 +539,7 @@ inline Symbol Symbol::caffe2(const std::string & s) { return Symbol::fromQualStr
 inline Symbol Symbol::dimname(const std::string & s) { return Symbol::fromQualString("dimname::" + s); }
 inline bool Symbol::is_attr() const { return ns() == namespaces::attr; }
 inline bool Symbol::is_aten() const { return ns() == namespaces::aten; }
+inline bool Symbol::is_cuda() const { return ns() == namespaces::cuda; }
 inline bool Symbol::is_prim() const { return ns() == namespaces::prim; }
 inline bool Symbol::is_onnx() const { return ns() == namespaces::onnx; }
 inline bool Symbol::is_user() const { return ns() == namespaces::user; }
```
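For orientation, a minimal sketch of what the new namespace enables; the constructor and predicate come straight from the diff above, while the `main` harness is illustrative:

```cpp
#include <ATen/core/interned_strings.h>

int main() {
  // Build a symbol in the newly registered "cuda" namespace; per the
  // inline definition above this is just
  // Symbol::fromQualString("cuda::_set_device").
  c10::Symbol s = c10::Symbol::cuda("_set_device");
  // is_cuda() tests that the namespace component is namespaces::cuda.
  bool in_cuda_ns = s.is_cuda();  // true
  (void)in_cuda_ns;
  return 0;
}
```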

aten/src/ATen/core/ivalue.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -125,7 +125,7 @@ TypePtr IValue::type() const {

 void IValue::visit(const std::function<bool (const IValue &)>& visitor) const {
   if (visitor(*this)) {
-    // Short cut.
+    // Shortcut
     return;
   }
   switch (this->tag) {
```

aten/src/ATen/core/type.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -440,7 +440,7 @@ MatchTypeReturn matchTypeVariables(
     // unknown type).
     return matchTypeVariables(opt_formal->getElementType(), actual, type_env);
   }
-  // note: if actual was non here we potentially did not fill in the type
+  // note: if actual was None here we potentially did not fill in the type
   // variables contained in the formal. It is still a valid match because None
   // matches Optional[T] later error checking on tryEvalTypeVariables will
   // report the problem if we never match variables in type T
```

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 35 additions & 19 deletions
```diff
@@ -944,7 +944,9 @@ static void apply_orgqr(Tensor& self, const Tensor& tau, int64_t m, int64_t n_co
 #endif
 }

-std::tuple<Tensor, Tensor> _qr_helper_cpu(const Tensor& self, bool some) {
+std::tuple<Tensor, Tensor> _linalg_qr_helper_cpu(const Tensor& self, std::string mode) {
+  bool compute_q, reduced;
+  std::tie(compute_q, reduced) = _parse_qr_mode(mode);
   std::vector<int64_t> infos(batchCount(self), 0);
   int64_t m = self.size(-2), n = self.size(-1);

@@ -954,25 +956,22 @@ std::tuple<Tensor, Tensor> _qr_helper_cpu(const Tensor& self, bool some) {
   self_sizes[self.dim() - 2] = std::min(m, n);
   auto tau_working_copy = at::empty(self_sizes, self.options());
   Tensor q_working_copy;
+  Tensor R;

   // Setup input geometry for apply_orgqr
   std::vector<int64_t> q_sizes, q_strides;
   int64_t n_columns_q;
-  Tensor R;
-  std::tie(q_sizes, q_strides, n_columns_q) = _compute_geometry_for_Q(self, some);
+  std::tie(q_sizes, q_strides, n_columns_q) = _compute_geometry_for_Q(self, reduced);

   // If there are no elements, then we simply return a pair of tensors of required dimensions
   if (self.numel() == 0) {
-    // Fix the number of columns of q appropriately
-    q_sizes[self.dim() - 1] = n_columns_q;
-    q_working_copy = at::eye(q_sizes[self.dim() - 2], q_sizes[self.dim() - 1], self.options());
-    q_working_copy = q_working_copy.expand_as(q_working_copy);
-
-    // We repurpose the same q_sizes for R
-    // Fix the number of rows and columns of q_working_copy appropriately
-    q_sizes[self.dim() - 1] = n;
-    q_sizes[self.dim() - 2] = n_columns_q;
-    R = at::empty(q_sizes, self.options());
+    R = at::empty({n_columns_q, n}, self.options());
+    if (compute_q) {
+      int64_t n_rows_q = q_sizes[self.dim() - 2];
+      q_working_copy = at::eye(n_rows_q, n_columns_q, self.options());
+    } else {
+      q_working_copy = at::empty({0}, self.options());
+    }
     return std::make_tuple(q_working_copy, R);
   }

@@ -992,6 +991,11 @@ std::tuple<Tensor, Tensor> _qr_helper_cpu(const Tensor& self, bool some) {
   }

   R = q_working_copy.slice(-2, 0, n_columns_q).slice(-1, 0, n).triu();
+  if (!compute_q) {
+    // this is for mode='r'
+    Tensor empty_Q = at::empty({0}, self.options());
+    return std::make_tuple(empty_Q, R);
+  }

   // Next perform ORGQR for Q using the results (both raw R and TAU) from GEQRF
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "qr_cpu", [&]{
@@ -1005,22 +1009,34 @@ std::tuple<Tensor, Tensor> _qr_helper_cpu(const Tensor& self, bool some) {
   return std::make_tuple(q_working_copy.narrow(-1, 0, n_columns_q), R);
 }

-std::tuple<Tensor,Tensor> qr(const Tensor& self, bool some) {
+std::tuple<Tensor,Tensor> linalg_qr(const Tensor& self, std::string mode) {
   TORCH_CHECK(self.dim() >= 2,
               "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
-  return at::_qr_helper(self, some);
+  return at::_linalg_qr_helper(self, mode);
 }

-std::tuple<Tensor&,Tensor&> qr_out(Tensor& Q, Tensor& R, const Tensor& self, bool some) {
+std::tuple<Tensor&,Tensor&> linalg_qr_out(Tensor& Q, Tensor& R, const Tensor& self, std::string mode) {
   TORCH_CHECK(self.dim() >= 2,
               "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
   Tensor Q_tmp, R_tmp;
-  std::tie(Q_tmp, R_tmp) = at::_qr_helper(self, some);
-  Q.resize_as_(Q_tmp).copy_(Q_tmp);
-  R.resize_as_(R_tmp).copy_(R_tmp);
+  std::tie(Q_tmp, R_tmp) = at::_linalg_qr_helper(self, mode);
+  at::native::resize_output(Q, Q_tmp.sizes());
+  Q.copy_(Q_tmp);
+  at::native::resize_output(R, R_tmp.sizes());
+  R.copy_(R_tmp);
   return std::tuple<Tensor&, Tensor&>(Q, R);
 }

+std::tuple<Tensor,Tensor> qr(const Tensor& self, bool some) {
+  std::string mode = some ? "reduced" : "complete";
+  return at::linalg_qr(self, mode);
+}
+
+std::tuple<Tensor&,Tensor&> qr_out(Tensor& Q, Tensor& R, const Tensor& self, bool some) {
+  std::string mode = some ? "reduced" : "complete";
+  return at::linalg_qr_out(Q, R, self, mode);
+}
+
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ syevd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 // This function computes eigenvalues 'w' and eigenvectors 'v' of the input that is stored initially in 'v'
```
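A quick sketch of the new mode-based entry point, assuming only the signatures shown in this diff (`at::linalg_qr(self, mode)` with modes "reduced", "complete", and "r"); shapes follow from `_compute_geometry_for_Q`:

```cpp
#include <ATen/ATen.h>

int main() {
  auto x = at::randn({5, 3});

  // "reduced": Q is 5x3, R is 3x3 (the old qr(self, /*some=*/true)).
  at::Tensor Q, R;
  std::tie(Q, R) = at::linalg_qr(x, "reduced");

  // "complete": Q is 5x5, R is 5x3 (the old qr(self, /*some=*/false)).
  std::tie(Q, R) = at::linalg_qr(x, "complete");

  // "r": the ORGQR step is skipped and Q comes back as an empty tensor.
  at::Tensor emptyQ, Ronly;
  std::tie(emptyQ, Ronly) = at::linalg_qr(x, "r");
  return 0;
}
```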

aten/src/ATen/native/LinearAlgebraUtils.h

Lines changed: 22 additions & 3 deletions
```diff
@@ -192,16 +192,35 @@ static inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
   return self.permute(perm);
 }

+// parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
+static inline std::tuple<bool, bool> _parse_qr_mode(std::string mode) {
+  bool compute_q;
+  bool reduced;
+  if (mode == "reduced") {
+    compute_q = true;
+    reduced = true;
+  } else if (mode == "complete") {
+    compute_q = true;
+    reduced = false;
+  } else if (mode == "r") {
+    compute_q = false;
+    reduced = true; // this is actually irrelevant in this mode
+  } else {
+    TORCH_CHECK(false, "Unrecognized mode '", mode, "'");
+  }
+  return std::make_tuple(compute_q, reduced);
+}
+
 // Function to compute sizes, strides and the extra columns for the Q matrix in the QR Decomposition
 static inline std::tuple<std::vector<int64_t>,
                          std::vector<int64_t>,
-                         int64_t> _compute_geometry_for_Q(const Tensor& input, bool some) {
+                         int64_t> _compute_geometry_for_Q(const Tensor& input, bool reduced) {
   int64_t m = input.size(-2), n = input.size(-1);
   int64_t n_columns_q;

-  // We need to compute the required size of Q based on the `some` option
+  // We need to compute the required size of Q based on the `reduced` option
   auto q_sizes = input.sizes().vec();
-  if (!some && m > n) {
+  if (!reduced && m > n) {
     q_sizes[input.dim() - 1] = m;
     n_columns_q = m;
   } else {
```
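The mode-to-flags mapping above, spelled out as it is consumed by `_linalg_qr_helper_cpu` (a sketch; the `at::native` qualification is an assumption based on where other utilities in this header live):

```cpp
#include <tuple>
#include <ATen/native/LinearAlgebraUtils.h>  // internal header

int main() {
  bool compute_q, reduced;
  std::tie(compute_q, reduced) = at::native::_parse_qr_mode("complete");
  // "reduced"  -> compute_q = true,  reduced = true
  // "complete" -> compute_q = true,  reduced = false
  // "r"        -> compute_q = false, reduced = true (value unused)
  // anything else fails with TORCH_CHECK("Unrecognized mode ...")
  return 0;
}
```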

aten/src/ATen/native/Pow.cpp

Lines changed: 2 additions & 4 deletions
```diff
@@ -31,11 +31,9 @@ Tensor& pow_out(Tensor& result, const Tensor& base, Scalar exp) {
     "result type ", common_dtype, "can't be cast to the desired output type ",
     result.scalar_type());

-  auto exponent = (exp.isComplex()) ? exp.toComplexDouble() : exp.toDouble();
-
-  if (exponent == 0.0) {
+  if (exp.equal(0.0)) {
     result.resize_as_(base).fill_(1);
-  } else if (exponent == 1.0) {
+  } else if (exp.equal(1.0)) {
     result.resize_as_(base).copy_(base);
   } else {
     auto iter = TensorIterator::unary_op(result, base.to(common_dtype));
```
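The point of switching to `Scalar::equal` is that the comparison handles the scalar's tag internally, so a complex exponent of 0+0i still takes the fast path without the manual `isComplex()` branch. A minimal sketch, assuming `Scalar::equal(double)` treats a complex scalar as equal when its real part matches and its imaginary part is zero (which is what the new fast-path check relies on):

```cpp
#include <c10/core/Scalar.h>
#include <c10/util/complex.h>

int main() {
  c10::Scalar real_zero(0.0);
  c10::Scalar complex_zero(c10::complex<double>(0.0, 0.0));
  c10::Scalar complex_i(c10::complex<double>(0.0, 1.0));

  bool a = real_zero.equal(0.0);     // true
  bool b = complex_zero.equal(0.0);  // true: 0+0i compares equal to 0.0
  bool c = complex_i.equal(0.0);     // false: nonzero imaginary part
  (void)a; (void)b; (void)c;
  return 0;
}
```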

aten/src/ATen/native/UnaryOps.cpp

Lines changed: 6 additions & 2 deletions
```diff
@@ -326,8 +326,12 @@ Tensor& reciprocal_out(Tensor& result, const Tensor& self) { return unary_op_imp
 Tensor reciprocal(const Tensor& self) { return unary_op_impl_float(self, reciprocal_stub); }
 Tensor& reciprocal_(Tensor& self) { return unary_op_impl_(self, at::reciprocal_out); }

-Tensor& rsqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, rsqrt_stub); }
-Tensor rsqrt(const Tensor& self) { return unary_op_impl(self, at::rsqrt_out); }
+Tensor& rsqrt_out(Tensor& result, const Tensor& self) {
+  return unary_op_impl_float_out(result, self, rsqrt_stub);
+}
+Tensor rsqrt(const Tensor& self) {
+  return unary_op_impl_float(self, rsqrt_stub);
+}
 Tensor& rsqrt_(Tensor& self) { return unary_op_impl_(self, at::rsqrt_out); }

 Tensor& sign_out(Tensor& result, const Tensor& self) {
```
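`unary_op_impl_float` is the same variant `reciprocal` uses just above: it runs the kernel in (and returns) a floating dtype rather than the input dtype. A hedged sketch of the user-visible effect of moving `rsqrt` onto it, assuming the usual float-promotion behavior of the `*_float` helpers:

```cpp
#include <ATen/ATen.h>

int main() {
  // Integral input: with the *_float variants the kernel runs in the
  // default floating dtype instead of the integer dtype.
  auto ints = at::tensor({1, 4, 16});  // kLong
  auto r = at::rsqrt(ints);            // expected: floating output, ~{1.0, 0.5, 0.25}
  (void)r;
  return 0;
}
```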

aten/src/ATen/native/cpu/PowKernel.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -63,7 +63,7 @@ void pow_tensor_scalar_kernel(TensorIterator& iter, Scalar exp_scalar) {
     );
   } else if (exp == -0.5) {
     cpu_kernel_vec(iter,
-      [](scalar_t base) -> scalar_t {
+      [](scalar_t base) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
         return 1.0 / std::sqrt(base);
       },
       [](Vec base) -> Vec { return base.rsqrt(); }
```
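The annotation matters because for `base == 0` this lambda computes `1.0 / std::sqrt(0.0)`, i.e. a float division by zero: well-defined under IEEE-754 (it yields +inf) but flagged by UBSAN's float-divide-by-zero check. A standalone reproduction of the behavior being whitelisted:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  double base = 0.0;
  // IEEE-754: 1.0 / 0.0 == +inf; -fsanitize=float-divide-by-zero still
  // reports it, which is why the kernel lambda above carries the
  // __ubsan_ignore_float_divide_by_zero__ attribute.
  double r = 1.0 / std::sqrt(base);
  std::printf("%f\n", r);  // prints "inf"
  return 0;
}
```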

aten/src/ATen/native/cpu/ReduceOpsKernel.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -225,7 +225,7 @@ static void norm_kernel_tensor_iterator_impl(
       binary_kernel_reduce(
         iter,
         AbsMaxOps<scalar_t, acc_t>(),
-        std::numeric_limits<acc_t>::min()
+        acc_t(0)
       );
     });
   } else if (val == -INFINITY) {
```
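The identity change matters because for floating-point `acc_t`, `std::numeric_limits<acc_t>::min()` is the smallest *positive normal* value, not the smallest value overall, so it is not a neutral element for a max-of-absolute-values reduction: an all-zero input would reduce to roughly 1.18e-38 instead of 0. A standalone illustration of the difference:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  // For float, numeric_limits::min() is ~1.18e-38, not an identity for max(|x|):
  float bad_identity = std::numeric_limits<float>::min();
  float good_identity = 0.0f;

  float x = 0.0f;  // e.g. reducing the inf-norm of an all-zero tensor
  std::printf("%g\n", std::max(bad_identity, std::fabs(x)));   // ~1.17549e-38
  std::printf("%g\n", std::max(good_identity, std::fabs(x)));  // 0
  return 0;
}
```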
