Skip to content

Commit 981bbc6

Browse files
committed
handle redundant t mode
1 parent 1eef83a commit 981bbc6

3 files changed

Lines changed: 45 additions & 17 deletions

File tree

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,8 @@ Note that `if` blocks without an `else` will not be rewritten as it could introd
487487
+open("foo")
488488
-open("f", "r", encoding="UTF-8")
489489
+open("f", encoding="UTF-8")
490+
-open("f", "wt")
491+
+open("f", "w")
490492
```
491493

492494

pyupgrade/_plugins/open_mode.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import ast
44
import functools
5+
import itertools
56
from typing import Iterable
67
from typing import NamedTuple
78

@@ -18,10 +19,20 @@
1819
from pyupgrade._token_helpers import find_open_paren
1920
from pyupgrade._token_helpers import parse_call_args
2021

21-
U_MODE_REMOVE = frozenset(('U', 'Ur', 'rU', 'r', 'rt', 'tr'))
22-
U_MODE_REPLACE_R = frozenset(('Ub', 'bU'))
23-
U_MODE_REMOVE_U = frozenset(('rUb', 'Urb', 'rbU', 'Ubr', 'bUr', 'brU'))
24-
U_MODE_REPLACE = U_MODE_REPLACE_R | U_MODE_REMOVE_U
22+
23+
def _plus(args: tuple[str, ...]) -> tuple[str, ...]:
24+
return args + tuple(f'{arg}+' for arg in args)
25+
26+
27+
def _permute(*args: str) -> tuple[str, ...]:
28+
return tuple(''.join(p) for s in args for p in itertools.permutations(s))
29+
30+
31+
MODE_REMOVE = frozenset(_permute('U', 'r', 'rU', 'rt'))
32+
MODE_REPLACE_R = frozenset(_permute('Ub'))
33+
MODE_REMOVE_T = frozenset(_plus(_permute('at', 'rt', 'wt', 'xt')))
34+
MODE_REMOVE_U = frozenset(_permute('rUb'))
35+
MODE_REPLACE = MODE_REPLACE_R | MODE_REMOVE_T | MODE_REMOVE_U
2536

2637

2738
class FunctionArg(NamedTuple):
@@ -35,12 +46,15 @@ def _fix_open_mode(i: int, tokens: list[Token], *, arg_idx: int) -> None:
3546
mode = tokens_to_src(tokens[slice(*func_args[arg_idx])])
3647
mode_stripped = mode.split('=')[-1]
3748
mode_stripped = ast.literal_eval(mode_stripped.strip())
38-
if mode_stripped in U_MODE_REMOVE:
49+
if mode_stripped in MODE_REMOVE:
3950
delete_argument(arg_idx, tokens, func_args)
40-
elif mode_stripped in U_MODE_REPLACE_R:
51+
elif mode_stripped in MODE_REPLACE_R:
4152
new_mode = mode.replace('U', 'r')
4253
tokens[slice(*func_args[arg_idx])] = [Token('SRC', new_mode)]
43-
elif mode_stripped in U_MODE_REMOVE_U:
54+
elif mode_stripped in MODE_REMOVE_T:
55+
new_mode = mode.replace('t', '')
56+
tokens[slice(*func_args[arg_idx])] = [Token('SRC', new_mode)]
57+
elif mode_stripped in MODE_REMOVE_U:
4458
new_mode = mode.replace('U', '')
4559
tokens[slice(*func_args[arg_idx])] = [Token('SRC', new_mode)]
4660
else:
@@ -69,13 +83,10 @@ def visit_Call(
6983
):
7084
if len(node.args) >= 2 and isinstance(node.args[1], ast.Str):
7185
if (
72-
node.args[1].s in U_MODE_REPLACE or
73-
(len(node.args) == 2 and node.args[1].s in U_MODE_REMOVE)
86+
node.args[1].s in MODE_REPLACE or
87+
(len(node.args) == 2 and node.args[1].s in MODE_REMOVE)
7488
):
75-
func = functools.partial(
76-
_fix_open_mode,
77-
arg_idx=1,
78-
)
89+
func = functools.partial(_fix_open_mode, arg_idx=1)
7990
yield ast_to_offset(node), func
8091
elif node.keywords and (len(node.keywords) + len(node.args) > 1):
8192
mode = next(
@@ -90,8 +101,8 @@ def visit_Call(
90101
mode is not None and
91102
isinstance(mode.value, ast.Str) and
92103
(
93-
mode.value.s in U_MODE_REMOVE or
94-
mode.value.s in U_MODE_REPLACE
104+
mode.value.s in MODE_REMOVE or
105+
mode.value.s in MODE_REPLACE
95106
)
96107
):
97108
func = functools.partial(

tests/features/open_mode_test.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,18 @@
44

55
from pyupgrade._data import Settings
66
from pyupgrade._main import _fix_plugins
7+
from pyupgrade._plugins.open_mode import _permute
8+
from pyupgrade._plugins.open_mode import _plus
9+
10+
11+
def test_plus():
12+
assert _plus(('a',)) == ('a', 'a+')
13+
assert _plus(('a', 'b')) == ('a', 'b', 'a+', 'b+')
14+
15+
16+
def test_permute():
17+
assert _permute('ab') == ('ab', 'ba')
18+
assert _permute('abc') == ('abc', 'acb', 'bac', 'bca', 'cab', 'cba')
719

820

921
@pytest.mark.parametrize(
@@ -18,8 +30,6 @@
1830
'open("foo", qux="r")',
1931
'open("foo", 3)',
2032
'open(mode="r")',
21-
# TODO: could maybe be rewritten to remove t?
22-
'open("foo", "wt")',
2333
# don't remove this, they meant to use `encoding=`
2434
'open("foo", "r", "utf-8")',
2535
),
@@ -70,6 +80,11 @@ def test_fix_open_mode_noop(s):
7080
'open("foo")',
7181
id='io.open also rewrites modes in a single pass',
7282
),
83+
('open("foo", "wt")', 'open("foo", "w")'),
84+
('open("foo", "xt")', 'open("foo", "x")'),
85+
('open("foo", "at")', 'open("foo", "a")'),
86+
('open("foo", "wt+")', 'open("foo", "w+")'),
87+
('open("foo", "rt+")', 'open("foo", "r+")'),
7388
),
7489
)
7590
def test_fix_open_mode(s, expected):

0 commit comments

Comments
 (0)