RewireComputationPaths

class penzai.toolshed.model_rewiring.RewireComputationPaths

Bases: Layer

Rewires computation across parallel model runs along a worlds axis.

This layer can be used to implement sophisticated ablation, activation patching, and path patching analyses. It assumes its input activation has a particular “worlds” axis, which indexes a minibatch of examples that represent counterfactual variants of the input or of the model. It then rewrites the activations by copying activations between the worlds according to its weights.
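As a rough mental model, rewiring along the worlds axis behaves like multiplying the activations by a worlds-by-worlds mixing matrix. The sketch below uses plain NumPy rather than Penzai. The helper `rewire` and the matrices `copy_clean` and `blend` are hypothetical names for illustration. It assumes the worlds axis is the leading axis and that `mix[i, j]` is the weight with which destination world `i` reads from source world `j`.

```python
import numpy as np

def rewire(activations: np.ndarray, mix: np.ndarray) -> np.ndarray:
    """Hypothetical helper (not Penzai API): mixes activations across worlds.

    activations has the worlds axis first; mix[i, j] is the weight with which
    destination world i reads from source world j.
    """
    # Contract mix's source axis with the leading worlds axis of activations.
    return np.tensordot(mix, activations, axes=([1], [0]))

# Two worlds along axis 0: "clean" (index 0) and "corrupted" (index 1).
acts = np.array([[1.0, 2.0, 3.0],
                 [9.0, 9.0, 9.0]])

# Clean reads from itself; corrupted copies the clean activation wholesale.
copy_clean = np.array([[1.0, 0.0],
                       [1.0, 0.0]])
patched = rewire(acts, copy_clean)   # [[1, 2, 3], [1, 2, 3]]

# Corrupted instead takes an equal blend of both worlds.
blend = np.array([[1.0, 0.0],
                  [0.5, 0.5]])
blended = rewire(acts, blend)        # [[1, 2, 3], [5, 5.5, 6]]
```

A row that is one-hot on its own index leaves that world unchanged, so the clean world is a no-op in both cases.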

This layer is intended to be inserted directly into the model at the point where you want to “bridge” the parallel worlds. For instance, if you want to patch the value activations from an attention block, you can insert this inside the input_to_value sublayer of the attention block and configure it to copy the desired part of the input. If you want to freeze the attention patterns of a set of blocks, you can insert this after the attention softmax and configure it to copy from the “original” world instead of the current world. You can also map a world to a length-zero tuple of sources to indicate that its value should be dropped entirely and zeroed out; this can be useful for e.g. disabling writes to the residual stream of a transformer.
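To make the attention-freezing idea concrete, here is a NumPy-only toy in which a single attention head runs over two worlds at once. Everything in it (shapes, the `softmax` helper, the `copy_original` and `drop_world_1` matrices) is an illustrative assumption, not Penzai code. The rewiring step after the softmax makes both worlds reuse world 0's pattern while still attending over their own values. An all-zero row stands in for a length-zero tuple of sources.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax over the given axis.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n_worlds, seq, dim = 2, 4, 8
q = rng.normal(size=(n_worlds, seq, dim))
k = rng.normal(size=(n_worlds, seq, dim))
v = rng.normal(size=(n_worlds, seq, dim))

# Per-world attention pattern with shape [worlds, query, key].
pattern = softmax(np.einsum("wqd,wkd->wqk", q, k) / np.sqrt(dim))

# Rewiring after the softmax: every world copies world 0's ("original") pattern.
copy_original = np.array([[1.0, 0.0],
                          [1.0, 0.0]])
frozen = np.tensordot(copy_original, pattern, axes=([1], [0]))

# Each world still mixes its *own* values with the frozen pattern.
out = np.einsum("wqk,wkd->wqd", frozen, v)

# An all-zero row drops world 1's output entirely, like an empty source tuple.
drop_world_1 = np.array([[1.0, 0.0],
                         [0.0, 0.0]])
dropped = np.tensordot(drop_world_1, out, axes=([1], [0]))
```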

One useful pattern is to insert a RewireComputationPaths block into a LinearizeAndAdjust block’s linearize_around attribute. This allows you to linearize a nonlinear operation around a single world’s input, but evaluate the resulting linear approximation at each world’s own input.
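The arithmetic behind this pattern is a first-order Taylor expansion taken at one world's input and evaluated at every world's input. The NumPy sketch below illustrates only that arithmetic. It uses an elementwise tanh, whose Jacobian is diagonal. It does not reproduce the `LinearizeAndAdjust` API, and the world layout and the `take_world_0` matrix are assumptions.

```python
import numpy as np

def f(x):
    # Nonlinearity being linearized (elementwise, so its Jacobian is diagonal).
    return np.tanh(x)

def df(x):
    # Elementwise derivative of tanh.
    return 1.0 - np.tanh(x) ** 2

# Worlds along axis 0; world 0 supplies the linearization point.
inputs = np.array([[0.1, -0.2],
                   [0.5,  0.3]])

# Rewire the linearization point: every world reads it from world 0.
take_world_0 = np.array([[1.0, 0.0],
                         [1.0, 0.0]])
x0 = take_world_0 @ inputs

# Evaluate the linear approximation at each world's own input.
approx = f(x0) + df(x0) * (inputs - x0)
```

World 0 recovers its exact output. World 1 instead sees how the operation would respond if it were linear around world 0's input.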

Another useful pattern is to either rewire or zero out the contributions of layers to the residual stream, or rewire the inputs of those layers when they read from the residual stream. This can allow you to measure the residual vectors flowing from one layer to another, or to measure the direct contribution of a layer to the final output logits ignoring the reads or writes from other layers.
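The following NumPy toy measures a layer's direct contribution in a two-layer residual model, using two worlds and two rewiring points. The layer functions, weights, and world names are hypothetical. In the second world, layer 1's write to the residual stream is zeroed, but layer 2 is rewired to read the clean world's residual stream. As a result, the only difference between the two worlds' final residual streams is layer 1's direct write.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
W1 = rng.normal(size=(d, d))
W2 = rng.normal(size=(d, d))

def layer1(h):
    return np.tanh(h @ W1)

def layer2(h):
    return np.tanh(h @ W2)

x = rng.normal(size=(d,))
# Worlds along axis 0: [clean, no_direct_l1], both starting from the same input.
resid0 = np.stack([x, x])

# Rewire layer 1's write: clean keeps it, world 1 takes from no sources (zeros).
zero_world_1 = np.array([[1.0, 0.0],
                         [0.0, 0.0]])
resid1 = resid0 + zero_world_1 @ layer1(resid0)

# Rewire layer 2's read: both worlds read the clean residual stream.
read_clean = np.array([[1.0, 0.0],
                       [1.0, 0.0]])
resid2 = resid1 + layer2(read_clean @ resid1)

# Only layer 1's direct write differs between the two worlds.
direct_l1 = resid2[0] - resid2[1]
```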

Note that this layer is designed for “batched rewiring”: all of the different input conditions are run through the model in a single batched forward pass. For instance, you might have a clean world where nothing is ablated and all rewiring blocks read back from the clean world (a no-op), and a corrupted world where some activations are rewired to be copied from the clean world, with these worlds mapped to indices 0 and 1 along a “worlds” axis of the input. This is a declarative alternative to running multiple forward passes and saving/restoring activations from a cache. The batched version is easier to express in Penzai because it is a stateless function, and it may reduce the memory and compute overhead of running many small forward passes and saving their activations. It can also be compiled into a single computation with jax.jit (which is even easier if you use penzai.toolshed.jit_wrapper).
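A small NumPy sketch can check that the batched formulation computes the same thing as the run-twice-and-cache approach. The toy model (`block_a`, `block_b`, `model_tail`) and the patch location are hypothetical. The batched version stacks a clean and a corrupted input along a leading worlds axis and makes the corrupted world take `block_a`'s output from the clean world.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 3
W_a = rng.normal(size=(d, d))
W_b = rng.normal(size=(d, d))

def block_a(x):
    return np.tanh(x @ W_a)

def block_b(x):
    return np.tanh(x @ W_b)

def model_tail(a_out, x):
    # The rest of the toy model: block_b reads block_a's output plus a skip connection.
    return block_b(a_out + x)

clean_in = rng.normal(size=(d,))
corrupt_in = rng.normal(size=(d,))

# Imperative approach: run the clean input, cache block_a's output, then run
# the corrupted input with that cached value patched in.
cache = block_a(clean_in)
seq_clean = model_tail(cache, clean_in)
seq_patched = model_tail(cache, corrupt_in)

# Batched approach: one forward pass over a worlds axis [clean, corrupted].
x = np.stack([clean_in, corrupt_in])
take_clean = np.array([[1.0, 0.0],
                       [1.0, 0.0]])
a = take_clean @ block_a(x)   # the corrupted world takes block_a's output from clean
batched = model_tail(a, x)
```

The batched version needs no cache: the patched value flows through the computation as an ordinary array.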

Variables:
  • worlds_axis (str) – Axis name of the “worlds” axis (often just “worlds”). Should be an axis name that is NOT already used by the model.

  • world_ordering (tuple[str, ...]) – A tuple of world names. We assume the input will have a worlds axis of the same length as this tuple.

  • taking (dict[str, From | tuple[From, ...]]) – A dictionary that maps destination world names to a source or sources that those worlds should read from. The keys should exactly match world_ordering and represent each world we are outputting to. The values should be instances of From or tuples of instances of From, determining where the value for each world should be taken from. If a destination world takes only from itself with weight 1, its value is left unchanged (a no-op). A common pattern is to have clean or unablated worlds read from themselves, but ablated or corrupted worlds take from the clean worlds.
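The rules for `taking` can be written down as a small validation routine. In this sketch, `From` is a hypothetical stand-in dataclass with assumed `world` and `weight` fields, not the class exported by `penzai.toolshed.model_rewiring`. The helpers `normalize_taking` and `is_noop` are likewise illustrative. A single `From` is promoted to a one-element tuple, the keys must match `world_ordering` exactly, and the configuration counts as a no-op only when every world takes from itself with weight 1.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class From:
    """Hypothetical stand-in for Penzai's From; field names are assumed."""
    world: str
    weight: float = 1.0

def normalize_taking(world_ordering, taking):
    """Checks a taking dict and returns it with every value as a tuple of From."""
    if set(taking) != set(world_ordering):
        raise ValueError(
            f"taking keys {sorted(taking)} must match {list(world_ordering)}")
    normalized = {}
    for dst, srcs in taking.items():
        srcs = (srcs,) if isinstance(srcs, From) else tuple(srcs)
        for src in srcs:
            if src.world not in world_ordering:
                raise ValueError(f"{dst!r} reads from unknown world {src.world!r}")
        normalized[dst] = srcs
    return normalized

def is_noop(world_ordering, taking):
    """True if every world reads only from itself with weight 1."""
    normalized = normalize_taking(world_ordering, taking)
    return all(srcs == (From(dst, 1.0),) for dst, srcs in normalized.items())

world_ordering = ("clean", "corrupted")
identity = {"clean": From("clean"), "corrupted": From("corrupted")}
patch = {"clean": From("clean"), "corrupted": From("clean")}
```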

Methods

__init__(worlds_axis, world_ordering, taking)

path_matrix()

Builds a matrix that maps "from" indices to their "to" indices.

__call__(inputs)

Attributes

worlds_axis

world_ordering

taking

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

input_structure()

Returns the input structure of this layer.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

output_structure()

Returns the output structure of this layer.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

path_matrix() → pz.nx.NamedArray

Builds a matrix that maps “from” indices to their “to” indices.
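One plausible way to build such a matrix from `world_ordering` and `taking` is sketched below in plain NumPy. The real method returns a `pz.nx.NamedArray` with named axes, which this sketch does not reproduce. The row/column orientation (rows are destination worlds, columns are source worlds), the summing of weights, and the `From` stand-in with its `world` and `weight` fields are all assumptions.

```python
from dataclasses import dataclass
import numpy as np

@dataclass(frozen=True)
class From:
    """Hypothetical stand-in for Penzai's From; field names are assumed."""
    world: str
    weight: float = 1.0

def path_matrix(world_ordering, taking):
    """Builds an illustrative [to, from] matrix of weights (plain NumPy)."""
    index = {name: i for i, name in enumerate(world_ordering)}
    matrix = np.zeros((len(world_ordering), len(world_ordering)))
    for dst, srcs in taking.items():
        if isinstance(srcs, From):
            srcs = (srcs,)
        for src in srcs:
            # Row = destination world, column = source world.
            matrix[index[dst], index[src.world]] += src.weight
    return matrix

world_ordering = ("clean", "corrupted", "ablated")
taking = {
    "clean": From("clean"),
    "corrupted": (From("clean", 0.5), From("corrupted", 0.5)),
    "ablated": (),
}
m = path_matrix(world_ordering, taking)
# m == [[1.0, 0.0, 0.0],
#       [0.5, 0.5, 0.0],
#       [0.0, 0.0, 0.0]]
```

Each row describes one destination world, so an empty tuple of sources shows up as an all-zero row.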