StochasticDropout

class penzai.nn.dropout.StochasticDropout

Bases: Layer

Stochastic dropout layer.

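Dropout layers randomly zero out each element with probability drop_rate and then scale the surviving elements up by a factor of 1 / (1 - drop_rate), so that the expected value of each output element equals the corresponding input element.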
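As a rough illustration of the masking and rescaling rule, the following sketch applies it to a plain NumPy array. It is not penzai's implementation: a NumPy random generator stands in for the random stream side input, a positional array stands in for a NamedArray, and the function name inverted_dropout is hypothetical.

```python
import numpy as np


def inverted_dropout(x, drop_rate, rng):
    """Hypothetical helper: zero each element with probability drop_rate and
    rescale the survivors by 1 / (1 - drop_rate)."""
    if not 0.0 <= drop_rate < 1.0:
        raise ValueError("drop_rate must be in [0, 1)")
    # rng.random draws from [0, 1), so this is True with probability 1 - drop_rate.
    keep = rng.random(x.shape) >= drop_rate
    return np.where(keep, x / (1.0 - drop_rate), 0.0)


rng = np.random.default_rng(0)
x = np.ones((1000, 1000))
y = inverted_dropout(x, 0.25, rng)
print(y.mean())  # close to 1.0, the mean of the input
```

Every output element is either 0 or 4/3 here, and the rescaling keeps the mean of the output close to the mean of the input.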

For simplicity, and to avoid having to pass configuration through the model, dropout layers are always stochastic. To disable dropout, you can remove the dropout layers from a model using logic such as

model = model.select().at_instances_of(StochasticDropout).remove_from_parent()

or just disable them using

model = model.select().at_instances_of(StochasticDropout).apply(lambda layer: layer.disable())
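The design choice here is that a dropout layer has no train/eval flag: disabling it means swapping it for a different, deterministic layer. A minimal plain-Python sketch of that pattern follows. The classes DropoutSketch and DisabledDropoutSketch are hypothetical stand-ins (not penzai classes), a list of layers plays the role of a model, and the list comprehension only mimics what the selection above does.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class DisabledDropoutSketch:
    """Hypothetical disabled layer: returns its input unchanged."""

    def __call__(self, x, rng=None):
        return x


@dataclasses.dataclass(frozen=True)
class DropoutSketch:
    """Hypothetical dropout layer that is always stochastic (no train/eval flag)."""

    drop_rate: float

    def __call__(self, x, rng):
        keep = rng.random(x.shape) >= self.drop_rate
        return np.where(keep, x / (1.0 - self.drop_rate), 0.0)

    def disable(self):
        # Disabling builds a different, deterministic layer instead of flipping a flag.
        return DisabledDropoutSketch()


# A toy "model": a list of layers applied in order.
model = [DropoutSketch(0.1), DropoutSketch(0.2)]
# Analogue of selecting every dropout instance and calling .disable() on it.
eval_model = [
    layer.disable() if isinstance(layer, DropoutSketch) else layer for layer in model
]

rng = np.random.default_rng(0)
x = np.arange(6.0).reshape(2, 3)
out = x
for layer in eval_model:
    out = layer(out, rng)
print(np.array_equal(out, x))  # True: the disabled layers are the identity
```

Because the layers are immutable, the original model is left untouched and still contains the stochastic layers.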

By default, dropout samples an independent mask value for every element of the input, across all of its axes. To share a mask along some axes, and thus drop out entire slices at a time, add those axis names to share_across_axes.
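One way to picture a shared mask is to sample it with size 1 along every shared axis and let broadcasting copy it across that axis. The NumPy sketch below does this under stated assumptions: NumPy has no named axes, so a tuple of names labelling each positional axis stands in for penzai's NamedArray, and the function name shared_mask_dropout is hypothetical.

```python
import numpy as np


def shared_mask_dropout(x, axis_names, drop_rate, share_across_axes, rng):
    """Hypothetical helper: dropout whose mask is broadcast along share_across_axes.

    axis_names labels each positional axis of x, standing in for named axes.
    """
    if isinstance(share_across_axes, str):
        share_across_axes = (share_across_axes,)
    unknown = set(share_across_axes) - set(axis_names)
    if unknown:
        raise ValueError(f"unknown axes: {sorted(unknown)}")
    # Shared axes get size 1 in the mask, so a single draw covers the whole axis.
    mask_shape = tuple(
        1 if name in share_across_axes else size
        for name, size in zip(axis_names, x.shape)
    )
    keep = rng.random(mask_shape) >= drop_rate
    # np.where broadcasts the (possibly size-1) mask against the full input.
    return np.where(keep, x / (1.0 - drop_rate), 0.0)


rng = np.random.default_rng(0)
x = np.ones((4, 8, 16))  # axes: batch, seq, embedding
y = shared_mask_dropout(x, ("batch", "seq", "embedding"), 0.5, ("seq",), rng)
# Each (batch, embedding) position is either 0.0 or 2.0 for every step along "seq".
```

Sharing the mask along "seq" here drops whole embedding channels for an entire sequence at once, while the "batch" and "embedding" axes still get independent draws.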

Variables:
  • drop_rate (float) – Probability of dropping an element.

  • share_across_axes (tuple[str, ...]) – Name or names of axes to share the dropout mask over. A single dropout mask will be broadcast across these axes. Any other axes will have independently-sampled dropout masks.

  • random_stream_input_name (str) – The side input key for the random stream used by the model at runtime.

Methods

__init__(drop_rate[, share_across_axes, ...])

disable()

Returns a disabled version of this layer.

__call__(value, /, **side_inputs)

Randomly drops out components of the input.

Attributes

random_stream_input_name

share_across_axes

drop_rate

Inherited Methods


attributes_dict()

Constructs a dictionary with all of the fields in the class.

bind_variables(variables[, allow_unused])

Convenience function to bind variables to a layer.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

stateless_call(variable_values, argument, /, ...)

Calls a layer with temporary variables, without modifying its state.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

__call__(value: named_axes.NamedArray, /, **side_inputs) -> named_axes.NamedArray

Randomly drops out components of the input.

disable() -> DisabledDropout

Returns a disabled version of this layer.