Skip to content

Commit 60cf0b9

Browse files
committed
Update on "[jit] Renamed prim::Concat as prim::VarConcat"
Differential Revision: [D29647586](https://our.internmc.facebook.com/intern/diff/D29647586) [ghstack-poisoned]
2 parents 4ac16bf + 9bcf4f3 commit 60cf0b9

41 files changed

Lines changed: 1473 additions & 1482 deletions

Some content is hidden.

Large commits have some of their content hidden by default. Use the search box below to find content that may be hidden.

aten/src/ATen/TensorIterator.cpp

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,38 @@ void TensorIteratorBase::build_borrowing_binary_float_op(const Tensor& out, cons
799799
.add_input(b));
800800
}
801801

802+
void TensorIteratorBase::build_comparison_op(const Tensor& out, const Tensor& a,
803+
const Tensor& b) {
804+
TensorIteratorConfig config;
805+
806+
config.set_check_mem_overlap(true);
807+
config.add_owned_output(out);
808+
config.add_owned_input(a);
809+
config.add_owned_input(b);
810+
config.allow_cpu_scalars(true);
811+
config.promote_inputs_to_common_dtype(true);
812+
813+
// When 'out' isn't defined (e.g. for the functional operator 'a == b'), we
814+
// want the output to be bool. Otherwise (e.g. 'torch.eq(a, b, out=c)') we
815+
// don't coerce the output.
816+
if (!out.defined()) {
817+
config.declare_static_dtype_and_device(kBool, a.device());
818+
}
819+
820+
// Note [special-case bool outputs]
821+
// We explicitly don't call `cast_common_dtype_to_outputs` when the output tensor
822+
// has `bool` dtype. This is a performance optimization: the functional
823+
// version of all comparison/logical ops uses a bool output tensor, and we'd like to
824+
// avoid creating a temporary copy of the output.
825+
// However, note that all kernels using this TensorIterator will need to special-case when
826+
// the output tensor has bool dtype, and provide a lambda of type (scalar_t, scalar_t -> bool).
827+
if (out.defined() && out.scalar_type() != kBool) {
828+
config.cast_common_dtype_to_outputs(true);
829+
}
830+
831+
build(config);
832+
}
833+
802834
// This cannot be a function because TensorIteratorConfig is not
803835
// copyable or movable, so it can't be returned from the function.
804836
#define BINARY_OP_CONFIG() \
@@ -875,33 +907,9 @@ TensorIterator TensorIterator::binary_float_op(Tensor& out, const Tensor& a, con
875907

876908
TensorIterator TensorIterator::comparison_op(Tensor& out, const Tensor& a,
877909
const Tensor& b) {
878-
// Note [special-case bool outputs]
879-
// We explicitly don't call `cast_common_dtype_to_outputs` when the output tensor
880-
// has `bool` dtype. This is a performance optimization: the functional
881-
// version of all comparison/logical ops uses a bool output tensor, and we'd like to
882-
// avoid creating a temporary copy of the output.
883-
// However, note that all kernels using this TensorIterator will need to special-case when
884-
// the output tensor has bool dtype, and provide a lambda of type (scalar_t, scalar_t -> bool).
885-
if (out.scalar_type() == kBool) {
886-
return TensorIteratorConfig()
887-
.set_check_mem_overlap(true)
888-
.add_owned_output(out)
889-
.add_owned_input(a)
890-
.add_owned_input(b)
891-
.allow_cpu_scalars(true)
892-
.promote_inputs_to_common_dtype(true)
893-
.build();
894-
} else {
895-
return TensorIteratorConfig()
896-
.set_check_mem_overlap(true)
897-
.add_owned_output(out)
898-
.add_owned_input(a)
899-
.add_owned_input(b)
900-
.allow_cpu_scalars(true)
901-
.promote_inputs_to_common_dtype(true)
902-
.cast_common_dtype_to_outputs(true)
903-
.build();
904-
}
910+
TensorIterator iter;
911+
iter.build_comparison_op(out, a, b);
912+
return iter;
905913
}
906914

907915
TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a) {

aten/src/ATen/TensorIterator.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,7 @@ struct TORCH_API TensorIteratorBase : public impl::MetaBase {
351351
void build_unary_float_op(const Tensor& out, const Tensor& a);
352352
void build_unary_op(const Tensor& out, const Tensor& a);
353353
void build_unary_force_boolean_op(const Tensor& out, const Tensor& a);
354+
void build_comparison_op(const Tensor& out, const Tensor& a, const Tensor& b);
354355

355356
#undef TORCH_DISALLOW_TEMPORARIES
356357
protected:

aten/src/ATen/native/BinaryOps.cpp

Lines changed: 104 additions & 121 deletions
Large diffs are not rendered by default.

aten/src/ATen/native/BinaryOps.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ inline void sub_check(const Tensor& self, const Scalar& scalar) {
3939
using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
4040
using structured_binary_fn = void(*)(TensorIteratorBase&);
4141

42-
using binary_fn_alpha = void(*)(TensorIterator&, const Scalar& alpha);
42+
using binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
4343
using binary_fn_double = void(*)(TensorIterator&, double);
4444
using binary_fn = void(*)(TensorIterator&);
4545
using binary_clamp_fn_alpha =
@@ -62,12 +62,12 @@ DECLARE_DISPATCH(structured_binary_fn, rshift_stub);
6262
DECLARE_DISPATCH(binary_fn, logical_xor_stub);
6363
DECLARE_DISPATCH(binary_fn, logical_and_stub);
6464
DECLARE_DISPATCH(binary_fn, logical_or_stub);
65-
DECLARE_DISPATCH(binary_fn, lt_stub);
66-
DECLARE_DISPATCH(binary_fn, le_stub);
67-
DECLARE_DISPATCH(binary_fn, gt_stub);
68-
DECLARE_DISPATCH(binary_fn, ge_stub);
69-
DECLARE_DISPATCH(binary_fn, eq_stub);
70-
DECLARE_DISPATCH(binary_fn, ne_stub);
65+
DECLARE_DISPATCH(structured_binary_fn, lt_stub);
66+
DECLARE_DISPATCH(structured_binary_fn, le_stub);
67+
DECLARE_DISPATCH(structured_binary_fn, gt_stub);
68+
DECLARE_DISPATCH(structured_binary_fn, ge_stub);
69+
DECLARE_DISPATCH(structured_binary_fn, eq_stub);
70+
DECLARE_DISPATCH(structured_binary_fn, ne_stub);
7171
DECLARE_DISPATCH(binary_fn, max_elementwise_stub);
7272
DECLARE_DISPATCH(binary_fn, min_elementwise_stub);
7373
DECLARE_DISPATCH(structured_binary_fn, maximum_stub);
@@ -76,9 +76,9 @@ DECLARE_DISPATCH(structured_binary_fn, fmax_stub);
7676
DECLARE_DISPATCH(structured_binary_fn, fmin_stub);
7777
DECLARE_DISPATCH(binary_fn_double, smooth_l1_stub);
7878
DECLARE_DISPATCH(binary_fn_double, huber_stub);
79-
DECLARE_DISPATCH(binary_fn, sigmoid_backward_stub);
79+
DECLARE_DISPATCH(structured_binary_fn, sigmoid_backward_stub);
8080
DECLARE_DISPATCH(binary_fn_alpha, logit_backward_stub);
81-
DECLARE_DISPATCH(binary_fn, tanh_backward_stub);
81+
DECLARE_DISPATCH(structured_binary_fn, tanh_backward_stub);
8282
DECLARE_DISPATCH(binary_fn, mse_stub);
8383
DECLARE_DISPATCH(structured_binary_fn, fmod_stub);
8484
DECLARE_DISPATCH(structured_binary_fn, logaddexp_stub);
@@ -91,7 +91,7 @@ DECLARE_DISPATCH(structured_binary_fn, igammac_stub);
9191
DECLARE_DISPATCH(structured_binary_fn, nextafter_stub);
9292
DECLARE_DISPATCH(structured_binary_fn, heaviside_stub);
9393
DECLARE_DISPATCH(structured_binary_fn, copysign_stub);
94-
DECLARE_DISPATCH(binary_fn, xlogy_stub);
94+
DECLARE_DISPATCH(structured_binary_fn, xlogy_stub);
9595
DECLARE_DISPATCH(structured_binary_fn, xlog1py_stub);
9696
DECLARE_DISPATCH(structured_binary_fn, zeta_stub);
9797

aten/src/ATen/native/cpu/BinaryOpsKernel.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ void rshift_kernel(TensorIteratorBase& iter) {
408408
}
409409
}
410410

411-
void lt_kernel(TensorIterator& iter) {
411+
void lt_kernel(TensorIteratorBase& iter) {
412412
// See Note [special-case bool outputs]
413413
if (iter.dtype() == ScalarType::Bool) {
414414
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "lt_cpu", [&]() {
@@ -431,7 +431,7 @@ void lt_kernel(TensorIterator& iter) {
431431
}
432432
}
433433

434-
void le_kernel(TensorIterator& iter) {
434+
void le_kernel(TensorIteratorBase& iter) {
435435
// See Note [special-case bool outputs]
436436
if (iter.dtype() == ScalarType::Bool) {
437437
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "le_cpu", [&]() {
@@ -454,7 +454,7 @@ void le_kernel(TensorIterator& iter) {
454454
}
455455
}
456456

457-
void gt_kernel(TensorIterator& iter) {
457+
void gt_kernel(TensorIteratorBase& iter) {
458458
// See Note [special-case bool outputs]
459459
if (iter.dtype() == ScalarType::Bool) {
460460
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "gt_cpu", [&]() {
@@ -477,7 +477,7 @@ void gt_kernel(TensorIterator& iter) {
477477
}
478478
}
479479

480-
void ge_kernel(TensorIterator& iter) {
480+
void ge_kernel(TensorIteratorBase& iter) {
481481
// See Note [special-case bool outputs]
482482
if (iter.dtype() == ScalarType::Bool) {
483483
AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "ge_cpu", [&]() {
@@ -500,7 +500,7 @@ void ge_kernel(TensorIterator& iter) {
500500
}
501501
}
502502

503-
void eq_kernel(TensorIterator& iter) {
503+
void eq_kernel(TensorIteratorBase& iter) {
504504
// See Note [special-case bool outputs]
505505
if (iter.dtype() == ScalarType::Bool) {
506506
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "eq_cpu", [&]() {
@@ -523,7 +523,7 @@ void eq_kernel(TensorIterator& iter) {
523523
}
524524
}
525525

526-
void ne_kernel(TensorIterator& iter) {
526+
void ne_kernel(TensorIteratorBase& iter) {
527527
// See Note [special-case bool outputs]
528528
if (iter.dtype() == ScalarType::Bool) {
529529
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.common_dtype(), "ne_cpu", [&]() {
@@ -671,7 +671,7 @@ void huber_kernel(TensorIterator& iter, double delta) {
671671
});
672672
}
673673

674-
void sigmoid_backward_kernel(TensorIterator& iter) {
674+
void sigmoid_backward_kernel(TensorIteratorBase& iter) {
675675
if (isComplexType(iter.dtype())) {
676676
AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "sigmoid_backward_cpu", [&]() {
677677
auto one_vec = Vectorized<scalar_t>(scalar_t{1});
@@ -700,7 +700,7 @@ void sigmoid_backward_kernel(TensorIterator& iter) {
700700
}
701701
}
702702

703-
void logit_backward_kernel(TensorIterator& iter, const Scalar& eps_scalar) {
703+
void logit_backward_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) {
704704
AT_DISPATCH_FLOATING_TYPES_AND(
705705
kBFloat16, iter.dtype(), "logit_backward_cpu", [&]() {
706706
const scalar_t eps = eps_scalar.to<scalar_t>();
@@ -750,7 +750,7 @@ void logit_backward_kernel(TensorIterator& iter, const Scalar& eps_scalar) {
750750
});
751751
}
752752

753-
void tanh_backward_kernel(TensorIterator& iter) {
753+
void tanh_backward_kernel(TensorIteratorBase& iter) {
754754
if (isComplexType(iter.dtype())) {
755755
AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "tanh_backward_cpu", [&]() {
756756
auto one_vec = Vectorized<scalar_t>(scalar_t{1});
@@ -961,7 +961,7 @@ void copysign_kernel(TensorIteratorBase& iter) {
961961
});
962962
}
963963

964-
void xlogy_kernel(TensorIterator& iter) {
964+
void xlogy_kernel(TensorIteratorBase& iter) {
965965
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "xlogy_cpu", [&]() {
966966
cpu_kernel(iter, [](scalar_t x, scalar_t y) -> scalar_t {
967967
if (at::_isnan(y)){

aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
namespace at {
1515
namespace native {
1616

17-
void sigmoid_backward_kernel_cuda(TensorIterator& iter) {
17+
void sigmoid_backward_kernel_cuda(TensorIteratorBase& iter) {
1818
if(isComplexType(iter.dtype())) {
1919
AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "sigmoid_backward_cuda", [&]() {
2020
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
@@ -30,7 +30,7 @@ void sigmoid_backward_kernel_cuda(TensorIterator& iter) {
3030
}
3131
}
3232

33-
void logit_backward_kernel_cuda(TensorIterator& iter, const Scalar& eps_scalar) {
33+
void logit_backward_kernel_cuda(TensorIteratorBase& iter, const Scalar& eps_scalar) {
3434
AT_DISPATCH_FLOATING_TYPES_AND2(
3535
at::ScalarType::Half,
3636
at::ScalarType::BFloat16,
@@ -63,7 +63,7 @@ void logit_backward_kernel_cuda(TensorIterator& iter, const Scalar& eps_scalar)
6363
});
6464
}
6565

66-
void tanh_backward_kernel_cuda(TensorIterator& iter) {
66+
void tanh_backward_kernel_cuda(TensorIteratorBase& iter) {
6767
if(isComplexType(iter.dtype())) {
6868
AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "tanh_backward_complex_cuda", [&]() {
6969
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {

aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void mse_kernel_cuda(TensorIterator& iter) {
4141
});
4242
}
4343

44-
void xlogy_kernel_cuda(TensorIterator& iter) {
44+
void xlogy_kernel_cuda(TensorIteratorBase& iter) {
4545
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "xlogy_cuda", [&]() {
4646
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t x, scalar_t y) -> scalar_t {
4747
if (at::_isnan(y)){

aten/src/ATen/native/cuda/CompareEQKernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct CompareEqFunctor {
1717
}
1818
};
1919

20-
void eq_kernel_cuda(TensorIterator& iter) {
20+
void eq_kernel_cuda(TensorIteratorBase& iter) {
2121
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "eq_cuda", [&]() {
2222
gpu_kernel_with_scalars(iter, CompareEqFunctor<scalar_t>());
2323
});

aten/src/ATen/native/cuda/CompareGEKernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct CompareGEFunctor {
1717
}
1818
};
1919

20-
void ge_kernel_cuda(TensorIterator& iter) {
20+
void ge_kernel_cuda(TensorIteratorBase& iter) {
2121
AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "ge_cuda", [&]() {
2222
gpu_kernel_with_scalars(iter, CompareGEFunctor<scalar_t>());
2323
});

aten/src/ATen/native/cuda/CompareGTKernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct CompareGTFunctor {
1717
}
1818
};
1919

20-
void gt_kernel_cuda(TensorIterator& iter) {
20+
void gt_kernel_cuda(TensorIteratorBase& iter) {
2121
AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "gt_cuda", [&]() {
2222
gpu_kernel_with_scalars(iter, CompareGTFunctor<scalar_t>());
2323
});

0 commit comments

Comments (0)