LabeledVariable

class penzai.core.variables.LabeledVariable

Bases: AbstractVariable, Generic[T]

Base implementation of a variable with a label, value, and metadata.

Conceptually, each variable is only valid inside a single JAX “trace”, which corresponds to a JAX transformation or control-flow primitive. Variables can be created and modified at a given trace level and can be read inside inner traces (e.g. you can read a variable inside a JAX cond), but you should generally avoid assigning a value to a variable inside an inner trace, because the assigned value may be a tracer that leaks out of that trace.

This restriction is currently unchecked, but that may change in the future.
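The read-inside, assign-outside rule can be illustrated in plain JAX. This is a minimal sketch, not penzai code: the Box class is a hypothetical stand-in for a mutable variable, and only jax.lax.cond is a real API here. The variable is read inside the cond branches (an inner trace), and the assignment happens at the outer level once cond has returned a concrete array.

```python
import jax
import jax.numpy as jnp


class Box:
    """Hypothetical stand-in for a mutable variable (not penzai's class)."""

    def __init__(self, value):
        self.value = value


counter = Box(jnp.array(0.0))


def step(pred):
    # Reading the variable inside the branches of lax.cond (an inner trace)
    # is fine: the current value is captured as a constant of each branch.
    new_value = jax.lax.cond(
        pred,
        lambda: counter.value + 1.0,
        lambda: counter.value,
    )
    # Avoid: assigning counter.value inside a branch lambda would store a
    # tracer belonging to the cond trace, which would then leak.
    # Instead, assign at the outer level after cond has returned.
    counter.value = new_value
    return new_value


step(True)
step(False)
step(True)
print(float(counter.value))  # 2.0
```

The same reasoning applies to other inner traces such as jax.jit or jax.lax.scan: compute the new value inside, then assign it where the variable was created.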

Variables:
  • label (VariableLabel) – The unique label for this variable.

  • value – The mutable value stored in the variable. Should be a JAX pytree.

  • metadata (dict[Any, Any]) – A dictionary of metadata associated with this variable.

Methods

  • __init__(*, label, value[, metadata]) – Constructs a new variable.

  • set_value(new_val) – Sets the value of the Variable.

  • treescope_color() – Returns a color for this variable in Treescope.

  • update(new_frozen_value) – Updates the value of this variable to match a frozen variable.

Attributes

  • value

  • label

  • metadata

Inherited Methods

  • freeze() – Returns a frozen copy of this variable.

  • get_slot() – Returns the slot that this variable is replaced with when unbound.
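The idea behind get_slot can be sketched conceptually: unbinding walks a model structure and replaces each variable with a slot placeholder, and binding puts the variables back. This is not penzai's implementation. ToyVariable, ToySlot, toy_unbind, and toy_bind are hypothetical names, and the sketch assumes a slot is identified only by its variable's label, which is one reason labels need to be unique.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ToyVariable:
    """Hypothetical stand-in for a labeled variable."""
    label: str
    value: Any

    def get_slot(self) -> "ToySlot":
        # Assumption: the slot only records the label of the variable.
        return ToySlot(self.label)


@dataclass(frozen=True)
class ToySlot:
    """Hypothetical placeholder left in the structure when a variable is unbound."""
    label: str


def toy_unbind(tree):
    """Replace each ToyVariable in nested dicts/lists with its slot."""
    found = {}

    def go(node):
        if isinstance(node, ToyVariable):
            found[node.label] = node
            return node.get_slot()
        if isinstance(node, dict):
            return {k: go(v) for k, v in node.items()}
        if isinstance(node, list):
            return [go(v) for v in node]
        return node

    return go(tree), found


def toy_bind(tree, variables):
    """Inverse of toy_unbind: swap each slot back for its variable."""
    if isinstance(tree, ToySlot):
        return variables[tree.label]
    if isinstance(tree, dict):
        return {k: toy_bind(v, variables) for k, v in tree.items()}
    if isinstance(tree, list):
        return [toy_bind(v, variables) for v in tree]
    return tree


# The same variable is shared by two layers; both uses become the same slot.
w = ToyVariable(label="dense/w", value=[1.0, 2.0])
model = {"layers": [{"w": w}, {"w": w}], "name": "mlp"}
unbound, variables = toy_unbind(model)
print(unbound["layers"][0]["w"])  # ToySlot(label='dense/w')
rebound = toy_bind(unbound, variables)
print(rebound["layers"][1]["w"] is w)  # True
```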

__init__(*, label: VariableLabel, value: T, metadata: dict[Any, Any] | None = None)

Constructs a new variable.

Parameters:
  • label – The unique label for this variable.

  • value – The initial value of the variable.

  • metadata – A dictionary of metadata associated with this Variable.
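The constructor signature above is keyword-only (note the bare *) and makes metadata optional. The following is a minimal sketch that mirrors that signature with a hypothetical ToyLabeledVariable class. It is not penzai's code: it uses a plain string where penzai uses a VariableLabel, and it assumes that omitted metadata becomes an empty dict.

```python
from __future__ import annotations

from typing import Any, Generic, TypeVar

T = TypeVar("T")


class ToyLabeledVariable(Generic[T]):
    """Hypothetical mirror of the documented constructor; not penzai's code."""

    def __init__(self, *, label: str, value: T, metadata: dict[Any, Any] | None = None):
        # The bare `*` makes every argument keyword-only, as in the signature above.
        self.label = label
        self.value = value
        # Assumption: omitted metadata is stored as an empty dict.
        self.metadata = {} if metadata is None else metadata


# The value can be any pytree-like container, e.g. a dict of nested lists.
w = ToyLabeledVariable(
    label="encoder/w",
    value={"kernel": [[0.1, 0.2], [0.3, 0.4]], "bias": [0.0, 0.0]},
    metadata={"trainable": True},
)
b = ToyLabeledVariable(label="encoder/b", value=0.0)

try:
    ToyLabeledVariable("encoder/c", 1.0)  # positional arguments are rejected
    positional_rejected = False
except TypeError:
    positional_rejected = True

print(w.metadata, b.metadata, positional_rejected)  # {'trainable': True} {} True
```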

set_value(new_val: T)

Sets the value of the Variable.

abstract treescope_color() → str | tuple[str, str]

Returns a color for this variable in Treescope.
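Because treescope_color is abstract, LabeledVariable is presumably meant to be subclassed, with each subclass choosing its own color. The following is a minimal sketch of that pattern using Python's abc module. ToyVariableBase, ToyParam, ToyState, and the hex color strings are all hypothetical; the only thing carried over from the signature above is the return type, a single string or a pair of strings.

```python
from __future__ import annotations

import abc
from typing import Any


class ToyVariableBase(abc.ABC):
    """Hypothetical base class with an abstract color hook like treescope_color."""

    def __init__(self, *, label: str, value: Any):
        self.label = label
        self.value = value

    @abc.abstractmethod
    def treescope_color(self) -> str | tuple[str, str]:
        """Return a single color string or a pair of color strings."""


class ToyParam(ToyVariableBase):
    def treescope_color(self) -> str:
        return "#ffd580"  # hypothetical single color


class ToyState(ToyVariableBase):
    def treescope_color(self) -> tuple[str, str]:
        return ("#ff9f9f", "#ffe0e0")  # hypothetical pair of colors


try:
    ToyVariableBase(label="x", value=0)  # abstract, so this raises TypeError
    base_instantiable = True
except TypeError:
    base_instantiable = False

print(base_instantiable)  # False
print(ToyParam(label="w", value=1.0).treescope_color())  # #ffd580
print(ToyState(label="s", value=0).treescope_color())  # ('#ff9f9f', '#ffe0e0')
```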

update(new_frozen_value: LabeledVariableValue)

Updates the value of this variable to match a frozen variable.
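Taken together, set_value, freeze, and update support a snapshot-and-restore cycle. One plausible workflow is to freeze variables into immutable values, compute with them, and write results back with update. The following is a minimal sketch with hypothetical ToyVariable and ToyFrozenValue classes that stand in for LabeledVariable and LabeledVariableValue. Two details are assumptions and not documented behavior: update rejects a snapshot whose label differs, and it copies only the value, as the description above says.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ToyFrozenValue:
    """Hypothetical immutable snapshot standing in for LabeledVariableValue."""
    label: str
    value: Any
    metadata: dict


class ToyVariable:
    """Hypothetical mutable variable with set_value / freeze / update."""

    def __init__(self, *, label: str, value: Any, metadata: dict | None = None):
        self.label = label
        self.value = value
        self.metadata = {} if metadata is None else metadata

    def set_value(self, new_val: Any) -> None:
        self.value = new_val

    def freeze(self) -> ToyFrozenValue:
        # Snapshot the current label, value, and metadata.
        return ToyFrozenValue(label=self.label, value=self.value, metadata=self.metadata)

    def update(self, new_frozen_value: ToyFrozenValue) -> None:
        # Assumption: a snapshot may only update the variable with the same label.
        if new_frozen_value.label != self.label:
            raise ValueError(
                f"label mismatch: {new_frozen_value.label!r} != {self.label!r}"
            )
        self.value = new_frozen_value.value


counter = ToyVariable(label="step", value=0)
snapshot = counter.freeze()  # immutable copy holding value 0
counter.set_value(5)         # mutate the live variable
counter.update(snapshot)     # restore the value from the snapshot
print(counter.value)  # 0

other = ToyVariable(label="other", value=1)
try:
    other.update(snapshot)
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True
print(mismatch_rejected)  # True
```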