Skip to content

Commit 0f8a3f0

Browse files
committed
Improve addmm error messages
1 parent fbcf17c commit 0f8a3f0

1 file changed

Lines changed: 6 additions & 5 deletions

File tree

aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,18 +174,19 @@ Tensor ger(const Tensor& self, const Tensor& vec2) {
174174

175175
static void addmm_impl_cpu_(
176176
Tensor &result, const Tensor &self, Tensor m1, Tensor m2, Scalar beta, Scalar alpha) {
177-
TORCH_CHECK(self.dim() == 2, "input must be a matrix");
178-
TORCH_CHECK(m1.dim() == 2, "m1 must be a matrix");
179-
TORCH_CHECK(m2.dim() == 2, "m2 must be a matrix");
177+
TORCH_CHECK(self.dim() == 2, "input must be a matrix, got ", self.dim(), "-D tensor");
178+
TORCH_CHECK(m1.dim() == 2, "mat1 must be a matrix, got ", m1.dim(), "-D tensor");
179+
TORCH_CHECK(m2.dim() == 2, "mat2 must be a matrix, got ", m2.dim(), "-D tensor");
180180

181181
TORCH_CHECK(
182-
m1.size(1) == m2.size(0), "m1 and m2 shapes cannot be multiplied (",
182+
m1.size(1) == m2.size(0), "mat1 and mat2 shapes cannot be multiplied (",
183183
m1.size(0), "x", m1.size(1), " and ", m2.size(0), "x", m2.size(1), ")");
184184

185185
TORCH_CHECK(
186186
self.size(0) == m1.size(0) && self.size(1) == m2.size(1),
187187
"input shape is incompatible with matrix multiplication (",
188-
m1.size(0), "x", m1.size(1), " and ", m2.size(0), "x", m2.size(1), ")");
188+
m1.size(0), "x", m1.size(1), " @ ", m2.size(0), "x", m2.size(1), " != ",
189+
self.size(0), "x", self.size(1), ")");
189190

190191
result.resize_as_(self);
191192

0 commit comments

Comments
 (0)