TrainStepFunction


class penzai.toolshed.basic_training.TrainStepFunction

Bases: Protocol

Protocol describing the signature of a common training step function once it has been built.

Methods

  • __init__(*args, **kwargs)

  • __call__(train_state, **kwargs) – Signature for a training step.

__call__(train_state: TrainState, **kwargs) → tuple[TrainState, AuxOutPyTree]

Signature for a training step.

Parameters:
  • train_state – The current training state.

  • **kwargs – Arguments passed to the train step, usually inputs to the model or labels.

Returns:

A tuple (new_train_state, aux_outputs).
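
Example:

The following is a minimal sketch of a callable with the shape described above, not penzai's own implementation. To keep it self-contained it does not import penzai: ToyTrainState and ToyTrainStepFunction are hypothetical stand-ins for TrainState and TrainStepFunction, and sgd_step, the x and y keyword arguments, and the fixed learning rate of 0.1 are assumptions made purely for illustration.

```python
from dataclasses import dataclass, replace
from typing import Any, Protocol


@dataclass(frozen=True)
class ToyTrainState:
    """Hypothetical stand-in for penzai's TrainState (not the real class)."""

    step: int
    weight: float


class ToyTrainStepFunction(Protocol):
    """Local mirror of the documented call signature, for illustration only."""

    def __call__(
        self, train_state: ToyTrainState, **kwargs: Any
    ) -> tuple[ToyTrainState, Any]: ...


def sgd_step(train_state: ToyTrainState, **kwargs: Any) -> tuple[ToyTrainState, dict]:
    """Takes one gradient step on the squared error (weight * x - y) ** 2."""
    # Model inputs and labels arrive as keyword arguments.
    x, y = kwargs["x"], kwargs["y"]
    error = train_state.weight * x - y
    grad = 2.0 * error * x
    # Return a new state instead of mutating the old one.
    new_state = replace(
        train_state,
        step=train_state.step + 1,
        weight=train_state.weight - 0.1 * grad,
    )
    # Auxiliary outputs can be any PyTree-like structure; a dict is used here.
    return new_state, {"loss": error**2}


# A plain function satisfies the protocol structurally.
step_fn: ToyTrainStepFunction = sgd_step

state = ToyTrainState(step=0, weight=0.0)
state, aux = step_fn(state, x=1.0, y=2.0)
```

Because the class is a Protocol, any callable with a matching signature conforms to it; no subclassing is needed. The caller threads the returned state into the next call and reads per-step values such as the loss from the auxiliary outputs.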