@@ -161,8 +161,7 @@ static c10::optional<c10::ScalarType> InferExpectedScalarType(const Node* n) {
161161 at::typeMetaToScalarType (at::get_default_dtype ());
162162 switch (scalar_type) {
163163 case at::kDouble :
164- case at::kFloat :
165- // floating-point numbers wrapped as float32/float64 tensors are
164+ // floating-point numbers wrapped as double tensors are
166165 // considered to have default type, instead of double.
167166 typesFromScalars.emplace_back (default_scalar_type);
168167 break ;
@@ -203,33 +202,23 @@ static c10::optional<c10::ScalarType> InferExpectedScalarType(const Node* n) {
203202 } else {
204203 typesFromTensors.emplace_back (scalar_type);
205204 }
206- } else if (auto scalar_type = get_scalar_type (input)) {
207- auto tensor_type = input->type ()->castRaw <TensorType>();
208- // get_scalar_type returns non-null value already guarantees
209- // that the input has a valid tensor_type.
210- TORCH_INTERNAL_ASSERT (nullptr != tensor_type);
211- // ONNX model track shape related computes that were done in pytorch
212- // by python numbers as tensor computes. This is the only way for ONNX
213- // to track them properly since ONNX only has tensor type, otherwise
214- // the computation result will be tracked statically as constant, and
215- // the model won't work for another input that differs in shape.
216-
217- // Now for type promotion logic, scalars should be treated differently
218- // with tensors. More info regarding type promotion logic commented at
219- // `emplace_type_from_scalar`. Here we filter out rank 0 tensors and
220- // run it with `emplace_type_from_scalar` to determine if they are
221- // considered scalars for type promotion.
222-
223- // NOTE that this might introduce regression that a REAL 0-rank tensor
224- // is now being recognized as scalar. The downside is the model will
225- // drop in accuracy for these cases as certain computations will
226- // happen in lower precision data types.
227- auto rank = tensor_type->dim ();
228- if (rank && rank.value () == 0 ) {
229- emplace_type_from_scalar (scalar_type.value ());
230- } else {
231- typesFromTensors.emplace_back (scalar_type.value ());
205+ } else if (nkind == prim::Param) {
206+ // ONNX doesn't support scalar as graph input. When
207+ // seeing a scalar input, we convert its expected type to tensor.
208+ if (auto scalar_type = get_scalar_type (input)) {
209+ auto tensor_type = input->type ()->castRaw <TensorType>();
210+ // get_scalar_type returns non-null value already guarantees
211+ // that the input has a valid tensor_type.
212+ TORCH_INTERNAL_ASSERT (nullptr != tensor_type);
213+ auto rank = tensor_type->dim ();
214+ if (rank && rank.value () == 0 ) {
215+ emplace_type_from_scalar (scalar_type.value ());
216+ } else {
217+ typesFromTensors.emplace_back (scalar_type.value ());
218+ }
232219 }
220+ } else if (auto scalar_type = get_scalar_type (input)) {
221+ typesFromTensors.emplace_back (*scalar_type);
233222 }
234223 });
235224
0 commit comments