Developer notes
The Flax Module lifecycle
Lifted transformations
FLIPs