TrainState
- class penzai.deprecated.v1.toolshed.basic_training.TrainState
Bases: Struct
Collection of state for ease of training.
The parameters and nonlearnable parts of the model are kept separate internally, to avoid the overhead of PyTree traversal and to simplify checkpointing the parameters. You can access the full model through the .model property.
- Variables:
step (int) – Current step of training.
root_rng (PRNGKeyArray) – Base random number generator; used in combination with step to derive per-step random numbers.
params (dict[str, Any]) – Values for the parameters of the model being optimized.
model_without_params (ModelPyTree) – The nonlearnable parts of the model being optimized. Should contain the Parameter instances, but without their values.
opt_state (OptimizerStatePyTree) – An optimizer state for the learnable parts of the model.
loss_fn_state (LossStatePyTree) – Arbitrary state managed by the loss function. For instance, if your model has local state, you can functionalize it using pz.de.handle_local_states and store its state dict here.
optimizer_def (optax.GradientTransformation) – An optax optimizer.
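Because root_rng and step are both stored, a fresh key for each step can be recomputed from the saved state rather than kept separately. The sketch below assumes the derivation uses jax.random.fold_in; this page does not say which function TrainState actually uses, and step_rng is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp


def step_rng(root_rng: jax.Array, step: int) -> jax.Array:
    """Hypothetical helper: derives a per-step key from a root key and a step.

    Assumes jax.random.fold_in is the derivation; the real TrainState may differ.
    """
    return jax.random.fold_in(root_rng, step)


root_rng = jax.random.PRNGKey(0)

# The same (root_rng, step) pair always yields the same key, so a run restored
# from a checkpoint at a given step reproduces that step's randomness.
k3a = step_rng(root_rng, 3)
k3b = step_rng(root_rng, 3)
k4 = step_rng(root_rng, 4)
print(bool(jnp.array_equal(k3a, k3b)))  # True
print(bool(jnp.array_equal(k3a, k4)))   # False
```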
Methods
__init__(step, root_rng, params, ...)
initial_state(model, optimizer_def, root_rng) – Constructs the initial training state.
Attributes
model – The full model, including parameters and nonlearnable parts.
step
root_rng
params
model_without_params
opt_state
loss_fn_state
optimizer_def
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- classmethod initial_state(model: ModelPyTree, optimizer_def: optax.GradientTransformation, root_rng: PRNGKeyArray, loss_fn_state: LossStatePyTree = None)
Constructs the initial training state.
- Parameters:
model – The model being optimized.
optimizer_def – The optax optimizer to use.
root_rng – Base random number generator; used in combination with step to derive per-step random numbers.
loss_fn_state – Optional initial state for the loss function.
- Returns:
An initial training state.
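The following stdlib-only sketch imitates the kind of state initial_state is documented to produce, using a flat dict as the model for brevity. Param, ToyOptimizer, momentum_sgd and ToyTrainState are hypothetical names, not penzai or optax APIs; ToyOptimizer only mimics the (init, update) shape of an optax.GradientTransformation, and starting the counter at step 0 is an assumption.

```python
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple


@dataclass(frozen=True)
class Param:
    """Hypothetical stand-in for a named parameter slot (value=None = stripped)."""
    name: str
    value: Any = None


class ToyOptimizer(NamedTuple):
    """Mimics the (init, update) pair of an optax.GradientTransformation."""
    init: Callable[[dict], Any]
    update: Callable[..., Any]


def momentum_sgd(lr: float, beta: float = 0.9) -> ToyOptimizer:
    def init(params):
        # One zero-initialized momentum buffer per parameter.
        return {name: 0.0 for name in params}

    def update(grads, state, params=None):
        new_state = {k: beta * state[k] + grads[k] for k in grads}
        return {k: -lr * v for k, v in new_state.items()}, new_state

    return ToyOptimizer(init, update)


@dataclass(frozen=True)
class ToyTrainState:
    step: int
    root_rng: Any
    params: dict
    model_without_params: dict
    opt_state: Any
    loss_fn_state: Any
    optimizer_def: ToyOptimizer

    @classmethod
    def initial_state(cls, model, optimizer_def, root_rng, loss_fn_state=None):
        # Split learnable values out of the model, leaving value-less Param slots.
        params = {p.name: p.value for p in model.values() if isinstance(p, Param)}
        skeleton = {
            k: Param(v.name) if isinstance(v, Param) else v
            for k, v in model.items()
        }
        return cls(
            step=0,  # assumption: training starts counting at step 0
            root_rng=root_rng,
            params=params,
            model_without_params=skeleton,
            opt_state=optimizer_def.init(params),  # optimizer state covers params only
            loss_fn_state=loss_fn_state,
            optimizer_def=optimizer_def,
        )


model = {"w": Param("w", 2.0), "b": Param("b", 0.5), "activation": "relu"}
state = ToyTrainState.initial_state(model, momentum_sgd(lr=0.1), root_rng=0)
print(state.params)                     # {'w': 2.0, 'b': 0.5}
print(state.model_without_params["w"])  # Param(name='w', value=None)
print(state.opt_state)                  # {'w': 0.0, 'b': 0.0}
```

Note that the optimizer is initialized from params alone, so its state never has to traverse the nonlearnable parts of the model.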
- property model: ModelPyTree
The full model, including parameters and nonlearnable parts.
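The model property can be pictured as the inverse of the parameter split: each value-less parameter slot in model_without_params is filled from params. This is a minimal stdlib sketch under that assumption; Param and SplitModel are hypothetical names, and a flat dict stands in for a real model PyTree.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Param:
    """Hypothetical parameter slot; value=None means the value was stripped."""
    name: str
    value: Any = None


@dataclass(frozen=True)
class SplitModel:
    """Hypothetical container mirroring TrainState's params/skeleton split."""
    params: dict
    model_without_params: dict

    @property
    def model(self) -> dict:
        # Rebuild the full model by filling each empty slot from self.params.
        return {
            k: Param(v.name, self.params[v.name]) if isinstance(v, Param) else v
            for k, v in self.model_without_params.items()
        }


split = SplitModel(
    params={"w": 2.0},
    model_without_params={"w": Param("w"), "activation": "relu"},
)
print(split.model)  # {'w': Param(name='w', value=2.0), 'activation': 'relu'}
```

In this sketch the full model is rebuilt on every access, and the stored skeleton is never mutated.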