@@ -166,9 +166,14 @@ def get_jit_class_def(cls, self_name):
166166 and not is_static_fn (cls , m .__name__ )
167167 and m .__name__ in cls .__dict__
168168 )
169+
170+ def is_classmethod (fn ):
171+ return inspect .ismethod (fn ) and getattr (fn , "__self__" , None ) == cls
172+
169173 methods = [get_jit_def (method [1 ],
170174 method [0 ],
171- self_name = self_name ) for method in methods ]
175+ self_name = self_name ,
176+ is_classmethod = is_classmethod (method [1 ])) for method in methods ]
172177
173178 properties = get_class_properties (cls , self_name )
174179
@@ -217,7 +222,7 @@ def remove_prefix(text, prefix):
217222 return aligned_prefix + aligned_suffix
218223
219224
220- def get_jit_def (fn , def_name , self_name = None ):
225+ def get_jit_def (fn , def_name , self_name = None , is_classmethod = False ):
221226 """
222227 Build a JIT AST (TreeView) from the given function.
223228
@@ -244,6 +249,12 @@ def _forward(self):
244249 ctx = SourceContext (source , filename , file_lineno , leading_whitespace_len , True )
245250 fn_def = py_ast .body [0 ]
246251
252+ if is_classmethod :
253+ arg_name = fn_def .args .args [0 ].arg
254+ # Insert a statement that assigns the first argument to the class
255+ assign_stmt = ast .parse (f"{ arg_name } = { self_name } " ).body [0 ]
256+ fn_def .body .insert (0 , assign_stmt )
257+
247258 # Swap out the function signature and body if it is unused
248259 if should_drop (fn ):
249260 unused_fn_def = ast .parse ("def unused_fn(self: Any):\n \t raise RuntimeError(\" Cannot call @unused methods\" )" )
0 commit comments