TrainStepFunction
- class penzai.deprecated.v1.toolshed.basic_training.TrainStepFunction
Bases: Protocol
Signature for a common training step function after it is built.
Methods
__init__(*args, **kwargs)
__call__(train_state, **kwargs) – Signature for a training step.
- __call__(train_state: TrainState, **kwargs) -> tuple[TrainState, AuxOutPyTree]
Signature for a training step.
- Parameters:
train_state – The current state.
**kwargs – Arguments passed to the train step, usually inputs to the model or labels.
- Returns:
A tuple (new_train_state, aux_outputs).
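Because TrainStepFunction is a Protocol, any callable with this shape satisfies it structurally; no subclassing is needed. The following is a minimal pure-Python sketch of a function that follows the convention. It is not penzai code: `ToyTrainState` and `toy_train_step` are hypothetical names, the keyword arguments `inputs`, `labels`, and `learning_rate` are assumptions chosen for illustration, and a real penzai training step would typically operate on JAX arrays and compute gradients with `jax.grad` rather than by hand.

```python
import dataclasses
from typing import Any


# Hypothetical stand-in for penzai's TrainState. The real class has other
# fields; this one only tracks a step counter and a single scalar weight.
@dataclasses.dataclass(frozen=True)
class ToyTrainState:
  step: int
  weight: float


def toy_train_step(
    train_state: ToyTrainState, **kwargs: Any
) -> tuple[ToyTrainState, dict[str, float]]:
  """One gradient-descent step on the loss (weight * inputs - labels) ** 2.

  Follows the TrainStepFunction convention: the current state plus keyword
  arguments (here the assumed `inputs`, `labels`, and optional
  `learning_rate`) go in, and a pair (new_train_state, aux_outputs) comes out.
  """
  inputs = kwargs["inputs"]
  labels = kwargs["labels"]
  learning_rate = kwargs.get("learning_rate", 0.1)

  error = train_state.weight * inputs - labels
  loss = error**2
  # Hand-derived gradient of the squared error with respect to the weight.
  grad = 2.0 * error * inputs

  # Build a new state instead of mutating the old one, in the functional
  # style used by JAX training loops.
  new_state = dataclasses.replace(
      train_state,
      step=train_state.step + 1,
      weight=train_state.weight - learning_rate * grad,
  )
  aux_outputs = {"loss": loss}
  return new_state, aux_outputs


if __name__ == "__main__":
  state = ToyTrainState(step=0, weight=0.0)
  for _ in range(3):
    # Each call consumes the previous state and returns a fresh one.
    state, aux = toy_train_step(state, inputs=2.0, labels=4.0)
    print(state.step, state.weight, aux["loss"])
```

Returning auxiliary values (such as the loss) alongside the new state, rather than storing them on the state, keeps the state limited to what the next step actually needs.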