|
11 | 11 | #include <ATen/Functions.h> |
12 | 12 | #include <ATen/NativeFunctions.h> |
13 | 13 | #else |
| 14 | +#include <ATen/ops/_assert_async.h> |
14 | 15 | #include <ATen/ops/_cudnn_ctc_loss.h> |
15 | 16 | #include <ATen/ops/_cudnn_ctc_loss_native.h> |
16 | 17 | #include <ATen/ops/_use_cudnn_ctc_loss.h> |
17 | 18 | #include <ATen/ops/_use_cudnn_ctc_loss_native.h> |
18 | 19 | #include <ATen/ops/empty.h> |
19 | 20 | #include <ATen/ops/empty_like.h> |
| 21 | +#include <ATen/ops/le.h> |
| 22 | +#include <ATen/ops/lt.h> |
20 | 23 | #endif |
21 | 24 |
|
22 | 25 | #if (!AT_CUDNN_ENABLED()) |
@@ -81,11 +84,6 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor( |
81 | 84 | namespace at { |
82 | 85 | namespace native { |
83 | 86 |
|
84 | | -namespace { |
85 | | -// "cache" whether we've previously failed the target lengths check |
86 | | -static bool tensor_failed_target_lengths_check = false; |
87 | | -} // namespace |
88 | | - |
89 | 87 | bool _use_cudnn_ctc_loss( |
90 | 88 | const Tensor& log_probs, |
91 | 89 | const Tensor& targets, |
@@ -132,29 +130,27 @@ bool _use_cudnn_ctc_loss_tensor( |
132 | 130 | (log_probs.dim() == 3) && (input_lengths.scalar_type() == at::kInt) && |
133 | 131 | (target_lengths.scalar_type() == at::kInt); |
134 | 132 |
|
135 | | - if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) { |
136 | | - Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous(); |
137 | | - IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel()); |
138 | | - for (const auto b : c10::irange(tl.size())) { |
139 | | - // target length < 256 is documented, but we see illegal memory accesses |
140 | | - // when target lengths > input lengths for CuDNN |
141 | | - Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous(); |
| 133 | + if (use_cudnn) { |
| 134 | + if (at::cuda::currentStreamCaptureStatus() == |
| 135 | + at::cuda::CaptureStatus::None) { |
142 | 136 | Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous(); |
143 | | - IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel()); |
144 | 137 | IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel()); |
145 | | - use_cudnn = use_cudnn && (tl[b] < 256) && (tl[b] <= il[b]); |
146 | | - if (!use_cudnn) { |
147 | | - tensor_failed_target_lengths_check = true; |
148 | | - break; |
| 138 | + for (const auto b : c10::irange(tl.size())) { |
| 139 | + // target length < 256 is documented, but we see illegal memory accesses |
| 140 | + // when target lengths > input lengths for CuDNN |
| 141 | + Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous(); |
| 142 | + Tensor tlc = |
| 143 | + target_lengths.to(Device(at::kCPU), at::kLong).contiguous(); |
| 144 | + IntArrayRef il(ilc.const_data_ptr<int64_t>(), ilc.numel()); |
| 145 | + IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel()); |
| 146 | + use_cudnn = use_cudnn && (tl[b] < 256) && (tl[b] <= il[b]); |
| 147 | + if (!use_cudnn) { |
| 148 | + break; |
| 149 | + } |
149 | 150 | } |
150 | | - } |
151 | | - } else { |
152 | | - use_cudnn = use_cudnn && !tensor_failed_target_lengths_check; |
153 | | - if (tensor_failed_target_lengths_check) { |
154 | | - TORCH_WARN( |
155 | | - "cuDNN max target length restriction < 256 cannot be checked during graph capture," |
156 | | - " but target length >= 256 was observed previously e.g., during warmup, so we" |
157 | | - " presume it is unsafe to dispatch to cuDNN ctc_loss."); |
| 151 | + } else { |
| 152 | + at::_assert_async(at::lt(input_lengths.max(), 256)); |
| 153 | + at::_assert_async(at::le(target_lengths, input_lengths).all()); |
158 | 154 | } |
159 | 155 | } |
160 | 156 |
|
|
0 commit comments