PatchableGraph¶
- class PatchableGraph(*args: Any, **kwargs: Any)¶
PatchableGraph is a wrapper around Module allowing activation patching at any tensor operation. It is the main entry point for graphpatch's functionality.

Internally, PatchableGraph builds a GraphModule for the module and each of its submodules using torch.compile(). This exposes the computational structure of the module while remaining equivalent to the original: you can perform any operation you would with the original module using the PatchableGraph. In case compilation fails (compile() is not yet compatible with all model code), PatchableGraph will fall back to patching submodule input, output, parameters, and buffers. See ExtractionOptions for options controlling this behavior and Notes on compilation for more discussion.

To perform activation patching, use the patch context manager. This method takes a mapping from NodePaths to lists of Patch to apply at the corresponding nodes. Note that the activation patches will only be applied inside the context block; using the PatchableGraph outside such a block is equivalent to running the original module.

Example:
>>> from graphpatch import PatchableGraph, ZeroPatch
>>> my_llm, my_tokenizer = MyLLM(), MyTokenizer()
>>> my_inputs = my_tokenizer("Hello, ")
>>> patchable_graph = PatchableGraph(my_llm, **my_inputs)
>>> # Patch the input to the third layer's MLP
>>> with patchable_graph.patch({"layers_2.mlp.x": [ZeroPatch()]}):
...     patched_output = patchable_graph(**my_inputs)
- Parameters:
  - module – The Module to wrap.
  - extraction_options_and_args – Arguments (example inputs) to be passed to the module during torch.compile(). If the first argument is an ExtractionOptions instance, it is applied during graph extraction.
  - extraction_kwargs – Keyword arguments to be passed to the module during torch.compile().
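As an illustrative sketch of the argument ordering described above (reusing the placeholder MyLLM and MyTokenizer names from the class example, and constructing ExtractionOptions with defaults):

```python
>>> from graphpatch import ExtractionOptions, PatchableGraph
>>> my_llm, my_tokenizer = MyLLM(), MyTokenizer()
>>> my_inputs = my_tokenizer("Hello, ")
>>> # The first positional argument may optionally be an ExtractionOptions
>>> # instance; the remaining args/kwargs are example inputs forwarded to
>>> # the module during torch.compile().
>>> pg = PatchableGraph(my_llm, ExtractionOptions(), **my_inputs)
```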
- property graph¶
Convenience property for working in REPL and notebook environments. Exposes the full NodePath hierarchy of this PatchableGraph via recursive attribute access. Children of the current node can be tab-completed at each step. Has a custom __repr__() to display the subgraph rooted at the current path.

Example:
In [1]: pg.graph
Out[1]:
<root>: Graph(3)
├─x: Tensor(3, 2)
├─linear: Graph(5)
│ ├─input: Tensor(3, 2)
│ ├─weight: Tensor(3, 2)
│ ├─bias: Tensor(3)
│ ├─linear: Tensor(3, 3)
│ └─output: Tensor(3, 3)
└─output: Tensor(3, 3)

In [2]: pg.graph.linear._code
Out[2]:
Calling context:
  File "/Users/evanlloyd/graphpatch/tests/fixtures/minimal_module.py", line 16, in forward
    return self.linear(x)
Compiled code:
def forward(self, input : torch.Tensor):
    input_1 = input
    weight = self.weight
    bias = self.bias
    linear = torch._C._nn.linear(input_1, weight, bias);  input_1 = weight = bias = None
    return linear

In [3]: pg.graph.output._shape
Out[3]: torch.Size([3, 3])
Also see Inspecting the graph structure for more discussion and examples.
- static load(file: str | PathLike | BinaryIO | IO[bytes], *args: Any, **kwargs: Any) PatchableGraph¶
Wrapper around torch.load(). All the normal caveats around pickling apply; you should not load() anything you downloaded from the Internet. Future versions of graphpatch will likely implement a more secure serialization scheme and disable the built-in torch.load().
- patch(patch_map: Dict[str | NodePath, List[Patch[Tensor | float | int | bool]] | Patch[Tensor | float | int | bool]]) Iterator[None]¶
Context manager that will cause the given activation patches to be applied when running inference on the wrapped module.
- Parameters:
patch_map – A mapping from NodePath to a Patch or list of Patches to apply to each respective node during inference.
- Yields:
  A context in which the given activation patch(es) will be applied when calling self.forward().
- Raises:
  - KeyError – If any NodePath in patch_map does not exist in the graph.
  - ValueError – If patch_map has any invalid types.
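As the signature indicates, each value in patch_map may be either a single Patch or a list of Patches, and patches at several nodes can be applied in one context. An illustrative sketch, reusing the placeholder model from the class example and ZeroPatch (the only Patch subclass shown in this document); the node paths are hypothetical:

```python
>>> from graphpatch import PatchableGraph, ZeroPatch
>>> pg = PatchableGraph(MyLLM(), **my_inputs)
>>> patch_map = {
...     "layers_2.mlp.x": ZeroPatch(),    # a bare Patch is accepted
...     "layers_5.mlp.x": [ZeroPatch()],  # as is a list of Patches
... }
>>> with pg.patch(patch_map):
...     patched_output = pg(**my_inputs)  # both patches active here
>>> unpatched_output = pg(**my_inputs)    # outside the block: original behavior
```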
- save(*args: Any, **kwargs: Any) None¶
Wrapper around torch.save(), because some PatchableGraph internals may need to be handled specially before pickling. Future versions of graphpatch will likely implement a more secure serialization scheme and disable the built-in torch.save().
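A minimal save/load round trip might look like the following sketch, assuming save() accepts a file path the way torch.save() does; the filename is hypothetical:

```python
>>> pg.save("my_model.graphpatch")
>>> # Later, restore the patchable graph. The usual pickling caveats apply:
>>> # only load() files you created yourself.
>>> restored = PatchableGraph.load("my_model.graphpatch")
```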