Skip to content

Commit ab1105a

Browse files
authored
Fix fix_test.py (#6415)
1 parent 6c186e3 commit ab1105a

File tree

1 file changed

+96
-69
lines changed

1 file changed

+96
-69
lines changed

scripts/fix_test.py

Lines changed: 96 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,20 @@
55
66
How to use:
77
1. Copy a specific test from the CPython repository to the RustPython repository.
8-
2. Remove all unexpected failures from the test and skip the tests that hang
9-
3. Run python ./scripts/fix_test.py --test test_venv --path ./Lib/test/test_venv.py or equivalent for the test from the project root.
10-
4. Ensure that there are no unexpected successes in the test.
11-
5. Actually fix the test.
8+
2. Remove all unexpected failures from the test and skip the tests that hang.
9+
3. Build RustPython: cargo build --release
10+
4. Run from the project root:
11+
- For single-file tests: python ./scripts/fix_test.py --path ./Lib/test/test_venv.py
12+
- For package tests: python ./scripts/fix_test.py --path ./Lib/test/test_inspect/test_inspect.py
13+
5. Verify: cargo run --release -- -m test test_venv (should pass with expected failures)
14+
6. Actually fix the tests marked with # TODO: RUSTPYTHON
1215
"""
1316

1417
import argparse
1518
import ast
1619
import itertools
1720
import platform
21+
import sys
1822
from pathlib import Path
1923

2024

@@ -58,85 +62,87 @@ def parse_results(result):
5862
in_test_results = True
5963
elif line.startswith("-----------"):
6064
in_test_results = False
61-
if (
62-
in_test_results
63-
and not line.startswith("tests")
64-
and not line.startswith("[")
65-
):
66-
line = line.split(" ")
67-
if line != [] and len(line) > 3:
68-
test = Test()
69-
test.name = line[0]
70-
test.path = line[1].strip("(").strip(")")
71-
test.result = " ".join(line[3:]).lower()
72-
test_results.tests.append(test)
73-
else:
74-
if "== Tests result: " in line:
75-
res = line.split("== Tests result: ")[1]
76-
res = res.split(" ")[0]
77-
test_results.tests_result = res
65+
if in_test_results and " ... " in line:
66+
line = line.strip()
67+
# Skip lines that don't look like test results
68+
if line.startswith("tests") or line.startswith("["):
69+
continue
70+
# Parse: "test_name (path) [subtest] ... RESULT"
71+
parts = line.split(" ... ")
72+
if len(parts) >= 2:
73+
test_info = parts[0]
74+
result_str = parts[-1].lower()
75+
# Only process FAIL or ERROR
76+
if result_str not in ("fail", "error"):
77+
continue
78+
# Extract test name (first word)
79+
first_space = test_info.find(" ")
80+
if first_space > 0:
81+
test = Test()
82+
test.name = test_info[:first_space]
83+
# Extract path from (path)
84+
rest = test_info[first_space:].strip()
85+
if rest.startswith("("):
86+
end_paren = rest.find(")")
87+
if end_paren > 0:
88+
test.path = rest[1:end_paren]
89+
test.result = result_str
90+
test_results.tests.append(test)
91+
elif "== Tests result: " in line:
92+
res = line.split("== Tests result: ")[1]
93+
res = res.split(" ")[0]
94+
test_results.tests_result = res
7895
return test_results
7996

8097

8198
def path_to_test(path) -> list[str]:
82-
return path.split(".")[2:]
99+
# path format: test.module_name[.submodule].ClassName.test_method
100+
# We need [ClassName, test_method] - always the last 2 elements
101+
parts = path.split(".")
102+
return parts[-2:] # Get class name and method name
83103

84104

85-
def modify_test(file: str, test: list[str], for_platform: bool = False) -> str:
105+
def find_test_lineno(file: str, test: list[str]) -> tuple[int, int] | None:
106+
"""Find the line number and column offset of a test function.
107+
Returns (lineno, col_offset) or None if not found.
108+
"""
86109
a = ast.parse(file)
87-
lines = file.splitlines()
88-
fixture = "@unittest.expectedFailure"
89-
for node in ast.walk(a):
90-
if isinstance(node, ast.FunctionDef):
91-
if node.name == test[-1]:
92-
assert not for_platform
93-
indent = " " * node.col_offset
94-
lines.insert(node.lineno - 1, indent + fixture)
95-
lines.insert(node.lineno - 1, indent + "# TODO: RUSTPYTHON")
96-
break
97-
return "\n".join(lines)
98-
99-
100-
def modify_test_v2(file: str, test: list[str], for_platform: bool = False) -> str:
101-
a = ast.parse(file)
102-
lines = file.splitlines()
103-
fixture = "@unittest.expectedFailure"
104110
for key, node in ast.iter_fields(a):
105111
if key == "body":
106-
for i, n in enumerate(node):
112+
for n in node:
107113
match n:
108114
case ast.ClassDef():
109115
if len(test) == 2 and test[0] == n.name:
110-
# look through body for function def
111-
for i, fn in enumerate(n.body):
116+
for fn in n.body:
112117
match fn:
113-
case ast.FunctionDef():
118+
case ast.FunctionDef() | ast.AsyncFunctionDef():
114119
if fn.name == test[-1]:
115-
assert not for_platform
116-
indent = " " * fn.col_offset
117-
lines.insert(
118-
fn.lineno - 1, indent + fixture
119-
)
120-
lines.insert(
121-
fn.lineno - 1,
122-
indent + "# TODO: RUSTPYTHON",
123-
)
124-
break
125-
case ast.FunctionDef():
120+
return (fn.lineno, fn.col_offset)
121+
case ast.FunctionDef() | ast.AsyncFunctionDef():
126122
if n.name == test[0] and len(test) == 1:
127-
assert not for_platform
128-
indent = " " * n.col_offset
129-
lines.insert(n.lineno - 1, indent + fixture)
130-
lines.insert(n.lineno - 1, indent + "# TODO: RUSTPYTHON")
131-
break
132-
if i > 500:
133-
exit()
123+
return (n.lineno, n.col_offset)
124+
return None
125+
126+
127+
def apply_modifications(file: str, modifications: list[tuple[int, int]]) -> str:
128+
"""Apply all modifications in reverse order to avoid line number offset issues."""
129+
lines = file.splitlines()
130+
fixture = "@unittest.expectedFailure"
131+
# Sort by line number in descending order
132+
modifications.sort(key=lambda x: x[0], reverse=True)
133+
for lineno, col_offset in modifications:
134+
indent = " " * col_offset
135+
lines.insert(lineno - 1, indent + fixture)
136+
lines.insert(lineno - 1, indent + "# TODO: RUSTPYTHON")
134137
return "\n".join(lines)
135138

136139

137140
def run_test(test_name):
138141
print(f"Running test: {test_name}")
139142
rustpython_location = "./target/release/rustpython"
143+
if sys.platform == "win32":
144+
rustpython_location += ".exe"
145+
140146
import subprocess
141147

142148
result = subprocess.run(
@@ -149,13 +155,34 @@ def run_test(test_name):
149155

150156
if __name__ == "__main__":
151157
args = parse_args()
152-
test_name = args.path.stem
158+
test_path = args.path.resolve()
159+
if not test_path.exists():
160+
print(f"Error: File not found: {test_path}")
161+
sys.exit(1)
162+
test_name = test_path.stem
153163
tests = run_test(test_name)
154-
f = open(args.path).read()
164+
f = test_path.read_text(encoding="utf-8")
165+
166+
# Collect all modifications first (with deduplication for subtests)
167+
modifications = []
168+
seen_tests = set() # Track (class_name, method_name) to avoid duplicates
155169
for test in tests.tests:
156170
if test.result == "fail" or test.result == "error":
157-
print("Modifying test:", test.name)
158-
f = modify_test_v2(f, path_to_test(test.path), args.platform)
159-
with open(args.path, "w") as file:
160-
# TODO: Find validation method, and make --force override it
161-
file.write(f)
171+
test_parts = path_to_test(test.path)
172+
test_key = tuple(test_parts)
173+
if test_key in seen_tests:
174+
continue # Skip duplicate (same test, different subtest)
175+
seen_tests.add(test_key)
176+
location = find_test_lineno(f, test_parts)
177+
if location:
178+
print(f"Modifying test: {test.name} at line {location[0]}")
179+
modifications.append(location)
180+
else:
181+
print(f"Warning: Could not find test: {test.name} ({test_parts})")
182+
183+
# Apply all modifications in reverse order
184+
if modifications:
185+
f = apply_modifications(f, modifications)
186+
test_path.write_text(f, encoding="utf-8")
187+
188+
print(f"Modified {len(modifications)} tests")

0 commit comments

Comments
 (0)