Index _ | A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P | Q | R | S | T | U | V | W | X | Z _ __call__() (flax.linen.activation.PReLU method) (flax.linen.BatchApply method) (flax.linen.BatchNorm method) (flax.linen.Bidirectional method) (flax.linen.Conv method) (flax.linen.ConvLocal method) (flax.linen.ConvLSTMCell method) (flax.linen.ConvTranspose method) (flax.linen.Dense method) (flax.linen.DenseGeneral method) (flax.linen.Dropout method) (flax.linen.Einsum method) (flax.linen.Embed method) (flax.linen.GroupNorm method) (flax.linen.GRUCell method) (flax.linen.InstanceNorm method) (flax.linen.LayerNorm method) (flax.linen.LSTMCell method) (flax.linen.MGUCell method) (flax.linen.MultiHeadAttention method) (flax.linen.MultiHeadDotProductAttention method) (flax.linen.OptimizedLSTMCell method) (flax.linen.RMSNorm method) (flax.linen.RNN method) (flax.linen.RNNCellBase method) (flax.linen.SelfAttention method) (flax.linen.Sequential method) (flax.linen.SimpleCell method) (flax.linen.SpectralNorm method) (flax.linen.WeightNorm method) __setattr__() (flax.linen.Module method) A activation_bias_init (flax.linen.MGUCell attribute) activation_fn (flax.linen.GRUCell attribute) (flax.linen.LSTMCell attribute) (flax.linen.MGUCell attribute) (flax.linen.OptimizedLSTMCell attribute) (flax.linen.SimpleCell attribute) AlreadyExistsError apply() (flax.linen.Module method) (in module flax.linen), [1] apply_gradients() (flax.training.train_state.TrainState method) apply_update() (flax.cursor.Cursor method) ApplyModuleInvalidMethodError ApplyScopeInvalidVariablesStructureError ApplyScopeInvalidVariablesTypeError AssignSubModuleError attend() (flax.linen.Embed method) attention_fn (flax.linen.MultiHeadAttention attribute) (flax.linen.MultiHeadDotProductAttention attribute) attn_weights_value_einsum_cls (flax.linen.MultiHeadDotProductAttention attribute) avg_pool() (in module flax.linen) axis (flax.linen.BatchNorm attribute) (flax.linen.DenseGeneral attribute) axis_index_groups (flax.linen.BatchNorm attribute) (flax.linen.GroupNorm attribute) (flax.linen.InstanceNorm attribute) (flax.linen.LayerNorm attribute) (flax.linen.RMSNorm attribute) axis_name (flax.linen.BatchNorm attribute) (flax.linen.GroupNorm attribute) (flax.linen.InstanceNorm attribute) (flax.linen.LayerNorm attribute) (flax.linen.RMSNorm attribute) B batch_dims (flax.linen.DenseGeneral attribute) BatchApply (class in flax.linen) BatchNorm (class in flax.linen) best_metric (flax.training.early_stopping.EarlyStopping attribute) bias (flax.linen.ConvLSTMCell attribute) bias_init (flax.linen.BatchNorm attribute) (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) (flax.linen.Dense attribute) (flax.linen.DenseGeneral attribute) (flax.linen.Einsum attribute) (flax.linen.GroupNorm attribute) (flax.linen.GRUCell attribute) (flax.linen.InstanceNorm attribute) (flax.linen.LayerNorm attribute) (flax.linen.LSTMCell attribute) (flax.linen.MultiHeadAttention attribute) (flax.linen.MultiHeadDotProductAttention attribute) (flax.linen.OptimizedLSTMCell attribute) (flax.linen.SimpleCell attribute) Bidirectional (class in flax.linen) bind() (flax.linen.Module method) Bound Module broadcast_dims (flax.linen.Dropout attribute) broadcast_dropout (flax.linen.MultiHeadAttention attribute) (flax.linen.MultiHeadDotProductAttention attribute) build() (flax.cursor.Cursor method) C CallCompactUnboundModuleError CallSetupUnboundModuleError CallShareScopeOnUnboundModuleError CallUnbindOnUnboundModuleError cell (flax.linen.RNN attribute) celu() (in module flax.linen.activation) collection_name (flax.linen.SpectralNorm attribute) Compact / Non-compact Module compact() (in module flax.linen) cond() (in module flax.linen) constant() (in module flax.linen.initializers) Conv (class in flax.linen) convert_pre_linen() (in module flax.training.checkpoints) ConvLocal (class in flax.linen) ConvLSTMCell (class in flax.linen) ConvTranspose (class in flax.linen) copy() (flax.core.frozen_dict.FrozenDict method) (flax.linen.Module method) (in module flax.core.frozen_dict) create() (flax.training.train_state.TrainState class method) create_constant_learning_rate_schedule() (in module flax.training.lr_schedule) create_cosine_learning_rate_schedule() (in module flax.training.lr_schedule) create_stepped_learning_rate_schedule() (in module flax.training.lr_schedule) Cursor (class in flax.cursor) cursor() (in module flax.cursor) CursorFindError custom_vjp() (in module flax.linen) D dataclass() (in module flax.struct) decode (flax.linen.MultiHeadAttention attribute) (flax.linen.MultiHeadDotProductAttention attribute) delta_orthogonal() (in module flax.linen.initializers) Dense (class in flax.linen) DenseGeneral (class in flax.linen) DescriptorAttributeError deterministic (flax.linen.Dropout attribute) (flax.linen.MultiHeadAttention attribute) (flax.linen.MultiHeadDotProductAttention attribute) disable_named_call() (in module flax.linen) dot_product_attention() (in module flax.linen) dot_product_attention_weights() (in module flax.linen) Dropout (class in flax.linen) dropout_rate (flax.linen.MultiHeadAttention attribute) (flax.linen.MultiHeadDotProductAttention attribute) dtype (flax.linen.BatchNorm attribute) (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvLSTMCell attribute) (flax.linen.ConvTranspose attribute) (flax.linen.Dense attribute) (flax.linen.DenseGeneral attribute) (flax.linen.Einsum attribute) (flax.linen.Embed attribute) (flax.linen.GroupNorm attribute) (flax.linen.GRUCell attribute) (flax.linen.InstanceNorm attribute) (flax.linen.LayerNorm attribute) (flax.linen.LSTMCell attribute) (flax.linen.MGUCell attribute) (flax.linen.MultiHeadAttention attribute) (flax.linen.MultiHeadDotProductAttention attribute) (flax.linen.OptimizedLSTMCell attribute) (flax.linen.RMSNorm attribute) (flax.linen.SimpleCell attribute) (flax.linen.SpectralNorm attribute) (flax.linen.WeightNorm attribute) E EarlyStopping (class in flax.training.early_stopping) Einsum (class in flax.linen) einsum_str (flax.linen.Einsum attribute) elu() (in module flax.linen.activation) Embed (class in flax.linen) embedding_init (flax.linen.Embed attribute) enable_named_call() (in module flax.linen) epsilon (flax.linen.BatchNorm attribute) (flax.linen.GroupNorm attribute) (flax.linen.InstanceNorm attribute) (flax.linen.LayerNorm attribute) (flax.linen.RMSNorm attribute) (flax.linen.SpectralNorm attribute) (flax.linen.WeightNorm attribute) error_on_non_matrix (flax.linen.SpectralNorm attribute) F feature_axes (flax.linen.InstanceNorm attribute) (flax.linen.LayerNorm attribute) (flax.linen.RMSNorm attribute) (flax.linen.WeightNorm attribute) feature_group_count (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) features (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvLSTMCell attribute) (flax.linen.ConvTranspose attribute) (flax.linen.Dense attribute) (flax.linen.DenseGeneral attribute) (flax.linen.Embed attribute) (flax.linen.GRUCell attribute) (flax.linen.LSTMCell attribute) (flax.linen.MGUCell attribute) (flax.linen.SimpleCell attribute) find() (flax.cursor.Cursor method) find_all() (flax.cursor.Cursor method) flax.core.variables module flax.errors module flax.jax_utils module flax.linen module flax.linen.activation module flax.linen.initializers module flax.linen.spmd module flax.linen.transforms module flax.serialization module flax.struct module flax.traceback_util module flax.training.checkpoints module flax.training.lr_schedule module Folding in forget_bias_init (flax.linen.MGUCell attribute) freeze() (in module flax.core.frozen_dict) from_bytes() (in module flax.serialization) from_state_dict() (in module flax.serialization) FrozenDict (class in flax.core.frozen_dict) Functional core G gate_fn (flax.linen.GRUCell attribute) (flax.linen.LSTMCell attribute) (flax.linen.MGUCell attribute) (flax.linen.OptimizedLSTMCell attribute) gelu() (in module flax.linen.activation) get_logical_axis_rules() (in module flax.linen) get_metrics() (in module flax.training.common_utils) get_partition_spec() (in module flax.linen) get_sharding() (in module flax.linen) get_variable() (flax.linen.Module method) glorot_normal() (in module flax.linen.initializers) glorot_uniform() (in module flax.linen.initializers) glu() (in module flax.linen.activation) group_size (flax.linen.GroupNorm attribute) GroupNorm (class in flax.linen) GRUCell (class in flax.linen) H hard_sigmoid() (in module flax.linen.activation) hard_silu() (in module flax.linen.activation) hard_swish() (in module flax.linen.activation) hard_tanh() (in module flax.linen.activation) has_improved (flax.training.early_stopping.EarlyStopping attribute) has_rng() (flax.linen.Module method) has_variable() (flax.linen.Module method) he_normal() (in module flax.linen.initializers) he_uniform() (in module flax.linen.initializers) hide_flax_in_tracebacks() (in module flax.traceback_util) I ImmutableVariableError IncorrectPostInitOverrideError init() (flax.linen.Module method) (in module flax.linen), [1] init_with_output() (flax.linen.Module method) (in module flax.linen), [1] initialize_carry() (flax.linen.ConvLSTMCell method) (flax.linen.GRUCell method) (flax.linen.LSTMCell method) (flax.linen.MGUCell method) (flax.linen.OptimizedLSTMCell method) (flax.linen.RNNCellBase method) (flax.linen.SimpleCell method) input_dilation (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) InstanceNorm (class in flax.linen) intercept_methods() (in module flax.linen) InvalidCheckpointError InvalidFilterError InvalidInstanceModuleError InvalidRngError InvalidScopeError is_initializing() (flax.linen.Module method) is_mutable_collection() (flax.linen.Module method) J JaxTransformError jit() (in module flax.linen) jvp() (in module flax.linen) K kaiming_normal() (in module flax.linen.initializers) kaiming_uniform() (in module flax.linen.initializers) keep_order (flax.linen.RNN attribute) kernel_dilation (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) kernel_init (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) (flax.linen.Dense attribute) (flax.linen.DenseGeneral attribute) (flax.linen.Einsum attribute) (flax.linen.GRUCell attribute) (flax.linen.LSTMCell attribute) (flax.linen.MGUCell attribute) (flax.linen.MultiHeadAttention attribute) (flax.linen.MultiHeadDotProductAttention attribute) (flax.linen.OptimizedLSTMCell attribute) (flax.linen.SimpleCell attribute) kernel_size (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvLSTMCell attribute) (flax.linen.ConvTranspose attribute) L latest_checkpoint() (in module flax.training.checkpoints) layer_instance (flax.linen.SpectralNorm attribute) (flax.linen.WeightNorm attribute) LayerNorm (class in flax.linen) layers (flax.linen.Sequential attribute) Lazy initialization lazy_init() (flax.linen.Module method) LazyInitError leaky_relu() (in module flax.linen.activation) lecun_normal() (in module flax.linen.initializers) lecun_uniform() (in module flax.linen.initializers) Lifted transformation log_sigmoid() (in module flax.linen.activation) log_softmax() (in module flax.linen.activation) logical_axis_rules() (in module flax.linen) logical_to_mesh() (in module flax.linen) logical_to_mesh_axes() (in module flax.linen) logical_to_mesh_sharding() (in module flax.linen) LogicallyPartitioned() (in module flax.linen) logsumexp() (in module flax.linen.activation) LSTMCell (class in flax.linen) M make_attention_mask() (in module flax.linen) make_causal_mask() (in module flax.linen) make_rng() (flax.linen.Module method) map_variables() (in module flax.linen) mask (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) max_pool() (in module flax.linen) MGUCell (class in flax.linen) min_delta (flax.training.early_stopping.EarlyStopping attribute) ModifyScopeVariableError Module module (class in flax.linen) flax.core.variables flax.errors flax.jax_utils flax.linen flax.linen.activation flax.linen.initializers flax.linen.spmd flax.linen.transforms flax.serialization flax.struct flax.traceback_util flax.training.checkpoints flax.training.lr_schedule module_paths() (flax.linen.Module method) momentum (flax.linen.BatchNorm attribute) MPACheckpointingRequiredError MPARestoreDataCorruptedError MPARestoreTargetRequiredError msgpack_restore() (in module flax.serialization) msgpack_serialize() (in module flax.serialization) MultiHeadAttention (class in flax.linen) MultiHeadDotProductAttention (class in flax.linen) MultipleMethodsCompactError N n_steps (flax.linen.SpectralNorm attribute) NameInUseError negative_slope_init (flax.linen.activation.PReLU attribute) normal() (in module flax.linen.initializers) normalize_qk (flax.linen.MultiHeadAttention attribute) (flax.linen.MultiHeadDotProductAttention attribute) nowrap() (in module flax.linen) num_embeddings (flax.linen.Embed attribute) num_groups (flax.linen.GroupNorm attribute) num_heads (flax.linen.MultiHeadAttention attribute) (flax.linen.MultiHeadDotProductAttention attribute) O one_hot() (in module flax.linen.activation) onehot() (in module flax.training.common_utils) ones() (in module flax.linen.initializers) ones_init() (in module flax.linen.initializers) OptimizedLSTMCell (class in flax.linen) orthogonal() (in module flax.linen.initializers) out_bias_init (flax.linen.MultiHeadDotProductAttention attribute) out_features (flax.linen.MultiHeadAttention attribute) (flax.linen.MultiHeadDotProductAttention attribute) out_kernel_init (flax.linen.MultiHeadDotProductAttention attribute) override_named_call() (in module flax.linen) P pad_shard_unpad() (in module flax.jax_utils) padding (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvLSTMCell attribute) (flax.linen.ConvTranspose attribute) param() (flax.linen.Module method) param_dtype (flax.linen.activation.PReLU attribute), [1] (flax.linen.BatchNorm attribute) (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvLSTMCell attribute) (flax.linen.ConvTranspose attribute) (flax.linen.Dense attribute) (flax.linen.DenseGeneral attribute) (flax.linen.Einsum attribute) (flax.linen.Embed attribute) (flax.linen.GroupNorm attribute) (flax.linen.GRUCell attribute) (flax.linen.InstanceNorm attribute) (flax.linen.LayerNorm attribute) (flax.linen.LSTMCell attribute) (flax.linen.MGUCell attribute) (flax.linen.MultiHeadAttention attribute) (flax.linen.MultiHeadDotProductAttention attribute) (flax.linen.OptimizedLSTMCell attribute) (flax.linen.RMSNorm attribute) (flax.linen.SimpleCell attribute) (flax.linen.SpectralNorm attribute) (flax.linen.WeightNorm attribute) Params / parameters partial_eval_by_shape() (in module flax.jax_utils) Partitioned() (in module flax.linen) PartitioningUnspecifiedError path (flax.linen.Module property) patience (flax.training.early_stopping.EarlyStopping attribute) patience_count (flax.training.early_stopping.EarlyStopping attribute) perturb() (flax.linen.Module method) pmean() (in module flax.jax_utils) pool() (in module flax.linen) pop() (flax.core.frozen_dict.FrozenDict method) (in module flax.core.frozen_dict) precision (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) (flax.linen.Dense attribute) (flax.linen.DenseGeneral attribute) (flax.linen.Einsum attribute) (flax.linen.MultiHeadAttention attribute) (flax.linen.MultiHeadDotProductAttention attribute) prefetch_to_device() (in module flax.jax_utils) PReLU (class in flax.linen.activation) pretty_repr() (flax.core.frozen_dict.FrozenDict method) (in module flax.core.frozen_dict) promote_dtype (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) (flax.linen.Dense attribute) (flax.linen.DenseGeneral attribute) (flax.linen.Einsum attribute) (flax.linen.Embed attribute) put_variable() (flax.linen.Module method) PyTreeNode (class in flax.struct) Q qk_attn_weights_einsum_cls (flax.linen.MultiHeadDotProductAttention attribute) qkv_features (flax.linen.MultiHeadAttention attribute) (flax.linen.MultiHeadDotProductAttention attribute) R rate (flax.linen.Dropout attribute) recurrent_kernel_init (flax.linen.GRUCell attribute) (flax.linen.LSTMCell attribute) (flax.linen.MGUCell attribute) (flax.linen.OptimizedLSTMCell attribute) (flax.linen.SimpleCell attribute) reduction_axes (flax.linen.GroupNorm attribute) (flax.linen.LayerNorm attribute) (flax.linen.RMSNorm attribute) register_serialization_state() (in module flax.serialization) relu() (in module flax.linen.activation) remat() (in module flax.linen) remat_scan() (in module flax.linen) replicate() (in module flax.jax_utils) ReservedModuleAttributeError reset_gate (flax.linen.MGUCell attribute) residual (flax.linen.SimpleCell attribute) restore_checkpoint() (in module flax.training.checkpoints) return_carry (flax.linen.RNN attribute) reverse (flax.linen.RNN attribute) RMSNorm (class in flax.linen) RNG sequences rng_collection (flax.linen.Dropout attribute) RNN (class in flax.linen) RNNCellBase (class in flax.linen) S save_checkpoint() (in module flax.training.checkpoints) save_checkpoint_multiprocess() (in module flax.training.checkpoints) scale_init (flax.linen.BatchNorm attribute) (flax.linen.GroupNorm attribute) (flax.linen.InstanceNorm attribute) (flax.linen.LayerNorm attribute) (flax.linen.RMSNorm attribute) (flax.linen.WeightNorm attribute) scan() (in module flax.linen) Scope ScopeCollectionNotFound ScopeParamNotFoundError ScopeParamShapeError ScopeVariableNotFoundError SelfAttention (class in flax.linen) selu() (in module flax.linen.activation) Sequential (class in flax.linen) set() (flax.cursor.Cursor method) set_logical_axis_rules() (in module flax.linen) SetAttributeFrozenModuleError SetAttributeInModuleSetupError setup() (flax.linen.Module method) shape (flax.linen.Einsum attribute) Shape inference shard() (in module flax.training.common_utils) shard_prng_key() (in module flax.training.common_utils) share_scope() (in module flax.linen) should_stop (flax.training.early_stopping.EarlyStopping attribute) show_flax_in_tracebacks() (in module flax.traceback_util) sigmoid() (in module flax.linen.activation) silu() (in module flax.linen.activation) SimpleCell (class in flax.linen) soft_sign() (in module flax.linen.activation) softmax() (in module flax.linen.activation) softplus() (in module flax.linen.activation) sow() (flax.linen.Module method) SpectralNorm (class in flax.linen) split_rngs (flax.linen.RNN attribute) stack_forest() (in module flax.training.common_utils) standardize() (in module flax.linen.activation) strides (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvLSTMCell attribute) (flax.linen.ConvTranspose attribute) swish() (in module flax.linen.activation) switch() (in module flax.linen) T tabulate() (flax.linen.Module method) (in module flax.linen) tanh() (in module flax.linen.activation) time_major (flax.linen.RNN attribute) to_bytes() (in module flax.serialization) to_state_dict() (in module flax.serialization) TraceContextError TrainState (class in flax.training.train_state) TransformedMethodReturnValueError TransformTargetError transpose_kernel (flax.linen.ConvTranspose attribute) TraverseTreeError truncated_normal() (in module flax.linen.initializers) U unbind() (flax.linen.Module method) unfreeze() (flax.core.frozen_dict.FrozenDict method) (in module flax.core.frozen_dict) uniform() (in module flax.linen.initializers) unreplicate() (in module flax.jax_utils) unroll (flax.linen.RNN attribute) update() (flax.training.early_stopping.EarlyStopping method) use_bias (flax.linen.BatchNorm attribute) (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) (flax.linen.Dense attribute) (flax.linen.DenseGeneral attribute) (flax.linen.Einsum attribute) (flax.linen.GroupNorm attribute) (flax.linen.InstanceNorm attribute) (flax.linen.LayerNorm attribute) (flax.linen.MultiHeadAttention attribute) (flax.linen.MultiHeadDotProductAttention attribute) use_fast_variance (flax.linen.BatchNorm attribute) (flax.linen.GroupNorm attribute) (flax.linen.InstanceNorm attribute) (flax.linen.LayerNorm attribute) (flax.linen.RMSNorm attribute) use_running_average (flax.linen.BatchNorm attribute) use_scale (flax.linen.BatchNorm attribute) (flax.linen.GroupNorm attribute) (flax.linen.InstanceNorm attribute) (flax.linen.LayerNorm attribute) (flax.linen.RMSNorm attribute) (flax.linen.WeightNorm attribute) V Variable (class in flax.linen) Variable collections Variable dictionary variable() (flax.linen.Module method) variable_axes (flax.linen.RNN attribute) variable_broadcast (flax.linen.RNN attribute) variable_carry (flax.linen.RNN attribute) variable_filter (flax.linen.WeightNorm attribute) variables (flax.linen.Module property) variance_scaling() (in module flax.linen.initializers) vjp() (in module flax.linen) vmap() (in module flax.linen) W WeightNorm (class in flax.linen) while_loop() (in module flax.linen) with_logical_constraint() (in module flax.linen) with_logical_partitioning() (in module flax.linen) with_partitioning() (in module flax.linen) X xavier_normal() (in module flax.linen.initializers) xavier_uniform() (in module flax.linen.initializers) Z zeros() (in module flax.linen.initializers) zeros_init() (in module flax.linen.initializers)