Linear
- class penzai.deprecated.v1.nn.linear_and_affine.Linear
Bases: Layer
A generalized linear (not affine) operator, for named arrays.
Applies an arbitrary contraction to the input NamedArray and a weight parameter. This can be used to express an arbitrary linear operator.
Linear layers are often (but not always) followed by AddBias to make an affine transformation.
- Variables:
weights (parameters.ParameterLike[NamedArray]) – The named array holding the weights for the linear operator.
in_axis_names (tuple[str, ...]) – The names of the axes to contract with the input, removing them.
out_axis_names (tuple[str, ...]) – The names of the axes that should not appear in the input and will be inserted into the output.
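As a rough illustration of the contraction described above, the following plain-Python sketch contracts a shared axis of an input against a weight. It does not use penzai: nested lists stand in for named arrays, and the helper `contract_features` and the axis names `batch`, `features` and `out` are hypothetical. In terms of the variables above, `in_axis_names` would correspond to `("features",)` and `out_axis_names` to `("out",)`.

```python
def contract_features(x, w):
    """Contract the shared "features" axis of x[batch][features] and w[features][out].

    Plain nested lists stand in for named arrays; this is an analogy,
    not the penzai implementation.
    """
    n_features = len(w)
    n_out = len(w[0])
    result = []
    for row in x:  # "batch" axis: not named by the layer, so it is kept as-is
        if len(row) != n_features:
            raise ValueError("input and weight disagree on the 'features' size")
        # Sum over "features" (removed) for each position along "out" (inserted).
        result.append(
            [sum(row[f] * w[f][o] for f in range(n_features)) for o in range(n_out)]
        )
    return result  # axes: batch, out


x = [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]    # batch=2, features=3
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # features=3, out=2
y = contract_features(x, w)
# y == [[4.0, 5.0], [0.0, 1.0]]
```

No bias term is added at any point, which is what makes the operator linear rather than affine.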
Methods
__init__(weights, in_axis_names, out_axis_names)
from_config(input_axes, output_axes[, ...]) – Constructs a Linear layer from a configuration.
input_structure()
output_structure()
treescope_color()
__call__(in_array) – Runs the linear operator.
Attributes
The axis names and sizes that should appear in the input only.
The axis names and sizes that will appear in the output only.
The axis names and sizes that should appear in both the input and output.
weights
in_axis_names
out_axis_names
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
- __call__(in_array: NamedArray) -> NamedArray
Runs the linear operator.
- classmethod from_config(input_axes: dict[str, int], output_axes: dict[str, int], parallel_axes: dict[str, int] | None = None, parallel_broadcast_axes: dict[str, int] | None = None, initializer: LinearOperatorWeightInitializer = functools.partial(<function variance_scaling_initializer>, scale=1.0, mode='fan_avg', distribution='uniform'), dtype: jax.typing.DTypeLike = <class 'jax.numpy.float32'>, rename_outputs_if_necessary: bool = True) -> Linear | LinearInPlace
Constructs a Linear layer from a configuration.
This can be used when building a new linear operator at the start of training. The returned operator will include UninitializedParameter nodes which should be initialized before training.
Note: For the purposes of the initializer, the parallel_axes and parallel_broadcast_axes are treated in the same way, without participating in output-dimension variance scaling. However, after initialization, the parallel_broadcast_axes will be treated like extra output axes (and assumed not to be present in the input).
- Parameters:
input_axes – Names and lengths for axes that the linear operator should contract over.
output_axes – Names and lengths for new axes that the linear operator should produce. If any axis names overlap with input_axes, the argument rename_outputs_if_necessary must be True.
parallel_axes – Names and lengths for axes that should be processed in parallel. These axes should appear in both the input and the output, and the resulting linear operator will apply a different operator to each slice. (This is similar to a block-diagonal matrix.) Must not overlap with any axes named in input_axes or output_axes.
parallel_broadcast_axes – Names and lengths for axes that should be treated like parallel_axes but will only appear in the output. The input will be implicitly broadcast over these axes. Must not overlap with any axes named in input_axes, output_axes or parallel_axes.
initializer – Function to use to initialize the weight.
dtype – Dtype for the weight.
rename_outputs_if_necessary – If True, and if output_axes and input_axes have overlapping names, avoids name conflicts by adding “primed” versions of the overlapping names, and returns an instance of LinearInPlace instead of a Linear layer directly.
- Returns:
A Linear layer with uninitialized weights, or possibly a LinearInPlace layer if rename_outputs_if_necessary is True and input_axes overlaps with output_axes.
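The axis rules described for from_config can be sketched as a small bookkeeping function. This is a minimal plain-Python sketch, not penzai code: the helper `plan_linear_axes`, the `needs_in_place_wrapper` key and the `"_primed"` suffix are hypothetical (the actual primed naming scheme is not stated here), and it assumes the weight carries every axis from all four groups.

```python
def plan_linear_axes(
    input_axes,
    output_axes,
    parallel_axes=None,
    parallel_broadcast_axes=None,
    rename_outputs_if_necessary=True,
    prime_suffix="_primed",  # hypothetical; the real primed naming is not given here
):
    """Sketch of how the from_config arguments could map onto weight axes."""
    parallel_axes = dict(parallel_axes or {})
    parallel_broadcast_axes = dict(parallel_broadcast_axes or {})

    # Parallel axes may not reuse a name from input_axes or output_axes.
    taken = set(input_axes) | set(output_axes)
    if taken & set(parallel_axes):
        raise ValueError("parallel_axes overlaps input_axes or output_axes")
    # Broadcast axes may not reuse any name from the other three groups.
    if (taken | set(parallel_axes)) & set(parallel_broadcast_axes):
        raise ValueError("parallel_broadcast_axes overlaps another axis group")

    overlap = [name for name in output_axes if name in input_axes]
    if overlap and not rename_outputs_if_necessary:
        raise ValueError(f"output axes {overlap} also appear in input_axes")

    # Overlapping output names get a "primed" stand-in so the weight can hold
    # both the contracted axis and the produced axis at once.
    renamed_outputs = {
        (name + prime_suffix if name in input_axes else name): size
        for name, size in output_axes.items()
    }

    return {
        # Assumption: the weight has one axis per entry in all four groups.
        "weight_axes": {
            **input_axes, **renamed_outputs,
            **parallel_axes, **parallel_broadcast_axes,
        },
        "in_axis_names": tuple(input_axes),
        # After initialization, broadcast axes behave like extra output axes.
        "out_axis_names": tuple(renamed_outputs) + tuple(parallel_broadcast_axes),
        "needs_in_place_wrapper": bool(overlap),
    }


plan = plan_linear_axes(
    input_axes={"features": 8},
    output_axes={"features": 8},
    parallel_axes={"heads": 4},
)
```

In this example the input and output both use the name `features`, so the sketch flags that a LinearInPlace-style wrapper would be needed to map the primed name back to `features` on output. The `heads` axis appears in the weight but in neither name tuple, matching the block-diagonal reading of parallel_axes.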