RewireComputationPaths
- class penzai.toolshed.model_rewiring.RewireComputationPaths
Bases: Layer
Rewires computation across parallel model runs along a worlds axis.
This layer can be used to implement sophisticated ablation, activation patching, and path patching analyses. It assumes its input activation has a particular “worlds” axis, which indicates a minibatch of examples that represent counterfactual variants of the input or of the model. It then rewrites the activations by copying activations between the worlds according to its weights.
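As a rough illustration of the copying step, the following plain-Python sketch treats the worlds axis as the leading index of a nested list and copies each world's activation from a named source world. The helper copy_between_worlds and the world names are hypothetical and are not part of Penzai; the real layer operates on named-axis arrays and is configured through its taking attribute.

```python
def copy_between_worlds(activations, world_ordering, source_for):
    """Return a new list whose entry for each world is copied from its source.

    activations: one activation vector (a list of floats) per world, in the
        order given by world_ordering.
    source_for: maps each destination world name to a source world name.
    """
    index_of = {name: i for i, name in enumerate(world_ordering)}
    return [
        list(activations[index_of[source_for[dest]]])
        for dest in world_ordering
    ]


world_ordering = ("clean", "corrupted")
activations = [
    [1.0, 2.0],   # activation computed in the clean world
    [9.0, -3.0],  # activation computed in the corrupted world
]

# The corrupted world reads the clean activation; the clean world reads itself.
patched = copy_between_worlds(
    activations, world_ordering, {"clean": "clean", "corrupted": "clean"}
)
print(patched)
```

The input list is left untouched, which mirrors the stateless, functional style the layer is described as having.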
This layer is intended to be directly inserted into the model at the point where you want to “bridge” the parallel worlds. For instance, if you want to patch the value activations from an attention block, you can insert this inside the input_to_value sublayer of the attention block, and configure it to copy the desired part of the input. If you want to freeze the attention patterns of a set of blocks, you can insert this after the attention softmax, and configure it to copy from the “original” world instead of the current world. You can also use a length-zero tuple to indicate that the value should be entirely dropped and zeroed out; this can be useful for e.g. disabling writes to the residual stream of a transformer.
One useful pattern you can use is to insert a RewireComputationPaths block into a LinearizeAndAdjust block’s linearize_around attribute. This will allow you to linearize a nonlinear operation around a single world’s input, but evaluate the linear approximation around each world individually.
Another useful pattern is to either rewire or zero out the contributions of layers to the residual stream, or rewire the inputs of those layers when they read from the residual stream. This can allow you to measure the residual vectors flowing from one layer to another, or to measure the direct contribution of a layer to the final output logits ignoring the reads or writes from other layers.
Note that this layer is designed to be used for “batched rewiring”: all of the different input conditions are run through the model in a single batched forward pass. For instance, you might have a clean world where nothing is ablated and all rewiring blocks read back from the clean world (a no-op) and a corrupted world where some activations are rewired to be copied from the clean world, where these worlds map to indices 0 and 1 along a “worlds” axis of the input. This is a declarative alternative to running multiple forward passes and saving/restoring activations from a cache. The batched rewiring version is easier to express in Penzai due to being a stateless function, and may reduce memory and compute overhead from saving many small forward passes. It can also be compiled into a single JIT-ted computation using jax.jit (which is even easier if you use penzai.toolshed.jit_wrapper).
- Variables:
worlds_axis (str) – Axis name of the “worlds” axis (often just “worlds”). Should be an axis name that is NOT already used by the model.
world_ordering (tuple[str, ...]) – A tuple of world names. We assume the input will have a worlds axis of the same length as this tuple.
taking (dict[str, From | tuple[From, ...]]) – A dictionary that maps destination world names to a source or sources that those worlds should read from. The keys should exactly match world_ordering and represent each world we are outputting to. The values should be instances of From or tuples of instances of From, determining where the value for each world should be taken from. If the taking key and From source are the same, and the weight is 1, this represents a no-op. A common pattern is to have clean or unablated worlds read from themselves, but ablated or corrupted worlds take from the clean worlds.
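To make the taking configuration concrete, here is a minimal plain-Python sketch of how such a dictionary could be turned into a “to”-by-“from” weight matrix and applied to per-world activations. It is not Penzai's implementation: the From dataclass below is a hypothetical stand-in with assumed field names source and weight, the helpers build_rewiring_matrix and apply_rewiring are invented for illustration, and it assumes that multiple sources combine as a weighted sum.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class From:
    """Hypothetical stand-in for the library's From: a source world and a weight."""
    source: str
    weight: float = 1.0


def build_rewiring_matrix(world_ordering, taking):
    """Build matrix[to][from]: the weight with which world `to` reads world `from`."""
    if set(taking) != set(world_ordering):
        raise ValueError("taking keys must exactly match world_ordering")
    index_of = {name: i for i, name in enumerate(world_ordering)}
    n = len(world_ordering)
    matrix = [[0.0] * n for _ in range(n)]
    for dest, sources in taking.items():
        if isinstance(sources, From):
            sources = (sources,)
        for src in sources:  # an empty tuple leaves the row all zeros
            matrix[index_of[dest]][index_of[src.source]] += src.weight
    return matrix


def apply_rewiring(matrix, activations):
    """Mix per-world vectors: out[to] = sum over from of matrix[to][from] * act[from]."""
    width = len(activations[0])
    return [
        [sum(row[j] * activations[j][k] for j in range(len(activations)))
         for k in range(width)]
        for row in matrix
    ]


world_ordering = ("clean", "corrupted", "dropped")
taking = {
    "clean": From("clean"),      # no-op: reads from itself with weight 1
    "corrupted": From("clean"),  # patched: takes the clean activation
    "dropped": (),               # length-zero tuple: zeroed out
}
matrix = build_rewiring_matrix(world_ordering, taking)
outputs = apply_rewiring(matrix, [[1.0, 2.0], [9.0, -3.0], [5.0, 5.0]])
print(matrix)
print(outputs)
```

In this sketch a world that reads only from itself with weight 1 produces an identity row, which is why that configuration acts as a no-op.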
Methods
__init__(worlds_axis, world_ordering, taking) – Builds a matrix that maps "from" indices to their "to" indices.
__call__(inputs, **_unused_side_inputs)
Attributes
worlds_axis
world_ordering
taking
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
bind_variables(variables[, allow_unused]) – Convenience function to bind variables to a layer.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
stateless_call(variable_values, argument, /, ...) – Calls a layer with temporary variables, without modifying its state.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.