LossFunction
- class penzai.deprecated.v1.toolshed.basic_training.LossFunction
Bases: Protocol
Signature for loss functions expected by the common training step.
Methods
__init__(*args, **kwargs)
__call__(*, model, state, rng, **kwargs): Signature for a loss function.
- __call__(*, model: ModelPyTree, state: LossStatePyTree, rng: PRNGKeyArray, **kwargs) -> tuple[jax.Array, LossStatePyTree, AuxOutPyTree]
Signature for a loss function.
- Parameters:
model – The structure with parameters, usually a neural network model.
state – Arbitrary state managed by the loss function. Can be None.
rng – A JAX PRNGKey. The loss function may ignore it.
**kwargs – Arguments passed to the train step, usually inputs to the model or labels.
- Returns:
A tuple (loss, new_state, aux_outputs) for this example. loss should be a scalar. new_state should match the structure of state. aux_outputs can be an arbitrary PyTree.
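The following is a minimal sketch of a function shaped like the signature described above, not code taken from the library. The name `mse_loss`, the `inputs` and `targets` keyword arguments, and the dict-of-arrays stand-in for the model are all hypothetical and chosen only for illustration; a real model would typically be a neural network structure.

```python
import jax
import jax.numpy as jnp


def mse_loss(*, model, state, rng, inputs, targets):
    """Hypothetical loss function following the keyword-only signature."""
    del rng  # This loss is deterministic, so the PRNG key is ignored.
    # Assumption: `model` is a plain dict of arrays standing in for a real model.
    predictions = inputs @ model["w"] + model["b"]
    loss = jnp.mean((predictions - targets) ** 2)  # Scalar loss.
    new_state = state  # Nothing to update; the structure matches `state`.
    aux_outputs = {"mean_prediction": jnp.mean(predictions)}  # Arbitrary PyTree.
    return loss, new_state, aux_outputs


# Toy data, made up for this sketch.
model = {"w": jnp.array([[1.0], [2.0]]), "b": jnp.array([0.5])}
inputs = jnp.array([[1.0, 0.0], [0.0, 1.0]])
targets = jnp.array([[1.5], [2.5]])

loss, new_state, aux = mse_loss(
    model=model,
    state=None,
    rng=jax.random.PRNGKey(0),
    inputs=inputs,
    targets=targets,
)
```

Here `inputs` and `targets` play the role of the `**kwargs` forwarded from the train step, and `state=None` illustrates the case where the loss function keeps no state.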