Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ ALWAYS check whether an existing function already covers your use case before im
- `series()` (`usethis._pipeweld.containers`) — Create a Series pipeline composition from the given components.
- `depgroup()` (`usethis._pipeweld.containers`) — Create a DepGroup pipeline composition tied to a named configuration group.
- `get_endpoint()` (`usethis._pipeweld.func`) — Get the last step name (endpoint) from a pipeline component.
- `get_predecessor()` (`usethis._pipeweld.func`) — Find the step that immediately precedes `step` in a pipeline component.
- `call_subprocess()` (`usethis._subprocess`) — Run a subprocess and return its stdout, raising SubprocessFailedError on failure.
- `change_cwd()` (`usethis._test`) — Change the working directory temporarily.
- `is_offline()` (`usethis._test`) — Return True if the current environment has no internet connectivity.
Expand Down
1 change: 1 addition & 0 deletions docs/functions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
- `series()` (`usethis._pipeweld.containers`) — Create a Series pipeline composition from the given components.
- `depgroup()` (`usethis._pipeweld.containers`) — Create a DepGroup pipeline composition tied to a named configuration group.
- `get_endpoint()` (`usethis._pipeweld.func`) — Get the last step name (endpoint) from a pipeline component.
- `get_predecessor()` (`usethis._pipeweld.func`) — Find the step that immediately precedes `step` in a pipeline component.
- `call_subprocess()` (`usethis._subprocess`) — Run a subprocess and return its stdout, raising SubprocessFailedError on failure.
- `change_cwd()` (`usethis._test`) — Change the working directory temporarily.
- `is_offline()` (`usethis._test`) — Return True if the current environment has no internet connectivity.
Expand Down
44 changes: 20 additions & 24 deletions src/usethis/_integrations/pre_commit/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)
from usethis._integrations.pre_commit.language import get_system_language
from usethis._integrations.pre_commit.yaml import PreCommitConfigYAMLManager
from usethis._pipeweld.containers import series
from usethis._pipeweld.func import Adder, get_predecessor

if TYPE_CHECKING:
from collections.abc import Collection
Expand Down Expand Up @@ -67,39 +69,33 @@ def add_repo(repo: schema.LocalRepo | schema.UriRepo) -> None:
mgr.commit_model(model)
else:
# There are existing hooks so we need to know where to insert the new hook.

# Get the precedents, i.e. hooks occurring before the new hook
# Also the successors, i.e. hooks occurring after the new hook
# Use pipeweld to determine the correct insertion position based on the
# canonical hook ordering.
try:
hook_idx = _HOOK_ORDER.index(hook_config.id)
except ValueError:
msg = f"Hook '{hook_config.id}' not recognized."
raise NotImplementedError(msg) from None
precedents = _HOOK_ORDER[:hook_idx]
successors = _HOOK_ORDER[hook_idx + 1 :]

existing_precedents = [hook for hook in existing_hooks if hook in precedents]
existing_successors = [hook for hook in existing_hooks if hook in successors]

# Add immediately after the last predecessor.
# If there isn't one, we want to add as late as possible without violating
# order, i.e. before the first successor, if there is one.
if existing_precedents:
last_precedent = existing_precedents[-1]
elif not existing_successors:
last_precedent = existing_hooks[-1]
else:
first_successor = existing_successors[0]
first_successor_idx = existing_hooks.index(first_successor)
if first_successor_idx == 0:
last_precedent = None
else:
last_precedent = existing_hooks[first_successor_idx - 1]

prerequisites = set(_HOOK_ORDER[:hook_idx])
postrequisites = set(_HOOK_ORDER[hook_idx + 1 :])

pipeline = series(*existing_hooks)
adder = Adder(
pipeline=pipeline,
step=hook_config.id,
prerequisites=prerequisites,
postrequisites=postrequisites,
force_linear=True,
)
result = adder.add()

predecessor = get_predecessor(result.solution, hook_config.id)

model.repos = insert_repo(
repo_to_insert=repo,
existing_repos=model.repos,
predecessor=last_precedent,
predecessor=predecessor,
)

mgr.commit_model(model)
Expand Down
100 changes: 100 additions & 0 deletions src/usethis/_pipeweld/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Adder(BaseModel):
prerequisites: set[str] = set()
postrequisites: set[str] = set()
compatible_config_groups: set[str] = set()
force_linear: bool = False

def add(self) -> WeldResult:
"""Add the step to the pipeline and return the modified pipeline with instructions."""
Expand All @@ -70,6 +71,11 @@ def add(self) -> WeldResult:

instructions += new_instructions

if self.force_linear:
original_order = _extract_ordered_steps(self.pipeline)
flat = _linearize_component(rearranged_pipeline, self.step, original_order)
rearranged_pipeline = series(*flat)

