|
8 | 8 | import numpy as np # type: ignore |
9 | 9 | from onnx import TensorProto |
10 | 10 | from onnx import mapping |
11 | | -from onnx.mapping import StringHolder |
| 11 | +from six import text_type |
12 | 12 | from typing import Sequence, Any, Optional, Text, List |
13 | 13 |
|
14 | 14 |
|
@@ -45,7 +45,7 @@ def to_array(tensor): # type: (TensorProto) -> np.ndarray[Any] |
45 | 45 |
|
46 | 46 | if tensor.data_type == TensorProto.STRING: |
47 | 47 | utf8_strings = getattr(tensor, storage_field) |
48 | | - ss = list(StringHolder(s.decode('utf-8')) for s in utf8_strings) |
| 48 | + ss = list(s.decode('utf-8') for s in utf8_strings) |
49 | 49 | return np.asarray(ss).astype(np_dtype).reshape(dims) |
50 | 50 |
|
51 | 51 | if tensor.HasField("raw_data"): |
@@ -85,10 +85,10 @@ def from_array(arr, name=None): # type: (np.ndarray[Any], Optional[Text]) -> Te |
85 | 85 | # Special care for strings. |
86 | 86 | tensor.data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[arr.dtype] |
87 | 87 | for e in arr: |
88 | | - if isinstance(e, StringHolder): |
89 | | - tensor.string_data.append(e.text.encode('utf-8')) |
| 88 | + if isinstance(e, text_type): |
| 89 | + tensor.string_data.append(e.encode('utf-8')) |
90 | 90 | else: |
91 | | - raise NotImplementedError("Unrecognized object in the object array, expect StringHolder") |
| 91 | + raise NotImplementedError("Unrecognized object in the object array, expect a string") |
92 | 92 |
|
93 | 93 | return tensor |
94 | 94 |
|
|
0 commit comments