Skip to content

Commit c934265

Browse files
rename fix_test and support removing unexpected success (#6748)
* fix_test to remove unexpected success * rename fix_test * Auto-format: ruff format --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent e9a57d1 commit c934265

File tree

1 file changed

+140
-11
lines changed

1 file changed

+140
-11
lines changed
Lines changed: 140 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,17 @@
2323
"""
2424

2525
import argparse
26+
import ast
2627
import shutil
2728
import sys
2829
from pathlib import Path
2930

30-
from lib_updater import PatchSpec, UtMethod, apply_patches
31+
from lib_updater import (
32+
COMMENT,
33+
PatchSpec,
34+
UtMethod,
35+
apply_patches,
36+
)
3137

3238

3339
def parse_args():
@@ -61,15 +67,18 @@ def __str__(self):
6167
class TestResult:
6268
tests_result: str = ""
6369
tests = []
70+
unexpected_successes = [] # Tests that passed but were marked as expectedFailure
6471
stdout = ""
6572

6673
def __str__(self):
67-
return f"TestResult(tests_result={self.tests_result},tests={len(self.tests)})"
74+
return f"TestResult(tests_result={self.tests_result},tests={len(self.tests)},unexpected_successes={len(self.unexpected_successes)})"
6875

6976

7077
def parse_results(result):
7178
lines = result.stdout.splitlines()
7279
test_results = TestResult()
80+
test_results.tests = []
81+
test_results.unexpected_successes = []
7382
test_results.stdout = result.stdout
7483
in_test_results = False
7584
for line in lines:
@@ -107,6 +116,19 @@ def parse_results(result):
107116
res = line.split("== Tests result: ")[1]
108117
res = res.split(" ")[0]
109118
test_results.tests_result = res
119+
# Parse: "UNEXPECTED SUCCESS: test_name (path)"
120+
elif line.startswith("UNEXPECTED SUCCESS: "):
121+
rest = line[len("UNEXPECTED SUCCESS: ") :]
122+
# Format: "test_name (path)"
123+
first_space = rest.find(" ")
124+
if first_space > 0:
125+
test = Test()
126+
test.name = rest[:first_space]
127+
path_part = rest[first_space:].strip()
128+
if path_part.startswith("(") and path_part.endswith(")"):
129+
test.path = path_part[1:-1]
130+
test.result = "unexpected_success"
131+
test_results.unexpected_successes.append(test)
110132
return test_results
111133

112134

@@ -117,6 +139,95 @@ def path_to_test(path) -> list[str]:
117139
return parts[-2:] # Get class name and method name
118140

119141

142+
def is_super_call_only(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> bool:
143+
"""Check if the method body is just 'return super().method_name()'."""
144+
if len(func_node.body) != 1:
145+
return False
146+
stmt = func_node.body[0]
147+
if not isinstance(stmt, ast.Return) or stmt.value is None:
148+
return False
149+
# Check for super().method_name() pattern
150+
call = stmt.value
151+
if not isinstance(call, ast.Call):
152+
return False
153+
if not isinstance(call.func, ast.Attribute):
154+
return False
155+
super_call = call.func.value
156+
if not isinstance(super_call, ast.Call):
157+
return False
158+
if not isinstance(super_call.func, ast.Name) or super_call.func.id != "super":
159+
return False
160+
return True
161+
162+
163+
def remove_expected_failures(
164+
contents: str, tests_to_remove: set[tuple[str, str]]
165+
) -> str:
166+
"""Remove @unittest.expectedFailure decorators from tests that now pass."""
167+
if not tests_to_remove:
168+
return contents
169+
170+
tree = ast.parse(contents)
171+
lines = contents.splitlines()
172+
lines_to_remove = set()
173+
174+
for node in ast.walk(tree):
175+
if not isinstance(node, ast.ClassDef):
176+
continue
177+
class_name = node.name
178+
for item in node.body:
179+
if not isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
180+
continue
181+
method_name = item.name
182+
if (class_name, method_name) not in tests_to_remove:
183+
continue
184+
185+
# Check if we should remove the entire method (super() call only)
186+
remove_entire_method = is_super_call_only(item)
187+
188+
if remove_entire_method:
189+
# Remove entire method including decorators and any preceding comment
190+
first_line = item.lineno - 1 # 0-indexed, def line
191+
if item.decorator_list:
192+
first_line = item.decorator_list[0].lineno - 1
193+
# Check for TODO comment before first decorator/def
194+
if first_line > 0:
195+
prev_line = lines[first_line - 1].strip()
196+
if prev_line.startswith("#") and COMMENT in prev_line:
197+
first_line -= 1
198+
# Remove from first_line to end_lineno (inclusive)
199+
for i in range(first_line, item.end_lineno):
200+
lines_to_remove.add(i)
201+
else:
202+
# Only remove the expectedFailure decorator
203+
for dec in item.decorator_list:
204+
dec_line = dec.lineno - 1 # 0-indexed
205+
line_content = lines[dec_line]
206+
207+
# Check if it's @unittest.expectedFailure
208+
if "expectedFailure" not in line_content:
209+
continue
210+
211+
# Check if TODO: RUSTPYTHON is on the same line or the line before
212+
has_comment_on_line = COMMENT in line_content
213+
has_comment_before = (
214+
dec_line > 0
215+
and lines[dec_line - 1].strip().startswith("#")
216+
and COMMENT in lines[dec_line - 1]
217+
)
218+
219+
if has_comment_on_line or has_comment_before:
220+
lines_to_remove.add(dec_line)
221+
if has_comment_before:
222+
lines_to_remove.add(dec_line - 1)
223+
224+
# Remove lines in reverse order to maintain line numbers
225+
for line_idx in sorted(lines_to_remove, reverse=True):
226+
del lines[line_idx]
227+
228+
return "\n".join(lines) + "\n" if lines else ""
229+
230+
120231
def build_patches(test_parts_set: set[tuple[str, str]]) -> dict:
121232
"""Convert failing tests to lib_updater patch format."""
122233
patches = {}
@@ -190,20 +301,38 @@ def run_test(test_name):
190301
f = test_path.read_text(encoding="utf-8")
191302

192303
# Collect failing tests (with deduplication for subtests)
193-
seen_tests = set() # Track (class_name, method_name) to avoid duplicates
304+
failing_tests = set() # Track (class_name, method_name) to avoid duplicates
194305
for test in tests.tests:
195306
if test.result == "fail" or test.result == "error":
196307
test_parts = path_to_test(test.path)
197308
if len(test_parts) == 2:
198309
test_key = tuple(test_parts)
199-
if test_key not in seen_tests:
200-
seen_tests.add(test_key)
201-
print(f"Marking test: {test_parts[0]}.{test_parts[1]}")
202-
203-
# Apply patches using lib_updater
204-
if seen_tests:
205-
patches = build_patches(seen_tests)
310+
if test_key not in failing_tests:
311+
failing_tests.add(test_key)
312+
print(f"Marking as failing: {test_parts[0]}.{test_parts[1]}")
313+
314+
# Collect unexpected successes (tests that now pass but have expectedFailure)
315+
unexpected_successes = set()
316+
for test in tests.unexpected_successes:
317+
test_parts = path_to_test(test.path)
318+
if len(test_parts) == 2:
319+
test_key = tuple(test_parts)
320+
if test_key not in unexpected_successes:
321+
unexpected_successes.add(test_key)
322+
print(f"Removing expectedFailure: {test_parts[0]}.{test_parts[1]}")
323+
324+
# Remove expectedFailure from tests that now pass
325+
if unexpected_successes:
326+
f = remove_expected_failures(f, unexpected_successes)
327+
328+
# Apply patches for failing tests
329+
if failing_tests:
330+
patches = build_patches(failing_tests)
206331
f = apply_patches(f, patches)
332+
333+
# Write changes if any modifications were made
334+
if failing_tests or unexpected_successes:
207335
test_path.write_text(f, encoding="utf-8")
208336

209-
print(f"Modified {len(seen_tests)} tests")
337+
print(f"Added expectedFailure to {len(failing_tests)} tests")
338+
print(f"Removed expectedFailure from {len(unexpected_successes)} tests")

0 commit comments

Comments
 (0)