Skip to content

Commit 6cd17e3

Browse files
Handle installation of extras properly (#231)
* Handle installation of extras properly * Pass pyright * Update tests
1 parent b933fe7 commit 6cd17e3

7 files changed

Lines changed: 303 additions & 146 deletions

File tree

src/usethis/_core/tool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def use_pre_commit(*, remove: bool = False) -> None:
9393

9494
if pyproject_fmt_tool.is_used():
9595
# We will use pre-commit instead of the dev-dep.
96-
remove_deps_from_group(pyproject_fmt_tool.get_unique_dev_deps(), "dev")
96+
remove_deps_from_group(pyproject_fmt_tool.dev_deps, "dev")
9797
pyproject_fmt_tool.add_pyproject_configs()
9898
_pyproject_fmt_instructions_pre_commit()
9999

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
import re
2-
31
from packaging.requirements import Requirement
4-
from pydantic import TypeAdapter
2+
from pydantic import BaseModel, TypeAdapter
53

64
from usethis._config import usethis_config
75
from usethis._console import tick_print
@@ -10,7 +8,19 @@
108
from usethis._integrations.uv.errors import UVDepGroupError, UVSubprocessFailedError
119

1210

13-
def get_dep_groups() -> dict[str, list[str]]:
11+
class Dependency(BaseModel):
12+
name: str
13+
extras: frozenset[str] = frozenset()
14+
15+
def __str__(self) -> str:
16+
extras = sorted(self.extras or set())
17+
return self.name + "".join(f"[{extra}]" for extra in extras)
18+
19+
def __hash__(self) -> int:
20+
return hash((self.__class__.__name__, self.name, self.extras))
21+
22+
23+
def get_dep_groups() -> dict[str, list[Dependency]]:
1424
pyproject = read_pyproject_toml()
1525
try:
1626
dep_groups_section = pyproject["dependency-groups"]
@@ -23,71 +33,85 @@ def get_dep_groups() -> dict[str, list[str]]:
2333
dep_groups_section
2434
)
2535
reqs_by_group = {
26-
group: [Requirement(req_str).name for req_str in req_strs]
36+
group: [Requirement(req_str) for req_str in req_strs]
2737
for group, req_strs in req_strs_by_group.items()
2838
}
29-
return reqs_by_group
39+
deps_by_group = {
40+
group: [Dependency(name=req.name, extras=frozenset(req.extras)) for req in reqs]
41+
for group, reqs in reqs_by_group.items()
42+
}
43+
return deps_by_group
3044

3145

32-
def get_deps_from_group(group: str) -> list[str]:
46+
def get_deps_from_group(group: str) -> list[Dependency]:
3347
dep_groups = get_dep_groups()
3448
try:
3549
return dep_groups[group]
3650
except KeyError:
3751
return []
3852

3953

40-
def add_deps_to_group(deps: list[str], group: str) -> None:
54+
def add_deps_to_group(deps: list[Dependency], group: str) -> None:
4155
"""Add a package as a non-build dependency using PEP 735 dependency groups."""
4256
existing_group = get_deps_from_group(group)
4357

44-
_deps = [dep for dep in deps if _strip_extras(dep) not in existing_group]
58+
to_add_deps = [
59+
dep for dep in deps if not is_dep_satisfied_in(dep, in_=existing_group)
60+
]
4561

46-
if not _deps:
62+
if not to_add_deps:
4763
return
4864

49-
deps_str = ", ".join([f"'{_strip_extras(dep)}'" for dep in _deps])
50-
ies = "y" if len(_deps) == 1 else "ies"
65+
deps_str = ", ".join([f"'{dep}'" for dep in to_add_deps])
66+
ies = "y" if len(to_add_deps) == 1 else "ies"
5167
tick_print(
5268
f"Adding dependenc{ies} {deps_str} to the '{group}' group in 'pyproject.toml'."
5369
)
5470

55-
for dep in _deps:
71+
for dep in to_add_deps:
5672
try:
5773
if not usethis_config.offline:
58-
call_uv_subprocess(["add", "--group", group, "--quiet", dep])
74+
call_uv_subprocess(["add", "--group", group, "--quiet", str(dep)])
5975
else:
6076
call_uv_subprocess(
61-
["add", "--group", group, "--quiet", "--offline", dep]
77+
["add", "--group", group, "--quiet", "--offline", str(dep)]
6278
)
6379
except UVSubprocessFailedError as err:
6480
msg = f"Failed to add '{dep}' to the '{group}' dependency group:\n{err}"
6581
raise UVDepGroupError(msg) from None
6682

6783

