Patch

class Patch

Base class for operations applying to nodes in a PatchableGraph. Derived classes should be keyword-only dataclasses (i.e. decorated with @dataclass(kw_only=True)) and override op().

requires_clone

Whether the operation modifies the original output. Defaults to True and is hidden from the constructor; derived classes that only read their target can override it with False.

Type:

bool

path

For nodes that output nested structures, the path within that structure that this operation should apply to. Hidden from the constructor, since setting the path will be handled by PatchableGraph.

Type:

str | None

op(original_output: PatchTarget) → PatchTarget

The operation to perform at this node. Should take in a single argument, which will be populated with the original output at this node, and return a value of the same type.

class AddPatch(*, value: Tensor | float | int | bool, slice: TensorSlice | None = None)

Patch that adds a value to (optionally, a slice of) its target.

Example

pg = PatchableGraph(model, **example_inputs)
delta = torch.ones((seq_len - 1,))
with pg.patch({"output": AddPatch(value=delta, slice=(slice(1, None), 0))}):
    patched_outputs = pg(**sample_inputs)
slice

Slice to perform addition on. Applies to full target if None.

Type:

TensorSlice | None

value

Value to add to target.

Type:

torch.Tensor | float | int | bool

class CustomPatch(*, requires_clone: bool = True, custom_op: Callable[[PatchTarget], PatchTarget])

Convenience for one-off patch operations without the need to define a new Patch class. Also exposes the normally hidden requires_clone field for operations that do not require cloning.

Example

Replace the output of a layer’s MLP with that of a previous layer:

pg = PatchableGraph(model, **example_inputs)
with pg.patch(
    {
        "layers_0.mlp.output": [layer_0 := ProbePatch()],
        "layers_1.mlp.output": CustomPatch(custom_op=lambda t: layer_0.activation),
    }
):
    print(pg(**sample_inputs))
custom_op

Operation to perform. Replace output at this node with the return value of custom_op(original_output).

Type:

Callable[[PatchTarget], PatchTarget]

requires_clone

Whether the operation modifies the original output tensor. Defaults to True. For read-only operations, set to False to avoid creating unnecessary copies.

Type:

bool
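
The trade-off behind requires_clone can be illustrated with a small stand-alone sketch. The apply_op helper and both operations below are hypothetical and use plain lists instead of tensors; how graphpatch applies the flag internally is not described here, so treat this only as an illustration of why a mutating operation wants a copy and a read-only one does not.

```python
import copy
from typing import Callable, List


def apply_op(
    op: Callable[[List[float]], List[float]],
    original: List[float],
    requires_clone: bool,
) -> List[float]:
    """Hypothetical helper: copy the original first only when the op may mutate it."""
    target = copy.deepcopy(original) if requires_clone else original
    return op(target)


def zero_first_in_place(values: List[float]) -> List[float]:
    # Mutating op: writes into its argument.
    values[0] = 0.0
    return values


def read_only_sum(values: List[float]) -> List[float]:
    # Read-only op: builds a new list and leaves its argument untouched.
    return [sum(values)]


original = [1.0, 2.0, 3.0]

# The mutating op runs on a copy, so the original survives intact.
patched = apply_op(zero_first_in_place, original, requires_clone=True)

# The read-only op can skip the copy safely because it never writes to its input.
summed = apply_op(read_only_sum, original, requires_clone=False)
```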

class ProbePatch

Patch that records the last activation of its target.

Example

pg = PatchableGraph(model, **example_inputs)
probe = ProbePatch()
with pg.patch({"transformer.h_17.mlp.act.mul_3": probe}):
    pg(**sample_inputs)
print(probe.activation)
activation

Most recently recorded activation of its target, or None if nothing has been recorded yet.

Type:

torch.Tensor | float | int | bool | None

class RecordPatch(*, activations: List[Tensor | float | int | bool] = <factory>)

Patch that records all activations of its target.

Example

Replace a layer’s output with a running mean of the previous layer’s activations:

pg = PatchableGraph(model, **example_inputs)
record = RecordPatch()
for i in range(10):
    with pg.patch(
        {
            "layers_0.output": record,
            "layers_1.output": CustomPatch(
                custom_op=lambda t: torch.mean(
                    torch.stack(record.activations, dim=2), dim=2
                )
            ),
        }
    ):
        print(pg(**sample_inputs[i]))
activations

List of activations.

Type:

List[torch.Tensor | float | int | bool]

class ReplacePatch(*, slice: TensorSlice | None = None, value: Tensor | float | int | bool)

Patch that replaces (optionally, a slice of) its target with the given value.

Example

pg = PatchableGraph(model, **example_inputs)
with pg.patch({"linear.input": ReplacePatch(value=42, slice=(slice(None), 0, 0))}):
    print(pg(**sample_inputs))
slice

Slice of the target to replace with value; applies to the whole tensor if None.

Type:

TensorSlice | None

value

Value with which to replace the target or slice of the target.

Type:

torch.Tensor | float | int | bool

class ZeroPatch(*, slice: TensorSlice | None = None)

Patch that zeroes out a slice of its target, or the whole tensor if no slice is provided.

Example

pg = PatchableGraph(model, **example_inputs)
with pg.patch({"layers_0.output": ZeroPatch()}):
    print(pg(**sample_inputs))
slice

Slice of the target to apply zeros to; applies to the whole tensor if None.

Type:

TensorSlice | None

Types

TensorSlice: TensorSliceElement | List[TensorSlice] | Tuple[TensorSlice, ...]

This is a datatype representing the indexing operation done when you slice a Tensor, as happens in code like

x[:, 5:8, 2] = 3

This is not a graphpatch-specific type (we have merely aliased it for convenience), but it interacts with Python internals that may be unfamiliar.

Briefly, you will almost always want to pass a sequence (tuple or list) with as many elements as your tensor has dimensions. Each element of this sequence can be an integer, a subsequence, a slice, or a Tensor, and selects a subset of the Tensor along the dimension with the corresponding index. An integer selects a single “row” along that dimension. A subsequence selects multiple “rows”. A slice selects a range of “rows”. (slice(None) selects all rows for that dimension, equivalent to writing a “:” within the bracket expression.) A Tensor performs a more complex selection that is beyond the scope of this brief note.

For a concrete example, we can accomplish the above operation with the following ReplacePatch:

ReplacePatch(value=3, slice=(slice(None), slice(5, 8), 2))
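
To see why the tuple above matches the bracket expression, it helps to look at the index object Python itself builds. The sketch below uses only the standard library; IndexCapture is a hypothetical helper class written for this illustration and is not part of graphpatch or PyTorch.

```python
class IndexCapture:
    """Returns whatever index object Python passes to __getitem__."""

    def __getitem__(self, index):
        return index


capture = IndexCapture()

# Python turns the bracket syntax into a tuple of slice objects and integers.
bracket_index = capture[:, 5:8, 2]

# The same index written out by hand, as you would pass it to a patch's slice argument.
explicit_index = (slice(None), slice(5, 8), 2)

same_index = bracket_index == explicit_index
```

Here same_index is True: a bare “:” becomes slice(None), and 5:8 becomes slice(5, 8), so any bracket expression can be translated element by element into a TensorSlice.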

See also: Tensor Indexing API.

TensorSliceElement: int | slice | torch.Tensor

One component of a TensorSlice.

PatchTarget: TypeVar

Generic type argument which will be specialized for patches expecting different data types. Almost always specialized to Tensor.