UninitializedParameter
- class penzai.deprecated.v1.nn.parameters.UninitializedParameter
Bases: Struct, Generic[T], SupportsParameterRenaming
An uninitialized parameter.
UninitializedParameter represents a parameter that has not yet been initialized, along with an initialization strategy for it. In most cases, model-building code should use UninitializedParameter to construct its initial parameters, so that a model can be built without initializing its parameters immediately.
- Variables:
initializer (Callable[[jax.Array], T]) – Callable used to initialize the parameter.
name (str) – The name to use when checkpointing this parameter, or restoring it from a checkpoint.
value_structure (shapecheck.StructureAnnotation) – The structure of the value that will be returned by the initializer, but with UninitializedArray in place of any of the actual parameters. Usually inferred automatically from initializer.
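The deferred-initialization idea described above can be illustrated with a small, self-contained sketch. The classes `SketchUninitialized` and `SketchParameter` are hypothetical stand-ins written for this example; they are not penzai's implementation and cover only the "store an initializer now, run it later" pattern.

```python
import dataclasses
from typing import Any, Callable

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class SketchParameter:
    """Hypothetical stand-in for an initialized parameter."""
    value: Any
    name: str


@dataclasses.dataclass(frozen=True)
class SketchUninitialized:
    """Hypothetical stand-in: a name plus a deferred initializer."""
    initializer: Callable[[jax.Array], Any]
    name: str

    def initialize(self, prng_key: jax.Array) -> SketchParameter:
        # The initializer runs only here, so arrays are allocated only here.
        return SketchParameter(value=self.initializer(prng_key), name=self.name)


# Building the "model" runs no initializer and allocates no parameter arrays.
model = {
    "kernel": SketchUninitialized(
        initializer=lambda key: 0.01 * jax.random.normal(key, (4, 8)),
        name="linear.kernel",
    ),
    "bias": SketchUninitialized(
        initializer=lambda key: jnp.zeros((8,)),
        name="linear.bias",
    ),
}

# Later, initialize each placeholder with its own PRNG key.
keys = jax.random.split(jax.random.PRNGKey(0), len(model))
initialized = {
    field: param.initialize(key)
    for (field, param), key in zip(model.items(), keys)
}
```

Separating construction from initialization in this way lets the same model definition be initialized randomly, filled from a checkpoint, or inspected without allocating any parameter memory.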
Methods
__init__(initializer, name[, value_structure]) – Constructs an uninitialized parameter.
as_empty_parameter() – Creates a placeholder parameter containing jax.ShapeDtypeStruct leaves.
initialize(prng_key) – Randomly initializes the parameter.
initialize_with_value(value[, ...]) – Directly initializes the parameter with a particular value.
treescope_color()
with_renamed_parameters(rename_fn)
Attributes
value – Value accessor for compatibility with ParameterLike.
initializer
name
value_structure
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
- __init__(initializer: Callable[[jax.Array], T], name: str, value_structure: Any = <object object>)
Constructs an uninitialized parameter.
- Parameters:
initializer – Initializer to use.
name – Name to use.
value_structure – If provided, an explicit value structure that the initializer should return. If not provided, it is inferred using jax.eval_shape.
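To show what inferring a structure with `jax.eval_shape` looks like in isolation, here is a minimal sketch. The `initializer` function and its output layout are made up for illustration, and the sketch does not reproduce how penzai converts the result into its own structure annotation.

```python
import jax
import jax.numpy as jnp


def initializer(key):
    # Hypothetical initializer returning a dict of two arrays.
    return {
        "scale": jnp.ones((16,), dtype=jnp.float32),
        "kernel": jax.random.normal(key, (16, 32), dtype=jnp.float32),
    }


# jax.eval_shape traces the function abstractly: it reports the shapes and
# dtypes of the outputs without computing the output arrays themselves.
structure = jax.eval_shape(initializer, jax.random.PRNGKey(0))

# Each leaf of the result is a jax.ShapeDtypeStruct.
kernel_struct = structure["kernel"]
```

Because the trace is abstract, this step is cheap even when the real initializer would produce very large arrays.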
- as_empty_parameter() → Parameter[T]
Creates a placeholder parameter containing jax.ShapeDtypeStruct leaves.
This can be used to create a model with the right PyTree structure without actually initializing the parameters; this is useful for restoring a model from a checkpoint, for instance.
- Returns:
A new instance of Parameter, but with empty jax.ShapeDtypeStruct leaves instead of initialized arrays.
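As a rough illustration of why a tree of `jax.ShapeDtypeStruct` leaves is useful as a restore target, the sketch below fills such a tree from saved arrays. The `checkpoint` dictionary and the `restore_leaf` helper are hypothetical; real checkpointing libraries have their own APIs, which are not shown here.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Placeholder tree: shapes and dtypes only, no array data.
empty = {
    "kernel": jax.ShapeDtypeStruct((4, 8), jnp.float32),
    "bias": jax.ShapeDtypeStruct((8,), jnp.float32),
}

# Hypothetical "checkpoint": saved arrays laid out like the placeholder tree.
checkpoint = {
    "kernel": np.full((4, 8), 0.5, dtype=np.float32),
    "bias": np.zeros((8,), dtype=np.float32),
}


def restore_leaf(placeholder, saved):
    # Refuse saved values that do not fit the placeholder.
    if saved.shape != placeholder.shape or saved.dtype != placeholder.dtype:
        raise ValueError(
            f"expected {placeholder}, got {saved.shape} {saved.dtype}"
        )
    return jnp.asarray(saved)


# The placeholder tree supplies the PyTree structure; the checkpoint
# supplies the data for each leaf.
restored = jax.tree_util.tree_map(restore_leaf, empty, checkpoint)
```

The restored tree has the same PyTree structure as the placeholder tree, which is what makes the placeholder usable as a template.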
- initialize(prng_key: jax.Array) → Parameter[T]
Randomly initializes the parameter.
- Parameters:
prng_key – Key to use for initialization.
- Returns:
A new instance of Parameter.
- Raises:
ValueError – If the initializer’s output does not match the expected structure from self.value_structure.
- initialize_with_value(value: T, check_structure: bool = True, strict_dtype: bool = True) → Parameter[T]
Directly initializes the parameter with a particular value.
This bypasses the initializer function and sets the value directly, which can be useful when loading from checkpoints, for instance.
- Parameters:
value – Value to set.
check_structure – Whether to check that the value matches the expected structure.
strict_dtype – Whether to check that the value matches the expected dtype. Ignored unless check_structure is True.
- Returns:
An initialized parameter with the given value.
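The two flags can be pictured with a simplified check on a single array. This is a sketch under assumptions: `check_value` is a hypothetical helper, and the source does not say exactly how penzai relaxes the comparison when `strict_dtype` is False; here the dtype comparison is simply skipped.

```python
import jax
import jax.numpy as jnp


def check_value(value, expected, strict_dtype=True):
    """Hypothetical check of one array against a jax.ShapeDtypeStruct."""
    if value.shape != expected.shape:
        raise ValueError(f"shape {value.shape} != expected {expected.shape}")
    # Assumption: a non-strict check ignores the dtype entirely.
    if strict_dtype and value.dtype != expected.dtype:
        raise ValueError(f"dtype {value.dtype} != expected {expected.dtype}")


expected = jax.ShapeDtypeStruct((2, 3), jnp.float32)
half_precision = jnp.zeros((2, 3), dtype=jnp.float16)

# Same shape, different dtype: accepted only when strict_dtype is False.
check_value(half_precision, expected, strict_dtype=False)

try:
    check_value(half_precision, expected, strict_dtype=True)
    strict_check_failed = False
except ValueError:
    strict_check_failed = True
```

A relaxed dtype check of this kind is convenient when, for example, a checkpoint was saved at a different precision than the one the initializer would produce.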
- property value
Value accessor for compatibility with ParameterLike.
- Raises:
UninitializedParameterError – Since UninitializedParameters are not initialized, retrieving their value always raises an error.
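The always-raising accessor can be sketched in plain Python. The names `SketchUninitialized`, `SketchUninitializedError`, and `read_value` are hypothetical and only mirror the behavior described above; they are not penzai classes.

```python
class SketchUninitializedError(Exception):
    """Hypothetical analogue of UninitializedParameterError."""


class SketchUninitialized:
    """Hypothetical placeholder exposing the same attribute as a real parameter."""

    def __init__(self, name):
        self.name = name

    @property
    def value(self):
        # There is no value yet, so every read fails loudly.
        raise SketchUninitializedError(
            f"parameter {self.name!r} has not been initialized"
        )


def read_value(parameter_like):
    # Code written against the shared interface needs no isinstance checks.
    return parameter_like.value


try:
    read_value(SketchUninitialized("linear.kernel"))
    raised = False
except SketchUninitializedError:
    raised = True
```

Exposing the attribute and failing on access means that running a model before initializing it produces a clear error at the point of use rather than an unrelated attribute lookup failure.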