2323"""
2424
2525import argparse
26+ import ast
2627import shutil
2728import sys
2829from pathlib import Path
2930
30- from lib_updater import PatchSpec , UtMethod , apply_patches
31+ from lib_updater import (
32+ COMMENT ,
33+ PatchSpec ,
34+ UtMethod ,
35+ apply_patches ,
36+ )
3137
3238
3339def parse_args ():
@@ -61,15 +67,18 @@ def __str__(self):
6167class TestResult :
6268 tests_result : str = ""
6369 tests = []
70+ unexpected_successes = [] # Tests that passed but were marked as expectedFailure
6471 stdout = ""
6572
6673 def __str__ (self ):
67- return f"TestResult(tests_result={ self .tests_result } ,tests={ len (self .tests )} )"
74+ return f"TestResult(tests_result={ self .tests_result } ,tests={ len (self .tests )} ,unexpected_successes= { len ( self . unexpected_successes ) } )"
6875
6976
7077def parse_results (result ):
7178 lines = result .stdout .splitlines ()
7279 test_results = TestResult ()
80+ test_results .tests = []
81+ test_results .unexpected_successes = []
7382 test_results .stdout = result .stdout
7483 in_test_results = False
7584 for line in lines :
@@ -107,6 +116,19 @@ def parse_results(result):
107116 res = line .split ("== Tests result: " )[1 ]
108117 res = res .split (" " )[0 ]
109118 test_results .tests_result = res
119+ # Parse: "UNEXPECTED SUCCESS: test_name (path)"
120+ elif line .startswith ("UNEXPECTED SUCCESS: " ):
121+ rest = line [len ("UNEXPECTED SUCCESS: " ) :]
122+ # Format: "test_name (path)"
123+ first_space = rest .find (" " )
124+ if first_space > 0 :
125+ test = Test ()
126+ test .name = rest [:first_space ]
127+ path_part = rest [first_space :].strip ()
128+ if path_part .startswith ("(" ) and path_part .endswith (")" ):
129+ test .path = path_part [1 :- 1 ]
130+ test .result = "unexpected_success"
131+ test_results .unexpected_successes .append (test )
110132 return test_results
111133
112134
@@ -117,6 +139,95 @@ def path_to_test(path) -> list[str]:
117139 return parts [- 2 :] # Get class name and method name
118140
119141
142+ def is_super_call_only (func_node : ast .FunctionDef | ast .AsyncFunctionDef ) -> bool :
143+ """Check if the method body is just 'return super().method_name()'."""
144+ if len (func_node .body ) != 1 :
145+ return False
146+ stmt = func_node .body [0 ]
147+ if not isinstance (stmt , ast .Return ) or stmt .value is None :
148+ return False
149+ # Check for super().method_name() pattern
150+ call = stmt .value
151+ if not isinstance (call , ast .Call ):
152+ return False
153+ if not isinstance (call .func , ast .Attribute ):
154+ return False
155+ super_call = call .func .value
156+ if not isinstance (super_call , ast .Call ):
157+ return False
158+ if not isinstance (super_call .func , ast .Name ) or super_call .func .id != "super" :
159+ return False
160+ return True
161+
162+
163+ def remove_expected_failures (
164+ contents : str , tests_to_remove : set [tuple [str , str ]]
165+ ) -> str :
166+ """Remove @unittest.expectedFailure decorators from tests that now pass."""
167+ if not tests_to_remove :
168+ return contents
169+
170+ tree = ast .parse (contents )
171+ lines = contents .splitlines ()
172+ lines_to_remove = set ()
173+
174+ for node in ast .walk (tree ):
175+ if not isinstance (node , ast .ClassDef ):
176+ continue
177+ class_name = node .name
178+ for item in node .body :
179+ if not isinstance (item , (ast .FunctionDef , ast .AsyncFunctionDef )):
180+ continue
181+ method_name = item .name
182+ if (class_name , method_name ) not in tests_to_remove :
183+ continue
184+
185+ # Check if we should remove the entire method (super() call only)
186+ remove_entire_method = is_super_call_only (item )
187+
188+ if remove_entire_method :
189+ # Remove entire method including decorators and any preceding comment
190+ first_line = item .lineno - 1 # 0-indexed, def line
191+ if item .decorator_list :
192+ first_line = item .decorator_list [0 ].lineno - 1
193+ # Check for TODO comment before first decorator/def
194+ if first_line > 0 :
195+ prev_line = lines [first_line - 1 ].strip ()
196+ if prev_line .startswith ("#" ) and COMMENT in prev_line :
197+ first_line -= 1
198+ # Remove from first_line to end_lineno (inclusive)
199+ for i in range (first_line , item .end_lineno ):
200+ lines_to_remove .add (i )
201+ else :
202+ # Only remove the expectedFailure decorator
203+ for dec in item .decorator_list :
204+ dec_line = dec .lineno - 1 # 0-indexed
205+ line_content = lines [dec_line ]
206+
207+ # Check if it's @unittest.expectedFailure
208+ if "expectedFailure" not in line_content :
209+ continue
210+
211+ # Check if TODO: RUSTPYTHON is on the same line or the line before
212+ has_comment_on_line = COMMENT in line_content
213+ has_comment_before = (
214+ dec_line > 0
215+ and lines [dec_line - 1 ].strip ().startswith ("#" )
216+ and COMMENT in lines [dec_line - 1 ]
217+ )
218+
219+ if has_comment_on_line or has_comment_before :
220+ lines_to_remove .add (dec_line )
221+ if has_comment_before :
222+ lines_to_remove .add (dec_line - 1 )
223+
224+ # Remove lines in reverse order to maintain line numbers
225+ for line_idx in sorted (lines_to_remove , reverse = True ):
226+ del lines [line_idx ]
227+
228+ return "\n " .join (lines ) + "\n " if lines else ""
229+
230+
120231def build_patches (test_parts_set : set [tuple [str , str ]]) -> dict :
121232 """Convert failing tests to lib_updater patch format."""
122233 patches = {}
@@ -190,20 +301,38 @@ def run_test(test_name):
190301 f = test_path .read_text (encoding = "utf-8" )
191302
192303 # Collect failing tests (with deduplication for subtests)
193- seen_tests = set () # Track (class_name, method_name) to avoid duplicates
304+ failing_tests = set () # Track (class_name, method_name) to avoid duplicates
194305 for test in tests .tests :
195306 if test .result == "fail" or test .result == "error" :
196307 test_parts = path_to_test (test .path )
197308 if len (test_parts ) == 2 :
198309 test_key = tuple (test_parts )
199- if test_key not in seen_tests :
200- seen_tests .add (test_key )
201- print (f"Marking test: { test_parts [0 ]} .{ test_parts [1 ]} " )
202-
203- # Apply patches using lib_updater
204- if seen_tests :
205- patches = build_patches (seen_tests )
310+ if test_key not in failing_tests :
311+ failing_tests .add (test_key )
312+ print (f"Marking as failing: { test_parts [0 ]} .{ test_parts [1 ]} " )
313+
314+ # Collect unexpected successes (tests that now pass but have expectedFailure)
315+ unexpected_successes = set ()
316+ for test in tests .unexpected_successes :
317+ test_parts = path_to_test (test .path )
318+ if len (test_parts ) == 2 :
319+ test_key = tuple (test_parts )
320+ if test_key not in unexpected_successes :
321+ unexpected_successes .add (test_key )
322+ print (f"Removing expectedFailure: { test_parts [0 ]} .{ test_parts [1 ]} " )
323+
324+ # Remove expectedFailure from tests that now pass
325+ if unexpected_successes :
326+ f = remove_expected_failures (f , unexpected_successes )
327+
328+ # Apply patches for failing tests
329+ if failing_tests :
330+ patches = build_patches (failing_tests )
206331 f = apply_patches (f , patches )
332+
333+ # Write changes if any modifications were made
334+ if failing_tests or unexpected_successes :
207335 test_path .write_text (f , encoding = "utf-8" )
208336
209- print (f"Modified { len (seen_tests )} tests" )
337+ print (f"Added expectedFailure to { len (failing_tests )} tests" )
338+ print (f"Removed expectedFailure from { len (unexpected_successes )} tests" )
0 commit comments