Skip to content

Commit 4ed790d

Browse files
Krovatkin authored and facebook-github-bot committed
Adding symbolic sizes, contiguity, stride indices (#36101)
Summary: Pull Request resolved: #36101 Reviewed By: jamesr66a Differential Revision: D20908711 Pulled By: Krovatkin fbshipit-source-id: f90ce74acffeb645d7d906d07e293164d65ed7e6
1 parent 9e32a1f commit 4ed790d

15 files changed

Lines changed: 503 additions & 250 deletions

File tree

aten/src/ATen/core/jit_type.h

Lines changed: 187 additions & 118 deletions
Large diffs are not rendered by default.

aten/src/ATen/core/type.cpp

Lines changed: 265 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
3939
}
4040
out << ")";
4141
}
42+
4243
if (value->undefined() && *value->undefined()) {
4344
out << "[Undefined]";
4445
}
@@ -81,55 +82,9 @@ AnyTypePtr AnyType::get() {
8182
return value;
8283
}
8384

84-
template <typename T>
85-
static bool compatible_optional(c10::optional<T> e, T a) {
86-
return !e.has_value() || e.value() == a;
87-
}
88-
89-
static bool compatible_varying_shape(const VaryingShape& e, at::IntArrayRef a) {
90-
if (!e.size().has_value()) {
91-
return true;
92-
}
93-
94-
if (e.size().value() != a.size()) {
95-
return false;
96-
}
97-
98-
auto ndim = a.size();
99-
for (size_t i = 0; i < ndim; i++) {
100-
if (!compatible_optional(e[i], a[i])) {
101-
return false;
102-
}
103-
}
104-
return true;
105-
}
106-
107-
bool TensorType::isCompatibleWithInCurrentExecutionContext(
108-
at::Tensor& t) const {
109-
// any updates to `isSubtypeOf`, TensorType c-tor or
110-
// `isCompatibleWithInCurrentExecutionContext` need to maintain the following
111-
// `TensorType::create(actual_tensor)->isSubtypeOf(expected_type)
112-
// == expected_type->isCompatibleWithInCurrentExecutionContext(t)`
113-
if (!t.defined()) {
114-
return compatible_optional(undefined(), !t.defined());
115-
}
116-
117-
return compatible_varying_shape(sizes(), t.sizes()) &&
118-
(t.is_sparse() || t.is_mkldnn() ||
119-
compatible_varying_shape(strides(), t.strides())) &&
120-
compatible_optional(
121-
requiresGrad(), t.requires_grad() && at::GradMode::is_enabled()) &&
122-
compatible_optional(scalarType(), t.scalar_type()) &&
123-
compatible_optional(device(), t.device());
124-
}
125-
12685
TensorTypePtr TensorType::get() {
  // Singleton for the most general tensor type: unknown dtype, device,
  // sizes, strides, and undefinedness.
  static auto value = TensorType::create(
      {},
      {},
      VaryingShape<ShapeSymbol>{},
      VaryingShape<Stride>{},
      {});
  return value;
}
13590

@@ -527,51 +482,111 @@ std::string TensorType::str() const {
527482
return "Tensor";
528483
}
529484

530-
template <typename T>
VaryingShape<T> VaryingShape<T>::merge(const VaryingShape<T>& other) const {
  // If either rank is unknown, or the ranks disagree, the merge carries no
  // shape information at all.
  if (!dims_ || !other.dims_ || dims_->size() != other.dims_->size()) {
    return VaryingShape<T>();
  }

  const size_t rank = dims_->size();
  ListOfOptionalElements merged;
  for (size_t dim = 0; dim < rank; ++dim) {
    merged.push_back(merge_primitive((*dims_)[dim], (*other.dims_)[dim]));
  }
  return VaryingShape<T>(std::move(merged));
}
540496

