Jitted

class penzai.toolshed.jit_wrapper.Jitted

Bases: Layer

Wraps a pure layer so that it runs under JIT compilation.

The Jitted wrapper has the same behavior and input/output structure as the layer it wraps, but modifies __call__ so that every call is run under JIT compilation. Since jax.jit operates on pure functions, the layer wrapped by Jitted should not have any unhandled effects (i.e. EffectRequest instances from penzai.data_effects).
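
Why purity matters under jax.jit can be seen in plain JAX, independent of penzai. Python-level side effects inside a jitted function run only while JAX traces it, not on each later call of the compiled version. The minimal sketch below illustrates this; the names `impure_double` and `trace_log` are illustrative only and are not part of penzai.

```python
import jax
import jax.numpy as jnp

trace_log = []  # Python-side record of a side effect (illustrative only)


@jax.jit
def impure_double(x):
    # This append is a Python side effect: it runs while JAX traces the
    # function, not every time the compiled function is called.
    trace_log.append("called")
    return 2 * x


a = impure_double(jnp.ones(3))   # first call: traced and compiled
b = impure_double(jnp.zeros(3))  # same shape and dtype: cached, no retrace
print(len(trace_log))  # 1
```

Anything a layer relies on happening outside JAX's traced computation is therefore not repeated per call, which is why the body of Jitted should be pure.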

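Since Jitted is an ordinary Layer, you can still inspect the contained layer and make modifications to it. Modifications that change the PyTree structure of the Jitted block (or the shapes or dtypes of its arrays) automatically trigger a recompile on the next call, since the jax.jit call depends on the PyTree structure of the Jitted block.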
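
The recompile behavior can be sketched in plain JAX, assuming (as the description suggests) that the wrapper is itself a PyTree passed as an argument to a jitted function. `JittedSketch`, `Dense`, `Chain`, and `_run` are hypothetical stand-ins, not penzai's API or its actual implementation. Swapping in new array values of the same shape and dtype reuses the cached compilation, while a structural edit causes a retrace.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while _run is being traced


@jax.tree_util.register_pytree_node_class
class Dense:
    """Hypothetical pure layer: an affine map whose arrays are PyTree leaves."""

    def __init__(self, w, b):
        self.w, self.b = w, b

    def tree_flatten(self):
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def __call__(self, x):
        return x @ self.w + self.b


@jax.tree_util.register_pytree_node_class
class Chain:
    """Hypothetical pure layer: applies its sublayers in order."""

    def __init__(self, layers):
        self.layers = tuple(layers)

    def tree_flatten(self):
        return (self.layers,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


@jax.jit
def _run(wrapper, argument):
    # The whole wrapper (and its body) is a PyTree argument, so the jit
    # cache is keyed on its structure and on its arrays' shapes/dtypes.
    global trace_count
    trace_count += 1
    return wrapper.body(argument)


@jax.tree_util.register_pytree_node_class
class JittedSketch:
    """Simplified stand-in for a Jitted-style wrapper."""

    def __init__(self, body):
        self.body = body

    def tree_flatten(self):
        return (self.body,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def __call__(self, argument):
        return _run(self, argument)


x = jnp.ones((2, 3))
y1 = JittedSketch(Dense(jnp.ones((3, 4)), jnp.zeros(4)))(x)      # traced
y2 = JittedSketch(Dense(2 * jnp.ones((3, 4)), jnp.zeros(4)))(x)  # cached
y3 = JittedSketch(Chain([Dense(jnp.ones((3, 4)), jnp.zeros(4)),
                         Dense(jnp.ones((4, 2)), jnp.zeros(2))]))(x)  # retraced
print(trace_count)  # 2
```

Passing the wrapper as a traced argument, rather than closing over it, is what lets value-only edits share one compilation while structural edits are picked up automatically.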

Variables:

body (pz.LayerLike) – The layer that should run under JIT compilation.

Methods

__init__(body)

__call__(argument, /)

Attributes

body

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

input_structure()

Returns the input structure of this layer.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

output_structure()

Returns the output structure of this layer.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.