vars_for_axes

penzai.core.shapecheck.vars_for_axes(var_name: str, axis_names_or_specs: Collection[str] | Mapping[str, int | None]) -> dict[str, DimVar | KnownDim]

Creates variables for a known collection of named axes.

Parameters:
  • var_name – A name for the variable that will store the concrete sizes for all of these named axes.

  • axis_names_or_specs – Either a collection of axis names, or a mapping from axis names to integer axis sizes or to None (for unknown axis sizes).

Returns:

A dictionary with the keys from axis_names_or_specs, whose values reflect either the known size given in axis_names_or_specs or an unknown dimension. Intended to be passed as the named_shape of an ArraySpec or unpacked into a larger dictionary of named shapes.
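The following is a minimal, self-contained sketch of the documented input/output behavior. It is not penzai's implementation. `DimVar` and `KnownDim` below are simplified stand-ins whose fields (`var_name`, `axis_name`, `size`) are assumptions, and `vars_for_axes_sketch` is a hypothetical name. It illustrates that a plain collection of axis names produces only unknown dimensions, while a mapping produces a known dimension for each integer size and an unknown dimension for each None.

```python
from __future__ import annotations

from collections.abc import Collection, Mapping
from dataclasses import dataclass


# Hypothetical stand-in for penzai's DimVar: an axis size to be solved for.
@dataclass(frozen=True)
class DimVar:
    var_name: str   # variable that collects sizes for the whole group of axes
    axis_name: str


# Hypothetical stand-in for penzai's KnownDim: a size that is already fixed.
@dataclass(frozen=True)
class KnownDim:
    var_name: str
    axis_name: str
    size: int


def vars_for_axes_sketch(
    var_name: str,
    axis_names_or_specs: Collection[str] | Mapping[str, int | None],
) -> dict[str, DimVar | KnownDim]:
    """Sketch only: mirrors the documented behavior of vars_for_axes."""
    # A plain collection of names means every axis size is unknown.
    if not isinstance(axis_names_or_specs, Mapping):
        axis_names_or_specs = {name: None for name in axis_names_or_specs}
    result: dict[str, DimVar | KnownDim] = {}
    for axis_name, size in axis_names_or_specs.items():
        if size is None:
            result[axis_name] = DimVar(var_name, axis_name)
        else:
            result[axis_name] = KnownDim(var_name, axis_name, size)
    return result


# One unknown axis ("batch") and one axis with a known size ("features").
spec = vars_for_axes_sketch("x", {"batch": None, "features": 8})

# Unpacking the result into a larger named-shape dictionary with other axes.
named_shape = {**vars_for_axes_sketch("x", ["batch", "seq"]), "features": 8}
```

Grouping every axis under a single variable name means that, after a successful shape check, all of the group's sizes can be read back from one place, which matches the description of var_name above.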