Skip to content

Commit beefa15

Browse files
yuslepukhinhouseroad
authored andcommitted
Use strings directly for casing as np.object w/o redundant StringHolder. (#1736)
1 parent 4023bae commit beefa15

File tree

3 files changed

+6
-22
lines changed

3 files changed

+6
-22
lines changed

onnx/mapping.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,6 @@
77
from typing import Text, Any
88
import numpy as np # type: ignore
99

10-
11-
class StringHolder:
12-
def __init__(self, s): # type: (Text) -> None
13-
self.text = s
14-
15-
def __eq__(self, other): # type: (object) -> Any
16-
if isinstance(other, self.__class__):
17-
return self.text == getattr(other, 'text')
18-
return False
19-
20-
def __ne__(self, other): # type: (object) -> Any
21-
return not self.__eq__(other)
22-
23-
2410
TENSOR_TYPE_TO_NP_TYPE = {
2511
int(TensorProto.FLOAT): np.dtype('float32'),
2612
int(TensorProto.UINT8): np.dtype('uint8'),

onnx/numpy_helper.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np # type: ignore
99
from onnx import TensorProto
1010
from onnx import mapping
11-
from onnx.mapping import StringHolder
11+
from six import text_type
1212
from typing import Sequence, Any, Optional, Text, List
1313

1414

@@ -45,7 +45,7 @@ def to_array(tensor): # type: (TensorProto) -> np.ndarray[Any]
4545

4646
if tensor.data_type == TensorProto.STRING:
4747
utf8_strings = getattr(tensor, storage_field)
48-
ss = list(StringHolder(s.decode('utf-8')) for s in utf8_strings)
48+
ss = list(s.decode('utf-8') for s in utf8_strings)
4949
return np.asarray(ss).astype(np_dtype).reshape(dims)
5050

5151
if tensor.HasField("raw_data"):
@@ -85,10 +85,10 @@ def from_array(arr, name=None): # type: (np.ndarray[Any], Optional[Text]) -> Te
8585
# Special care for strings.
8686
tensor.data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[arr.dtype]
8787
for e in arr:
88-
if isinstance(e, StringHolder):
89-
tensor.string_data.append(e.text.encode('utf-8'))
88+
if isinstance(e, text_type):
89+
tensor.string_data.append(e.encode('utf-8'))
9090
else:
91-
raise NotImplementedError("Unrecognized object in the object array, expect StringHolder")
91+
raise NotImplementedError("Unrecognized object in the object array, expect a string")
9292

9393
return tensor
9494

onnx/test/numpy_helper_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import numpy as np # type: ignore
77

88
from onnx import numpy_helper
9-
from onnx.mapping import StringHolder
109

1110
import unittest
1211

@@ -53,8 +52,7 @@ def test_int64(self): # type: () -> None
5352
self._test_numpy_helper_int_type(np.int64)
5453

5554
def test_string(self): # type: () -> None
56-
strholder_list = list(StringHolder(s) for s in ['Amy', 'Billy', 'Cindy', 'David'])
57-
a = np.array(strholder_list).astype(np.object)
55+
a = np.array(['Amy', 'Billy', 'Cindy', 'David']).astype(np.object)
5856
tensor_def = numpy_helper.from_array(a, "test")
5957
self.assertEqual(tensor_def.name, "test")
6058
a_recover = numpy_helper.to_array(tensor_def)

0 commit comments

Comments
 (0)