Skip to content

Commit 46f158b

Browse files
eqy and Skylion007
authored and committed
[cuDNN] Check shapes during graph capture in cuDNN CTCLoss (#130071)
Found out from #125952 about the existence of `_assert_async`.

Pull Request resolved: #130071
Approved by: https://github.com/Skylion007
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
1 parent 592e3a3 commit 46f158b

1 file changed

Lines changed: 21 additions & 25 deletions

File tree

aten/src/ATen/native/cudnn/LossCTC.cpp

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@
1111
#include <ATen/Functions.h>
1212
#include <ATen/NativeFunctions.h>
1313
#else
14+
#include <ATen/ops/_assert_async.h>
1415
#include <ATen/ops/_cudnn_ctc_loss.h>
1516
#include <ATen/ops/_cudnn_ctc_loss_native.h>
1617
#include <ATen/ops/_use_cudnn_ctc_loss.h>
1718
#include <ATen/ops/_use_cudnn_ctc_loss_native.h>
1819
#include <ATen/ops/empty.h>
1920
#include <ATen/ops/empty_like.h>
21+
#include <ATen/ops/le.h>
22+
#include <ATen/ops/lt.h>
2023
#endif
2124

2225
#if (!AT_CUDNN_ENABLED())
@@ -81,11 +84,6 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
8184
namespace at {
8285
namespace native {
8386

84-
namespace {
85-
// "cache" whether we've previously failed the target lengths check
86-
static bool tensor_failed_target_lengths_check = false;
87-
} // namespace
88-
8987
bool _use_cudnn_ctc_loss(
9088
const Tensor& log_probs,
9189
const Tensor& targets,
@@ -132,29 +130,27 @@ bool _use_cudnn_ctc_loss_tensor(
132130
(log_probs.dim() == 3) && (input_lengths.scalar_type() == at::kInt) &&
133131
(target_lengths.scalar_type() == at::kInt);
134132

135-
if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
136-
Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous();
137-
IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
138-
for (const auto b : c10::irange(tl.size())) {
139-
// target length < 256 is documented, but we see illegal memory accesses
140-
// when target lengths > input lengths for CuDNN
141-
Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous();
133+
if (use_cudnn) {
134+
if (at::cuda::currentStreamCaptureStatus() ==
135+
at::cuda::CaptureStatus::None) {
142136
Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous();
143-
IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
144137
IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
145-
use_cudnn = use_cudnn && (tl[b] < 256) && (tl[b] <= il[b]);
146-
if (!use_cudnn) {
147-
tensor_failed_target_lengths_check = true;
148-
break;
138+
for (const auto b : c10::irange(tl.size())) {
139+
// target length < 256 is documented, but we see illegal memory accesses
140+
// when target lengths > input lengths for CuDNN
141+
Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous();
142+
Tensor tlc =
143+
target_lengths.to(Device(at::kCPU), at::kLong).contiguous();
144+
IntArrayRef il(ilc.const_data_ptr<int64_t>(), ilc.numel());
145+
IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
146+
use_cudnn = use_cudnn && (tl[b] < 256) && (tl[b] <= il[b]);
147+
if (!use_cudnn) {
148+
break;
149+
}
149150
}
150-
}
151-
} else {
152-
use_cudnn = use_cudnn && !tensor_failed_target_lengths_check;
153-
if (tensor_failed_target_lengths_check) {
154-
TORCH_WARN(
155-
"cuDNN max target length restriction < 256 cannot be checked during graph capture,"
156-
" but target length >= 256 was observed previously e.g., during warmup, so we"
157-
" presume it is unsafe to dispatch to cuDNN ctc_loss.");
151+
} else {
152+
at::_assert_async(at::lt(input_lengths.max(), 256));
153+
at::_assert_async(at::le(target_lengths, input_lengths).all());
158154
}
159155
}
160156

0 commit comments

Comments (0)