@@ -1008,38 +1008,39 @@ bool ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_
10081008 return are_equal;
10091009}
10101010
1011- bool StridedTensorContentEquals (int dim_index, int64_t left_offset, int64_t right_offset,
1012- int elem_size, const Tensor& left, const Tensor& right) {
1011+ namespace {
1012+
1013+ bool StridedIntegerTensorContentEquals (const int dim_index, int64_t left_offset,
1014+ int64_t right_offset, int elem_size,
1015+ const Tensor& left, const Tensor& right) {
1016+ const auto n = left.shape ()[dim_index];
1017+ const auto left_stride = left.strides ()[dim_index];
1018+ const auto right_stride = right.strides ()[dim_index];
10131019 if (dim_index == left.ndim () - 1 ) {
1014- for (int64_t i = 0 ; i < left.shape ()[dim_index]; ++i) {
1015- if (memcmp (left.raw_data () + left_offset + i * left.strides ()[dim_index],
1016- right.raw_data () + right_offset + i * right.strides ()[dim_index],
1017- elem_size) != 0 ) {
1020+ for (int64_t i = 0 ; i < n; ++i) {
1021+ if (memcmp (left.raw_data () + left_offset + i * left_stride,
1022+ right.raw_data () + right_offset + i * right_stride, elem_size) != 0 ) {
10181023 return false ;
10191024 }
10201025 }
10211026 return true ;
10221027 }
1023- for (int64_t i = 0 ; i < left. shape ()[dim_index] ; ++i) {
1024- if (!StridedTensorContentEquals (dim_index + 1 , left_offset, right_offset, elem_size ,
1025- left, right)) {
1028+ for (int64_t i = 0 ; i < n ; ++i) {
1029+ if (!StridedIntegerTensorContentEquals (dim_index + 1 , left_offset, right_offset,
1030+ elem_size, left, right)) {
10261031 return false ;
10271032 }
1028- left_offset += left. strides ()[dim_index] ;
1029- right_offset += right. strides ()[dim_index] ;
1033+ left_offset += left_stride ;
1034+ right_offset += right_stride ;
10301035 }
10311036 return true ;
10321037}
10331038
1034- bool TensorEquals (const Tensor& left, const Tensor& right) {
1039+ bool IntegerTensorEquals (const Tensor& left, const Tensor& right) {
10351040 bool are_equal;
10361041 // The arrays are the same object
10371042 if (&left == &right) {
10381043 are_equal = true ;
1039- } else if (left.type_id () != right.type_id ()) {
1040- are_equal = false ;
1041- } else if (left.size () == 0 ) {
1042- are_equal = true ;
10431044 } else {
10441045 const bool left_row_major_p = left.is_row_major ();
10451046 const bool left_column_major_p = left.is_column_major ();
@@ -1048,14 +1049,9 @@ bool TensorEquals(const Tensor& left, const Tensor& right) {
10481049
10491050 if (!(left_row_major_p && right_row_major_p) &&
10501051 !(left_column_major_p && right_column_major_p)) {
1051- const auto & shape = left.shape ();
1052- if (shape != right.shape ()) {
1053- are_equal = false ;
1054- } else {
1055- const auto & type = checked_cast<const FixedWidthType&>(*left.type ());
1056- are_equal =
1057- StridedTensorContentEquals (0 , 0 , 0 , type.bit_width () / 8 , left, right);
1058- }
1052+ const auto & type = checked_cast<const FixedWidthType&>(*left.type ());
1053+ are_equal =
1054+ StridedIntegerTensorContentEquals (0 , 0 , 0 , type.bit_width () / 8 , left, right);
10591055 } else {
10601056 const auto & size_meta = checked_cast<const FixedWidthType&>(*left.type ());
10611057 const int byte_width = size_meta.bit_width () / CHAR_BIT;
@@ -1071,6 +1067,85 @@ bool TensorEquals(const Tensor& left, const Tensor& right) {
10711067 return are_equal;
10721068}
10731069
1070+ template <typename DataType>
1071+ bool StridedFloatTensorContentEquals (const int dim_index, int64_t left_offset,
1072+ int64_t right_offset, const Tensor& left,
1073+ const Tensor& right, const EqualOptions& opts) {
1074+ using c_type = typename DataType::c_type;
1075+ const auto n = left.shape ()[dim_index];
1076+ const auto left_stride = left.strides ()[dim_index];
1077+ const auto right_stride = right.strides ()[dim_index];
1078+ if (dim_index == left.ndim () - 1 ) {
1079+ auto left_data = left.raw_data ();
1080+ auto right_data = right.raw_data ();
1081+ if (opts.nans_equal ()) {
1082+ for (int64_t i = 0 ; i < n; ++i) {
1083+ c_type left_value =
1084+ *reinterpret_cast <const c_type*>(left_data + left_offset + i * left_stride);
1085+ c_type right_value = *reinterpret_cast <const c_type*>(right_data + right_offset +
1086+ i * right_stride);
1087+ if (!(left_value == right_value ||
1088+ (std::isnan (left_value) && std::isnan (right_value)))) {
1089+ return false ;
1090+ }
1091+ }
1092+ } else {
1093+ for (int64_t i = 0 ; i < n; ++i) {
1094+ c_type left_value =
1095+ *reinterpret_cast <const c_type*>(left_data + left_offset + i * left_stride);
1096+ c_type right_value = *reinterpret_cast <const c_type*>(right_data + right_offset +
1097+ i * right_stride);
1098+ if (left_value != right_value) {
1099+ return false ;
1100+ }
1101+ }
1102+ }
1103+ return true ;
1104+ }
1105+ for (int64_t i = 0 ; i < n; ++i) {
1106+ if (!StridedFloatTensorContentEquals<DataType>(dim_index + 1 , left_offset,
1107+ right_offset, left, right, opts)) {
1108+ return false ;
1109+ }
1110+ left_offset += left_stride;
1111+ right_offset += right_stride;
1112+ }
1113+ return true ;
1114+ }
1115+
1116+ template <typename DataType>
1117+ bool FloatTensorEquals (const Tensor& left, const Tensor& right,
1118+ const EqualOptions& opts) {
1119+ static_assert (std::is_floating_point<typename DataType::c_type>::value,
1120+ " DataType must be a floating point type" );
1121+ return StridedFloatTensorContentEquals<DataType>(0 , 0 , 0 , left, right, opts);
1122+ }
1123+
1124+ } // namespace
1125+
1126+ bool TensorEquals (const Tensor& left, const Tensor& right, const EqualOptions& opts) {
1127+ if (left.type_id () != right.type_id ()) {
1128+ return false ;
1129+ } else if (left.size () == 0 && right.size () == 0 ) {
1130+ return true ;
1131+ } else if (left.shape () != right.shape ()) {
1132+ return false ;
1133+ }
1134+
1135+ switch (left.type_id ()) {
1136+ // TODO: Support half-float tensors
1137+ // case Type::HALF_FLOAT:
1138+ case Type::FLOAT:
1139+ return FloatTensorEquals<FloatType>(left, right, opts);
1140+
1141+ case Type::DOUBLE:
1142+ return FloatTensorEquals<DoubleType>(left, right, opts);
1143+
1144+ default :
1145+ return IntegerTensorEquals (left, right);
1146+ }
1147+ }
1148+
10741149namespace {
10751150
10761151template <typename LeftSparseIndexType, typename RightSparseIndexType>
0 commit comments