68-
def remove_deps_from_group(deps: list[str], group: str) -> None:
84+
def is_dep_satisfied_in(dep: Dependency, *, in_: list[Dependency]) -> bool:
85+
return any(_is_dep_satisfied_by(dep, by=by) for by in in_)
86+
87+
88+
def _is_dep_satisfied_by(dep: Dependency, *, by: Dependency) -> bool:
89+
# Name is the same and extras are a subset of the extras of the dependency
90+
return dep.name == by.name and (dep.extras or set()) <= (by.extras or set())
91+
92+
93+
def remove_deps_from_group(deps: list[Dependency], group: str) -> None:
6994
"""Remove the tool's development dependencies, if present."""
7095
existing_group = get_deps_from_group(group)
7196

72-
_deps = [dep for dep in deps if _strip_extras(dep) in existing_group]
97+
_deps = [dep for dep in deps if is_dep_satisfied_in(dep, in_=existing_group)]
7398

7499
if not _deps:
75100
return
76101

77-
deps_str = ", ".join([f"'{_strip_extras(dep)}'" for dep in _deps])
102+
deps_str = ", ".join([f"'{dep}'" for dep in _deps])
78103
ies = "y" if len(_deps) == 1 else "ies"
79104
tick_print(
80105
f"Removing dependenc{ies} {deps_str} from the '{group}' group in 'pyproject.toml'."
81106
)
82107

83108
for dep in _deps:
84109
try:
85-
se_dep = _strip_extras(dep)
86110
if not usethis_config.offline:
87-
call_uv_subprocess(["remove", "--group", group, "--quiet", se_dep])
111+
call_uv_subprocess(["remove", "--group", group, "--quiet", str(dep)])
88112
else:
89113
call_uv_subprocess(
90-
["remove", "--group", group, "--quiet", "--offline", se_dep]
114+
["remove", "--group", group, "--quiet", "--offline", str(dep)]
91115
)
92116
except UVSubprocessFailedError as err:
93117
msg = (
@@ -96,12 +120,7 @@ def remove_deps_from_group(deps: list[str], group: str) -> None:
96120
raise UVDepGroupError(msg) from None
97121

98122

99-
def is_dep_in_any_group(dep: str) -> bool:
100-
return _strip_extras(dep) in {
101-
dep for group in get_dep_groups().values() for dep in group
102-
}
103-
104-
105-
def _strip_extras(dep: str) -> str:
106-
"""Remove extras from a dependency string."""
107-
return re.sub(r"\[.*\]", "", dep)
123+
def is_dep_in_any_group(dep: Dependency) -> bool:
124+
return is_dep_satisfied_in(
125+
dep, in_=[dep for group in get_dep_groups().values() for dep in group]
126+
)

src/usethis/_tool.py

Lines changed: 28 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
remove_config_value,
2525
set_config_value,
2626
)
27-
from usethis._integrations.uv.deps import is_dep_in_any_group
27+
from usethis._integrations.uv.deps import Dependency, is_dep_in_any_group
2828

2929

3030
class Tool(Protocol):
@@ -38,8 +38,8 @@ def name(self) -> str:
3838
"""
3939

4040
@property
41-
def dev_deps(self) -> list[str]:
42-
"""The name of the tool's development dependencies."""
41+
def dev_deps(self) -> list[Dependency]:
42+
"""The tool's development dependencies."""
4343
return []
4444

4545
def get_pre_commit_repos(self) -> list[LocalRepo | UriRepo]:
@@ -54,10 +54,6 @@ def get_associated_ruff_rules(self) -> list[str]:
5454
"""Get the Ruff rule codes associated with the tool."""
5555
return []
5656

