@@ -22,9 +22,10 @@ Tensor& pow_out(Tensor& result, const Tensor& base, Scalar exp) {
2222 TORCH_CHECK (!(isIntegralType (base.scalar_type (), true ) &&
2323 exp.isIntegral (true ) && exp.toLong () < 0 ),
2424 " Integers to negative integer powers are not allowed." );
25- if (exp.toDouble () == 0.0 ) {
25+ // Avoid runtime error when typecasting
26+ if (!exp.isComplex () && (exp.toDouble () == 0.0 )) {
2627 result.resize_as_ (base).fill_ (1 );
27- } else if (exp.toDouble () == 1.0 ) {
28+ } else if (! exp.isComplex () && (exp. toDouble () == 1.0 ) ) {
2829 result.resize_as_ (base).copy_ (base);
2930 } else {
3031 auto iter = TensorIterator::unary_op (result, base,
@@ -52,12 +53,28 @@ Tensor& pow_(Tensor& base, Scalar alpha) {
5253}
5354
5455Tensor pow (const Tensor& base, const Tensor& exp) {
55- Tensor result = at::empty ({0 }, base.options ());
56+ // If the exponent is complex, the result needs to be complex
57+ // we can't rely on result_type because it will break current
58+ // handling
59+ // TODO: change it to use type promotion after #37098 is merged
60+ ScalarType dtype = (exp.is_complex () ? exp.scalar_type () : base.scalar_type ());
61+ Tensor result = at::empty ({0 }, base.options ().dtype (dtype));
5662 return native::pow_out (result, base, exp);
5763}
5864
5965Tensor pow (const Tensor& base, Scalar exp) {
60- Tensor result = at::empty_like (base, MemoryFormat::Preserve);
66+ // If the exponent is complex, the result needs to be complex
67+ // we can't rely on result_type because it will break current
68+ // handling for other datatypes
69+ // TODO: change it to use type promotion after #37098 is merged
70+ ScalarType dtype = (exp.isComplex () ? exp.type () : base.scalar_type ());
71+ Tensor result = at::empty ({0 }, base.options ().dtype (dtype));
72+ if (exp.isComplex ()) {
73+ // The type checking logic in unary_op TensorIterator does not allow
74+ // a float tensor to output to a complex tensor, but binary ops allow it
75+ // so we create a tensor for the exponent to avoid using this iterator until its fixed
76+ return native::pow_out (result, base, c10::scalar_to_tensor (exp, base.device ()));
77+ }
6178 return native::pow_out (result, base, exp);
6279}
6380
0 commit comments