StatefulTrainer
- class penzai.toolshed.basic_training.StatefulTrainer
Bases: Struct
A trainer object that updates its state in place.
StatefulTrainer manages its own state as well as the state of the model using pz.StateVariable objects.
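One way to picture "updates its state in place": every holder of a reference to a mutable variable sees the new value after an update. Nothing is returned that the caller must thread back in. The sketch below shows that aliasing behavior with a hypothetical `Cell` class. It is only a plain Python stand-in for a mutable variable, not pz.StateVariable, and `sgd_update_in_place` is likewise a made-up helper.

```python
import jax.numpy as jnp


class Cell:
    """Hypothetical mutable box; a stand-in for a state variable, not penzai's class."""

    def __init__(self, value):
        self.value = value


# The caller and the "trainer" share the same Cell objects (no copy is made).
params = {"w": Cell(jnp.array(2.0))}
trainer_view = params


def sgd_update_in_place(cells, grads, lr):
    # Overwrite each cell's value; no new container is returned.
    for name, cell in cells.items():
        cell.value = cell.value - lr * grads[name]


sgd_update_in_place(trainer_view, {"w": jnp.array(1.0)}, lr=0.5)
# The caller's reference now observes the updated value (2.0 - 0.5 * 1.0 = 1.5).
print(params["w"].value)
```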
- Variables:
root_rng (PRNGKeyArray) – Base random number generator; used in combination with the training step index to derive per-step random numbers.
model (ModelPyTree) – The model being optimized. Usually contains Variables for its parameters and possibly other state.
state (pz.StateVariable[InternalTrainerState]) – The internal state of the trainer, wrapped in a Variable to allow in-place updates.
optimizer_def (optax.GradientTransformation) – An optax optimizer.
loss_fn (LossFunction) – A user-specified loss function.
step_fn (Any) – The function that performs a single step of training. Usually constructed for you by the build() class method.
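The root_rng description implies that each step's randomness is a deterministic function of the root key and the step index. A common JAX idiom for this is `jax.random.fold_in`. The sketch below assumes that approach for illustration only and does not claim it is exactly how StatefulTrainer derives its keys. `step_rng` is a hypothetical helper.

```python
import jax


def step_rng(root_rng, step_index):
    # Hypothetical helper: fold the step index into the root key so each
    # step gets a key that depends only on (root_rng, step_index).
    return jax.random.fold_in(root_rng, step_index)


root_rng = jax.random.key(0)

# Recomputing the key for a given step reproduces the same random draw,
# which is what makes a replayed or resumed step deterministic.
draw_step1_first = jax.random.uniform(step_rng(root_rng, 1))
draw_step1_again = jax.random.uniform(step_rng(root_rng, 1))

# Different step indices give different keys, hence (almost surely) different draws.
draw_step0 = jax.random.uniform(step_rng(root_rng, 0))
```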
Methods
- __init__(root_rng, model, state, ...)
- build(root_rng, model, optimizer_def, loss_fn)
- step(**kwargs) – Runs one step of training.
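The following simplified, hypothetical `MiniStatefulTrainer` shows how a build() class method, a step_fn, and step(**kwargs) can fit together. It is not penzai's implementation:

- it keeps parameters in a plain dataclass attribute rather than pz.StateVariable objects;
- it uses hand-written gradient descent in place of an optax optimizer;
- it assumes a loss function of the form `loss_fn(params, rng, **kwargs)` returning a scalar.

Only the overall shape is meant to carry over: build a jitted pure step function once, then mutate stored state on each step() call.

```python
import dataclasses
from typing import Any, Callable

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class MiniStatefulTrainer:
    """Hypothetical, simplified analogue of a stateful trainer (not penzai's code)."""

    root_rng: jax.Array
    params: Any  # stands in for the model's parameters
    step_index: int
    learning_rate: float
    step_fn: Callable

    @classmethod
    def build(cls, root_rng, params, learning_rate, loss_fn):
        # Construct and jit one pure update function at build time.
        @jax.jit
        def step_fn(params, rng, **kwargs):
            loss, grads = jax.value_and_grad(loss_fn)(params, rng, **kwargs)
            new_params = jax.tree_util.tree_map(
                lambda p, g: p - learning_rate * g, params, grads
            )
            return new_params, loss

        return cls(root_rng, params, 0, learning_rate, step_fn)

    def step(self, **kwargs):
        # Derive this step's key, run the pure update, then store results in place.
        rng = jax.random.fold_in(self.root_rng, self.step_index)
        self.params, loss = self.step_fn(self.params, rng, **kwargs)
        self.step_index += 1
        return {"loss": loss}


def loss_fn(params, rng, x, y):
    # Least-squares loss for y = w * x + b; rng is unused here but would feed dropout etc.
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)


x = jnp.linspace(-1.0, 1.0, 32)
y = 3.0 * x + 1.0
trainer = MiniStatefulTrainer.build(
    jax.random.key(0),
    {"w": jnp.array(0.0), "b": jnp.array(0.0)},
    learning_rate=0.1,
    loss_fn=loss_fn,
)
# Keyword arguments to step() are forwarded to the loss function.
losses = [float(trainer.step(x=x, y=y)["loss"]) for _ in range(200)]
```

In this sketch, step_fn stays pure and jitted, and all mutation is confined to step(). JAX compiles the numerical work once, while callers get a simple imperative training loop.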
Attributes
- root_rng
- model
- state
- optimizer_def
- loss_fn
- step_fn
Inherited Methods
- attributes_dict() – Constructs a dictionary with all of the fields in the class.
- from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
- key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
- select() – Wraps this struct in a selection, enabling functional-style mutations.
- tree_flatten() – Flattens this tree node.
- tree_flatten_with_keys() – Flattens this tree node with keys.
- tree_unflatten(aux_data, children) – Unflattens this tree node.
- treescope_color() – Computes a CSS color to display for this object in treescope.
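The inherited tree_flatten() and tree_unflatten(aux_data, children) methods match the protocol that `jax.tree_util.register_pytree_node_class` expects. Flattening returns `(children, aux_data)`, and unflattening rebuilds the object from them. Struct provides these for you. The sketch below writes them by hand on a hypothetical `Pair` class only to show what they do.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Pair:
    """Hypothetical two-field node illustrating the flatten/unflatten protocol."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def tree_flatten(self):
        # Children are the leaves JAX traverses; aux_data holds static metadata (none here).
        return (self.a, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild a Pair from its children; aux_data is unused in this example.
        return cls(*children)


p = Pair(jnp.array(1.0), jnp.array(2.0))
leaves, treedef = jax.tree_util.tree_flatten(p)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
# tree_map flattens, applies the function to each leaf, and unflattens back into a Pair.
doubled = jax.tree_util.tree_map(lambda v: 2 * v, p)
```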