@@ -408,7 +408,7 @@ void rshift_kernel(TensorIteratorBase& iter) {
408408 }
409409}
410410
411- void lt_kernel (TensorIterator & iter) {
411+ void lt_kernel (TensorIteratorBase & iter) {
412412 // See Note [special-case bool outputs]
413413 if (iter.dtype () == ScalarType::Bool) {
414414 AT_DISPATCH_ALL_TYPES_AND3 (kBool , kBFloat16 , kHalf , iter.common_dtype (), " lt_cpu" , [&]() {
@@ -431,7 +431,7 @@ void lt_kernel(TensorIterator& iter) {
431431 }
432432}
433433
434- void le_kernel (TensorIterator & iter) {
434+ void le_kernel (TensorIteratorBase & iter) {
435435 // See Note [special-case bool outputs]
436436 if (iter.dtype () == ScalarType::Bool) {
437437 AT_DISPATCH_ALL_TYPES_AND3 (kBool , kBFloat16 , kHalf , iter.common_dtype (), " le_cpu" , [&]() {
@@ -454,7 +454,7 @@ void le_kernel(TensorIterator& iter) {
454454 }
455455}
456456
457- void gt_kernel (TensorIterator & iter) {
457+ void gt_kernel (TensorIteratorBase & iter) {
458458 // See Note [special-case bool outputs]
459459 if (iter.dtype () == ScalarType::Bool) {
460460 AT_DISPATCH_ALL_TYPES_AND3 (kBool , kBFloat16 , kHalf , iter.common_dtype (), " gt_cpu" , [&]() {
@@ -477,7 +477,7 @@ void gt_kernel(TensorIterator& iter) {
477477 }
478478}
479479
480- void ge_kernel (TensorIterator & iter) {
480+ void ge_kernel (TensorIteratorBase & iter) {
481481 // See Note [special-case bool outputs]
482482 if (iter.dtype () == ScalarType::Bool) {
483483 AT_DISPATCH_ALL_TYPES_AND3 (kBool , kBFloat16 , kHalf , iter.common_dtype (), " ge_cpu" , [&]() {
@@ -500,7 +500,7 @@ void ge_kernel(TensorIterator& iter) {
500500 }
501501}
502502
503- void eq_kernel (TensorIterator & iter) {
503+ void eq_kernel (TensorIteratorBase & iter) {
504504 // See Note [special-case bool outputs]
505505 if (iter.dtype () == ScalarType::Bool) {
506506 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3 (kBool , kBFloat16 , kHalf , iter.common_dtype (), " eq_cpu" , [&]() {
@@ -523,7 +523,7 @@ void eq_kernel(TensorIterator& iter) {
523523 }
524524}
525525
526- void ne_kernel (TensorIterator & iter) {
526+ void ne_kernel (TensorIteratorBase & iter) {
527527 // See Note [special-case bool outputs]
528528 if (iter.dtype () == ScalarType::Bool) {
529529 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3 (kBool , kBFloat16 , kHalf , iter.common_dtype (), " ne_cpu" , [&]() {
@@ -671,7 +671,7 @@ void huber_kernel(TensorIterator& iter, double delta) {
671671 });
672672}
673673
674- void sigmoid_backward_kernel (TensorIterator & iter) {
674+ void sigmoid_backward_kernel (TensorIteratorBase & iter) {
675675 if (isComplexType (iter.dtype ())) {
676676 AT_DISPATCH_COMPLEX_TYPES (iter.dtype (), " sigmoid_backward_cpu" , [&]() {
677677 auto one_vec = Vectorized<scalar_t >(scalar_t {1 });
@@ -700,7 +700,7 @@ void sigmoid_backward_kernel(TensorIterator& iter) {
700700 }
701701}
702702
703- void logit_backward_kernel (TensorIterator & iter, const Scalar& eps_scalar) {
703+ void logit_backward_kernel (TensorIteratorBase & iter, const Scalar& eps_scalar) {
704704 AT_DISPATCH_FLOATING_TYPES_AND (
705705 kBFloat16 , iter.dtype (), " logit_backward_cpu" , [&]() {
706706 const scalar_t eps = eps_scalar.to <scalar_t >();
@@ -750,7 +750,7 @@ void logit_backward_kernel(TensorIterator& iter, const Scalar& eps_scalar) {
750750 });
751751}
752752
753- void tanh_backward_kernel (TensorIterator & iter) {
753+ void tanh_backward_kernel (TensorIteratorBase & iter) {
754754 if (isComplexType (iter.dtype ())) {
755755 AT_DISPATCH_COMPLEX_TYPES (iter.dtype (), " tanh_backward_cpu" , [&]() {
756756 auto one_vec = Vectorized<scalar_t >(scalar_t {1 });
@@ -961,7 +961,7 @@ void copysign_kernel(TensorIteratorBase& iter) {
961961 });
962962}
963963
964- void xlogy_kernel (TensorIterator & iter) {
964+ void xlogy_kernel (TensorIteratorBase & iter) {
965965 AT_DISPATCH_FLOATING_TYPES_AND2 (kBFloat16 , kHalf , iter.common_dtype (), " xlogy_cpu" , [&]() {
966966 cpu_kernel (iter, [](scalar_t x, scalar_t y) -> scalar_t {
967967 if (at::_isnan (y)){
0 commit comments