build_train_step_fn

penzai.toolshed.basic_training.build_train_step_fn(loss_fn: LossFunction, jit: bool = True, donate_params_and_state: bool = False, train_state_shardings: TrainState | None = None, input_kwarg_shardings: dict[str, Any] | None = None, aux_output_shardings: AuxOutPyTree | None = None) → TrainStepFunction

Builds a train step function for a common training loop.

For simplicity, the train step function returns only the third output of the loss function (its auxiliary outputs), not the loss value itself. If you want to observe the loss, you can return it both as the first output of the loss function and again as part of the third output, as shown in the sketch below. For more control, consider forking this function and modifying its logic.
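For example, a minimal loss function might duplicate the loss into its auxiliary outputs so that the train step reports it. This is a sketch, assuming a stateless model called as model(xs) and hypothetical keyword inputs xs and ys:

    import jax.numpy as jnp

    # A minimal sketch: the loss is returned both as the first output (used
    # to compute gradients) and inside the third output (so the train step
    # reports it). `xs` and `ys` are hypothetical keyword inputs; `model`,
    # `state`, and `rng` follow the argument names described above.
    def loss_fn(model, state, rng, xs, ys):
      preds = model(xs)  # assumes a stateless, deterministic model
      loss = jnp.mean(jnp.square(preds - ys))
      return loss, state, {"loss": loss}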

If your model has its own local state variables or stochastic layers, your loss function is responsible for handling those effects using its arguments. For instance, you could transform your model to handle the RandomEffect and LocalStateEffect, and pass the transformed (non-effectful) model and its initial state dict to TrainState.initial_state. Then, in your loss_fn, you could forward the rng and state arguments to the transformed model, following the argument structure expected by the handlers you used (as sketched below). Since models are PyTrees both before and after handling effects, you are free to handle their effects wherever is most convenient.
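As a sketch of the stateful case, suppose the model has already been transformed so that its local state and randomness are explicit call arguments; the call structure below is hypothetical, and the real structure depends on the handlers you used. The loss function then simply threads state and rng through:

    import jax.numpy as jnp

    # Hypothetical sketch: `model` is assumed to take its state dict and
    # random key as explicit arguments and to return its output together
    # with an updated state dict. The exact call structure depends on the
    # handlers used to transform the model.
    def stateful_loss_fn(model, state, rng, xs, ys):
      preds, new_state = model(xs, state=state, random_key=rng)
      loss = jnp.mean(jnp.square(preds - ys))
      return loss, new_state, {"loss": loss}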

Parameters:
  • loss_fn – Loss function taking a model, state, rng, and additional keyword-argument inputs, and returning a tuple (loss_scalar, new_state, aux_outputs).

  • jit – Whether to JIT-compile the train step.

  • donate_params_and_state – Whether to donate the old parameters and states when JIT-compiling the train step. If True, parameter and state arrays may be deleted after each step, meaning that any previous references to them (e.g. the version of the model with initial parameters) will break. Parts of the model that are not learnable will not be donated.

  • train_state_shardings – Optional TrainState with leaves replaced by JAX shardings. If provided, the train step will be compiled to shard its inputs and outputs according to these shardings. If not provided, JAX is allowed to infer an appropriate sharding. Ignored unless jit=True. Shardings for step and root_rng are ignored.

  • input_kwarg_shardings – Optional mapping from input keyword-argument names to shardings. If provided, the train step will be compiled to shard its user-provided inputs according to these shardings. If not provided, JAX is allowed to infer an appropriate sharding. Ignored unless jit=True.

  • aux_output_shardings – Optional auxiliary-output PyTree with leaves replaced by JAX shardings. If provided, the train step will be compiled to shard its auxiliary outputs according to these shardings. If not provided, JAX is allowed to infer an appropriate sharding. Ignored unless jit=True. (A sketch of providing explicit shardings follows this list.)
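As a sketch of providing explicit input shardings, assuming a one-dimensional "data" device mesh and the hypothetical keyword inputs xs and ys from the loss function above:

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec
    from penzai.toolshed import basic_training

    # A sketch: build a 1D "data" mesh over all devices and shard the
    # leading (batch) axis of the hypothetical inputs `xs` and `ys`
    # across it.
    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
    data_parallel = NamedSharding(mesh, PartitionSpec("data"))

    train_step = basic_training.build_train_step_fn(
        loss_fn,
        input_kwarg_shardings={"xs": data_parallel, "ys": data_parallel},
    )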

Returns:

A train step function, which updates the model parameters and internal states and returns a new train state along with the auxiliary outputs of the loss function.
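Putting it together, a typical training loop might look like the following sketch. The keyword names passed to TrainState.initial_state, the optimizer, and the data iterable are assumptions for illustration:

    import jax
    import optax
    from penzai.toolshed import basic_training

    # A sketch, assuming `model` is an already-constructed penzai model and
    # `loss_fn` follows the contract above. The keyword names below are
    # assumptions for illustration.
    train_state = basic_training.TrainState.initial_state(
        model=model,
        optimizer_def=optax.adam(1e-3),
        root_rng=jax.random.key(0),
    )
    train_step = basic_training.build_train_step_fn(loss_fn)
    for xs, ys in data_batches:  # hypothetical iterable of (input, target) batches
      train_state, outputs = train_step(train_state, xs=xs, ys=ys)
      print(outputs["loss"])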