Skip to content

Commit b27be3e

Browse files
gchanan authored and facebook-github-bot committed
Avoid double dispatch in logical_not for compilation speed reasons. (#38565)
Summary: Pull Request resolved: #38565. Also note this turns on "-Wno-unused-local-typedefs" because we are using dispatch macros for error checking. Test Plan: Imported from OSS. Differential Revision: D21598478. Pulled By: gchanan. fbshipit-source-id: 28f9ad01bd678df0601a10d0daf3ed31c47c4ab2
1 parent 176174a commit b27be3e

3 files changed

Lines changed: 9 additions & 4 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,7 @@ if(NOT MSVC)
504504
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-variable")
505505
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-function")
506506
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-result")
507+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-local-typedefs")
507508
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-strict-overflow")
508509
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-strict-aliasing")
509510
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=deprecated-declarations")

aten/src/ATen/native/cpu/UnaryOpsKernel.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ static void frac_kernel(TensorIterator& iter) {
129129
}
130130

131131
static void logical_not_kernel(TensorIterator& iter) {
132+
// NOTE: this implementation differs from the CUDA implementation which only does single dispatch
133+
// (to avoid expensive compilation) because CPU kernels don't handle dynamic_casting
134+
// (see needs_dynamic_casting).
132135
AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, iter.dtype(1), "logical_not_cpu", [&]() {
133136
using self_t = scalar_t;
134137
AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, iter.dtype(0), "logical_not_cpu", [&]() {

aten/src/ATen/native/cuda/UnarySignKernels.cu

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
namespace at { namespace native {
1212

1313
void logical_not_kernel_cuda(TensorIterator& iter) {
14+
// error check -- this is just ensuring we don't dispatch on types that aren't in ALL_TYPES_AND2(...)
15+
// so we don't have to maintain a separate list or to do double dispatch.
16+
AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, iter.dtype(0), "logical_not_cuda", [&]() {});
17+
1418
AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, iter.dtype(1), "logical_not_cuda", [&]() {
15-
using self_t = scalar_t;
16-
AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, iter.dtype(0), "logical_not_cuda", [&]() {
17-
gpu_kernel(iter, []GPU_LAMBDA(self_t a) -> scalar_t { return static_cast<scalar_t>(!a); });
18-
});
19+
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> bool { return !a; });
1920
});
2021
}
2122

0 commit comments

Comments (0)