Jitted
class penzai.toolshed.jit_wrapper.Jitted
Bases: Layer
Wraps a pure layer to run under JIT compilation.

The Jitted wrapper has the same behavior and input/output structure as the layer it wraps, but modifies __call__ so that every call runs under JIT compilation. Variables in the layer are still updated correctly, because the call is compiled with pz.variable_jit.

Since Jitted is an ordinary Layer, you can still inspect the contained layer and modify it. Changes that alter the PyTree structure of the Jitted block automatically trigger a recompile, since the jax.jit call depends on that structure.

Variables:
body (pz.nn.Layer) – The layer that should run under JIT compilation.
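The recompilation rule above follows from how jax.jit caches compiled functions. When a wrapper passes its body into a jitted function as an ordinary argument, JAX compiles once per distinct PyTree structure (and argument shape/dtype), while new leaf values reuse the cached executable. The following is a minimal sketch of that mechanism in plain JAX, not penzai's actual implementation: JittedSketch, _jitted_apply, Scale, and Chain are hypothetical names, and the sketch ignores variables and side inputs entirely. It counts traces with a Python side effect, which only runs while JAX is tracing.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp

# Number of times the jitted function has been traced. Python side effects
# run only during tracing, so each increment marks a new compilation.
TRACE_COUNT = 0


@jax.jit
def _jitted_apply(body, argument):
  global TRACE_COUNT
  TRACE_COUNT += 1
  # `body` arrives rebuilt from its PyTree leaves, which are now tracers.
  return body(argument)


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Scale:
  """Hypothetical toy layer: multiplies its input by `factor`."""
  factor: jax.Array

  def __call__(self, x):
    return x * self.factor

  def tree_flatten(self):
    return (self.factor,), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Chain:
  """Hypothetical toy combinator: applies `layers` in order."""
  layers: tuple

  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    return x

  def tree_flatten(self):
    return (self.layers,), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class JittedSketch:
  """Hypothetical stand-in for Jitted: same call behavior, run under jax.jit."""
  body: Any

  def __call__(self, argument):
    # The body is a jit argument: its PyTree structure is part of the cache
    # key, while its array leaves are traced rather than baked in.
    return _jitted_apply(self.body, argument)

  def tree_flatten(self):
    return (self.body,), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)


x = jnp.arange(3.0)
model = JittedSketch(Scale(jnp.array(2.0)))
y1 = model(x)        # first call: traced and compiled
y2 = model(x + 1.0)  # same structure and shapes: cached executable reused
# New leaf value, same structure: still served from the cache.
y3 = JittedSketch(Scale(jnp.array(3.0)))(x)
count_before_structural_change = TRACE_COUNT
# Different body structure (Chain instead of Scale): retraced and recompiled.
y4 = JittedSketch(Chain((Scale(jnp.array(2.0)), Scale(jnp.array(5.0)))))(x)
```

In this sketch, replacing Scale(2.0) with Scale(3.0) reuses the compiled function, whereas swapping in a Chain changes the PyTree structure and forces a second trace. The description of Jitted attributes its recompiles to the same dependence on PyTree structure.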
Methods

__init__(body)
__call__(argument, /, **side_inputs)

Attributes

body
Inherited Methods
attributes_dict() – Constructs a dictionary with all of the fields in the class.
bind_variables(variables[, allow_unused]) – Convenience function to bind variables to a layer.
from_attributes(**field_values) – Directly instantiates a struct given all of its fields.
key_for_field(field_name) – Generates a JAX PyTree key for a given field name.
select() – Wraps this struct in a selection, enabling functional-style mutations.
stateless_call(variable_values, argument, /, ...) – Calls a layer with temporary variables, without modifying its state.
tree_flatten() – Flattens this tree node.
tree_flatten_with_keys() – Flattens this tree node with keys.
tree_unflatten(aux_data, children) – Unflattens this tree node.
treescope_color() – Computes a CSS color to display for this object in treescope.