StatefulTrainer
- class penzai.toolshed.basic_training.StatefulTrainer
Bases: Struct
A trainer object that updates its state in place.
StatefulTrainer manages its own state as well as the state of the model using pz.StateVariable objects.
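The difference between a functional trainer and a stateful one can be sketched in plain Python. A functional trainer returns a new state that the caller must keep. A stateful trainer overwrites a variable that it owns. `Cell` and `StatefulCounter` are hypothetical names. `Cell` only imitates the idea of a container with a reassignable value, which is what `pz.StateVariable` provides here. It is not penzai's implementation.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Cell:
    """Hypothetical mutable container standing in for pz.StateVariable."""
    value: Any


def functional_step(step_count: int) -> int:
    # Functional style: the caller receives a new state and must keep it.
    return step_count + 1


class StatefulCounter:
    """Hypothetical stateful object that overwrites the value in its own Cell."""

    def __init__(self) -> None:
        self.state = Cell(value=0)

    def step(self) -> None:
        # The Cell object stays the same; only its contents change.
        self.state.value = self.state.value + 1


counter = StatefulCounter()
cell_before = counter.state
counter.step()
counter.step()
print(counter.state.value)           # 2
print(counter.state is cell_before)  # True
print(functional_step(0))            # 1
```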
- Variables:
root_rng (PRNGKeyArray) – Base random number generator; used in combination with the training step index to derive per-step random numbers.
model (ModelPyTree) – The model being optimized. Usually contains Variables for its parameters and possibly for other state.
state (pz.StateVariable[InternalTrainerState]) – The internal state of the trainer, wrapped in a Variable to allow in-place updates.
optimizer_def (optax.GradientTransformation) – An optax optimizer.
loss_fn (LossFunction) – A user-specified loss function.
step_fn (Any) – The function that performs a single step of training. Usually constructed for you by the build() class method.
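The `root_rng` entry says that per-step random numbers come from combining the base key with the step index. In JAX this is commonly done with `jax.random.fold_in(key, step)`. This page does not say which function penzai uses. The dependency-free sketch below shows the same property with Python's standard `random` module. `step_rng` is a hypothetical helper. Rerunning a given step reproduces its randomness, and each step index gets its own stream.

```python
import random


def step_rng(root_seed: int, step: int) -> random.Random:
    """Hypothetical helper: derive a per-step generator from (root_seed, step).

    random.Random converts a string seed to an integer deterministically,
    so the result does not depend on the salted built-in hash().
    """
    return random.Random(f"{root_seed}:{step}")


root_seed = 42
a = step_rng(root_seed, step=3).random()
b = step_rng(root_seed, step=3).random()
c = step_rng(root_seed, step=4).random()
print(a == b)  # True: same step index, same numbers
# A different step index (c) draws from a different stream.
```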
Methods
__init__(root_rng, model, state, ...)
build(root_rng, model, optimizer_def, loss_fn)
step(**kwargs) – Runs one step of training.
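Here is a toy, dependency-free sketch of how `build` and `step` might fit together. It is not penzai's implementation. `MiniTrainer`, `Cell`, `GradTransform`, `sgd` and `finite_diff_grad` are hypothetical stand-ins for `StatefulTrainer`, `pz.StateVariable`, `optax.GradientTransformation`, an optax optimizer and `jax.grad`. The sketch follows the pattern the method table suggests:

- `build` assembles a `step_fn` from the model, the optimizer and the loss.
- `step(**kwargs)` forwards its keyword arguments to the loss.
- After each step, the new parameters and optimizer state are written back in place.

`GradTransform` copies the shape of `optax.GradientTransformation`, which pairs an `init(params)` function with an `update(updates, state, params=None)` function.

```python
import random
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple


@dataclass
class Cell:
    """Hypothetical mutable container standing in for pz.StateVariable."""
    value: Any


class GradTransform(NamedTuple):
    """Hypothetical; mirrors the (init, update) pair of optax.GradientTransformation."""
    init: Callable[..., Any]
    update: Callable[..., tuple]


def sgd(learning_rate: float) -> GradTransform:
    """Plain SGD: no optimizer state, and the update is -learning_rate * grad."""
    def init(params):
        return ()

    def update(grads, opt_state, params=None):
        return -learning_rate * grads, opt_state

    return GradTransform(init, update)


def finite_diff_grad(f: Callable[[float], float], x: float, eps: float = 1e-6) -> float:
    """Central difference; a toy stand-in for jax.grad on one scalar."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


@dataclass
class MiniTrainer:
    """Hypothetical toy analogue of StatefulTrainer (not penzai code)."""
    root_seed: int
    model: Cell          # toy "model": a single scalar parameter
    state: Cell          # {"step": int, "opt_state": ...}
    optimizer_def: GradTransform
    loss_fn: Callable[..., float]
    step_fn: Callable[..., dict]

    @classmethod
    def build(cls, root_seed, model, optimizer_def, loss_fn):
        state = Cell({"step": 0, "opt_state": optimizer_def.init(model.value)})

        def step_fn(model_cell, state_cell, **kwargs):
            step = state_cell.value["step"]
            seed = f"{root_seed}:{step}"

            def loss_at(p):
                # A freshly seeded generator per call keeps every loss
                # evaluation within this step on the same random stream.
                return loss_fn(p, rng=random.Random(seed), **kwargs)

            loss = loss_at(model_cell.value)
            grads = finite_diff_grad(loss_at, model_cell.value)
            updates, new_opt_state = optimizer_def.update(
                grads, state_cell.value["opt_state"], model_cell.value
            )
            # Write results back in place instead of returning a new trainer.
            model_cell.value = model_cell.value + updates
            state_cell.value = {"step": step + 1, "opt_state": new_opt_state}
            return {"loss": loss}

        return cls(root_seed, model, state, optimizer_def, loss_fn, step_fn)

    def step(self, **kwargs) -> dict:
        """Runs one step of training; kwargs are forwarded to the loss."""
        return self.step_fn(self.model, self.state, **kwargs)


def squared_error(w, rng, target):
    del rng  # this toy loss does not need randomness
    return (w - target) ** 2


trainer = MiniTrainer.build(
    root_seed=0, model=Cell(0.0), optimizer_def=sgd(0.1), loss_fn=squared_error
)
for _ in range(50):
    out = trainer.step(target=3.0)
print(round(trainer.model.value, 3), trainer.state.value["step"])  # 3.0 50
```

`step` writes into `trainer.model` and `trainer.state`, so the loop never reassigns `trainer`. A functional trainer would instead return an updated copy on each iteration.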
Attributes
root_rng
model
state
optimizer_def
loss_fn
step_fn
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.
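The inherited `tree_flatten` / `tree_unflatten` pair follows JAX's custom-pytree protocol. Flattening returns `(children, aux_data)`, and the class method `tree_unflatten(aux_data, children)` rebuilds the node. The dependency-free sketch below mimics that round trip, along with `attributes_dict` and `from_attributes`, on a hypothetical dataclass `Point`. It treats every field as a child. It does not show how `pz.Struct` itself decides what goes into `aux_data`.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Point:
    """Hypothetical struct-like class; not pz.Struct itself."""
    x: float
    y: float

    def attributes_dict(self) -> dict:
        # One entry per dataclass field, in declaration order.
        return {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}

    @classmethod
    def from_attributes(cls, **field_values):
        return cls(**field_values)

    def tree_flatten(self):
        # children: the subtrees/leaves; aux_data: static info needed to rebuild.
        names = tuple(f.name for f in dataclasses.fields(self))
        children = tuple(getattr(self, n) for n in names)
        return children, names

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls.from_attributes(**dict(zip(aux_data, children)))


p = Point(1.0, 2.0)
children, aux = p.tree_flatten()
print(children, aux)        # (1.0, 2.0) ('x', 'y')
doubled = Point.tree_unflatten(aux, tuple(2 * c for c in children))
print(doubled)              # Point(x=2.0, y=4.0)
print(p.attributes_dict())  # {'x': 1.0, 'y': 2.0}
```

If JAX is available, you can decorate a class like this with `jax.tree_util.register_pytree_node_class`. `jax.tree_util.tree_map` can then map over its fields.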