@@ -39,6 +39,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
3939 }
4040 out << " )" ;
4141 }
42+
4243 if (value->undefined () && *value->undefined ()) {
4344 out << " [Undefined]" ;
4445 }
@@ -81,55 +82,9 @@ AnyTypePtr AnyType::get() {
8182 return value;
8283}
8384
84- template <typename T>
85- static bool compatible_optional (c10::optional<T> e, T a) {
86- return !e.has_value () || e.value () == a;
87- }
88-
89- static bool compatible_varying_shape (const VaryingShape& e, at::IntArrayRef a) {
90- if (!e.size ().has_value ()) {
91- return true ;
92- }
93-
94- if (e.size ().value () != a.size ()) {
95- return false ;
96- }
97-
98- auto ndim = a.size ();
99- for (size_t i = 0 ; i < ndim; i++) {
100- if (!compatible_optional (e[i], a[i])) {
101- return false ;
102- }
103- }
104- return true ;
105- }
106-
107- bool TensorType::isCompatibleWithInCurrentExecutionContext (
108- at::Tensor& t) const {
109- // any updates to `isSubtypeOf`, TensorType c-tor or
110- // `isCompatibleWithInCurrentExecutionContext` need to maintain the following
111- // `TensorType::create(actual_tensor)->isSubtypeOf(expected_type)
112- // == expected_type->isCompatibleWithInCurrentExecutionContext(t)`
113- if (!t.defined ()) {
114- return compatible_optional (undefined (), !t.defined ());
115- }
116-
117- return compatible_varying_shape (sizes (), t.sizes ()) &&
118- (t.is_sparse () || t.is_mkldnn () ||
119- compatible_varying_shape (strides (), t.strides ())) &&
120- compatible_optional (
121- requiresGrad (), t.requires_grad () && at::GradMode::is_enabled ()) &&
122- compatible_optional (scalarType (), t.scalar_type ()) &&
123- compatible_optional (device (), t.device ());
124- }
125-
12685TensorTypePtr TensorType::get () {
12786 static auto value = TensorType::create (
128- {},
129- {},
130- VaryingShape{c10::optional<size_t >()},
131- VaryingShape{c10::optional<size_t >()},
132- {});
87+ {}, {}, VaryingShape<ShapeSymbol>{}, VaryingShape<Stride>{}, {});
13388 return value;
13489}
13590
@@ -527,51 +482,111 @@ std::string TensorType::str() const {
527482 return " Tensor" ;
528483}
529484
530- VaryingShape VaryingShape::merge (const VaryingShape& other) const {
485+ template <typename T>
486+ VaryingShape<T> VaryingShape<T>::merge(const VaryingShape<T>& other) const {
531487 if (!dims_ || !other.dims_ || dims_->size () != other.dims_ ->size ()) {
532- return VaryingShape ();
488+ return VaryingShape<T> ();
533489 }
534- ListOfOptionalInts dims;
490+ ListOfOptionalElements dims;
535491 for (size_t i = 0 , n = dims_->size (); i < n; i++) {
536492 dims.push_back (merge_primitive ((*dims_)[i], (*other.dims_ )[i]));
537493 }
538- return VaryingShape (std::move (dims));
494+ return VaryingShape<T> (std::move (dims));
539495}
540496
541- TensorTypePtr TensorType::merge (TensorTypePtr other) const {
497+ VaryingShape<int64_t > TensorType::sizes () const {
498+ if (!sizes_.size ().has_value ()) {
499+ return VaryingShape<int64_t >();
500+ }
501+ return VaryingShape<int64_t >(
502+ fmap (*sizes_.sizes (), [](c10::optional<ShapeSymbol> ss) {
503+ // we turn symbolic shapes into unknowns
504+ return ss.has_value () && ss->is_static ()
505+ ? c10::optional<int64_t >(ss->static_size ())
506+ : c10::nullopt ;
507+ }));
508+ }
509+
510+ TensorTypePtr TensorType::merge (TensorTypePtr other, bool merge_sizes) const {
542511 auto scalar_type = merge_primitive (scalarType (), other->scalarType ());
543512 auto dev = merge_primitive (device (), other->device ());
544- auto sz = sizes ().merge (other->sizes ());
545- auto srs = strides ().merge (other->strides ());
513+ auto sprops = stride_properties ().merge (other->stride_properties ());
546514 auto gr = merge_primitive (requiresGrad (), other->requiresGrad ());
547515 auto undef = merge_primitive (undefined (), other->undefined ());
548- return TensorType::create (scalar_type, dev, sz, srs, gr, undef);
516+ return TensorType::create (
517+ scalar_type,
518+ dev,
519+ merge_sizes ? symbolic_sizes ().merge (other->symbolic_sizes ())
520+ : symbolic_sizes (),
521+ sprops,
522+ gr,
523+ undef);
524+ }
525+
526+ bool TensorType::operator ==(const c10::Type& rhs) const {
527+ if (rhs.kind () != kind ()) {
528+ return false ;
529+ }
530+ auto rt = rhs.expect <TensorType>();
531+
532+ return scalar_type_ == rt->scalarType () && sizes () == rt->sizes () &&
533+ stride_properties () == rt->stride_properties () &&
534+ device () == rt->device () && requiresGrad () == rt->requiresGrad () &&
535+ undefined () == rt->undefined ();
549536}
550537
551- std::ostream& operator <<(std::ostream & out, const VaryingShape & vs) {
538+ template <typename T>
539+ std::ostream& operator <<(std::ostream& out, const VaryingShape<T>& vs) {
540+ out << " (" ;
541+ if (!vs.size ()) {
542+ out << " *)" ;
543+ return out;
544+ }
552545
553- out << " (" ;
554- if (!vs.size ()) {
555- out << " *)" ;
556- return out;
546+ for (size_t i = 0 ; i < vs.size (); i++) {
547+ if (i > 0 ) {
548+ out << " , " ;
557549 }
558-
559- for (size_t i = 0 ; i < vs.size (); i++)
560- {
561- if (i > 0 ) {
562- out << " , " ;
563- }
564- if (vs[i].has_value ())
565- {
566- out << vs[i].value ();
567- }
568- else
569- {
570- out << " *" ;
571- }
550+ if (vs[i].has_value ()) {
551+ out << vs[i].value ();
552+ } else {
553+ out << " *" ;
572554 }
573- out << " )" ;
574- return out;
555+ }
556+ out << " )" ;
557+ return out;
558+ }
559+
560+ template std::ostream& operator <<(
561+ std::ostream& out,
562+ const VaryingShape<int64_t >& vs);
563+ template std::ostream& operator <<(
564+ std::ostream& out,
565+ const VaryingShape<ShapeSymbol>& vs);
566+ template std::ostream& operator <<(
567+ std::ostream& out,
568+ const VaryingShape<Stride>& vs);
569+
570+ std::ostream& operator <<(std::ostream& os, const ShapeSymbol& s) {
571+ os << " SS(" << s.value_ << ' )' ;
572+ return os;
573+ }
574+
575+ std::ostream& operator <<(std::ostream& os, const Stride& s) {
576+ os << " {" ;
577+ if (s.stride_index_ .has_value ()) {
578+ os << *s.stride_index_ ;
579+ } else {
580+ os << " *" ;
581+ }
582+ os << " :" ;
583+ if (s.stride_ .has_value ()) {
584+ os << *s.stride_ ;
585+ } else {
586+ os << " *" ;
587+ }
588+ os << ' }' ;
589+ return os;
575590}
576591
577592TupleTypePtr TupleType::createNamed (
@@ -708,6 +723,180 @@ std::string TupleType::python_str_impl(TypePrinter printer) const {
708723 return ss.str ();
709724}
710725
726+ static std::vector<bool > findContiguous (
727+ const at::IntArrayRef& sizes,
728+ const at::IntArrayRef& strides) {
729+ AT_ASSERT (sizes.size () == strides.size ());
730+ std::vector<bool > cont (sizes.size ());
731+ for (size_t i = 0 ; i < sizes.size (); ++i) {
732+ const auto expected_stride =
733+ (i + 1 < sizes.size ()) ? sizes[i + 1 ] * strides[i + 1 ] : 1 ;
734+ cont[i] = (strides[i] == expected_stride);
735+ }
736+ return cont;
737+ }
738+
739+ VaryingShape<int64_t > TensorType::strides () const {
740+ if (!strides_.size ().has_value ()) {
741+ return VaryingShape<int64_t >();
742+ }
743+ std::vector<c10::optional<int64_t >> ss (*strides_.size ());
744+ for (size_t i = 0 ; i < *strides_.size (); i++) {
745+ if (!strides_[i].has_value ()) {
746+ continue ;
747+ }
748+ auto s = *strides_[i];
749+ if (s.stride_index_ .has_value () && s.stride_ .has_value ()) {
750+ ss[*s.stride_index_ ] = *s.stride_ ;
751+ }
752+ }
753+ return VaryingShape<int64_t >(ss);
754+ }
755+
756+ VaryingShape<Stride> TensorType::computeStrideProps (
757+ at::IntArrayRef sizes,
758+ at::IntArrayRef strides) {
759+ std::vector<size_t > stride_indices (sizes.size ());
760+ std::iota (stride_indices.begin (), stride_indices.end (), 0 );
761+
762+ std::sort (
763+ stride_indices.begin (),
764+ stride_indices.end (),
765+ [&strides](const int & a, const int & b) {
766+ // break ties in case of unsqueezed dims
767+ // i.e. (1, 1, 5)
768+ if (strides[a] == strides[b]) {
769+ return a > b;
770+ }
771+ return strides[a] < strides[b];
772+ });
773+
774+ std::vector<Stride> stride_properties;
775+ for (size_t i = 0 ; i < stride_indices.size (); i++) {
776+ Stride s{stride_indices[i], false , strides[stride_indices[i]]};
777+ // innermost stride expected to be 1
778+ // TODO: turn contiguous_ into an enum CONTIGUOUS, NONCONTIGUOUS,
779+ // BROADCASTED
780+ if (i == 0 ) {
781+ s.contiguous_ = strides[stride_indices[i]] == 1 ;
782+ } else {
783+ s.contiguous_ = strides[stride_indices[i]] == 1 ||
784+ (strides[stride_indices[i]] != 0 &&
785+ strides[stride_indices[i]] ==
786+ strides[stride_indices[i - 1 ]] * sizes[stride_indices[i - 1 ]]);
787+ }
788+ stride_properties.push_back (s);
789+ }
790+
791+ return VaryingShape<Stride>{stride_properties};
792+ }
793+
794+ std::atomic<size_t > ShapeSymbol::num_symbols{1 };
795+
796+ template struct VaryingShape <c10::ShapeSymbol>;
797+ template struct VaryingShape <bool >;
798+ template struct VaryingShape <size_t >;
799+ template struct VaryingShape <int64_t >;
800+
801+ TensorType::TensorType (
802+ c10::optional<at::ScalarType> scalar_type,
803+ c10::optional<Device> device,
804+ const VaryingShape<ShapeSymbol>& sizes,
805+ const VaryingShape<Stride>& strides,
806+ c10::optional<bool > requires_grad,
807+ c10::optional<bool > undefined)
808+ : Type(TypeKind::TensorType),
809+ scalar_type_ (scalar_type),
810+ device_(device),
811+ sizes_(sizes),
812+ strides_(strides),
813+ requires_grad_(requires_grad),
814+ undefined_(undefined) {}
815+
816+ TensorTypePtr TensorType::create (const at::Tensor& t) {
817+ VaryingShape<bool > contiguity;
818+ VaryingShape<size_t > stride_indices;
819+ VaryingShape<int64_t > strides;
820+ VaryingShape<int64_t > sizes;
821+ if (!t.is_mkldnn () && !t.is_sparse ()) {
822+ sizes = VaryingShape<int64_t >{t.sizes ().vec ()};
823+ strides = VaryingShape<int64_t >{t.strides ().vec ()};
824+ return TensorType::create (
825+ t.scalar_type (), t.device (), sizes, strides, t.requires_grad (), false );
826+ }
827+
828+ return TensorType::create (
829+ t.scalar_type (),
830+ t.device (),
831+ VaryingShape<ShapeSymbol>{},
832+ VaryingShape<Stride>{},
833+ t.requires_grad (),
834+ false );
835+ }
836+
837+ TensorTypePtr TensorType::create (
838+ c10::optional<at::ScalarType> scalar_type,
839+ c10::optional<Device> device,
840+ const VaryingShape<int64_t >& sizes,
841+ const VaryingShape<int64_t >& strides,
842+ c10::optional<bool > requires_grad,
843+ c10::optional<bool > undefined) {
844+ TORCH_INTERNAL_ASSERT (sizes.concrete_sizes ().has_value ());
845+ TORCH_INTERNAL_ASSERT (
846+ !strides.concrete_sizes ().has_value () ||
847+ sizes.concrete_sizes ()->size () == strides.concrete_sizes ()->size ());
848+ auto sprops = strides.concrete_sizes ().has_value ()
849+ ? computeStrideProps (*sizes.concrete_sizes (), *strides.concrete_sizes ())
850+ : VaryingShape<Stride>();
851+
852+ auto symbol_sizes =
853+ VaryingShape<ShapeSymbol>::fromStaticShape (*sizes.concrete_sizes ());
854+ return TensorType::create (
855+ scalar_type, device, symbol_sizes, sprops, requires_grad, undefined);
856+ }
857+
858+ TensorTypePtr TensorType::create (
859+ c10::optional<at::ScalarType> scalar_type,
860+ c10::optional<Device> device,
861+ const VaryingShape<ShapeSymbol>& sizes,
862+ const VaryingShape<Stride>& strides,
863+ c10::optional<bool > requires_grad,
864+ c10::optional<bool > undefined) {
865+ return TensorTypePtr (new TensorType (
866+ scalar_type, device, sizes, strides, requires_grad, undefined));
867+ }
868+
869+ TensorTypePtr TensorType::create (
870+ c10::optional<at::ScalarType> scalar_type,
871+ c10::optional<Device> device,
872+ c10::optional<size_t > dim,
873+ c10::optional<bool > requires_grad) {
874+ return TensorType::create (
875+ scalar_type,
876+ device,
877+ VaryingShape<ShapeSymbol>(dim),
878+ VaryingShape<Stride>(dim),
879+ requires_grad);
880+ }
881+
882+ TensorTypePtr TensorType::createContiguous (
883+ at::ScalarType scalar_type,
884+ at::Device device,
885+ at::IntArrayRef sizes) {
886+ auto strides = contiguousStridesOf (sizes);
887+ TORCH_INTERNAL_ASSERT (strides.size () == sizes.size ());
888+ return create (
889+ scalar_type,
890+ device,
891+ VaryingShape<int64_t >(sizes),
892+ VaryingShape<int64_t >(strides),
893+ c10::nullopt );
894+ }
895+
896+ const VaryingShape<ShapeSymbol>& TensorType::symbolic_sizes () const {
897+ return sizes_;
898+ }
899+
711900bool TensorType::isSubtypeOfExt (const TypePtr rhs, std::ostream* why_not) const {
712901 if (auto rhs_p = rhs->cast <TensorType>()) {
713902 // if we have the same pointer, avoid computing the merge
0 commit comments