StochasticDropout

class penzai.nn.dropout.StochasticDropout

Bases: Layer

Stochastic dropout layer.

Dropout layers randomly mask out (zero) each element with probability drop_rate, and then scale the output up by a factor of 1 / (1 - drop_rate), so that the expected value of each element is unchanged.
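The mask-and-rescale rule above can be sketched in plain Python. This is only an illustration of the arithmetic, not penzai's implementation: the helper name `dropout` and the use of the standard-library `random` module are assumptions made for this example.

```python
import random


def dropout(values, drop_rate, rng):
    """Illustrative mask-and-rescale dropout over a flat list of floats."""
    if not 0.0 <= drop_rate < 1.0:
        raise ValueError("drop_rate must be in [0, 1)")
    keep_prob = 1.0 - drop_rate
    scale = 1.0 / keep_prob
    # Each element is kept independently with probability keep_prob;
    # kept elements are scaled up, dropped elements become zero.
    return [v * scale if rng.random() < keep_prob else 0.0 for v in values]


rng = random.Random(0)
out = dropout([1.0] * 10_000, drop_rate=0.25, rng=rng)
mean = sum(out) / len(out)
```

Because kept elements are multiplied by 1 / (1 - drop_rate), the average of `out` stays close to the average of the input, which is the reason for the rescaling.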

For simplicity, and to avoid having to pass configuration through the model, dropout layers are always stochastic. To disable dropout, you can remove the dropout layers from a model using logic such as

model.select().at_instances_of(StochasticDropout).remove_from_parent()

or disable them in place using

model.select().at_instances_of(StochasticDropout).apply(lambda x: x.disable())

Note that, by default, dropout samples an independent mask entry for every element, along every axis of the input. If you wish to share a mask along certain axes, and thus drop out entire slices at a time, add those axis names to share_across_axes.
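To make the effect of sharing a mask concrete, here is a minimal plain-Python sketch on a two-dimensional input. It is not penzai code: the function `dropout_shared_over_columns` is hypothetical, and the column axis simply stands in for an axis that would be listed in share_across_axes.

```python
import random


def dropout_shared_over_columns(rows, drop_rate, rng):
    """One keep/drop decision per row, broadcast across all columns."""
    keep_prob = 1.0 - drop_rate
    scale = 1.0 / keep_prob
    # The mask has one entry per row; the column axis is the shared axis,
    # so an entire row is either kept (and rescaled) or zeroed.
    row_mask = [rng.random() < keep_prob for _ in rows]
    return [
        [v * scale if keep else 0.0 for v in row]
        for row, keep in zip(rows, row_mask)
    ]


rng = random.Random(0)
batch = [[1.0, 2.0, 3.0] for _ in range(8)]
out = dropout_shared_over_columns(batch, drop_rate=0.5, rng=rng)
```

Each row of `out` is either all zeros or the original row scaled by 2, never a mixture, which is what "dropping out entire slices at a time" means here.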

Variables:
  • drop_rate (float) – Probability of dropping an element.

  • share_across_axes (tuple[str, ...]) – Name or names of axes to share the dropout mask over. A single dropout mask will be broadcast across these axes. Any other axes will have independently-sampled dropout masks.

  • rng (random.RandomEffect) – The (request for the) random stream used by the model at runtime.

Methods

__init__(drop_rate[, share_across_axes, rng])

disable()

Returns a disabled version of this layer.

input_structure()

output_structure()

__call__(value, /)

Randomly drops out components of the input.

Attributes

rng

share_across_axes

drop_rate

Inherited Methods

attributes_dict()

Constructs a dictionary with all of the fields in the class.

from_attributes(**field_values)

Directly instantiates a struct given all of its fields.

key_for_field(field_name)

Generates a JAX PyTree key for a given field name.

select()

Wraps this struct in a selection, enabling functional-style mutations.

tree_flatten()

Flattens this tree node.

tree_flatten_with_keys()

Flattens this tree node with keys.

tree_unflatten(aux_data, children)

Unflattens this tree node.

treescope_color()

Computes a CSS color to display for this object in treescope.

__call__(value: named_axes.NamedArray, /) → named_axes.NamedArray

Randomly drops out components of the input.

disable() → DisabledDropout

Returns a disabled version of this layer.