variance_scaling_initializer
- penzai.nn.linear_and_affine.variance_scaling_initializer(key: jax.Array, *, scale: float, mode: Literal['fan_in', 'fan_out', 'fan_avg'], distribution: Literal['uniform', 'normal', 'truncated_normal'], input_axes: dict[str, int], output_axes: dict[str, int], parallel_axes: dict[str, int], convolution_spatial_axes: dict[str, int], dtype: jax.typing.DTypeLike) -> jax.Array
Generic variance scaling initializer.
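This page does not spell out how the named-axis arguments combine into a variance. The sketch below is not penzai's implementation. It assumes the same convention as `jax.nn.initializers.variance_scaling`: fan-in is the product of the `input_axes` sizes times the product of the `convolution_spatial_axes` sizes, fan-out is the same with `output_axes`, `parallel_axes` are excluded from both fans (each parallel slice acts as an independent weight), and samples are drawn with variance `scale / max(1, fan)`. The helper names `compute_fans` and `sketch_variance_scaling`, and the axis order of the returned array (parallel, spatial, input, output), are hypothetical.

```python
import math
from typing import Literal

import jax
import jax.numpy as jnp


def compute_fans(
    input_axes: dict[str, int],
    output_axes: dict[str, int],
    convolution_spatial_axes: dict[str, int],
) -> tuple[int, int]:
  """Assumed fan computation: parallel axes deliberately do not contribute."""
  receptive_field = math.prod(convolution_spatial_axes.values())
  fan_in = math.prod(input_axes.values()) * receptive_field
  fan_out = math.prod(output_axes.values()) * receptive_field
  return fan_in, fan_out


def sketch_variance_scaling(
    key: jax.Array,
    *,
    scale: float,
    mode: Literal["fan_in", "fan_out", "fan_avg"],
    distribution: Literal["uniform", "normal", "truncated_normal"],
    input_axes: dict[str, int],
    output_axes: dict[str, int],
    parallel_axes: dict[str, int],
    convolution_spatial_axes: dict[str, int],
    dtype=jnp.float32,
) -> jax.Array:
  """Hypothetical variance-scaling sampler mirroring jax.nn.initializers."""
  fan_in, fan_out = compute_fans(input_axes, output_axes, convolution_spatial_axes)
  denominator = {
      "fan_in": fan_in,
      "fan_out": fan_out,
      "fan_avg": (fan_in + fan_out) / 2,
  }[mode]
  variance = scale / max(1.0, denominator)
  # Assumed axis order of the result: parallel, spatial, input, output.
  shape = (
      *parallel_axes.values(),
      *convolution_spatial_axes.values(),
      *input_axes.values(),
      *output_axes.values(),
  )
  if distribution == "normal":
    return math.sqrt(variance) * jax.random.normal(key, shape, dtype)
  if distribution == "uniform":
    # Uniform on [-a, a] has variance a**2 / 3, so a = sqrt(3 * variance).
    limit = math.sqrt(3.0 * variance)
    return jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit)
  if distribution == "truncated_normal":
    # A standard normal truncated to [-2, 2] has std ~0.8796; dividing it out
    # keeps the truncated samples at the target variance.
    stddev = math.sqrt(variance) / 0.87962566103423978
    return stddev * jax.random.truncated_normal(key, -2.0, 2.0, shape, dtype)
  raise ValueError(f"Unknown distribution: {distribution!r}")


# Usage: a per-head projection from a 256-wide embedding to 1024 neurons.
w = sketch_variance_scaling(
    jax.random.PRNGKey(0),
    scale=1.0,
    mode="fan_in",
    distribution="normal",
    input_axes={"embedding": 256},
    output_axes={"neurons": 1024},
    parallel_axes={"heads": 4},
    convolution_spatial_axes={},
)
print(w.shape)  # (4, 256, 1024)
```

Under these assumptions, `scale=1.0` with `mode="fan_in"` gives each entry a standard deviation of `1 / sqrt(256) = 0.0625`, and adding a `heads` parallel axis changes the output shape without changing that standard deviation.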