class penzai.toolshed.basic_training.LossFunction

Bases: Protocol

Signature for loss functions expected by the common training step.
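Because LossFunction is a typing.Protocol, it is satisfied structurally: any callable whose keyword-only signature is compatible with __call__ can be passed where a LossFunction is expected, with no subclassing. A static type checker performs that check; nothing is enforced at runtime. The sketch below illustrates the idea under stated assumptions: LossFunctionSketch is a simplified stand-in, not penzai's source, the PyTree aliases are replaced by typing.Any, and constant_loss is a hypothetical toy function.

```python
from typing import Any, Protocol

# Stand-in type aliases for illustration; penzai defines its own PyTree aliases.
ModelPyTree = Any
LossStatePyTree = Any
AuxOutPyTree = Any
PRNGKeyArray = Any


class LossFunctionSketch(Protocol):
    """Simplified callback protocol with a keyword-only __call__ (not penzai's source)."""

    def __call__(
        self,
        *,
        model: ModelPyTree,
        state: LossStatePyTree,
        rng: PRNGKeyArray,
        **kwargs: Any,
    ) -> tuple[Any, LossStatePyTree, AuxOutPyTree]:
        ...


def constant_loss(*, model, state, rng, **kwargs):
    """Hypothetical toy loss: ignores its inputs and always returns 0.0."""
    del model, rng, kwargs  # Unused in this toy example.
    return 0.0, state, {}


# A plain function is accepted wherever the protocol is expected; a static type
# checker verifies the signature, and no subclassing is involved.
loss_fn: LossFunctionSketch = constant_loss
print(loss_fn(model=None, state=None, rng=None))  # (0.0, None, {})
```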


Methods

__init__(*args, **kwargs)

__call__(*, model, state, rng, **kwargs) – Signature for a loss function.

__call__(*, model: ModelPyTree, state: LossStatePyTree, rng: PRNGKeyArray, **kwargs) → tuple[jax.Array, LossStatePyTree, AuxOutPyTree]

Signature for a loss function.

Parameters:

  • model – The PyTree containing the parameters, usually a neural network model.

  • state – Arbitrary state managed by the loss function. Can be None.

  • rng – A JAX PRNG key; the loss function may ignore it.

  • **kwargs – Additional keyword arguments passed to the training step, usually model inputs or labels.
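As an illustration of how these parameters fit together, here is a minimal sketch of a hypothetical loss function, counting_mse_loss. It assumes a linear model stored as a dict PyTree with keys "w" and "b", a scalar int32 call counter as state, and inputs and targets as the extra keyword arguments. It ignores rng, which the signature permits.

```python
import jax
import jax.numpy as jnp


def counting_mse_loss(*, model, state, rng, inputs, targets):
    """Hypothetical loss: MSE of a linear model, plus a call counter in `state`.

    Assumed layout: `model` is a dict PyTree {"w": (d,) array, "b": scalar},
    `state` is an int32 scalar, and `inputs`/`targets` arrive as extra keyword
    arguments from the training step.
    """
    del rng  # Deterministic loss, so the PRNG key is ignored.
    preds = inputs @ model["w"] + model["b"]
    loss = jnp.mean((preds - targets) ** 2)  # Scalar loss.
    new_state = state + 1  # Same structure as `state`: one scalar.
    aux = {"mean_prediction": jnp.mean(preds)}  # Arbitrary PyTree.
    return loss, new_state, aux


model = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
loss, state, aux = counting_mse_loss(
    model=model,
    state=jnp.array(0, dtype=jnp.int32),
    rng=jax.random.key(0),
    inputs=jnp.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]),
    targets=jnp.array([6.0, 1.0]),
)
print(float(loss), int(state))  # 0.5 1
```

Every argument is passed by keyword, matching the keyword-only (`*`) signature; anything beyond model, state, and rng, such as inputs and targets here, travels through **kwargs.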


Returns:

A tuple (loss, new_state, aux_outputs) for this example. loss should be a scalar array. new_state should have the same PyTree structure as state. aux_outputs can be an arbitrary PyTree.
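The sketch below checks the return contract (scalar loss, new_state with the same structure as state) using a hypothetical helper, check_loss_outputs, and then differentiates a hypothetical stateless loss with jax.value_and_grad(..., has_aux=True). The differentiation step is a generic JAX pattern shown for illustration; it is not necessarily how penzai's training step is implemented.

```python
import jax
import jax.numpy as jnp


def stateless_loss(*, model, state, rng, inputs, targets):
    """Hypothetical stateless loss: `state` is passed through unchanged."""
    del rng  # Unused.
    preds = inputs * model["scale"]
    loss = jnp.mean((preds - targets) ** 2)
    return loss, state, {"preds": preds}


def check_loss_outputs(outputs, old_state):
    """Hypothetical helper that checks the (loss, new_state, aux_outputs) contract."""
    loss, new_state, aux = outputs
    if jnp.ndim(loss) != 0:
        raise ValueError("loss should be a scalar")
    if jax.tree_util.tree_structure(new_state) != jax.tree_util.tree_structure(old_state):
        raise ValueError("new_state should match the structure of state")
    return loss, new_state, aux


model = {"scale": jnp.array(2.0)}
batch = {"inputs": jnp.array([1.0, 2.0]), "targets": jnp.array([1.0, 2.0])}
key = jax.random.key(0)
check_loss_outputs(stateless_loss(model=model, state=None, rng=key, **batch), old_state=None)


# Generic JAX pattern (not necessarily penzai's training step): differentiate the
# loss with respect to `model` and carry (new_state, aux) as auxiliary outputs.
def loss_for_grad(model):
    loss, new_state, aux = stateless_loss(model=model, state=None, rng=key, **batch)
    return loss, (new_state, aux)


(loss, (new_state, aux)), grads = jax.value_and_grad(loss_for_grad, has_aux=True)(model)
print(float(loss), float(grads["scale"]))  # 2.5 5.0
```

With has_aux=True, only the scalar loss is differentiated; new_state and aux_outputs are returned alongside it unchanged, which is why the loss must be a scalar while the other two outputs may be arbitrary PyTrees.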