Skip to content

Commit 0fd856c

Browse files
Revert "[ONNX] Fix scalar type promotion between fp16 tensor and fp32 scalar (#113404)"
This reverts commit 39ca5a3. Reverted #113404 on behalf of https://github.com/jeanschmidt due to sorry it is breaking CI jobs on main ([comment](#113404 (comment)))
1 parent d64bc8f commit 0fd856c

2 files changed

Lines changed: 17 additions & 40 deletions

File tree

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7138,18 +7138,6 @@ def forward(self, fp16, fp32):
71387138
fp32 = Tensor([1.5])
71397139
self.run_test(CatModel(), (fp16, fp32))
71407140

7141-
@skipIfUnsupportedMinOpsetVersion(9)
7142-
def test_scalar_type_does_not_trigger_upcast_type_promotion(self):
7143-
class DoNotUpcastModel(torch.nn.Module):
7144-
def forward(self, x):
7145-
scale = x.size()[-1] ** -0.5
7146-
# 'scale' is exported as onnx float32 rank 0 tensor.
7147-
# The following 'Mul' should NOT be promoted to float32.
7148-
return x * scale
7149-
7150-
x = torch.ones(2, 3, dtype=torch.float16)
7151-
self.run_test(DoNotUpcastModel(), x)
7152-
71537141
@skipIfUnsupportedMinOpsetVersion(9)
71547142
def test_full_like(self):
71557143
class FullLikeModel(torch.nn.Module):

torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,7 @@ static c10::optional<c10::ScalarType> InferExpectedScalarType(const Node* n) {
161161
at::typeMetaToScalarType(at::get_default_dtype());
162162
switch (scalar_type) {
163163
case at::kDouble:
164-
case at::kFloat:
165-
// floating-point numbers wrapped as float32/float64 tensors are
164+
// floating-point numbers wrapped as double tensors are
166165
// considered to have default type, instead of double.
167166
typesFromScalars.emplace_back(default_scalar_type);
168167
break;
@@ -203,33 +202,23 @@ static c10::optional<c10::ScalarType> InferExpectedScalarType(const Node* n) {
203202
} else {
204203
typesFromTensors.emplace_back(scalar_type);
205204
}
206-
} else if (auto scalar_type = get_scalar_type(input)) {
207-
auto tensor_type = input->type()->castRaw<TensorType>();
208-
// get_scalar_type returns non-null value already guarantees
209-
// that the input has a valid tensor_type.
210-
TORCH_INTERNAL_ASSERT(nullptr != tensor_type);
211-
// ONNX model track shape related computes that were done in pytorch
212-
// by python numbers as tensor computes. This is the only way for ONNX
213-
// to track them properly since ONNX only has tensor type, otherwise
214-
// the computation result will be tracked statically as constant, and
215-
// the model won't work for another input that differs in shape.
216-
217-
// Now for type promotion logic, scalars should be treated differently
218-
// with tensors. More info regarding type promotion logic commented at
219-
// `emplace_type_from_scalar`. Here we filter out rank 0 tensors and
220-
// run it with `emplace_type_from_scalar` to determine if they are
221-
// considered scalars for type promotion.
222-
223-
// NOTE that this might introduce regression that a REAL 0-rank tensor
224-
// is now being recognized as scalar. The downside is the model will
225-
// drop in accuracy for these cases as certain computations will
226-
// happen in lower precision data types.
227-
auto rank = tensor_type->dim();
228-
if (rank && rank.value() == 0) {
229-
emplace_type_from_scalar(scalar_type.value());
230-
} else {
231-
typesFromTensors.emplace_back(scalar_type.value());
205+
} else if (nkind == prim::Param) {
206+
// ONNX doesn't support scalar as graph input. When
207+
// seeing a scalar input, we convert its expected type to tensor.
208+
if (auto scalar_type = get_scalar_type(input)) {
209+
auto tensor_type = input->type()->castRaw<TensorType>();
210+
// get_scalar_type returns non-null value already guarantees
211+
// that the input has a valid tensor_type.
212+
TORCH_INTERNAL_ASSERT(nullptr != tensor_type);
213+
auto rank = tensor_type->dim();
214+
if (rank && rank.value() == 0) {
215+
emplace_type_from_scalar(scalar_type.value());
216+
} else {
217+
typesFromTensors.emplace_back(scalar_type.value());
218+
}
232219
}
220+
} else if (auto scalar_type = get_scalar_type(input)) {
221+
typesFromTensors.emplace_back(*scalar_type);
233222
}
234223
});
235224

0 commit comments

Comments
 (0)