Skip to content

Commit 7d029cf

Browse files
committed
Update on "Add cuBLAS path for batched torch.geqrf"
`geqrfBatched` from cuBLAS is used if ``` (input.size(-2) <= 256 && batchCount(input) >= std::max<int64_t>(2, input.size(-2) / 16)) ``` Differential Revision: [D27960156](https://our.internmc.facebook.com/intern/diff/D27960156) [ghstack-poisoned]
2 parents 48ac7b6 + 9ec5c96 commit 7d029cf

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

aten/src/ATen/cuda/CUDABlas.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
206206

207207
template <class Dtype>
208208
void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) {
209-
TORCH_CHECK(
209+
TORCH_INTERNAL_ASSERT(
210210
false,
211211
"at::cuda::blas::geqrfBatched: not implemented for ",
212212
typeid(Dtype).name());

0 commit comments

Comments
 (0)