57-
def get_unique_dev_deps(self) -> list[str]:
58-
"""Any development dependencies only used by this tool (not shared)."""
59-
return self.dev_deps
60-
6157
def get_managed_files(self) -> list[Path]:
6258
"""Get (relative) paths to files managed by the tool."""
6359
return []
@@ -74,9 +70,7 @@ def is_used(self) -> bool:
7470
2. Whether any of the tool's managed files are in the project.
7571
3. Whether any of the tool's managed pyproject.toml sections are present.
7672
"""
77-
is_any_deps = any(
78-
is_dep_in_any_group(dep) for dep in self.get_unique_dev_deps()
79-
)
73+
is_any_deps = any(is_dep_in_any_group(dep) for dep in self.dev_deps)
8074
is_any_files = any(
8175
file.exists() and file.is_file() for file in self.get_managed_files()
8276
)
@@ -167,8 +161,8 @@ def name(self) -> str:
167161
return "coverage"
168162

169163
@property
170-
def dev_deps(self) -> list[str]:
171-
return ["coverage[toml]"]
164+
def dev_deps(self) -> list[Dependency]:
165+
return [Dependency(name="coverage", extras=frozenset({"toml"}))]
172166

173167
def get_pyproject_configs(self) -> list[PyProjectConfig]:
174168
return [
@@ -194,10 +188,10 @@ def get_pyproject_configs(self) -> list[PyProjectConfig]:
194188
),
195189
]
196190

197-
def get_pyproject_id_keys(self):
191+
def get_pyproject_id_keys(self) -> list[list[str]]:
198192
return [["tool", "coverage"]]
199193

200-
def get_managed_files(self):
194+
def get_managed_files(self) -> list[Path]:
201195
return [Path(".coveragerc")]
202196

203197

@@ -207,8 +201,8 @@ def name(self) -> str:
207201
return "deptry"
208202

209203
@property
210-
def dev_deps(self) -> list[str]:
211-
return ["deptry"]
204+
def dev_deps(self) -> list[Dependency]:
205+
return [Dependency(name="deptry")]
212206

213207
def get_pre_commit_repos(self) -> list[LocalRepo | UriRepo]:
214208
return [
@@ -234,10 +228,10 @@ def name(self) -> str:
234228
return "pre-commit"
235229

236230
@property
237-
def dev_deps(self) -> list[str]:
238-
return ["pre-commit"]
231+
def dev_deps(self) -> list[Dependency]:
232+
return [Dependency(name="pre-commit")]
239233

240-
def get_managed_files(self):
234+
def get_managed_files(self) -> list[Path]:
241235
return [Path(".pre-commit-config.yaml")]
242236

243237

@@ -247,8 +241,8 @@ def name(self) -> str:
247241
return "pyproject-fmt"
248242

249243
@property
250-
def dev_deps(self) -> list[str]:
251-
return ["pyproject-fmt"]
244+
def dev_deps(self) -> list[Dependency]:
245+
return [Dependency(name="pyproject-fmt")]
252246

253247
def get_pre_commit_repos(self) -> list[LocalRepo | UriRepo]:
254248
return [
@@ -267,7 +261,7 @@ def get_pyproject_configs(self) -> list[PyProjectConfig]:
267261
)
268262
]
269263

270-
def get_pyproject_id_keys(self):
264+
def get_pyproject_id_keys(self) -> list[list[str]]:
271265
return [["tool", "pyproject-fmt"]]
272266

273267

@@ -277,8 +271,11 @@ def name(self) -> str:
277271
return "pytest"
278272

279273
@property
280-
def dev_deps(self) -> list[str]:
281-
return ["pytest", "pytest-cov"]
274+
def dev_deps(self) -> list[Dependency]:
275+
return [
276+
Dependency(name="pytest"),
277+
Dependency(name="pytest-cov"),
278+
]
282279

283280
def get_pyproject_configs(self) -> list[PyProjectConfig]:
284281
return [
@@ -299,13 +296,10 @@ def get_pyproject_configs(self) -> list[PyProjectConfig]:
299296
def get_associated_ruff_rules(self) -> list[str]:
300297
return ["PT"]
301298

302-
def get_unique_dev_deps(self):
303-
return ["pytest", "pytest-cov"]
304-
305-
def get_pyproject_id_keys(self):
299+
def get_pyproject_id_keys(self) -> list[list[str]]:
306300
return [["tool", "pytest"]]
307301

308-
def get_managed_files(self):
302+
def get_managed_files(self) -> list[Path]:
309303
return [Path("tests/conftest.py")]
310304

311305

@@ -315,7 +309,7 @@ def name(self) -> str:
315309
return "requirements.txt"
316310

317311
@property
318-
def dev_deps(self) -> list[str]:
312+
def dev_deps(self) -> list[Dependency]:
319313
return []
320314

321315
def get_pre_commit_repos(self) -> list[LocalRepo | UriRepo]:
@@ -346,8 +340,8 @@ def name(self) -> str:
346340
return "Ruff"
347341

348342
@property
349-
def dev_deps(self) -> list[str]:
350-
return ["ruff"]
343+
def dev_deps(self) -> list[Dependency]:
344+
return [Dependency(name="ruff")]
351345

352346
def get_pre_commit_repos(self) -> list[LocalRepo | UriRepo]:
353347
return [
@@ -397,10 +391,10 @@ def get_pyproject_configs(self) -> list[PyProjectConfig]:
397391
)
398392
]
399393

400-
def get_pyproject_id_keys(self):
394+
def get_pyproject_id_keys(self) -> list[list[str]]:
401395
return [["tool", "ruff"]]
402396

403-
def get_managed_files(self):
397+
def get_managed_files(self) -> list[Path]:
404398
return [Path("ruff.toml"), Path(".ruff.toml")]
405399

406400

0 commit comments

Comments
 (0)