basic_training

basic_training#

Basic training logic for training simple models.

This module provides a barebones implementation of training logic that supports training Penzai models. It can be used to train simple models that do not need a more sophisticated training loop, and it also serves as a starting point for more complex training scripts.

Classes

LossFunction

Signature for loss functions expected by the common training step.

TrainState

Collection of state for ease of training.

TrainStepFunction

Signature for a common training step function after it is built.
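The three classes above describe the shapes that a loss function, a bundle of training state, and a built step function take. The sketch below shows in plain JAX how such signatures might look. It is an illustration under stated assumptions and not Penzai's definitions. The names `LossFunctionSketch`, `TrainStateSketch`, `TrainStepFunctionSketch` and `mse_loss`, the field names, and the convention that a loss returns `(loss, new_loss_state, aux_outputs)` are all hypothetical. Consult the Penzai API reference for the actual signatures.

```python
import dataclasses
from typing import Any, Protocol

import jax
import jax.numpy as jnp


class LossFunctionSketch(Protocol):
  """Hypothetical loss signature: returns (loss, new_loss_state, aux_outputs)."""

  def __call__(self, *, model: Any, state: Any, rng: jax.Array, **kwargs) -> tuple[jax.Array, Any, Any]:
    ...


@dataclasses.dataclass(frozen=True)
class TrainStateSketch:
  """Hypothetical bundle of everything that changes from one step to the next."""
  step: int
  rng: jax.Array
  model: Any
  loss_fn_state: Any


class TrainStepFunctionSketch(Protocol):
  """Hypothetical built step: takes a state plus batch kwargs, returns (new_state, outputs)."""

  def __call__(self, train_state: TrainStateSketch, **kwargs) -> tuple[TrainStateSketch, Any]:
    ...


def mse_loss(*, model, state, rng, x, y):
  """Example loss matching LossFunctionSketch; `model` is a dict holding a scalar weight 'w'."""
  del rng  # This loss is deterministic, so the random key is unused.
  pred = x * model["w"]
  loss = jnp.mean((pred - y) ** 2)
  # The loss state is passed through unchanged; aux carries extra diagnostics.
  return loss, state, {"mean_pred": jnp.mean(pred)}


loss, new_state, aux = mse_loss(
    model={"w": jnp.array(2.0)},
    state=None,
    rng=jax.random.PRNGKey(0),
    x=jnp.array([1.0, 2.0]),
    y=jnp.array([2.0, 4.0]),
)
train_state = TrainStateSketch(step=0, rng=jax.random.PRNGKey(0), model={"w": jnp.array(2.0)}, loss_fn_state=None)
next_train_state = dataclasses.replace(train_state, step=train_state.step + 1)
```

Keeping the loss keyword-only makes it easy for a generic step function to forward arbitrary batch fields (here `x` and `y`) without knowing their names in advance.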

Functions

build_train_step_fn(loss_fn[, jit, ...])

Builds a train step function for a common training loop.

compute_training_outputs_and_updates(...)

Runs a loss function and computes all of its outputs.
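The two functions listed above can be sketched in plain JAX, assuming the hypothetical loss convention `loss_fn(*, model, state, rng, **kwargs) -> (loss, new_state, aux)`. Every name here (`TrainState`, `compute_outputs_and_updates_sketch`, `build_train_step_fn_sketch`, `mse_loss`) is hypothetical, and plain SGD is swapped in as the update rule rather than a configurable optimizer. The central piece is `jax.value_and_grad(..., has_aux=True)`, which returns the loss, its gradient, and any extra outputs from a single call.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class TrainState:
  """Hypothetical minimal train state (not Penzai's TrainState)."""
  step: int
  rng: jax.Array
  model: Any
  loss_fn_state: Any


def compute_outputs_and_updates_sketch(loss_fn, model, loss_fn_state, rng, **kwargs):
  """Runs `loss_fn` once; returns loss, grads w.r.t. `model`, new loss state, aux outputs."""

  def wrapped(m):
    loss, new_state, aux = loss_fn(model=m, state=loss_fn_state, rng=rng, **kwargs)
    # Only the first output is differentiated; the rest ride along as aux data.
    return loss, (new_state, aux)

  (loss, (new_state, aux)), grads = jax.value_and_grad(wrapped, has_aux=True)(model)
  return loss, grads, new_state, aux


def build_train_step_fn_sketch(loss_fn, learning_rate=0.1, jit=True):
  """Builds a step function that applies one plain SGD update per call."""

  def core(model, loss_fn_state, rng, **kwargs):
    loss, grads, new_loss_state, aux = compute_outputs_and_updates_sketch(
        loss_fn, model, loss_fn_state, rng, **kwargs)
    # SGD: move every parameter leaf against its gradient.
    new_model = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, model, grads)
    return new_model, new_loss_state, {"loss": loss, "aux": aux}

  if jit:
    core = jax.jit(core)

  def step_fn(train_state, **kwargs):
    # Use a fresh key for this step and carry the other one forward.
    step_rng, next_rng = jax.random.split(train_state.rng)
    new_model, new_loss_state, outputs = core(
        train_state.model, train_state.loss_fn_state, step_rng, **kwargs)
    new_train_state = dataclasses.replace(
        train_state, step=train_state.step + 1, rng=next_rng,
        model=new_model, loss_fn_state=new_loss_state)
    return new_train_state, outputs

  return step_fn


def mse_loss(*, model, state, rng, x, y):
  """Example loss: fit y = w * x by mean squared error."""
  del rng
  loss = jnp.mean((x * model["w"] - y) ** 2)
  return loss, state, None


# Usage: learn w = 2 from three points.
state = TrainState(step=0, rng=jax.random.PRNGKey(0), model={"w": jnp.array(0.0)}, loss_fn_state=None)
step_fn = build_train_step_fn_sketch(mse_loss, learning_rate=0.1)
x = jnp.array([1.0, 2.0, 3.0])
y = 2.0 * x
losses = []
for _ in range(20):
  state, out = step_fn(state, x=x, y=y)
  losses.append(float(out["loss"]))
```

Only the pure array computation is wrapped in `jax.jit`. The frozen dataclass is unpacked outside the compiled function because it is not registered as a JAX pytree, and the step counter stays an ordinary Python integer.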