TrainState

class penzai.toolshed.basic_training.TrainState

Bases: Struct

Collection of state for ease of training.

Internally, the parameters and the nonlearnable parts of the model are stored separately. This avoids the overhead of PyTree traversal and simplifies checkpointing the parameters. You can access the full model through the .model property.
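The following is a minimal sketch of that kind of separation, not penzai's implementation. It assumes two hypothetical dataclasses, Param (a labeled value) and ParamPlaceholder (the label alone), and uses jax.tree_util.tree_map to move parameter values into a flat dict keyed by label and later put them back. That round trip is roughly what rebuilding a full model from params and model_without_params involves.

```python
import dataclasses
from typing import Any

import jax


@dataclasses.dataclass(frozen=True)
class Param:
  """Hypothetical labeled parameter that carries its value."""
  label: str
  value: Any


@dataclasses.dataclass(frozen=True)
class ParamPlaceholder:
  """Hypothetical stand-in that keeps a parameter's label but not its value."""
  label: str


def _is_param(x: Any) -> bool:
  return isinstance(x, Param)


def split_params(model: Any) -> tuple[Any, dict[str, Any]]:
  """Returns (model_without_params, params) for a PyTree containing Params."""
  params: dict[str, Any] = {}

  def take_value(leaf: Any) -> Any:
    if isinstance(leaf, Param):
      # Record the value under its label and leave only the label behind.
      params[leaf.label] = leaf.value
      return ParamPlaceholder(leaf.label)
    return leaf

  model_without_params = jax.tree_util.tree_map(
      take_value, model, is_leaf=_is_param
  )
  return model_without_params, params


def merge_params(model_without_params: Any, params: dict[str, Any]) -> Any:
  """Rebuilds the full model by filling each placeholder with its value."""

  def fill_value(leaf: Any) -> Any:
    if isinstance(leaf, ParamPlaceholder):
      return Param(leaf.label, params[leaf.label])
    return leaf

  return jax.tree_util.tree_map(fill_value, model_without_params)


model = {
    "dense": {"w": Param("dense.w", 1.5), "b": Param("dense.b", 0.0)},
    "activation": "relu",
}
model_without_params, params = split_params(model)
full_model = merge_params(model_without_params, params)
```

Because params ends up as a flat dict of values, it can be passed to an optimizer or a checkpointing routine without walking the rest of the model.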

Variables:
  • step (int) – Current step of training.

  • root_rng (PRNGKeyArray) – Base random number generator; used in combination with step to derive per-step random numbers.

  • params (dict[str, Any]) – Values for the parameters of the model being optimized.

  • model_without_params (ModelPyTree) – The nonlearnable parts of the model being optimized. This should contain the Parameter instances, but without their values.

  • opt_state (OptimizerStatePyTree) – The optimizer state for the learnable parts of the model.

  • loss_fn_state (LossStatePyTree) – Arbitrary state managed by the loss function. For instance, if your model has local state, you can functionalize it using pz.de.handle_local_states and store its state dict here.

  • optimizer_def (optax.GradientTransformation) – An optax optimizer.

Methods

__init__(step, root_rng, params, ...)

initial_state(model, optimizer_def, root_rng)

Constructs the initial training state.

Attributes

model

The full model, including parameters and nonlearnable parts.

step

root_rng

params

model_without_params

opt_state

loss_fn_state

optimizer_def

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

classmethod initial_state(model: ModelPyTree, optimizer_def: optax.GradientTransformation, root_rng: PRNGKeyArray, loss_fn_state: LossStatePyTree = None)

Constructs the initial training state.

Parameters:
  • model – The model being optimized.

  • optimizer_def – The optax optimizer to use.

  • root_rng – Base random number generator; used in combination with step to derive per-step random numbers.

  • loss_fn_state – Optional initial state for the loss function.

Returns:

An initial training state.

property model: ModelPyTree

The full model, including parameters and nonlearnable parts.