Skip to content

Commit 31db055

Browse files
committed
Support scripting classmethod called with object instances
1 parent 963f762 commit 31db055

2 files changed

Lines changed: 39 additions & 2 deletions

File tree

test/jit/test_class_type.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,6 +1212,32 @@ def test_function(a: int, b: int) -> 'ClassWithStaticMethod':
12121212

12131213
self.checkScript(test_function, (1, 2))
12141214

1215+
def test_classmethod(self):
1216+
"""
1217+
Test classmethods on class types.
1218+
"""
1219+
global ClassWithClassMethod
1220+
1221+
@torch.jit.script
1222+
class ClassWithClassMethod:
1223+
def __init__(self, a: int):
1224+
self.a: int = a
1225+
1226+
def __eq__(self, other: 'ClassWithClassMethod'):
1227+
return self.a == other.a
1228+
1229+
@classmethod
1230+
def create(cls, a: int) -> 'ClassWithClassMethod':
1231+
return cls(a)
1232+
1233+
def test_function(a: int) -> 'ClassWithClassMethod':
1234+
x = ClassWithClassMethod(a)
1235+
# Support calling classmethod with an instance
1236+
# Calling with the class is not supported.
1237+
return x.create(a)
1238+
1239+
self.checkScript(test_function, (1,))
1240+
12151241
def test_properties(self):
12161242
"""
12171243
Test that a scripted class can make use of the @property decorator.

torch/jit/frontend.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,14 @@ def get_jit_class_def(cls, self_name):
166166
and not is_static_fn(cls, m.__name__)
167167
and m.__name__ in cls.__dict__
168168
)
169+
170+
def is_classmethod(fn):
171+
return inspect.ismethod(fn) and getattr(fn, "__self__", None) == cls
172+
169173
methods = [get_jit_def(method[1],
170174
method[0],
171-
self_name=self_name) for method in methods]
175+
self_name=self_name,
176+
is_classmethod=is_classmethod(method[1])) for method in methods]
172177

173178
properties = get_class_properties(cls, self_name)
174179

@@ -217,7 +222,7 @@ def remove_prefix(text, prefix):
217222
return aligned_prefix + aligned_suffix
218223

219224

220-
def get_jit_def(fn, def_name, self_name=None):
225+
def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
221226
"""
222227
Build a JIT AST (TreeView) from the given function.
223228
@@ -244,6 +249,12 @@ def _forward(self):
244249
ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True)
245250
fn_def = py_ast.body[0]
246251

252+
if is_classmethod:
253+
arg_name = fn_def.args.args[0].arg
254+
# Insert a statement that assigns the first argument to the class
255+
assign_stmt = ast.parse(f"{arg_name} = {self_name}").body[0]
256+
fn_def.body.insert(0, assign_stmt)
257+
247258
# Swap out the function signature and body if it is unused
248259
if should_drop(fn):
249260
unused_fn_def = ast.parse("def unused_fn(self: Any):\n\traise RuntimeError(\"Cannot call @unused methods\")")

0 commit comments

Comments
 (0)