Linear
- class penzai.deprecated.v1.nn.linear_and_affine.Linear
Bases: Layer
A generalized linear (not affine) operator, for named arrays.
Applies an arbitrary contraction to the input NamedArray and a weight parameter. This can be used to express an arbitrary linear operator.
Linear layers are often (but not always) followed by AddBias to make an affine transformation.
- Variables:
weights (parameters.ParameterLike[NamedArray]) – The named array holding the weights for the linear operator.
in_axis_names (tuple[str, ...]) – The names of the axes to contract with the input, removing them.
out_axis_names (tuple[str, ...]) – The names of the axes that should not appear in the input and will be inserted into the output.
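The contraction described above can be sketched in plain Python, without penzai or JAX. This is only an illustration of the semantics, not the library's implementation: the helper `apply_linear` and the dict-based array representation are hypothetical, chosen so that the role of `in_axis_names` and of the output axes is visible.

```python
from itertools import product


def apply_linear(x_axes, x, w_axes, w, in_axis_names):
    """Contract `in_axis_names` between an input `x` and a weight `w`.

    Hypothetical representation (not penzai's NamedArray):
      x_axes / w_axes: dict of axis name -> size; insertion order is the layout.
      x / w: dict mapping a tuple of positions (in that order) to a float.
    Returns (out_axes, out) in the same representation.
    """
    # Axes of the result: every input axis that is not contracted, followed by
    # every weight axis that the input does not have (the "output" axes).
    out_axes = {n: s for n, s in x_axes.items() if n not in in_axis_names}
    out_axes.update({n: s for n, s in w_axes.items()
                     if n not in in_axis_names and n not in x_axes})
    contracted = {n: x_axes[n] for n in in_axis_names}

    out = {}
    for out_pos in product(*(range(s) for s in out_axes.values())):
        index = dict(zip(out_axes, out_pos))
        total = 0.0
        # Sum over every position along the contracted axes.
        for in_pos in product(*(range(s) for s in contracted.values())):
            index.update(zip(contracted, in_pos))
            total += (x[tuple(index[n] for n in x_axes)]
                      * w[tuple(index[n] for n in w_axes)])
        out[out_pos] = total
    return out_axes, out


# Input with a "batch" axis and a 3-element "features" axis.
x_axes = {"batch": 2, "features": 3}
x = {(b, f): float(3 * b + f) for b in range(2) for f in range(3)}

# Weight that maps "features" to a new 2-element "hidden" axis.
w_axes = {"features": 3, "hidden": 2}
w = {(f, h): 1.0 if h == 0 else float(f) for f in range(3) for h in range(2)}

out_axes, out = apply_linear(x_axes, x, w_axes, w, in_axis_names=("features",))
print(out_axes)  # {'batch': 2, 'hidden': 2}
print(out)       # {(0, 0): 3.0, (0, 1): 5.0, (1, 0): 12.0, (1, 1): 14.0}
```

In this sketch "features" plays the role of an entry in in_axis_names (it disappears from the result) and "hidden" plays the role of an entry in out_axis_names (it is absent from the input and appears in the result). An axis shared by the input and the weight but not contracted would be carried through unchanged, with a separate slice of the weight applied at each position.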
Methods
__init__(weights, in_axis_names, out_axis_names)
from_config(input_axes, output_axes[, ...]) – Constructs a Linear layer from a configuration.
input_structure()
output_structure()
treescope_color()
__call__(in_array) – Runs the linear operator.
Attributes
The axis names and sizes that should appear in the input only.
The axis names and sizes that will appear in the output only.
The axis names and sizes that should appear in both input and output.
weights
in_axis_names
out_axis_names
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
- __call__(in_array: NamedArray) -> NamedArray
Runs the linear operator.
- classmethod from_config(input_axes: dict[str, int], output_axes: dict[str, int], parallel_axes: dict[str, int] | None = None, parallel_broadcast_axes: dict[str, int] | None = None, initializer: LinearOperatorWeightInitializer = functools.partial(<function variance_scaling_initializer>, scale=1.0, mode='fan_avg', distribution='uniform'), dtype: jax.typing.DTypeLike = <class 'jax.numpy.float32'>, rename_outputs_if_necessary: bool = True) -> Linear | LinearInPlace
Constructs a Linear layer from a configuration.
This can be used when building a new linear operator at the start of training. The returned operator will include UninitializedParameter nodes which should be initialized before training.
Note: For the purposes of the initializer, the parallel_axes and parallel_broadcast_axes are treated in the same way, without participating in output-dimension variance scaling. However, after initialization, the parallel_broadcast_axes will be treated like extra output axes (and assumed not to be present in the input).
- Parameters:
input_axes – Names and lengths for axes that the linear operator should contract over.
output_axes – Names and lengths for new axes that the linear operator should produce. If any axis names overlap with input_axes, the argument rename_outputs_if_necessary must be True.
parallel_axes – Names and lengths for axes that should be processed in parallel. These axes should appear in both the input and the output, and the resulting linear operator will apply a different operator to each slice. (This is similar to a block-diagonal matrix.) Must not overlap with any axes named in input_axes or output_axes.
parallel_broadcast_axes – Names and lengths for axes that should be treated like parallel_axes but will only appear in the output. The input will be implicitly broadcast over these axes. Must not overlap with any axes named in input_axes, output_axes or parallel_axes.
initializer – Function to use to initialize the weight.
dtype – Dtype for the weight.
rename_outputs_if_necessary – If True, and if output_axes and input_axes have overlapping names, avoids name conflicts by adding “primed” versions of the overlapping names, and returns an instance of LinearInPlace instead of a Linear layer directly.
- Returns:
A Linear layer with uninitialized weights, or possibly a LinearInPlace layer if rename_outputs_if_necessary is True and input_axes overlaps with output_axes.
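The axis bookkeeping that the parameters above describe can be sketched in plain Python. This is a rough illustration under stated assumptions, not penzai's code: the helper `plan_linear_axes`, the `"_primed"` suffix, the ordering of the weight axes, and the returned dictionary keys are all hypothetical. Only the overlap rules, the treatment of parallel_broadcast_axes as extra output axes, and the exclusion of parallel axes from variance scaling come from the description above.

```python
import math


def plan_linear_axes(input_axes, output_axes, parallel_axes=None,
                     parallel_broadcast_axes=None,
                     rename_outputs_if_necessary=True,
                     primed_suffix="_primed"):  # suffix is an assumption
    """Hypothetical helper: work out axis names for a configured linear layer."""
    parallel_axes = parallel_axes or {}
    parallel_broadcast_axes = parallel_broadcast_axes or {}

    # Overlap rules stated in the parameter descriptions.
    if set(parallel_axes) & (set(input_axes) | set(output_axes)):
        raise ValueError("parallel_axes overlaps input_axes or output_axes")
    if set(parallel_broadcast_axes) & (
            set(input_axes) | set(output_axes) | set(parallel_axes)):
        raise ValueError("parallel_broadcast_axes overlaps another axis group")

    overlap = [name for name in output_axes if name in input_axes]
    if overlap and not rename_outputs_if_necessary:
        raise ValueError(f"output axes {overlap} also appear in input_axes")

    # Give each clashing output axis a temporary "primed" name and remember
    # how to rename it back after the contraction. (This sketch does not guard
    # against the primed name itself colliding with another axis.)
    rename_after = {}
    weight_output_axes = {}
    for name, size in output_axes.items():
        new_name = name + primed_suffix if name in input_axes else name
        weight_output_axes[new_name] = size
        if new_name != name:
            rename_after[new_name] = name

    return {
        # Axis order of the weight is an arbitrary choice in this sketch.
        "weight_axes": {**input_axes, **weight_output_axes,
                        **parallel_axes, **parallel_broadcast_axes},
        "in_axis_names": tuple(input_axes),
        # Broadcast axes behave like extra output axes after initialization.
        "out_axis_names": (tuple(weight_output_axes)
                           + tuple(parallel_broadcast_axes)),
        "rename_after": rename_after,
        # Parallel and broadcast axes do not count toward variance scaling.
        "fan_in": math.prod(input_axes.values()),
        "fan_out": math.prod(output_axes.values()),
    }


plan = plan_linear_axes(
    input_axes={"features": 4},
    output_axes={"features": 4},
    parallel_axes={"heads": 2},
)
print(plan["weight_axes"])     # {'features': 4, 'features_primed': 4, 'heads': 2}
print(plan["out_axis_names"])  # ('features_primed',)
print(plan["rename_after"])    # {'features_primed': 'features'}
```

A non-empty `rename_after` mapping corresponds to the case in which the description above says a LinearInPlace is returned rather than a plain Linear: the contraction produces the primed axis, which is then renamed back to the original name.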