PatchableGraph

class PatchableGraph(*args: Any, **kwargs: Any)

PatchableGraph is a wrapper around Module allowing activation patching at any tensor operation. It is the main entrypoint for graphpatch’s functionality.

Internally, PatchableGraph builds a GraphModule for the module and each of its submodules using torch.compile(). This exposes the computational structure of the module while remaining functionally equivalent to the original: you can perform any operation you would with the original module using the PatchableGraph. In case compilation fails (compile() is not yet compatible with all model code), PatchableGraph will fall back to patching submodule inputs, outputs, parameters, and buffers. See ExtractionOptions for options controlling this behavior and Notes on compilation for more discussion.

To perform activation patching, use the patch context manager. This method takes a mapping from NodePath to a Patch (or list of Patches) to apply at the corresponding node. Note that the activation patches will only be applied inside the context block; using the PatchableGraph outside such a block is equivalent to running the original module.

Example

>>> from graphpatch import PatchableGraph, ZeroPatch
>>> my_llm, my_tokenizer = MyLLM(), MyTokenizer()
>>> my_inputs = my_tokenizer("Hello, ")
>>> patchable_graph = PatchableGraph(my_llm, **my_inputs)
>>> # Patch the input to the third layer's MLP
>>> with patchable_graph.patch({"layers_2.mlp.x": [ZeroPatch()]}):
...     patched_output = patchable_graph(**my_inputs)
Parameters:
  • module – The Module to wrap.

  • extraction_options_and_args – Arguments (example inputs) to be passed to the module during torch.compile(). If the first argument is an ExtractionOptions instance, it is applied during graph extraction.

  • extraction_kwargs – Keyword arguments to be passed to the module during torch.compile().

property graph

Convenience property for working in REPL and notebook environments. Exposes the full NodePath hierarchy of this PatchableGraph via recursive attribute access. Children of the current node can be tab-completed at each step. Has a custom __repr__() to display the subgraph rooted at the current path.

Example:

In [1]: pg.graph
Out[1]: 
<root>: Graph(3)
├─x: Tensor(3, 2)
├─linear: Graph(5)
│ ├─input: Tensor(3, 2)
│ ├─weight: Tensor(3, 2)
│ ├─bias: Tensor(3)
│ ├─linear: Tensor(3, 3)
│ └─output: Tensor(3, 3)
└─output: Tensor(3, 3)

In [2]: pg.graph.linear._code
Out[2]: 
Calling context:
File "/Users/evanlloyd/graphpatch/tests/fixtures/minimal_module.py", line 16, in forward
    return self.linear(x)
Compiled code:
def forward(self, input : torch.Tensor):
    input_1 = input
    weight = self.weight
    bias = self.bias
    linear = torch._C._nn.linear(input_1, weight, bias);  input_1 = weight = bias = None
    return linear

In [3]: pg.graph.output._shape
Out[3]: torch.Size([3, 3])

Also see Inspecting the graph structure for more discussion and examples.

static load(file: str | PathLike | BinaryIO | IO[bytes], *args: Any, **kwargs: Any) → PatchableGraph

Wrapper around torch.load(). All the normal caveats around pickling apply; you should not load() anything you downloaded from the Internet.

Future versions of graphpatch will likely implement a more secure serialization scheme and disable the built-in torch.load().

patch(patch_map: Dict[str | NodePath, List[Patch[Tensor | float | int | bool]] | Patch[Tensor | float | int | bool]]) → Iterator[None]

Context manager that will cause the given activation patches to be applied when running inference on the wrapped module.

Parameters:

patch_map – A mapping from NodePath to a Patch or list of Patches to apply to each respective node during inference.

Yields:

A context in which the given activation patch(es) will be applied when calling self.forward().
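The mechanics of a patching context manager can be illustrated with a plain-Python sketch. This is not graphpatch's implementation; TinyPatchable and its node name "hidden" are invented for illustration, and the lambda stands in for a Patch such as ZeroPatch. The key idea is that patches are installed on entry and removed on exit, so inference outside the block runs unmodified.

```python
from contextlib import contextmanager

class TinyPatchable:
    """Hypothetical stand-in for a patchable model: computes x * 2 + 1,
    with the intermediate value exposed as a patchable node "hidden"."""

    def __init__(self):
        self.active_patches = {}

    def __call__(self, x):
        hidden = x * 2
        # Apply any patches registered for the "hidden" node.
        for patch in self.active_patches.get("hidden", []):
            hidden = patch(hidden)
        return hidden + 1

    @contextmanager
    def patch(self, patch_map):
        # Accept a single patch or a list, mirroring patch_map's signature.
        self.active_patches = {
            name: patches if isinstance(patches, list) else [patches]
            for name, patches in patch_map.items()
        }
        try:
            yield
        finally:
            # Patches are removed even if the block raises.
            self.active_patches = {}

model = TinyPatchable()
zero_patch = lambda t: 0  # analogous to ZeroPatch: replace the activation with zeros
print(model(3))                          # unpatched: 3 * 2 + 1 = 7
with model.patch({"hidden": zero_patch}):
    print(model(3))                      # patched: 0 + 1 = 1
print(model(3))                          # patch removed again: 7
```

The try/finally ensures the patches are uninstalled even when an exception escapes the block, which is why the context-manager form is safer than manually toggling patch state.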

save(*args: Any, **kwargs: Any) → None

Wrapper around torch.save() because some PatchableGraph internals may need to be handled specially before pickling.

Future versions of graphpatch will likely implement a more secure serialization scheme and disable the built-in torch.save().
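To illustrate why a wrapper around torch.save() can be necessary, here is a generic sketch (not graphpatch's implementation) of the standard pickle hooks for handling state that cannot be pickled directly. The Wrapper class and its cache attribute are invented for illustration; the same pattern applies to any object that must drop or rebuild internals around serialization.

```python
import pickle

class Wrapper:
    """Hypothetical object mixing picklable data with unpicklable state."""

    def __init__(self):
        self.data = [1, 2, 3]
        self.cache = lambda x: x * 2  # lambdas cannot be pickled

    def __getstate__(self):
        # Strip the unpicklable part before pickling.
        state = self.__dict__.copy()
        del state["cache"]
        return state

    def __setstate__(self, state):
        # Restore the picklable state, then rebuild the rest on load.
        self.__dict__.update(state)
        self.cache = lambda x: x * 2

w = Wrapper()
restored = pickle.loads(pickle.dumps(w))
print(restored.data)      # [1, 2, 3]
print(restored.cache(5))  # 10
```

Without the __getstate__/__setstate__ pair, pickling would fail on the lambda; a save() wrapper lets a library apply this kind of special handling without burdening the caller.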