Linear

class penzai.nn.linear_and_affine.Linear

Bases: Layer

A generalized linear (not affine) operator, for named arrays.

Applies an arbitrary contraction to the input NamedArray and a weight parameter. This can be used to express an arbitrary linear operator.

Linear layers are often (but not always) followed by AddBias to make an affine transformation.

Variables:
  • weights (parameters.ParameterLike[NamedArray]) – The named array holding the weights for the linear operator.

  • in_axis_names (tuple[str, ...]) – The names of the axes to contract with the input, removing them.

  • out_axis_names (tuple[str, ...]) – The names of the axes that should not appear in the input and will be inserted into the output.
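To make the roles of the three fields concrete, here is a minimal sketch of the contraction using plain NumPy. It assumes a hypothetical representation where a named array is an ndarray paired with a tuple of axis names; `linear_apply` is an illustrative helper, not part of penzai. Axes in `in_axis_names` are summed away, every other input axis (such as a batch axis) passes through, and `out_axis_names` are appended.

```python
import string

import numpy as np


def linear_apply(x, x_axes, weights, w_axes, in_axis_names, out_axis_names):
    """Hypothetical stand-in for Linear.__call__ on (array, axis-names) pairs.

    Sums over every axis in in_axis_names, keeps all other input axes
    (e.g. batch or parallel axes), and appends out_axis_names.
    """
    if set(out_axis_names) & set(x_axes):
        raise ValueError("output axes must not already appear in the input")
    # Give each distinct axis name a single einsum letter.
    names = sorted(set(x_axes) | set(w_axes))
    letter = {name: string.ascii_letters[i] for i, name in enumerate(names)}
    out_axes = tuple(n for n in x_axes if n not in in_axis_names) + tuple(out_axis_names)
    spec = "{},{}->{}".format(
        "".join(letter[n] for n in x_axes),
        "".join(letter[n] for n in w_axes),
        "".join(letter[n] for n in out_axes),
    )
    return np.einsum(spec, x, weights), out_axes


# Project a batch of 3-dim embeddings to 4 hidden units.
x = np.ones((2, 3))                # axes ("batch", "embedding")
w = np.arange(12.0).reshape(3, 4)  # axes ("embedding", "hidden")
y, y_axes = linear_apply(x, ("batch", "embedding"), w, ("embedding", "hidden"),
                         in_axis_names=("embedding",), out_axis_names=("hidden",))
print(y_axes)  # ('batch', 'hidden')
```

Matching axes by name rather than by position is what lets the same weight apply regardless of where the batch axis sits in the input.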

Methods

__init__(weights, in_axis_names, out_axis_names)

from_config(input_axes, output_axes[, ...])

Constructs a Linear layer from a configuration.

input_structure()

output_structure()

treescope_color()

__call__(in_array)

Runs the linear operator.

Attributes

input_axes

The axis names and sizes that should appear in the input only.

output_axes

The axis names and sizes that will appear in the output only.

parallel_axes

The axis names and sizes that should appear in both the input and output.
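The three properties can be read off the weight's named shape. The sketch below assumes that the weight carries exactly the input, output and parallel axes, so any weight axis not listed in `in_axis_names` or `out_axis_names` is treated as parallel; `split_weight_axes` is a hypothetical helper, and a plain dict stands in for the weight's named shape.

```python
def split_weight_axes(weight_named_shape, in_axis_names, out_axis_names):
    """Hypothetical helper mirroring input_axes / output_axes / parallel_axes.

    Assumes the weight's named shape holds exactly the input, output and
    parallel axes of the operator.
    """
    input_axes = {n: weight_named_shape[n] for n in in_axis_names}
    output_axes = {n: weight_named_shape[n] for n in out_axis_names}
    parallel_axes = {
        n: size for n, size in weight_named_shape.items()
        if n not in in_axis_names and n not in out_axis_names
    }
    return input_axes, output_axes, parallel_axes


# A per-head projection: one 16 -> 8 matrix for each of 4 heads.
shape = {"heads": 4, "embedding": 16, "projection": 8}
print(split_weight_axes(shape, ("embedding",), ("projection",)))
# ({'embedding': 16}, {'projection': 8}, {'heads': 4})
```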

weights

in_axis_names

out_axis_names

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

__call__(in_array: NamedArray) → NamedArray

Runs the linear operator.

classmethod from_config(input_axes: dict[str, int], output_axes: dict[str, int], parallel_axes: dict[str, int] | None = None, parallel_broadcast_axes: dict[str, int] | None = None, initializer: LinearOperatorWeightInitializer = functools.partial(variance_scaling_initializer, scale=1.0, mode='fan_avg', distribution='uniform'), dtype: jax.typing.DTypeLike = jax.numpy.float32, rename_outputs_if_necessary: bool = True) → Linear | LinearInPlace

Constructs a Linear layer from a configuration.

This can be used when building a new linear operator at the start of training. The returned operator will contain UninitializedParameter nodes, which must be initialized before training.

Note: For the purposes of the initializer, the parallel_axes and parallel_broadcast_axes are treated in the same way, without participating in output-dimension variance scaling. However, after initialization, the parallel_broadcast_axes will be treated like extra output axes (and assumed not to be present in the input).
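As a worked illustration of this note, the sketch below computes the bound of the default initializer's uniform distribution. It assumes that `variance_scaling_initializer` follows the usual convention of `jax.nn.initializers.variance_scaling`: with mode='fan_avg' and distribution='uniform', weights are drawn from U(-a, a) with a = sqrt(3 · scale / fan_avg), where fan_avg = (fan_in + fan_out) / 2. Here fan_in and fan_out are taken as products of the input and output axis sizes, and parallel (or parallel-broadcast) axes are left out. `fan_avg_uniform_limit` is a hypothetical helper.

```python
import math


def fan_avg_uniform_limit(input_axes, output_axes, parallel_axes=None, scale=1.0):
    """Bound of U(-limit, limit) for variance scaling with mode='fan_avg'.

    parallel_axes is accepted only to show that it is ignored: each parallel
    slice is a separate operator, so it does not change fan_in or fan_out.
    """
    del parallel_axes  # Deliberately unused, per the note above.
    fan_in = math.prod(input_axes.values())
    fan_out = math.prod(output_axes.values())
    fan_avg = (fan_in + fan_out) / 2
    # A uniform distribution on [-a, a] has variance a**2 / 3.
    return math.sqrt(3.0 * scale / fan_avg)


limit = fan_avg_uniform_limit({"embedding": 64}, {"hidden": 192}, {"heads": 8})
print(round(limit, 4))  # 0.1531
```

Adding or removing the "heads" axis leaves the bound unchanged, which is the sense in which parallel axes do not participate in the scaling.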

Parameters:
  • input_axes – Names and lengths for axes that the linear operator should contract over.

  • output_axes – Names and lengths for new axes that the linear operator should produce. If any axis names overlap with input_axes, the argument rename_outputs_if_necessary must be True.

  • parallel_axes – Names and lengths for axes that should be processed in parallel. These axes should appear in both the input and the output, and the resulting linear operator will apply a different operator to each slice. (This is similar to a block-diagonal matrix.) Must not overlap with any axes named in input_axes or output_axes.

  • parallel_broadcast_axes – Names and lengths for axes that should be treated like parallel_axes but will only appear in the output. The input will be implicitly broadcast over these axes. Must not overlap with any axes named in input_axes, output_axes or parallel_axes.

  • initializer – Function used to initialize the weights.

  • dtype – Dtype for the weight.

  • rename_outputs_if_necessary – If True, and if output_axes and input_axes have overlapping names, avoids name conflicts by adding “primed” versions of the overlapping names, and returns an instance of LinearInPlace instead of a Linear layer directly.
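The difference between parallel_axes and parallel_broadcast_axes can be checked numerically. This NumPy sketch, which uses made-up sizes and random weights rather than anything from penzai, contracts over a single "embedding" axis while keeping a "heads" axis. When "heads" is a parallel axis, it appears in the input, the weight and the output, and the result equals one block-diagonal matrix applied to the flattened input. When "heads" is a broadcast axis, it appears only in the weight and the output, so the same input vector is fed to every slice of the weight.

```python
import numpy as np

rng = np.random.default_rng(0)
heads, emb, proj = 2, 3, 4

# parallel_axes={"heads": 2}: "heads" is in the input, the weight and the output.
x = rng.normal(size=(heads, emb))
w = rng.normal(size=(heads, emb, proj))
y_parallel = np.einsum("he,hep->hp", x, w)

# The same computation as one block-diagonal matrix acting on flattened inputs.
block_diag = np.zeros((heads * emb, heads * proj))
for h in range(heads):
    block_diag[h * emb:(h + 1) * emb, h * proj:(h + 1) * proj] = w[h]
y_block = (x.reshape(-1) @ block_diag).reshape(heads, proj)

# parallel_broadcast_axes={"heads": 2}: "heads" is in the weight and output only,
# so the single input vector is implicitly broadcast over it.
x_shared = rng.normal(size=(emb,))
y_broadcast = np.einsum("e,hep->hp", x_shared, w)

print(np.allclose(y_parallel, y_block))  # True
```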

Returns:

A Linear layer with uninitialized weights, or possibly a LinearInPlace layer if rename_outputs_if_necessary is True and input_axes overlaps with output_axes.
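The renaming step can be pictured as follows. This is a hypothetical sketch, not penzai's implementation. Output axes that collide with input axes get a temporary "primed" name so the inner operator never sees the same name on both sides, and a mapping records how to rename them back afterwards. The "_prime" suffix is an illustrative assumption, and the actual naming used by LinearInPlace may differ.

```python
def prime_overlapping_outputs(input_axes, output_axes, suffix="_prime"):
    """Hypothetical sketch of the renaming behind rename_outputs_if_necessary.

    The suffix is an illustrative assumption, not penzai's actual naming.
    Returns the output axes to give the inner operator and a mapping that
    renames the primed axes back afterwards.
    """
    overlap = set(input_axes) & set(output_axes)
    primed_outputs = {
        (name + suffix if name in overlap else name): size
        for name, size in output_axes.items()
    }
    rename_back = {name + suffix: name for name in overlap}
    return primed_outputs, rename_back


# Map an 8-dim embedding to a new 8-dim embedding plus a "heads" axis.
print(prime_overlapping_outputs({"embedding": 8}, {"embedding": 8, "heads": 2}))
# ({'embedding_prime': 8, 'heads': 2}, {'embedding_prime': 'embedding'})
```

An empty `rename_back` mapping corresponds to the case where no renaming is needed and a plain Linear layer suffices.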

property input_axes: dict[str, int]

The axis names and sizes that should appear in the input only.

property output_axes: dict[str, int]

The axis names and sizes that will appear in the output only.

property parallel_axes: dict[str, int]

The axis names and sizes that should appear in both the input and output.