StatefulTrainer


class penzai.toolshed.basic_training.StatefulTrainer

Bases: Struct

A trainer object that updates its state in place.

StatefulTrainer manages its own state as well as the state of the model using pz.StateVariable objects.

Variables:
  • root_rng (PRNGKeyArray) – Base random number generator; used in combination with the training step index to derive per-step random numbers.

  • model (ModelPyTree) – The model being optimized. It usually contains Variables for its parameters and possibly other state.

  • state (pz.StateVariable[InternalTrainerState]) – The internal state of the trainer, wrapped in a StateVariable so that it can be updated in place.

  • optimizer_def (optax.GradientTransformation) – An optax optimizer.

  • loss_fn (LossFunction) – A user-specified loss function.

  • step_fn (Any) – The function that performs a single step of training. Usually constructed for you by the build() class method.
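The root_rng description above can be illustrated with a short JAX sketch. It assumes the per-step key is derived with jax.random.fold_in, a common JAX idiom for combining a key with an integer. This page does not say which function StatefulTrainer actually uses, and per_step_rng is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def per_step_rng(root_rng: jax.Array, step: int) -> jax.Array:
    """Hypothetical helper: fold the step index into the root key (assumed mechanism)."""
    return jax.random.fold_in(root_rng, step)


root_rng = jax.random.key(0)

# Keys for two consecutive training steps.
k0 = per_step_rng(root_rng, 0)
k1 = per_step_rng(root_rng, 1)

# Re-deriving the key for the same step reproduces the same random numbers,
# while a different step index yields different ones.
noise_step1 = jax.random.normal(k1, (3,))
noise_step1_again = jax.random.normal(per_step_rng(root_rng, 1), (3,))
noise_step0 = jax.random.normal(k0, (3,))
```

Under this assumption the derived key depends only on root_rng and the step index, so resuming from a saved step index reproduces that step's randomness without storing a separate key per step.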

Methods

__init__(root_rng, model, state, ...)

build(root_rng, model, optimizer_def, loss_fn)

step(**kwargs)

Runs one step of training.

Attributes

root_rng

model

state

optimizer_def

loss_fn

step_fn

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

step(**kwargs) → AuxOutPyTree

Runs one step of training.
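To make the in-place update concrete, here is a minimal, self-contained sketch of the same pattern in plain JAX. MiniStatefulTrainer and its fields are hypothetical and only loosely mirror the attributes documented above: ordinary mutable Python attributes stand in for pz.StateVariable, a hand-written SGD update stands in for the optax optimizer, and the keyword-only loss_fn(params=..., rng=..., **kwargs) signature returning (loss, aux) is an assumption, not penzai's LossFunction protocol.

```python
import dataclasses
from typing import Any, Callable

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class MiniStatefulTrainer:
    """Hypothetical, simplified analogue of StatefulTrainer (not penzai's code)."""
    root_rng: jax.Array
    params: Any                        # stand-in for the model's parameters
    step_index: int                    # stand-in for the internal trainer state
    loss_fn: Callable[..., Any]        # assumed to return (loss, aux_outputs)
    learning_rate: float               # plain SGD instead of an optax optimizer
    step_fn: Callable[..., Any]

    @classmethod
    def build(cls, root_rng, params, loss_fn, learning_rate=0.1):
        def step_fn(cur_params, cur_step, **kwargs):
            # Derive this step's randomness from the root key and step index.
            rng = jax.random.fold_in(root_rng, cur_step)
            # Differentiate the loss with respect to the parameters only.
            (loss, aux), grads = jax.value_and_grad(
                lambda p: loss_fn(params=p, rng=rng, **kwargs), has_aux=True
            )(cur_params)
            # Plain SGD update applied leaf by leaf.
            new_params = jax.tree_util.tree_map(
                lambda p, g: p - learning_rate * g, cur_params, grads
            )
            return new_params, aux

        # Compile once; later calls to step() reuse the compiled function.
        return cls(root_rng, params, 0, loss_fn, learning_rate, jax.jit(step_fn))

    def step(self, **kwargs):
        """Runs one training step, overwriting params and step_index in place."""
        new_params, aux = self.step_fn(self.params, self.step_index, **kwargs)
        self.params = new_params
        self.step_index += 1
        return aux


# Usage: fit y = 3x + 0.5 with mean squared error.
def loss_fn(*, params, rng, x, y):
    del rng  # unused here; shows where per-step randomness would arrive
    pred = params["w"] * x + params["b"]
    loss = jnp.mean((pred - y) ** 2)
    return loss, {"loss": loss}


x = jnp.linspace(-1.0, 1.0, 32)
y = 3.0 * x + 0.5
trainer = MiniStatefulTrainer.build(
    jax.random.key(0), {"w": jnp.array(0.0), "b": jnp.array(0.0)}, loss_fn
)
first = trainer.step(x=x, y=y)["loss"]
for _ in range(200):
    last = trainer.step(x=x, y=y)["loss"]
```

In this sketch the caller never threads parameters or optimizer state through the loop; step() reads the stored values, runs the compiled step_fn, and writes the results back. That is the convenience a stateful trainer offers over a purely functional update loop.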