
Commit 9c4683e

Mike Ruberry authored and facebook-github-bot committed
Revert D20312366: [pytorch][PR] Added type promotion logic for complex numbers
Test Plan: revert-hammer

Differential Revision: D20312366

Original commit changeset: 90f00a1a916d

fbshipit-source-id: 4510739a888b2eec5d8a72e792998ac46da6d82a
1 parent 0d8447a commit 9c4683e

9 files changed: 27 additions & 91 deletions


aten/src/ATen/native/TypeProperties.cpp

Lines changed: 6 additions & 13 deletions
@@ -57,13 +57,10 @@ static inline ScalarType promote_skip_undefined(ScalarType a, ScalarType b) {
 
 
 static inline ScalarType combine_categories(ScalarType higher, ScalarType lower) {
-  if(isComplexType(higher)) {
+  if (isFloatingType(higher)) {
     return higher;
   }
-  else if(!isComplexType(lower) && isFloatingType(higher)) {
-    return higher;
-  }
-  if (higher == ScalarType::Bool || isFloatingType(lower) || isComplexType(lower)) {
+  if (higher == ScalarType::Bool || isFloatingType(lower)) {
     return promote_skip_undefined(higher, lower);
   }
   if (higher != ScalarType::Undefined) {
@@ -78,14 +75,8 @@ ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state) {
   }
   ResultTypeState new_state = in_state;
   ScalarType current = tensor.scalar_type();
-  if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
-    auto current_default = typeMetaToScalarType(at::get_default_dtype());
-    if(isComplexType(current)) {
-      current = typeMetaToScalarType(at::get_default_complex_dtype());
-    }
-    else if(isFloatingType(current)) {
-      current = current_default;
-    }
+  if (tensor.unsafeGetTensorImpl()->is_wrapped_number() && isFloatingType(current)) {
+    current = typeMetaToScalarType(at::get_default_dtype());
   }
   if ( tensor.dim() > 0 ) {
     new_state.dimResult = promote_skip_undefined(in_state.dimResult, current);
@@ -94,6 +85,7 @@ ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state) {
   } else {
     new_state.zeroResult = promote_skip_undefined(in_state.zeroResult, current);
   }
+
   return new_state;
 }
 
@@ -106,6 +98,7 @@ ScalarType result_type(TensorList tensors) {
   for (const Tensor& tensor : tensors) {
     state = update_result_type_state(tensor, state);
   }
+
   return result_type(state);
 }
 

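With the complex branches gone, combine_categories again treats floating point as the highest category, and a wrapped Python scalar only picks up the default dtype when it is floating point. A minimal Python sketch of the behavior this restores (the checks below are illustrative, not taken from the commit):

    import torch

    # Mixed floating-point operands promote to the wider floating type.
    a = torch.ones(4, dtype=torch.float32)
    b = torch.ones(4, dtype=torch.float64)
    print((a + b).dtype)  # torch.float64

    # A wrapped Python scalar maps to the default dtype, so a float
    # scalar does not widen an existing floating-point tensor.
    print((a * 5.5).dtype)  # torch.float32
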
aten/src/ATen/native/cuda/CUDAScalar.cu

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ namespace native {
 
 Scalar _local_scalar_dense_cuda(const Tensor& self) {
   Scalar r;
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
+  AT_DISPATCH_ALL_TYPES_AND3(
     at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_cuda", [&] {
       scalar_t value;
       cudaStream_t stream = at::cuda::getCurrentCUDAStream();
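Dropping the _COMPLEX variant of the dispatch macro removes complex dtypes from the .item() / _local_scalar_dense path on CUDA. A hedged probe of the expected effect at this commit (requires a CUDA build; the exact error text is not taken from the source):

    import torch

    if torch.cuda.is_available():
        # Real dtypes are still covered by AT_DISPATCH_ALL_TYPES_AND3.
        print(torch.tensor([1.5], device="cuda").item())  # 1.5
        try:
            # Complex is no longer in the dispatch list, so this should
            # raise at this commit (later releases re-enable it).
            torch.ones(1, dtype=torch.complex64, device="cuda").item()
        except RuntimeError as err:
            print("complex .item() not dispatched:", err)
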

c10/core/DefaultDtype.cpp

Lines changed: 0 additions & 9 deletions
@@ -3,21 +3,12 @@
 
 namespace c10 {
 static auto default_dtype = caffe2::TypeMeta::Make<float>();
-static auto default_complex_dtype = caffe2::TypeMeta::Make<std::complex<float>>();
 
 void set_default_dtype(caffe2::TypeMeta dtype) {
   default_dtype = std::move(dtype);
-  if(dtype == caffe2::TypeMeta::Make<double>()) {
-    default_complex_dtype = std::move(caffe2::TypeMeta::Make<std::complex<double>>());
-  } else {
-    default_complex_dtype = std::move(caffe2::TypeMeta::Make<std::complex<float>>());
-  }
 }
 
 const caffe2::TypeMeta& get_default_dtype() {
   return default_dtype;
 }
-const caffe2::TypeMeta& get_default_complex_dtype() {
-  return default_complex_dtype;
-}
 } // namespace c10
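After the revert there is a single default dtype again; set_default_dtype no longer keeps a complex counterpart in sync. A small sketch of the surviving behavior as seen from Python (illustrative, not part of the commit):

    import torch

    # Python floats are wrapped numbers, so tensors built from them take
    # the (real) default dtype tracked in c10/core/DefaultDtype.cpp.
    torch.set_default_dtype(torch.float64)
    print(torch.tensor([1.0, 2.0]).dtype)  # torch.float64

    torch.set_default_dtype(torch.float32)  # restore the usual default
    print(torch.tensor([1.0, 2.0]).dtype)  # torch.float32
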

c10/core/DefaultDtype.h

Lines changed: 0 additions & 1 deletion
@@ -9,5 +9,4 @@ class TypeMeta;
 namespace c10 {
 C10_API void set_default_dtype(caffe2::TypeMeta dtype);
 C10_API const caffe2::TypeMeta& get_default_dtype();
-C10_API const caffe2::TypeMeta& get_default_complex_dtype();
 } // namespace c10

docs/source/tensor_attributes.rst

Lines changed: 3 additions & 12 deletions
@@ -15,15 +15,13 @@ torch.dtype
 .. class:: torch.dtype
 
 A :class:`torch.dtype` is an object that represents the data type of a
-:class:`torch.Tensor`. PyTorch has eleven different data types:
+:class:`torch.Tensor`. PyTorch has nine different data types:
 
 ======================== =========================================== ===========================
-Data type                dtype                                       Legacy Constructors
+Data type                dtype                                       Tensor types
 ======================== =========================================== ===========================
 32-bit floating point    ``torch.float32`` or ``torch.float``       ``torch.*.FloatTensor``
 64-bit floating point    ``torch.float64`` or ``torch.double``      ``torch.*.DoubleTensor``
-64-bit complex           ``torch.complex64`` or ``torch.cfloat``
-128-bit floating point   ``torch.complex128`` or ``torch.cdouble``
 16-bit floating point    ``torch.float16`` or ``torch.half``        ``torch.*.HalfTensor``
 8-bit integer (unsigned) ``torch.uint8``                             ``torch.*.ByteTensor``
 8-bit integer (signed)   ``torch.int8``                              ``torch.*.CharTensor``
@@ -36,16 +34,13 @@ Boolean ``torch.bool`` ``torch.*.BoolTensor``
 To find out if a :class:`torch.dtype` is a floating point data type, the property :attr:`is_floating_point`
 can be used, which returns ``True`` if the data type is a floating point data type.
 
-To find out if a :class:`torch.dtype` is a complex data type, the property :attr:`is_complex`
-can be used, which returns ``True`` if the data type is a complex data type.
-
 .. _type-promotion-doc:
 
 When the dtypes of inputs to an arithmetic operation (`add`, `sub`, `div`, `mul`) differ, we promote
 by finding the minimum dtype that satisfies the following rules:
 
 * If the type of a scalar operand is of a higher category than tensor operands
-  (where complex > floating > integral > boolean), we promote to a type with sufficient size to hold
+  (where floating > integral > boolean), we promote to a type with sufficient size to hold
   all scalar operands of that category.
 * If a zero-dimension tensor operand has a higher category than dimensioned operands,
   we promote to a type with sufficient size and category to hold all zero-dim tensor operands of
@@ -62,8 +57,6 @@ Promotion Examples::
 
     >>> float_tensor = torch.ones(1, dtype=torch.float)
     >>> double_tensor = torch.ones(1, dtype=torch.double)
-    >>> complex_float_tensor = torch.ones(1, dtype=torch.complex64)
-    >>> complex_double_tensor = torch.ones(1, dtype=torch.complex128)
     >>> int_tensor = torch.ones(1, dtype=torch.int)
     >>> long_tensor = torch.ones(1, dtype=torch.long)
     >>> uint_tensor = torch.ones(1, dtype=torch.uint8)
@@ -88,8 +81,6 @@ Promotion Examples::
     torch.uint8
     >>> (float_tensor + double_tensor).dtype
     torch.float64
-    >>> (complex_float_tensor + complex_double_tensor).dtype
-    torch.complex128
     >>> (bool_tensor + int_tensor).dtype
     torch.int32
     # Since long is a different kind than float, result dtype only needs to be large enough
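The two promotion rules left in the docs can be checked directly. A minimal sketch of both bullets with real dtypes only, as documented after this revert (the examples are chosen here, not taken from the page):

    import torch

    # Rule 1: a scalar of a higher category than the tensor operands
    # promotes to a dtype large enough for the scalar's category.
    int_tensor = torch.ones(3, dtype=torch.int32)
    print((int_tensor + 5.5).dtype)       # torch.float32 (the default dtype)

    # Rule 2: a zero-dim tensor of a higher category promotes the result
    # to a type that can hold it.
    zero_dim = torch.tensor(5.5, dtype=torch.float64)
    print((int_tensor + zero_dim).dtype)  # torch.float64
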

test/test_type_promotion.py

Lines changed: 13 additions & 48 deletions
@@ -28,6 +28,7 @@ def wrapped_fn(*args, **kwargs):
 
     return wrapped_fn
 
+
 class TestTypePromotion(TestCase):
 
     # In-place operations don't promote.
@@ -80,44 +81,14 @@ def test_int_promotion(self, device):
 
     @float_double_default_dtype
     def test_float_promotion(self, device):
-        def test_promotion(dtype_float, dtype_double):
-            a = torch.ones([4, 4, 4], dtype=dtype_float, device=device)
-            b = torch.ones([4, 4, 4], dtype=dtype_double, device=device)
-            c = a + b
-            self.assertEqual(c, b + b)
-            self.assertEqual(c.dtype, dtype_double)
-            c = b + a
-            self.assertEqual(c, b + b)
-            self.assertEqual(c.dtype, dtype_double)
-        test_promotion(torch.float, torch.double)
-
-    @float_double_default_dtype
-    def test_complex_promotion(self, device):
-        def test_promotion(dtype_float, dtype_double):
-            a = torch.ones([4, 4, 4], dtype=dtype_float, device=device)
-            b = torch.ones([4, 4, 4], dtype=dtype_double, device=device)
-            c = a + b
-            self.assertEqual(c, b + b)
-            self.assertEqual(c.dtype, dtype_double)
-            c = b + a
-            self.assertEqual(c, b + b)
-            self.assertEqual(c.dtype, dtype_double)
-
-        test_promotion(torch.complex64, torch.complex128)
-
-        a = torch.randn(3, dtype=torch.complex64, device=device)
-        self.assertEqual((a * 5).dtype, torch.complex64)
-        # not a "wrapped number"
-        other = torch.tensor(5.5, dtype=torch.double, device=device)
-        self.assertEqual((a + other).dtype, torch.complex64)
-
-    @float_double_default_dtype
-    def test_complex_scalar_mult_tensor_promotion(self, device):
-        a = 1j * torch.ones(2, device=device)
-        a = a + 1j
-        b = torch.tensor([2j, 2j], device=device)
-        self.assertEqual(a, b)
-        self.assertEqual(a.dtype, b.dtype)
+        a = torch.ones([4, 4, 4], dtype=torch.float, device=device)
+        b = torch.ones([4, 4, 4], dtype=torch.double, device=device)
+        c = a + b
+        self.assertEqual(c, b + b)
+        self.assertEqual(c.dtype, torch.double)
+        c = b + a
+        self.assertEqual(c, b + b)
+        self.assertEqual(c.dtype, torch.double)
 
     @float_double_default_dtype
     def test_add_wrapped(self, device):
@@ -205,17 +176,12 @@ def _get_test_tensor(self, device, dtype, remove_zeros=False):
         shape = [5, 5, 5]
         if dtype == torch.bool:
             tensor = torch.randint(int(remove_zeros), 2, shape, device=device, dtype=dtype)
-        elif dtype.is_floating_point or dtype.is_complex:
+        elif dtype.is_floating_point:
             # "_th_normal_ not supported on CPUType for Half" so simpler create and convert
             tensor = torch.randn(shape, device=device)
             tensor = tensor.to(dtype)
             if remove_zeros:
-                tensor_abs = torch.abs(tensor)
-                if dtype == torch.complex64:
-                    tensor_abs = tensor_abs.to(torch.float)
-                elif dtype == torch.complex128:
-                    tensor_abs = tensor_abs.to(torch.double)
-                tensor[tensor_abs < 0.05] = 5
+                tensor[torch.abs(tensor) < 0.05] = 5
         else:
             tensor = torch.randint(-5 if dtype.is_signed else 0, 10, shape, device=device, dtype=dtype)
             if remove_zeros:
@@ -228,9 +194,8 @@ def _get_test_tensor(self, device, dtype, remove_zeros=False):
     def test_many_promotions(self, device):
         # Can also include half on CPU in cases where it will be promoted to a
        # supported dtype
-        complex_dtypes = get_all_complex_dtypes()
-        dtypes1 = torch.testing.get_all_math_dtypes('cuda') + complex_dtypes
-        dtypes2 = torch.testing.get_all_math_dtypes(device) + complex_dtypes
+        dtypes1 = torch.testing.get_all_math_dtypes('cuda')
+        dtypes2 = torch.testing.get_all_math_dtypes(device)
         ops = [torch.add, torch.sub, torch.mul, torch.div, torch.rsub]
         for dt1, dt2 in itertools.product(dtypes1, dtypes2):
             for op, non_contiguous in itertools.product(ops, [True, False]):
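The float-only test that survives the revert is easy to exercise outside the unittest harness; a standalone sketch equivalent to the retained test_float_promotion body (CPU, for simplicity):

    import torch

    a = torch.ones([4, 4, 4], dtype=torch.float)
    b = torch.ones([4, 4, 4], dtype=torch.double)
    c = a + b
    assert c.dtype == torch.double and torch.equal(c, b + b)
    c = b + a
    assert c.dtype == torch.double and torch.equal(c, b + b)
    print("float/double promotion OK")
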

torch/csrc/utils/tensor_dtypes.cpp

Lines changed: 2 additions & 2 deletions
@@ -36,9 +36,9 @@ static std::pair<std::string, std::string> getDtypeNames(
     case at::ScalarType::ComplexHalf:
       return std::make_pair("complex32", "");
     case at::ScalarType::ComplexFloat:
-      return std::make_pair("complex64", "cfloat");
+      return std::make_pair("complex64", "");
     case at::ScalarType::ComplexDouble:
-      return std::make_pair("complex128", "cdouble");
+      return std::make_pair("complex128", "");
     case at::ScalarType::Bool:
       return std::make_pair("bool", "");
     case at::ScalarType::QInt8:
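getDtypeNames pairs each scalar type's canonical name with a short alias that gets exported as a torch attribute; blanking the second member of the pair removes the alias. A hedged probe (on a build at this commit torch.cfloat / torch.cdouble should be absent; they exist again in later releases where the change was re-landed):

    import torch

    # Pairs that keep a non-empty alias still export both names.
    print(torch.float32 is torch.float)  # True

    # The complex aliases are blanked by this revert, so these lookups
    # are expected to be False at this commit only.
    print(hasattr(torch, "cfloat"), hasattr(torch, "cdouble"))
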

torch/onnx/symbolic_helper.py

Lines changed: 2 additions & 2 deletions
@@ -492,8 +492,8 @@ def _set_operator_export_type(operator_export_type):
     'int64_t': 'Long',
     'int16_t': 'Short',
     'bool': 'Bool',
-    'complex64': 'ComplexFloat',
-    'complex128': 'ComplexDouble'
+    'complex64': '',
+    'complex128': ''
 }
 
 
torch/testing/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -102,9 +102,6 @@ def get_all_math_dtypes(device):
 
     return dtypes
 
-def get_all_complex_dtypes():
-    dtypes = [torch.complex64, torch.complex128]
-    return dtypes
 
 def get_all_device_types():
     return ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
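With get_all_complex_dtypes deleted, callers have to list the complex dtypes inline; the removed helper was just a two-element list. A sketch against the testing helpers of this vintage (get_all_math_dtypes was itself removed from torch.testing in much later releases):

    import torch

    # Numeric dtypes usable in math ops on a device, per this era's helper
    # (no complex after this revert).
    dtypes = torch.testing.get_all_math_dtypes('cpu')
    print(dtypes)

    # The deleted helper amounted to this literal list:
    complex_dtypes = [torch.complex64, torch.complex128]
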
