LinearOperatorWeightInitializer

class penzai.nn.linear_and_affine.LinearOperatorWeightInitializer

Bases: Protocol

Protocol for an initializer for a general linear NamedArray weight.

Methods

__init__(*args, **kwargs)

__call__(key, *, input_axes, output_axes, ...)

Signature for a generalized linear operator NamedArray initializer.

__call__(key: jax.Array, *, input_axes: dict[str, int], output_axes: dict[str, int], parallel_axes: dict[str, int], convolution_spatial_axes: dict[str, int], dtype: jax.typing.DTypeLike) → NamedArray

Signature for a generalized linear operator NamedArray initializer.

This signature attempts to make every dimension used by an initializer explicit, so that the initializer can be used for general linear layers without assuming which axes are inputs and which are outputs.

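The provided sets of axes must not overlap: each axis name may appear in at most one of input_axes, output_axes, parallel_axes, and convolution_spatial_axes.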

Parameters:
  • key – Random key.

  • input_axes – Names and lengths for axes that the linear operator should contract over.

  • output_axes – Names and lengths for new axes that the linear operator should produce.

  • parallel_axes – Names and lengths for axes that should be processed in parallel, such as the “heads” of an attention layer. These axes may appear in both the input and the output, and the resulting linear operator will apply a different operator to each slice. (This is similar to a block-diagonal matrix.)

  • convolution_spatial_axes – Names and lengths for axes that correspond to spatial dimensions of a convolution, e.g. the convolution kernel’s width and height. (Convolutions over these axes cannot be expressed as an einsum.)

  • dtype – Desired dtype.

Returns:

An initialized weight.
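The block-diagonal comparison in the parallel_axes description can be checked numerically. The sketch below uses plain JAX arrays rather than NamedArrays, with illustrative sizes and names: applying a separate `(d_in, d_out)` matrix to each head gives the same result as flattening the heads into the feature dimension and multiplying by the block-diagonal matrix assembled from those per-head matrices.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

heads, d_in, d_out = 3, 4, 5
k_w, k_x = jax.random.split(jax.random.key(0))
hp = jax.lax.Precision.HIGHEST  # avoid reduced-precision matmuls on accelerators

# One independent (d_in, d_out) matrix per head, as for a weight with a
# parallel "heads" axis (positional here, ignoring names).
w = jax.random.normal(k_w, (heads, d_in, d_out))
x = jax.random.normal(k_x, (heads, d_in))

# Parallel-axis application: each head's slice only sees its own matrix.
y_parallel = jnp.einsum("hi,hio->ho", x, w, precision=hp)

# Dense view: flatten heads into the features and use a block-diagonal
# matrix of shape (heads * d_in, heads * d_out).
dense = block_diag(*w)
y_dense = jnp.matmul(x.reshape(-1), dense, precision=hp).reshape(heads, d_out)

print(jnp.allclose(y_parallel, y_dense, atol=1e-5))
```

The off-diagonal blocks are zero, which is why a parallel axis stores only `heads * d_in * d_out` parameters instead of the `(heads * d_in) * (heads * d_out)` a dense matrix would need.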