Skip to content

Commit f86e951

Browse files
yskim1501bddppq
authored andcommitted
Add domain as an optional parameter for make_node function (#1588)
* add domain as an optional parameter for creating NodeProto * at least two space for comments * add test case * add suggested comment * modify comment * explicitly check for None
1 parent ff45588 commit f86e951

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

onnx/helper.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def make_node(
2323
outputs, # type: Sequence[Text]
2424
name=None, # type: Optional[Text]
2525
doc_string=None, # type: Optional[Text]
26+
domain=None, # type: Optional[Text]
2627
**kwargs # type: Any
2728
): # type: (...) -> NodeProto
2829
"""Construct a NodeProto.
@@ -33,6 +34,8 @@ def make_node(
3334
outputs (list of string): list of output names
3435
name (string, default None): optional unique identifier for NodeProto
3536
doc_string (string, default None): optional documentation string for NodeProto
37+
domain (string, default None): optional domain for NodeProto.
38+
If it's None, we will just use default domain (which is empty)
3639
**kwargs (dict): the attributes of the node. The acceptable values
3740
are documented in :func:`make_attribute`.
3841
"""
@@ -45,6 +48,8 @@ def make_node(
4548
node.name = name
4649
if doc_string:
4750
node.doc_string = doc_string
51+
if domain is not None:
52+
node.domain = domain
4853
if kwargs:
4954
node.attribute.extend(
5055
make_attribute(key, value)

onnx/test/helper_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,11 @@ def test_node_with_arg(self): # type: () -> None
202202
node_def.attribute[0],
203203
helper.make_attribute("arg_value", 1))
204204

205+
def test_node_domain(self): # type: () -> None
206+
node_def = helper.make_node(
207+
"Relu", ["X"], ["Y"], name="test", doc_string="doc", domain="test.domain")
208+
self.assertEqual(node_def.domain, "test.domain")
209+
205210
def test_graph(self): # type: () -> None
206211
node_def1 = helper.make_node(
207212
"Relu", ["X"], ["Y"])

0 commit comments

Comments
 (0)