return WeldResult(
solution=rearranged_pipeline,
instructions=instructions,
Expand Down Expand Up @@ -616,3 +622,97 @@ def get_endpoint(component: str | Series | DepGroup | Parallel) -> str:
return get_endpoint(component.series)
else:
assert_never(component)


def get_predecessor(
    component: str | Series | Parallel | DepGroup, step: str
) -> str | None:
    """Return the name of the step that directly precedes `step`.

    The component is searched recursively:

    - A bare step name has no predecessor of its own.
    - In a `Series`, the predecessor is looked up inside the member that
      contains `step`; if `step` opens that member, the endpoint of the
      previous member (when there is one) is used instead.
    - In a `Parallel`, only the member containing `step` is searched.
    - A `DepGroup` delegates to its inner series.

    Returns `None` when nothing precedes `step` within `component`.
    Raises `ValueError` when `step` does not occur in `component`.
    """
    not_found = f"Step '{step}' not found in component."

    if isinstance(component, str):
        if component != step:
            raise ValueError(not_found)
        return None

    if isinstance(component, Series):
        for index, child in enumerate(component.root):
            if not _has_any_steps(child, steps={step}):
                continue
            nested = get_predecessor(child, step)
            if nested is None and index > 0:
                # `step` opens this child, so whatever finishes the previous
                # child is what runs immediately before it.
                return get_endpoint(component.root[index - 1])
            return nested
        raise ValueError(not_found)

    if isinstance(component, Parallel):
        for child in component.root:
            if not _has_any_steps(child, steps={step}):
                continue
            return get_predecessor(child, step)
        raise ValueError(not_found)

    if isinstance(component, DepGroup):
        return get_predecessor(component.series, step)

    assert_never(component)


def _extract_ordered_steps(
    component: str | Series | Parallel | DepGroup,
) -> list[str]:
    """Collect every step name in a component, walking it depth-first.

    A bare step name yields a one-element list; `Series` and `Parallel`
    concatenate the steps of their members in member order; a `DepGroup`
    yields the steps of its inner series.
    """
    if isinstance(component, str):
        return [component]

    if isinstance(component, (Series, Parallel)):
        collected: list[str] = []
        for child in component.root:
            collected.extend(_extract_ordered_steps(child))
        return collected

    if isinstance(component, DepGroup):
        return _extract_ordered_steps(component.series)

    assert_never(component)


def _linearize_component(
    component: str | Series | Parallel | DepGroup,
    new_step: str,
    original_order: list[str],
) -> list[str]:
    """Flatten a pipeline component to a linear list of step names.

    Series members are flattened and concatenated in member order. Within a
    parallel group, the flattened steps are sorted so that existing steps
    follow their relative order in ``original_order`` and ``new_step`` comes
    after all of them. A DepGroup is flattened via its inner series.

    Args:
        component: The pipeline component to flatten.
        new_step: Name of the step being added; sorted last within any
            parallel group that contains it.
        original_order: Step names in the order they had before the new step
            was added. When a name occurs more than once, its first position
            is the one used for ordering.

    Returns:
        The step names of ``component`` as a flat list.
    """
    if isinstance(component, str):
        return [component]
    elif isinstance(component, Series):
        return [
            name
            for sub in component.root
            for name in _linearize_component(sub, new_step, original_order)
        ]
    elif isinstance(component, Parallel):
        all_items = [
            name
            for sub in component.root
            for name in _linearize_component(sub, new_step, original_order)
        ]

        # Build the name -> position table once instead of scanning
        # `original_order` for every item being sorted. `setdefault` keeps the
        # first occurrence, matching what `list.index` would return.
        position: dict[str, int] = {}
        for index, name in enumerate(original_order):
            position.setdefault(name, index)
        # Steps missing from `original_order` sort after all known steps.
        unknown_position = len(original_order)

        def sort_key(item: str) -> tuple[int, int]:
            # Existing steps sort by their original position (priority 0);
            # the new step sorts last (priority 1).
            if item == new_step:
                return (1, 0)
            return (0, position.get(item, unknown_position))

        # The sort is stable, so items sharing a key keep their current order.
        all_items.sort(key=sort_key)
        return all_items
    elif isinstance(component, DepGroup):
        return _linearize_component(component.series, new_step, original_order)
    else:
        assert_never(component)
119 changes: 119 additions & 0 deletions tests/usethis/_integrations/pre_commit/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,3 +743,122 @@ def test_aliases(self):
schema.HookDefinition(id="ruff-check"),
schema.HookDefinition(id="ruff"),
)


class TestAddRepoPipeweld:
    """Integration tests for pipeweld-based hook insertion."""

    @staticmethod
    def _add_local_hook(hook_id: str, entry: str) -> None:
        """Add a local repo holding one system-language hook named after its id."""
        hook = schema.HookDefinition(
            id=hook_id,
            name=hook_id,
            entry=entry,
            language=schema.Language("system"),
        )
        add_repo(schema.LocalRepo(repo="local", hooks=[hook]))

    def test_insert_between_nondependent_and_postrequisite(self, tmp_path: Path):
        """Insert a recognized hook between an unrecognized hook and a postrequisite."""
        with change_cwd(tmp_path), files_manager():
            # Arrange: foo (unrecognized) followed by codespell (recognized,
            # late in the canonical order).
            self._add_local_hook("foo", "foo .")
            self._add_local_hook("codespell", "codespell .")

            # Act: ruff-format belongs before codespell and after foo.
            self._add_local_hook("ruff-format", "ruff format .")

            # Assert: ruff-format lands between the two existing hooks.
            assert get_hook_ids() == ["foo", "ruff-format", "codespell"]

    def test_insert_with_prerequisite_present(self, tmp_path: Path):
        """Insert a hook after an existing prerequisite."""
        with change_cwd(tmp_path), files_manager():
            self._add_local_hook("ruff-check", "ruff check .")

            self._add_local_hook("ruff-format", "ruff format .")

            assert get_hook_ids() == ["ruff-check", "ruff-format"]

    def test_insert_before_postrequisite_only(self, tmp_path: Path):
        """Insert a hook before an existing postrequisite when no predecessor exists."""
        with change_cwd(tmp_path), files_manager():
            self._add_local_hook("codespell", "codespell .")

            self._add_local_hook("ruff-check", "ruff check .")

            assert get_hook_ids() == ["ruff-check", "codespell"]
Loading
Loading