@@ -115,13 +115,9 @@ Tensor isfinite(const Tensor& self) {
115115
116116bool is_nonzero (const Tensor& self) {
117117 auto n = self.numel ();
118- AT_ASSERT (n >= 0 );
119- if (n == 0 ) {
120- AT_ERROR (" bool value of Tensor with no values is ambiguous" );
121- }
122- if (n > 1 ) {
123- AT_ERROR (" bool value of Tensor with more than one value is ambiguous" );
124- }
118+ TORCH_CHECK (n != 0 , " Boolean value of Tensor with no values is ambiguous" );
119+ TORCH_CHECK (n < 2 , " Boolean value of Tensor with more than one value is ambiguous" );
120+
125121 Scalar localScalar = self.item ();
126122 if (localScalar.isFloatingPoint ()) {
127123 return localScalar.to <double >() != 0 ;
@@ -132,18 +128,17 @@ bool is_nonzero(const Tensor& self) {
132128 } else if (localScalar.isBoolean ()) {
133129 return localScalar.to <bool >();
134130 }
135- AT_ERROR ( " expected non-Tensor backend scalar" );
131+ TORCH_INTERNAL_ASSERT ( false , " Expected non-Tensor backend scalar" );
136132}
137133
138134Tensor where (const Tensor& condition, const Tensor& self, const Tensor& other) {
139135 TORCH_CHECK (condition.device () == self.device () && self.device () == other.device (),
140- " expected condition, x and y to be on the same device, but condition is on " ,
136+ " Expected condition, x and y to be on the same device, but condition is on " ,
141137 condition.device (), " and x and y are on " , self.device (), " and " , other.device (),
142138 " respectively" );
143- if (condition.scalar_type () != ScalarType::Byte && condition.scalar_type () != ScalarType::Bool) {
144- AT_ERROR (" Expected condition to have ScalarType Byte, but got ScalarType " ,
145- toString (condition.scalar_type ()));
146- }
139+ TORCH_CHECK (condition.scalar_type () == ScalarType::Byte || condition.scalar_type () == ScalarType::Bool,
140+ " Expected condition to have ScalarType Byte, but got ScalarType " ,
141+ toString (condition.scalar_type ()));
147142 Tensor b_condition, b_self, b_other;
148143 std::tie (b_condition, b_self, b_other) = expand_outplace (condition, self, other, " where" );
149144 return at::_s_where (b_condition, b_self, b_other);
0 commit comments