TrainState
- class penzai.deprecated.v1.toolshed.basic_training.TrainState
Bases: Struct
Collection of state for ease of training.
The parameters and nonlearnable parts of the model are kept separate internally, both to avoid the overhead of PyTree traversal and to simplify checkpointing the parameters. You can access the full model through the .model property.
- Variables:
step (int) – Current step of training.
root_rng (PRNGKeyArray) – Base random number generator; used in combination with step to derive per-step random numbers.
params (dict[str, Any]) – Values for the parameters of the model being optimized.
model_without_params (ModelPyTree) – The nonlearnable parts of the model being optimized. Should contain Parameter instances but without values.
opt_state (OptimizerStatePyTree) – An optimizer state for the learnable parts of model.
loss_fn_state (LossStatePyTree) – Arbitrary state managed by the loss function. For instance, if your model has local state, you can functionalize it using pz.de.handle_local_states and store its state dict here.
optimizer_def (optax.GradientTransformation) – An optax optimizer.
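The split between params and model_without_params described above can be pictured with a small pure-Python sketch. This is a simplified stand-in and not penzai's implementation: Slot, ToyLinear, split, and rebind are hypothetical names introduced only for illustration.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class Slot:
    """Hypothetical placeholder marking where a named parameter belongs."""
    name: str


@dataclass(frozen=True)
class ToyLinear:
    """Toy layer whose `weight` and `bias` hold either a value or a Slot."""
    weight: Any
    bias: Any
    activation: str  # nonlearnable configuration


def split(layer: ToyLinear) -> tuple[dict[str, Any], ToyLinear]:
    """Separates the learnable values from the rest of the layer."""
    params = {"weight": layer.weight, "bias": layer.bias}
    skeleton = replace(layer, weight=Slot("weight"), bias=Slot("bias"))
    return params, skeleton


def rebind(params: dict[str, Any], skeleton: ToyLinear) -> ToyLinear:
    """Fills each Slot in the skeleton with its value from `params`."""
    return replace(
        skeleton,
        weight=params[skeleton.weight.name],
        bias=params[skeleton.bias.name],
    )


layer = ToyLinear(weight=[[1.0, 2.0]], bias=[0.5], activation="relu")
params, skeleton = split(layer)      # analogous to params / model_without_params
restored = rebind(params, skeleton)  # analogous to reading the full model back
```

In this sketch, only the flat params dictionary would need to be checkpointed or passed to an optimizer; the skeleton keeps the structure and the nonlearnable configuration.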
Methods
__init__(step, root_rng, params, ...)
initial_state(model, optimizer_def, root_rng) – Constructs the initial training state.
Attributes
model – The full model, including parameters and nonlearnable parts.
step
root_rng
params
model_without_params
opt_state
loss_fn_state
optimizer_def
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
- classmethod initial_state(model: ModelPyTree, optimizer_def: optax.GradientTransformation, root_rng: PRNGKeyArray, loss_fn_state: LossStatePyTree = None)
Constructs the initial training state.
- Parameters:
model – The model being optimized.
optimizer_def – The optax optimizer to use.
root_rng – Base random number generator; used in combination with step to derive per-step random numbers.
loss_fn_state – Optional initial state for the loss function.
- Returns:
An initial training state.
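One common way to derive per-step random numbers from a base key and a step counter is jax.random.fold_in. The documentation here does not say which derivation TrainState uses, so the sketch below is an assumption about the general pattern rather than a description of the library's internals; rng_for_step is a hypothetical helper name.

```python
import jax


def rng_for_step(root_rng, step):
    """Derives a key for one training step (assumed pattern, not penzai's code)."""
    # fold_in deterministically mixes an integer into an existing key.
    return jax.random.fold_in(root_rng, step)


root_rng = jax.random.PRNGKey(0)

key_3 = rng_for_step(root_rng, 3)
key_3_again = rng_for_step(root_rng, 3)  # same root and step give the same key
key_4 = rng_for_step(root_rng, 4)        # a different step gives a different key

# The per-step key can then drive any stochastic part of the loss.
noise = jax.random.normal(key_3, (2,))
```

With this pattern only the root key and the step counter need to be stored, and a run resumed from a saved step reproduces the same per-step randomness.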
- property model: ModelPyTree
The full model, including parameters and nonlearnable parts.