penzai.toolshed.basic_training.compute_training_outputs_and_updates(current_params: dict[str, Any], model_without_params: ModelPyTree, opt_state: OptimizerStatePyTree, loss_fn_state: LossStatePyTree, root_rng: PRNGKeyArray, step: int | jax.Array, loss_kwargs: dict[str, Any], loss_fn: LossFunction, optimizer_def: optax.GradientTransformation) -> tuple[ModelPyTree, OptimizerStatePyTree, LossStatePyTree, AuxOutPyTree]

Runs the loss function and computes its outputs together with the optimizer's parameter updates.

This function runs the model and loss function and updates the corresponding parameters using the optimizer. Each component of the input is passed as a separate argument, which makes the function easy to JIT-compile and allows donating the buffers for the parts that will be updated.
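The split between `current_params` and `model_without_params` can be illustrated with a minimal pure-Python sketch. This is not penzai's actual implementation: the helper names `split_params`/`merge_params` are hypothetical, and a plain sentinel object stands in for `pz.NotInThisPartition`.

```python
NOT_IN_PARTITION = object()  # stand-in for pz.NotInThisPartition (illustration only)


def split_params(model_tree, path=""):
    """Split a nested dict into (flat params dict, structure with placeholders)."""
    params = {}
    structure = {}
    for key, value in model_tree.items():
        full = f"{path}/{key}" if path else key
        if isinstance(value, dict):
            sub_params, sub_structure = split_params(value, full)
            params.update(sub_params)
            structure[key] = sub_structure
        else:
            # Leaf: pull the value out into the params dict, leave a placeholder.
            params[full] = value
            structure[key] = NOT_IN_PARTITION
    return params, structure


def merge_params(structure, params, path=""):
    """Inverse of split_params: rebuild the full tree from structure + params."""
    out = {}
    for key, value in structure.items():
        full = f"{path}/{key}" if path else key
        if isinstance(value, dict):
            out[key] = merge_params(value, params, full)
        else:
            out[key] = params[full]
    return out
```

Because the learnable values live only in the flat params dict, the training step can donate and replace that dict each iteration while the (static) structure is reused unchanged.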

Parameters:

  • current_params – A dictionary of model parameters. These are the parts that WILL be updated by the optimizer.

  • model_without_params – A model PyTree that includes pz.NotInThisPartition in place of each of the learnable parameter values. These are the parts of the model that will NOT be updated by the optimizer.

  • opt_state – State for the optimizer.

  • loss_fn_state – State for the loss function.

  • root_rng – Root random key for the training process.

  • step – Current step of training, used to adjust the root RNG so each step receives different randomness.

  • loss_kwargs – Keyword arguments for the loss function.

  • loss_fn – The loss function.

  • optimizer_def – The optimizer.


Returns:

  Tuple of (new_params, new_opt_state, new_loss_fn_state, aux_outputs).
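The return contract can be sketched with a toy, dependency-free stand-in. This is not the penzai function itself: `toy_compute_updates` and `quadratic_loss` are hypothetical names, the "optimizer" is plain SGD with the learning rate stored in `opt_state`, and the loss function is assumed to return its own gradients.

```python
def toy_compute_updates(current_params, opt_state, loss_fn_state, step, loss_fn):
    """Toy training step mirroring the (new_params, new_opt_state,
    new_loss_fn_state, aux_outputs) return contract."""
    loss, grads, new_loss_fn_state, aux = loss_fn(current_params, loss_fn_state, step)
    # "Optimizer": plain SGD, with the learning rate kept in opt_state.
    lr = opt_state["lr"]
    new_params = {name: value - lr * grads[name] for name, value in current_params.items()}
    aux_outputs = dict(aux, loss=loss)
    return new_params, opt_state, new_loss_fn_state, aux_outputs


def quadratic_loss(params, state, step):
    """Hypothetical loss: L(w) = w**2, with gradient dL/dw = 2*w."""
    w = params["w"]
    return w * w, {"w": 2.0 * w}, state, {}
```

A single step from w = 2.0 with lr = 0.25 moves the parameter to 2.0 - 0.25 * 4.0 = 1.0, and the auxiliary outputs carry the loss value alongside whatever the loss function emitted.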