541-
TensorTypePtr TensorType::merge(TensorTypePtr other) const {
497+
VaryingShape<int64_t> TensorType::sizes() const {
  // Unknown rank maps to a fully unknown shape.
  if (!sizes_.size().has_value()) {
    return VaryingShape<int64_t>();
  }
  // Project symbolic dimensions onto concrete sizes; non-static (symbolic)
  // dimensions become unknown.
  return VaryingShape<int64_t>(fmap(
      *sizes_.sizes(),
      [](c10::optional<ShapeSymbol> dim) -> c10::optional<int64_t> {
        if (dim.has_value() && dim->is_static()) {
          return dim->static_size();
        }
        return c10::nullopt;
      }));
}
509+
510+
TensorTypePtr TensorType::merge(TensorTypePtr other, bool merge_sizes) const {
  // Merge each property pairwise; a disagreement collapses that property to
  // "unknown" (see merge_primitive / VaryingShape::merge).
  auto merged_scalar_type = merge_primitive(scalarType(), other->scalarType());
  auto merged_device = merge_primitive(device(), other->device());
  auto merged_sprops =
      stride_properties().merge(other->stride_properties());
  auto merged_grad = merge_primitive(requiresGrad(), other->requiresGrad());
  auto merged_undefined = merge_primitive(undefined(), other->undefined());
  // Callers may opt out of merging sizes and keep this type's symbolic sizes.
  auto merged_sizes = merge_sizes
      ? symbolic_sizes().merge(other->symbolic_sizes())
      : symbolic_sizes();
  return TensorType::create(
      merged_scalar_type,
      merged_device,
      merged_sizes,
      merged_sprops,
      merged_grad,
      merged_undefined);
}
525+
526+
bool TensorType::operator==(const c10::Type& rhs) const {
  // Only another TensorType can compare equal.
  if (rhs.kind() != kind()) {
    return false;
  }
  auto other = rhs.expect<TensorType>();

  return scalar_type_ == other->scalarType() &&
      sizes() == other->sizes() &&
      stride_properties() == other->stride_properties() &&
      device() == other->device() &&
      requiresGrad() == other->requiresGrad() &&
      undefined() == other->undefined();
}
550537

551-
std::ostream& operator<<(std::ostream & out, const VaryingShape & vs) {
538+
template <typename T>
539+
std::ostream& operator<<(std::ostream& out, const VaryingShape<T>& vs) {
540+
out << "(";
541+
if (!vs.size()) {
542+
out << "*)";
543+
return out;
544+
}
552545

553-
out << "(";
554-
if (!vs.size()) {
555-
out << "*)";
556-
return out;
546+
for (size_t i = 0; i < vs.size(); i++) {
547+
if (i > 0) {
548+
out << ", ";
557549
}
558-
559-
for (size_t i = 0; i < vs.size(); i++)
560-
{
561-
if (i > 0) {
562-
out << ", ";
563-
}
564-
if (vs[i].has_value())
565-
{
566-
out << vs[i].value();
567-
}
568-
else
569-
{
570-
out << "*";
571-
}
550+
if (vs[i].has_value()) {
551+
out << vs[i].value();
552+
} else {
553+
out << "*";
572554
}
573-
out << ")";
574-
return out;
555+
}
556+
out << ")";
557+
return out;
558+
}
559+
560+
template std::ostream& operator<<(
561+
std::ostream& out,
562+
const VaryingShape<int64_t>& vs);
563+
template std::ostream& operator<<(
564+
std::ostream& out,
565+
const VaryingShape<ShapeSymbol>& vs);
566+
template std::ostream& operator<<(
567+
std::ostream& out,
568+
const VaryingShape<Stride>& vs);
569+
570+
std::ostream& operator<<(std::ostream& os, const ShapeSymbol& s) {
571+
os << "SS(" << s.value_ << ')';
572+
return os;
573+
}
574+
575+
std::ostream& operator<<(std::ostream& os, const Stride& s) {
576+
os << "{";
577+
if (s.stride_index_.has_value()) {
578+
os << *s.stride_index_;
579+
} else {
580+
os << "*";
581+
}
582+
os << ":";
583+
if (s.stride_.has_value()) {
584+
os << *s.stride_;
585+
} else {
586+
os << "*";
587+
}
588+
os << '}';
589+
return os;
575590
}
576591

577592
TupleTypePtr TupleType::createNamed(
@@ -708,6 +723,180 @@ std::string TupleType::python_str_impl(TypePrinter printer) const {
708723
return ss.str();
709724
}
710725

726+
static std::vector<bool> findContiguous(
727+
const at::IntArrayRef& sizes,
728+
const at::IntArrayRef& strides) {
729+
AT_ASSERT(sizes.size() == strides.size());
730+
std::vector<bool> cont(sizes.size());
731+
for (size_t i = 0; i < sizes.size(); ++i) {
732+
const auto expected_stride =
733+
(i + 1 < sizes.size()) ? sizes[i + 1] * strides[i + 1] : 1;
734+
cont[i] = (strides[i] == expected_stride);
735+
}
736+
return cont;
737+
}
738+
739+
VaryingShape<int64_t> TensorType::strides() const {
  if (!strides_.size().has_value()) {
    return VaryingShape<int64_t>();
  }
  const size_t ndim = *strides_.size();
  // Stride properties are kept in increasing-stride order; scatter each
  // known stride back to its dimension via stride_index_.
  std::vector<c10::optional<int64_t>> result(ndim);
  for (size_t i = 0; i < ndim; ++i) {
    const auto& prop = strides_[i];
    if (!prop.has_value()) {
      continue;
    }
    if (prop->stride_index_.has_value() && prop->stride_.has_value()) {
      result[*prop->stride_index_] = *prop->stride_;
    }
  }
  return VaryingShape<int64_t>(result);
}
755+
756+
VaryingShape<Stride> TensorType::computeStrideProps(
757+
at::IntArrayRef sizes,
758+
at::IntArrayRef strides) {
759+
std::vector<size_t> stride_indices(sizes.size());
760+
std::iota(stride_indices.begin(), stride_indices.end(), 0);
761+
762+
std::sort(
763+
stride_indices.begin(),
764+
stride_indices.end(),
765+
[&strides](const int& a, const int& b) {
766+
// break ties in case of unsqueezed dims
767+
// i.e. (1, 1, 5)
768+
if (strides[a] == strides[b]) {
769+
return a > b;
770+
}
771+
return strides[a] < strides[b];
772+
});
773+
774+
std::vector<Stride> stride_properties;
775+
for (size_t i = 0; i < stride_indices.size(); i++) {
776+
Stride s{stride_indices[i], false, strides[stride_indices[i]]};
777+
// innermost stride expected to be 1
778+
// TODO: turn contiguous_ into an enum CONTIGUOUS, NONCONTIGUOUS,
779+
// BROADCASTED
780+
if (i == 0) {
781+
s.contiguous_ = strides[stride_indices[i]] == 1;
782+
} else {
783+
s.contiguous_ = strides[stride_indices[i]] == 1 ||
784+
(strides[stride_indices[i]] != 0 &&
785+
strides[stride_indices[i]] ==
786+
strides[stride_indices[i - 1]] * sizes[stride_indices[i - 1]]);
787+
}
788+
stride_properties.push_back(s);
789+
}
790+
791+
return VaryingShape<Stride>{stride_properties};
792+
}
793+
794+
std::atomic<size_t> ShapeSymbol::num_symbols{1};
795+
796+
template struct VaryingShape<c10::ShapeSymbol>;
797+
template struct VaryingShape<bool>;
798+
template struct VaryingShape<size_t>;
799+
template struct VaryingShape<int64_t>;
800+
801+
// Private constructor; use one of the TensorType::create overloads.
TensorType::TensorType(
    c10::optional<at::ScalarType> scalar_type,
    c10::optional<Device> device,
    const VaryingShape<ShapeSymbol>& sizes,
    const VaryingShape<Stride>& strides,
    c10::optional<bool> requires_grad,
    c10::optional<bool> undefined)
    : Type(TypeKind::TensorType),
      scalar_type_(scalar_type),
      device_(device),
      sizes_(sizes),
      strides_(strides),
      requires_grad_(requires_grad),
      undefined_(undefined) {}
815+
816+
TensorTypePtr TensorType::create(const at::Tensor& t) {
  // Only dense (non-mkldnn, non-sparse) tensors get concrete sizes and
  // strides recorded; everything else falls through to an unknown shape.
  // Fix: the original declared `VaryingShape<bool> contiguity` and
  // `VaryingShape<size_t> stride_indices` locals that were never used, and
  // hoisted sizes/strides out of the branch that needs them — removed.
  if (!t.is_mkldnn() && !t.is_sparse()) {
    VaryingShape<int64_t> sizes{t.sizes().vec()};
    VaryingShape<int64_t> strides{t.strides().vec()};
    return TensorType::create(
        t.scalar_type(), t.device(), sizes, strides, t.requires_grad(), false);
  }

  return TensorType::create(
      t.scalar_type(),
      t.device(),
      VaryingShape<ShapeSymbol>{},
      VaryingShape<Stride>{},
      t.requires_grad(),
      false);
}
836+
837+
TensorTypePtr TensorType::create(
    c10::optional<at::ScalarType> scalar_type,
    c10::optional<Device> device,
    const VaryingShape<int64_t>& sizes,
    const VaryingShape<int64_t>& strides,
    c10::optional<bool> requires_grad,
    c10::optional<bool> undefined) {
  // Concrete sizes are required here; strides are optional but, when
  // present, must match the rank of sizes.
  TORCH_INTERNAL_ASSERT(sizes.concrete_sizes().has_value());
  TORCH_INTERNAL_ASSERT(
      !strides.concrete_sizes().has_value() ||
      sizes.concrete_sizes()->size() == strides.concrete_sizes()->size());

  auto stride_props = strides.concrete_sizes().has_value()
      ? computeStrideProps(*sizes.concrete_sizes(), *strides.concrete_sizes())
      : VaryingShape<Stride>();

  return TensorType::create(
      scalar_type,
      device,
      VaryingShape<ShapeSymbol>::fromStaticShape(*sizes.concrete_sizes()),
      stride_props,
      requires_grad,
      undefined);
}
857+
858+
// Base overload: every other create() overload funnels into this one.
TensorTypePtr TensorType::create(
    c10::optional<at::ScalarType> scalar_type,
    c10::optional<Device> device,
    const VaryingShape<ShapeSymbol>& sizes,
    const VaryingShape<Stride>& strides,
    c10::optional<bool> requires_grad,
    c10::optional<bool> undefined) {
  return TensorTypePtr(new TensorType(
      scalar_type, device, sizes, strides, requires_grad, undefined));
}
868+
869+
TensorTypePtr TensorType::create(
    c10::optional<at::ScalarType> scalar_type,
    c10::optional<Device> device,
    c10::optional<size_t> dim,
    c10::optional<bool> requires_grad) {
  // Rank-only type: `dim` dimensions, each with unknown size and stride.
  return TensorType::create(
      scalar_type,
      device,
      VaryingShape<ShapeSymbol>(dim),
      VaryingShape<Stride>(dim),
      requires_grad);
}
881+
882+
TensorTypePtr TensorType::createContiguous(
    at::ScalarType scalar_type,
    at::Device device,
    at::IntArrayRef sizes) {
  // Derive dense, row-major strides from the sizes.
  auto contiguous_strides = contiguousStridesOf(sizes);
  TORCH_INTERNAL_ASSERT(contiguous_strides.size() == sizes.size());
  return create(
      scalar_type,
      device,
      VaryingShape<int64_t>(sizes),
      VaryingShape<int64_t>(contiguous_strides),
      c10::nullopt);
}
895+
896+
// Accessor for the stored (possibly symbolic) sizes.
const VaryingShape<ShapeSymbol>& TensorType::symbolic_sizes() const {
  return sizes_;
}
899+
711900
bool TensorType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const {
712901
if (auto rhs_p = rhs->cast<TensorType>()) {
713902
// if we have the same pointer, avoid computing the merge

0 commit comments

Comments
 (0)