LabeledVariable
- class penzai.core.variables.LabeledVariable
Bases: AbstractVariable, Generic[T]
Base implementation of a variable with a label, value, and metadata.
Conceptually, each variable is only valid inside a single JAX “trace”, corresponding to a JAX transformation or control-flow primitive. Variables can be created and modified at a given trace level and can be read inside inner traces (e.g. you can read a variable inside a JAX cond), but you should generally avoid assigning a value to a variable inside an inner trace, because the value may leak. This is currently unchecked, but that may change in the future.
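The leak hazard is a general property of JAX tracing rather than something specific to penzai. The sketch below uses a hypothetical plain-Python Box class (not part of penzai) as a stand-in for a variable: reading its value inside jax.lax.cond is fine, while assigning to it inside a cond branch stores a tracer that outlives the trace that created it.

```python
import jax
import jax.numpy as jnp


class Box:
  """Hypothetical mutable holder standing in for a variable (not penzai's class)."""

  def __init__(self, value):
    self.value = value


box = Box(jnp.float32(1.0))


def read_inside_cond(x):
  # Safe: the concrete array in box.value is captured as a constant of the
  # traced branches; nothing from the inner trace escapes.
  return jax.lax.cond(
      x > 0, lambda y: y + box.value, lambda y: y - box.value, x
  )


def write_inside_cond(x):
  def branch(y):
    # Unsafe: y is a tracer of the inner cond trace, so this assignment stores
    # a tracer that escapes ("leaks") once tracing of the branch finishes.
    box.value = y * 2
    return y

  return jax.lax.cond(x > 0, branch, branch, x)


result = read_inside_cond(jnp.float32(3.0))  # 3.0 + 1.0
write_inside_cond(jnp.float32(3.0))

# box.value now holds a leaked tracer instead of a concrete array; using it in
# later computations typically raises jax.errors.UnexpectedTracerError.
print(type(box.value).__name__)
```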
- Variables:
label (VariableLabel) – The unique label for this variable.
value – The mutable value stored in the variable. Should be a JAX pytree.
metadata (dict[Any, Any]) – A dictionary of metadata associated with this variable.
Methods
__init__(*, label, value[, metadata]) – Constructs a new variable.
set_value(new_val) – Sets the value of the variable.
treescope_color() – Returns a color for this variable in Treescope.
update(new_frozen_value) – Updates the value of this variable to match a frozen variable.
Attributes
value
label
metadata
Inherited Methods
freeze() – Returns a frozen copy of this variable.
get_slot() – Returns the slot that this variable is replaced with when unbound.
- __init__(*, label: VariableLabel, value: T, metadata: dict[Any, Any] | None = None)
Constructs a new variable.
- Parameters:
label – The unique label for this variable.
value – The initial value of the variable.
metadata – An optional dictionary of metadata associated with this variable.
- abstract treescope_color() → str | tuple[str, str]
Returns a color for this variable in Treescope.
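Because treescope_color is marked abstract, concrete subclasses are presumably expected to provide it, and in the usual Python abc sense the abstract base itself cannot be instantiated. The sketch below illustrates that pattern with hypothetical classes (SketchLabeledVariable, SketchStateVariable) that are not part of penzai; the returned value is assumed to be a CSS color string, and the specific color is arbitrary.

```python
from __future__ import annotations

import abc


class SketchLabeledVariable(abc.ABC):
  """Hypothetical abstract base; only the abstract rendering hook is modeled."""

  @abc.abstractmethod
  def treescope_color(self) -> str | tuple[str, str]:
    """Returns a color (or a pair of colors) used to render this variable."""


class SketchStateVariable(SketchLabeledVariable):
  """Hypothetical concrete subclass that supplies a fixed color."""

  def treescope_color(self) -> str:
    return "#ffb3b3"  # Arbitrary illustrative color, not penzai's choice.


try:
  SketchLabeledVariable()  # Abstract method not overridden: raises TypeError.
except TypeError:
  print("abstract base cannot be instantiated")

print(SketchStateVariable().treescope_color())
```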
- update(new_frozen_value: LabeledVariableValue)
Updates the value of this variable to match a frozen variable.
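Taken together, the constructor, set_value, freeze, and update describe a round trip between a mutable variable and an immutable snapshot. The following is a minimal plain-Python sketch of that round trip under stated assumptions: SketchVariable and SketchFrozenValue are hypothetical stand-ins for LabeledVariable and LabeledVariableValue, string labels replace VariableLabel, and the label check in update is an illustrative assumption, not documented penzai behavior.

```python
from __future__ import annotations

import dataclasses
from typing import Any, Generic, TypeVar

T = TypeVar("T")


@dataclasses.dataclass(frozen=True)
class SketchFrozenValue(Generic[T]):
  """Hypothetical immutable snapshot of a variable's label, value, metadata."""
  label: Any
  value: T
  metadata: dict[Any, Any]


class SketchVariable(Generic[T]):
  """Hypothetical mutable variable; not penzai's implementation."""

  def __init__(self, *, label: Any, value: T,
               metadata: dict[Any, Any] | None = None):
    # Keyword-only arguments, mirroring the documented __init__ signature.
    self.label = label
    self.value = value
    # Use a fresh dict per instance when metadata is omitted.
    self.metadata = {} if metadata is None else metadata

  def set_value(self, new_val: T) -> None:
    self.value = new_val

  def freeze(self) -> SketchFrozenValue[T]:
    # Copy metadata so later edits to the live variable do not alter the snapshot.
    return SketchFrozenValue(
        label=self.label, value=self.value, metadata=dict(self.metadata)
    )

  def update(self, new_frozen_value: SketchFrozenValue[T]) -> None:
    # Assumption for illustration: only accept snapshots with the same label.
    if new_frozen_value.label != self.label:
      raise ValueError("label mismatch")
    self.value = new_frozen_value.value


counter = SketchVariable(label="counter", value=0)
snapshot = counter.freeze()   # Records value 0.
counter.set_value(5)          # Mutates the live variable.
counter.update(snapshot)      # Restores the value recorded in the snapshot.
print(counter.value)
```

Keeping the snapshot immutable (frozen=True) makes it safe to pass around as ordinary data, while the live variable remains the single place where mutation happens.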