UninitializedParameter

class penzai.nn.parameters.UninitializedParameter

Bases: Struct, Generic[T], SupportsParameterRenaming

An uninitialized parameter.

UninitializedParameter represents a parameter that has not yet been initialized, together with the strategy for initializing it. In most cases, model-building code should use UninitializedParameter to construct its initial parameters, so that a model can be built without immediately initializing its parameters.
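The deferred-initialization pattern can be sketched in plain JAX. The `LazyParam` class and the `dense_kernel_init` function below are hypothetical stand-ins, not Penzai's implementation; the point is only that building the model stores an initializer, and sampling happens later, once a PRNG key is supplied.

```python
import dataclasses
from typing import Any, Callable

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class LazyParam:
    """Hypothetical stand-in: stores how to build a value, not the value."""

    initializer: Callable[[jax.Array], Any]
    name: str

    def initialize(self, prng_key: jax.Array) -> Any:
        # The initializer runs (and allocates memory) only at this point.
        return self.initializer(prng_key)


def dense_kernel_init(key: jax.Array) -> jax.Array:
    # Hypothetical initializer: a 3x4 kernel from a scaled normal distribution.
    return 0.1 * jax.random.normal(key, (3, 4), dtype=jnp.float32)


# Building the model is cheap: nothing has been sampled yet.
model = {"kernel": LazyParam(dense_kernel_init, name="dense.kernel")}

# Initialization happens later, explicitly, with a PRNG key.
kernel = model["kernel"].initialize(jax.random.PRNGKey(0))
print(kernel.shape)  # (3, 4)
```

Keeping the initializer separate from the value is what lets a model's structure be built, inspected, or restored from a checkpoint without paying for random initialization first.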

Variables:
  • initializer (Callable[[jax.Array], T]) – Callable used to initialize the parameter.

  • name (str) – The name to use when checkpointing this parameter, or restoring it from a checkpoint.

  • value_structure (shapecheck.StructureAnnotation) – The structure of the value that will be returned by the initializer, but with UninitializedArray in place of any of the actual parameters. Usually inferred automatically from initializer.

Methods

__init__(initializer, name[, value_structure])

Constructs an uninitialized parameter.

as_empty_parameter()

Creates a placeholder parameter containing jax.ShapeDtypeStruct leaves.

initialize(prng_key)

Randomly initializes the parameter.

initialize_with_value(value[, ...])

Directly initializes the parameter with a particular value.

treescope_color()

with_renamed_parameters(rename_fn)

Attributes

value

Value accessor for compatibility with ParameterLike.

initializer

name

value_structure

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.
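The inherited `tree_flatten` and `tree_unflatten` methods are what make a struct usable with `jax.tree_util` and JAX transformations. The generic JAX protocol is sketched below with `jax.tree_util.register_pytree_node_class` and a hypothetical `Pair` class; how Penzai's `Struct` decides which fields become children and which become auxiliary data is not shown here.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Pair:
    """Hypothetical struct: `a` and `b` are PyTree children; `label` is static."""

    def __init__(self, a, b, label: str):
        self.a, self.b, self.label = a, b, label

    def tree_flatten(self):
        # Children are traversed by JAX; aux_data must be hashable and static.
        return (self.a, self.b), self.label

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the node from its static data and (possibly new) children.
        return cls(*children, label=aux_data)


p = Pair(jnp.ones(2), jnp.zeros(2), label="demo")

# tree_map flattens p, transforms each leaf, then unflattens into a new Pair.
doubled = jax.tree_util.tree_map(lambda x: x * 2, p)
print(len(jax.tree_util.tree_leaves(p)))  # 2
```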

__init__(initializer: Callable[[jax.Array], T], name: str, value_structure: Any = <object object>)

Constructs an uninitialized parameter.

Parameters:
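  • initializer – Callable that takes a JAX PRNG key and returns the initialized parameter value.

  • name – Name used when checkpointing this parameter or restoring it from a checkpoint.

  • value_structure – If provided, an explicit value structure that the initializer should return. If not provided, it is inferred by calling jax.eval_shape on the initializer.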
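Inferring the structure with `jax.eval_shape` traces the initializer abstractly, so no random numbers are drawn and no device memory is allocated. The sketch below shows that JAX behavior with a hypothetical `init_layer` initializer; Penzai may further convert the resulting leaves into its own structure annotation.

```python
import jax
import jax.numpy as jnp


def init_layer(key: jax.Array) -> dict:
    # Hypothetical initializer returning a small PyTree of arrays.
    w_key, _ = jax.random.split(key)
    return {
        "w": jax.random.normal(w_key, (8, 16), dtype=jnp.float32),
        "b": jnp.zeros((16,), dtype=jnp.float32),
    }


# eval_shape runs init_layer on abstract values only: each leaf of the
# result is a jax.ShapeDtypeStruct describing shape and dtype, not data.
structure = jax.eval_shape(init_layer, jax.random.PRNGKey(0))
print(structure["w"].shape, structure["w"].dtype)  # (8, 16) float32
```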

as_empty_parameter() → Parameter[T]

Creates a placeholder parameter containing jax.ShapeDtypeStruct leaves.

This can be used to create a model with the right PyTree structure without actually initializing the parameters; this is useful for restoring a model from a checkpoint, for instance.

Returns:

A new instance of Parameter, but with empty jax.ShapeDtypeStruct leaves instead of initialized arrays.
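The placeholder idea can be illustrated with plain JAX: a tree of `jax.ShapeDtypeStruct` leaves records shapes and dtypes without holding any data, and a restore step can fill it leaf by leaf. The `placeholders` tree, the `loaded` dictionary standing in for checkpoint contents, and the `fill` function are all hypothetical; this is not Penzai's checkpointing code.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Placeholder tree: the right PyTree structure, but no array data.
placeholders = {
    "w": jax.ShapeDtypeStruct((2, 3), jnp.float32),
    "b": jax.ShapeDtypeStruct((3,), jnp.float32),
}

# Hypothetical stand-in for arrays read back from a checkpoint.
loaded = {"w": np.ones((2, 3), np.float32), "b": np.zeros((3,), np.float32)}


def fill(spec: jax.ShapeDtypeStruct, array) -> jax.Array:
    # Reject checkpoint entries whose shape disagrees with the placeholder.
    if array.shape != spec.shape:
        raise ValueError(f"expected shape {spec.shape}, got {array.shape}")
    return jnp.asarray(array, dtype=spec.dtype)


# Walk both trees in parallel; ShapeDtypeStruct leaves drive the conversion.
restored = jax.tree_util.tree_map(fill, placeholders, loaded)
```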

initialize(prng_key: jax.Array) → Parameter[T]

Randomly initializes the parameter.

Parameters:

prng_key – Key to use for initialization.

Returns:

A new instance of Parameter.

Raises:
  • ValueError – If the initializer’s output does not match the expected structure from self.value_structure.
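A structure check of this kind can be sketched as follows. `initialize_checked`, `good_init`, and `bad_init` are hypothetical, and treating "matching" as equal PyTree structure plus equal per-leaf shape and dtype is an assumption; Penzai's own check may differ in detail. The sketch also shows that reusing a PRNG key reproduces the same value.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def _spec(x):
    # (shape, dtype) pair; works for both arrays and ShapeDtypeStructs.
    return tuple(x.shape), jnp.dtype(x.dtype)


def initialize_checked(initializer, expected, prng_key):
    """Hypothetical helper: run `initializer`, then validate its output."""
    value = initializer(prng_key)
    if tree_util.tree_structure(value) != tree_util.tree_structure(expected):
        raise ValueError("initializer returned a different PyTree structure")
    for got, want in zip(tree_util.tree_leaves(value), tree_util.tree_leaves(expected)):
        if _spec(got) != _spec(want):
            raise ValueError(f"expected {_spec(want)}, got {_spec(got)}")
    return value


expected = {"w": jax.ShapeDtypeStruct((4,), jnp.float32)}


def good_init(key):
    return {"w": jax.random.uniform(key, (4,), dtype=jnp.float32)}


def bad_init(key):
    # Wrong shape on purpose.
    return {"w": jax.random.uniform(key, (5,), dtype=jnp.float32)}


key = jax.random.PRNGKey(42)
v1 = initialize_checked(good_init, expected, key)
v2 = initialize_checked(good_init, expected, key)
same = bool(jnp.array_equal(v1["w"], v2["w"]))  # same key, same value

try:
    initialize_checked(bad_init, expected, key)
    rejected = False
except ValueError:
    rejected = True
```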

initialize_with_value(value: T, check_structure: bool = True, strict_dtype: bool = True) → Parameter[T]

Directly initializes the parameter with a particular value.

This bypasses the initializer function and sets the value directly, which is useful, for instance, when loading from a checkpoint.

Parameters:
  • value – Value to set.

  • check_structure – Whether to check that the value matches the expected structure.

  • strict_dtype – Whether to check that the value matches the expected dtype. Ignored unless check_structure is True.
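The two flags can be modeled for a single array as below. `value_is_acceptable` is hypothetical and follows only the documented behavior: `check_structure=False` skips all checks, and `strict_dtype=False` drops the dtype comparison. That the shape is still compared when `strict_dtype=False` is an assumption of this sketch.

```python
import jax
import jax.numpy as jnp
import numpy as np


def value_is_acceptable(expected, value, check_structure=True, strict_dtype=True):
    """Hypothetical single-array version of the documented checks."""
    if not check_structure:
        # strict_dtype is ignored when structure checking is disabled.
        return True
    if tuple(np.shape(value)) != tuple(expected.shape):
        return False
    if strict_dtype and np.dtype(value.dtype) != np.dtype(expected.dtype):
        return False
    return True


expected = jax.ShapeDtypeStruct((3,), jnp.float32)
half = np.zeros((3,), dtype=np.float16)

strict = value_is_acceptable(expected, half)                       # dtype differs
lenient = value_is_acceptable(expected, half, strict_dtype=False)  # shape matches
```

Relaxing only the dtype check is handy when, for example, a checkpoint stores weights at a different precision than the freshly built model expects.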

Returns:

An initialized parameter with the given value.

property value

Value accessor for compatibility with ParameterLike.

Raises:

UninitializedParameterError – Since UninitializedParameters are not initialized, retrieving their value always raises an error.
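Raising on access, rather than returning `None` or a placeholder, means code written against the ParameterLike `.value` interface fails immediately when it reaches a parameter that was never initialized. A minimal sketch of that behavior, using a hypothetical `PendingParam` class and a locally defined stand-in for `UninitializedParameterError`:

```python
import dataclasses
from typing import Any, Callable


class UninitializedParameterError(Exception):
    """Local stand-in for the error documented above."""


@dataclasses.dataclass(frozen=True)
class PendingParam:
    """Hypothetical parameter that only knows how to initialize itself."""

    initializer: Callable[[Any], Any]
    name: str

    @property
    def value(self):
        # There is nothing to return yet, so every access fails loudly.
        raise UninitializedParameterError(
            f"Parameter {self.name!r} has not been initialized; "
            "call initialize() first."
        )


p = PendingParam(initializer=lambda key: 0.0, name="bias")
try:
    _ = p.value
    message = ""
except UninitializedParameterError as err:
    message = str(err)
```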