Patch¶
- class Patch¶
Base class for operations applying to nodes in a
PatchableGraph. Derived classes should be keyword-only dataclasses (i.e. decorated with @dataclass(kw_only=True)) and override op().
- requires_clone¶
Whether the operation modifies the original output. Defaults to True and is hidden from the constructor; derived classes can override it for read-only operations.
- Type:
bool
- path¶
For nodes that output nested structures, the path within that structure that this operation should apply to. Hidden from the constructor, since setting the path will be handled by
PatchableGraph.
- Type:
str | None
- op(original_output: PatchTarget) → PatchTarget¶
The operation to perform at this node. Should take in a single argument, which will be populated with the original output at this node, and return a value of the same type.
- class AddPatch(*, value: Tensor | float | int | bool, slice: TensorSlice | None = None)¶
Patch that adds a value to (optionally, a slice of) its target.
Example
pg = PatchableGraph(model, **example_inputs)
delta = torch.ones((seq_len - 1,))
with pg.patch({"output": AddPatch(value=delta, slice=(slice(1, None), 0))}):
    patched_outputs = pg(**sample_inputs)
- slice¶
Slice to perform addition on. Applies to full target if None.
- Type:
TensorSlice | None
- value¶
Value to add to target.
- Type:
torch.Tensor | float | int | bool
- class CustomPatch(*, requires_clone: bool = True, custom_op: Callable[[PatchTarget], PatchTarget])¶
Convenience for one-off patch operations without the need to define a new Patch class. Also exposes the normally hidden
requires_clone field for operations that do not require cloning.
Example
Replace the output of a layer’s MLP with that of a previous layer:
pg = PatchableGraph(model, **example_inputs)
with pg.patch(
    {
        "layers_0.mlp.output": [layer_0 := ProbePatch()],
        "layers_1.mlp.output": CustomPatch(custom_op=lambda t: layer_0.activation),
    }
):
    print(pg(**sample_inputs))
- custom_op¶
Operation to perform. Replace output at this node with the return value of
custom_op(original_output).
- Type:
Callable[[PatchTarget], PatchTarget]
- class ProbePatch¶
Patch that records the last activation of its target.
Example
pg = PatchableGraph(**example_inputs)
probe = ProbePatch()
with pg.patch({"transformer.h_17.mlp.act.mul_3": probe}):
    pg(**sample_inputs)
print(probe.activation)
- activation¶
Value of the previous activation of its target, or None if not yet recorded.
- Type:
torch.Tensor | float | int | bool | None
- class RecordPatch(*, activations: List[torch.Tensor | float | int | bool] = <factory>)¶
Patch that records all activations of its target.
Example
Replace a layer’s output with a running mean of the previous layer’s activations:
pg = PatchableGraph(**example_inputs)
record = RecordPatch()
for i in range(10):
    with pg.patch(
        {
            "layers_0.output": record,
            "layers_1.output": CustomPatch(
                custom_op=lambda t: torch.mean(
                    torch.stack(record.activations, dim=2), dim=2
                )
            ),
        }
    ):
        print(pg(**sample_inputs[i]))
- activations¶
List of activations.
- Type:
List[torch.Tensor | float | int | bool]
- class ReplacePatch(*, slice: TensorSlice | None = None, value: Tensor | float | int | bool)¶
Patch that replaces (optionally, a slice of) its target with the given value.
Example
pg = PatchableGraph(**example_inputs)
with pg.patch({"linear.input": ReplacePatch(value=42, slice=(slice(None), 0, 0))}):
    print(pg(**sample_inputs))
- slice¶
Slice of the target to replace with
value; applies to the whole tensor if None.
- Type:
TensorSlice | None
- value¶
Value with which to replace the target or slice of the target.
- Type:
torch.Tensor | float | int | bool
- class ZeroPatch(*, slice: TensorSlice | None = None)¶
Patch that zeroes out a slice of its target, or the whole tensor if no slice is provided.
Example
pg = PatchableGraph(**example_inputs)
with pg.patch({"layers_0.output": ZeroPatch()}):
    print(pg(**sample_inputs))
- slice¶
Slice of the target to apply zeros to; applies to the whole tensor if None.
- Type:
TensorSlice | None
Types¶
- TensorSlice: TensorSliceElement | List[TensorSlice] | Tuple[TensorSlice, ...]¶
This is a datatype representing the indexing operation done when you slice a
Tensor, as happens in code like
x[:, 5:8, 2] = 3
This is not a
graphpatch-specific type (we have merely aliased it for convenience), but interacts with Python internals which may be unfamiliar.
Briefly, you will almost always want to pass a sequence (tuple or list) with as many elements as the dimensionality of your tensor. Within this sequence, elements can be either integers, subsequences,
slices, or Tensors. Each element of the sequence will select a subset of the Tensor along the dimension with the corresponding index. An integer will select a single “row” along that dimension. A subsequence will select multiple “rows”. A slice will select a range of “rows”. (slice(None) selects all rows for that dimension, equivalent to writing a “:” within the bracket expression.) A Tensor will perform a complex operation that is out of the scope of this brief note.
For a concrete example, we can accomplish the above operation with the following
ReplacePatch:
ReplacePatch(value=3, slice=(slice(None), slice(5, 8), 2))
See also: Tensor Indexing API.
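To make the correspondence between bracket syntax and slice objects concrete, the following sketch uses a small helper class (ShowIndex, a hypothetical name introduced only for this illustration) whose __getitem__ returns the key Python passes to it. It uses only the standard library and does not depend on torch or graphpatch.

```python
class ShowIndex:
    """Hypothetical helper: returns whatever key Python hands to __getitem__."""

    def __getitem__(self, key):
        return key


# The bracket expression from the text above, applied to the helper.
key = ShowIndex()[:, 5:8, 2]
print(key)
# (slice(None, None, None), slice(5, 8, None), 2)

# The same key written out explicitly, as it would be passed to `slice=`.
explicit = (slice(None), slice(5, 8), 2)
print(key == explicit)
# True
```

In other words, each ":" in the bracket expression becomes slice(None), each "a:b" becomes slice(a, b), and a bare integer stays an integer, all collected into one tuple.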
- TensorSliceElement: int | slice | torch.Tensor¶
One component of a
TensorSlice.