ShareableUninitializedParameter

class penzai.nn.parameters.ShareableUninitializedParameter

Bases: UninitializedParameter

A shareable variant of an uninitialized parameter.

A ShareableUninitializedParameter is just like an ordinary UninitializedParameter, except that it has been tagged as safe to share by name.

Tagging a parameter as shareable does not by itself make it shared. For gradients to propagate correctly, each shared parameter must appear exactly once in the PyTree. Models or submodels containing ShareableUninitializedParameter instances should therefore be transformed with attach_shared_parameters, which takes ownership of the shareable parameters and makes them actually shared. If ShareableUninitializedParameter instances with the same name are not bound with attach_shared_parameters, initialization will encounter a name conflict.
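To make the single-copy rule concrete, here is a minimal, library-free Python sketch. Everything in it is hypothetical: ToyUninit, ToySharedRef, toy_attach_shared, and toy_initialize are illustrative stand-ins rather than penzai APIs, and an integer seed replaces a JAX PRNG key. It shows only why two same-named shareable parameters conflict at initialization, and how replacing every use with a reference to one owned copy makes all uses resolve to the same value. penzai's actual attach_shared_parameters operates on PyTrees and differs in detail.

```python
from dataclasses import dataclass
from typing import Any, Callable


# Hypothetical stand-ins, NOT penzai classes. An int seed replaces a JAX PRNG key.
@dataclass(frozen=True)
class ToyUninit:
    name: str
    initializer: Callable[[int], Any]
    shareable: bool = False


@dataclass(frozen=True)
class ToySharedRef:
    """Placeholder that refers to a parameter owned elsewhere, by name."""
    name: str


def toy_attach_shared(params):
    """Hypothetical analogue of attach_shared_parameters.

    Keeps one owned copy per shareable name and replaces every shareable
    occurrence in the body with a reference to that copy.
    """
    owned = {}
    body = []
    for p in params:
        if isinstance(p, ToyUninit) and p.shareable:
            owned.setdefault(p.name, p)  # first occurrence becomes the owned copy
            body.append(ToySharedRef(p.name))
        else:
            body.append(p)
    return list(owned.values()), body


def toy_initialize(params, seed=0):
    """Initializes each ToyUninit once; a repeated name is a conflict."""
    values = {}
    for i, p in enumerate(params):
        if isinstance(p, ToyUninit):
            if p.name in values:
                raise ValueError(f"name conflict: {p.name!r}")
            values[p.name] = p.initializer(seed + i)
    return values


embedding = ToyUninit("embedding", lambda key: [key, key], shareable=True)
model = [embedding, ToyUninit("dense", lambda key: key * 10), embedding]

# Unbound: the shareable name appears twice, so initialization fails.
try:
    toy_initialize(model)
except ValueError as err:
    print(err)  # name conflict: 'embedding'

# Bound: one owned copy, and both uses resolve to the same value object.
owned, body = toy_attach_shared(model)
values = toy_initialize(owned + body)
resolved = [values[p.name] for p in body]
print(resolved[0] is resolved[2])  # True
```

Keeping the owned copy outside the body is what guarantees a single leaf per shared parameter; in a differentiable setting, that single leaf is where the gradient contributions from every use are summed.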

Inherited Attributes

value

Value accessor for compatibility with ParameterLike.

Methods

from_uninitialized(uninit)

Returns a ShareableUninitializedParameter equivalent to uninit.

Attributes

value

Value accessor for compatibility with ParameterLike.

initializer

name

value_structure

Inherited Methods


__init__(initializer, name[, value_structure])

Constructs an uninitialized parameter.

as_empty_parameter()

Creates a placeholder parameter containing jax.ShapeDtypeStruct leaves.

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

initialize(prng_key)

Randomly initializes the parameter.

initialize_with_value(value[, ...])

Directly initializes the parameter with a particular value.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

with_renamed_parameters(rename_fn)

classmethod from_uninitialized(uninit: UninitializedParameter) → ShareableUninitializedParameter

Returns a ShareableUninitializedParameter equivalent to uninit.
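The following Python sketch illustrates what "equivalent" plausibly means here: every field (initializer, name, value_structure) is carried over unchanged, and only the class, that is, the shareable tag, changes. Uninit and ShareableUninit are hypothetical frozen-dataclass stand-ins, not penzai's classes, and the method body is an assumption rather than penzai's source.

```python
import dataclasses
from typing import Any, Callable, Optional


@dataclasses.dataclass(frozen=True)
class Uninit:
    """Hypothetical stand-in for UninitializedParameter."""
    initializer: Callable[[Any], Any]
    name: str
    value_structure: Optional[Any] = None


@dataclasses.dataclass(frozen=True)
class ShareableUninit(Uninit):
    """Hypothetical stand-in for ShareableUninitializedParameter."""

    @classmethod
    def from_uninitialized(cls, uninit: Uninit) -> "ShareableUninit":
        # Copy every field as-is; only the resulting type differs.
        fields = {f.name: getattr(uninit, f.name) for f in dataclasses.fields(uninit)}
        return cls(**fields)


plain = Uninit(initializer=lambda key: 0.0, name="bias")
shared = ShareableUninit.from_uninitialized(plain)
print(type(shared).__name__, shared.name)  # ShareableUninit bias
```

Because the subclass adds no fields, the conversion is a pure retag, so the result still passes isinstance checks against the base class, mirroring the Bases: UninitializedParameter relationship listed above.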