diff --git a/docs/conf.py b/docs/conf.py index 32dc8addfb..f39622c4ad 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -116,7 +116,7 @@ href="https://flax-nnx.readthedocs.io/en/latest/index.html" style="text-decoration: none; color: white;" > - This is the Flax Linen site. Check out the new Flax NNX! + This is the Flax Linen site. Check out the new Flax NNX API! """ diff --git a/docs_nnx/conf.py b/docs_nnx/conf.py index 69a9116d40..d6bcdb6a6e 100644 --- a/docs_nnx/conf.py +++ b/docs_nnx/conf.py @@ -113,10 +113,10 @@ # href with no underline and white bold text color announcement = """ - This is the Flax NNX site. Click here for Flax Linen. + This site covers the new Flax NNX API. Click here for Flax Linen. """ diff --git a/docs_nnx/examples/community_examples.rst b/docs_nnx/examples/community_examples.rst deleted file mode 100644 index 079568c9a7..0000000000 --- a/docs_nnx/examples/community_examples.rst +++ /dev/null @@ -1,110 +0,0 @@ -Community examples -================== - -In addition to the `curated list of official Flax examples on GitHub `__, -there is a growing community of people using Flax to build new types of machine -learning models. We are happy to showcase any example built by the community here! - -If you want to submit your own Flax example, you can start by forking -one of the `official Flax examples on GitHub `__. - -Models -****** -.. list-table:: - :header-rows: 1 - - * - Link - - Author - - Task type - - Reference - * - `matthias-wright/flaxmodels `__ - - `@matthias-wright `__ - - Various - - GPT-2, ResNet, StyleGAN-2, VGG, ... - * - `DarshanDeshpande/jax-models `__ - - `@DarshanDeshpande `__ - - Various - - Segformer, Swin Transformer, ... also some stand-alone layers - * - `google/vision_transformer `__ - - `@andsteing `__ - - Image classification, image/text - - https://arxiv.org/abs/2010.11929, https://arxiv.org/abs/2105.01601, https://arxiv.org/abs/2111.07991, ... - * - `jax-resnet `__ - - `@n2cholas `__ - - Various resnet implementations - - `torch.hub `__ - * - `Wav2Vec2 finetuning `__ - - `@vasudevgupta7 `__ - - Automatic Speech Recognition - - https://arxiv.org/abs/2006.11477 - -Examples -******** - -.. list-table:: - :header-rows: 1 - - * - Link - - Author - - Task type - - Reference - * - `JAX-RL `__ - - `@henry-prior `__ - - Reinforcement learning - - N/A - * - `BigBird Fine-tuning `__ - - `@vasudevgupta7 `__ - - Question-Answering - - https://arxiv.org/abs/2007.14062 - * - `DCGAN `__ - - `@bkkaggle `__ - - Image Synthesis - - https://arxiv.org/abs/1511.06434 - * - `denoising-diffusion-flax `__ - - `@yiyixuxu `__ - - Image generation - - https://arxiv.org/abs/2006.11239 - -Tutorials -********* - -.. currently left empty as a placeholder for tutorials -.. list-table:: - :header-rows: 1 - - * - Link - - Author - - Task type - - Reference - * - - - - - - - - -Contributing policy -******************* - -If you are interested in adding a project to the Community Examples section, take the following -into consideration: - -* **Code examples**: Examples must contain a README that is helpful, clear, and explains - how to run the code. The code itself should be easy to follow. -* **Tutorials**: These docs should preferrably be a Jupyter Notebook format - (refer to `Contributing `__ - to learn how to convert a Jupyter Notebook into a Markdown file with `jupytext`). - Your tutorial should be well-written, and discuss/describe an interesting topic/task. - To avoid duplication, the content of these docs must be different from - `existing docs on the Flax documentation site `__ - or other community examples mentioned in this document. -* **Models**: repositories with models ported to Flax must provide at least one of the following: - - * Metrics that are comparable to the original work when the model is trained to completion. Having - available plots of the metric's history during training is highly encouraged. - * Tests to verify numerical equivalence against a well known implementation (same inputs - + weights = same outputs) preferably using pretrained weights. - -In all cases mentioned above, the code must work with the latest stable versions of the -following packages: ``jax``, ``flax``, and ``optax``, and make substantial use of Flax. -Note that both ``jax`` and ``optax`` are `required packages `__ -of ``flax`` (refer to the `installation instructions `__ -for more details). diff --git a/docs_nnx/examples/core_examples.rst b/docs_nnx/examples/core_examples.rst index 34e3bacc11..40f705228d 100644 --- a/docs_nnx/examples/core_examples.rst +++ b/docs_nnx/examples/core_examples.rst @@ -7,81 +7,20 @@ directory. Each example is designed to be **self-contained and easily forkable**, while reproducing relevant results in different areas of machine learning. -As discussed in `#231 `__, we decided -to go for a standard pattern for all examples including the simplest ones (like MNIST). -This makes every example a bit more verbose, but once you know one example, you -know the structure of all of them. Having unit tests and integration tests is also -very useful when you fork these examples. - Some of the examples below have a link "Interactive🕹" that lets you run them directly in Colab. -Image classification +Transformers ******************** -- :octicon:`mark-github;0.9em` `MNIST `__ - - `Interactive🕹 `__: - Convolutional neural network for MNIST classification (featuring simple - code). - -- :octicon:`mark-github;0.9em` `ImageNet `__ - - `Interactive🕹 `__: - Resnet-50 on ImageNet with weight decay (featuring multi-host SPMD, custom - preprocessing, checkpointing, dynamic scaling, mixed precision). - -Reinforcement learning -********************** - -- :octicon:`mark-github;0.9em` `Proximal Policy Optimization `__: - Learning to play Atari games (featuring single host SPMD, RL setup). - -Natural language processing -*************************** - -- :octicon:`mark-github;0.9em` `Sequence to sequence for number - addition `__: - (featuring simple code, LSTM state handling, on the fly data generation). -- :octicon:`mark-github;0.9em` `Parts-of-speech - tagging `__: Simple - transformer encoder model using the universal dependency dataset. -- :octicon:`mark-github;0.9em` `Sentiment - classification `__: - with a LSTM model. -- :octicon:`mark-github;0.9em` `Transformer encoder/decoder model trained on - WMT `__: - Translating English/German (featuring multihost SPMD, dynamic bucketing, - attention cache, packed sequences, recipe for TPU training on GCP). -- :octicon:`mark-github;0.9em` `Transformer encoder trained on one billion word - benchmark `__: - for autoregressive language modeling, based on the WMT example above. +- :octicon:`mark-github;0.9em` `Gemma `__ : + A family of open-weights Large Language Model (LLM) by Google DeepMind, based on Gemini research and technology. -Generative models -***************** +- :octicon:`mark-github;0.9em` `LM1B `__ : + Transformer encoder trained on the One Billion Word Benchmark. -- :octicon:`mark-github;0.9em` `Variational - auto-encoder `__: - Trained on binarized MNIST (featuring simple code, vmap). - -Graph modeling -************** - -- :octicon:`mark-github;0.9em` `Graph Neural Networks `__: - Molecular predictions on ogbg-molpcba from the Open Graph Benchmark. - -Contributing to core Flax examples -********************************** - -Most of the `core Flax examples on GitHub `__ -follow a structure that the Flax dev team found works well with Flax projects. -The team strives to make these examples easy to explore and fork. In particular -(as per GitHub Issue `#231 `__): +Toy examples +******************** -- README: contains links to paper, command line, `TensorBoard `__ metrics. -- Focus: an example is about a single model/dataset. -- Configs: we use ``ml_collections.ConfigDict`` stored under ``configs/``. -- Tests: executable ``main.py`` loads ``train.py`` which has ``train_test.py``. -- Data: is read from `TensorFlow Datasets `__. -- Standalone: every directory is self-contained. -- Requirements: versions are pinned in ``requirements.txt``. -- Boilerplate: is reduced by using `clu `__. -- Interactive: the example can be explored with a `Colab `__. \ No newline at end of file +`NNX toy examples `__ +directory contains a few smaller, standalone toy examples for simple training scenarios. diff --git a/docs_nnx/examples/google_research_examples.rst b/docs_nnx/examples/google_research_examples.rst deleted file mode 100644 index 83e0101001..0000000000 --- a/docs_nnx/examples/google_research_examples.rst +++ /dev/null @@ -1,269 +0,0 @@ -######################## -Google Research examples -######################## - -A collection of research by Google Research made with Flax. - -Attention -********* - -Fast Attention (FAVOR+) and Rethinking Attention with Performers -================================================================ - -- Code on GitHub: - - - `Performer's Fast Attention (FAVOR+) module `__ - -- Research paper: - - - `Rethinking Attention with Performers `__ (Choromanski et al., 2020) - - - Introduces *"Performers, Transformer architectures which can estimate regular (softmax) full-rank-attention Transformers with provable accuracy, but using only linear (as opposed to quadratic) space and time complexity, without relying on any priors such as sparsity or low-rankness. To approximate softmax attention-kernels, Performers use a novel Fast Attention Via positive Orthogonal Random features approach (FAVOR+), which may be of independent interest for scalable kernel methods. FAVOR+ can be also used to efficiently model kernelizable attention mechanisms beyond softmax."* - -Self-attention Does Not Need O(n^2) Memory -========================================== - -- `Code on GitHub `__ -- `Colab notebook `__ - -- Research paper: - - - `Self-attention Does Not Need O(n^2) Memory `__ (Rabe and Staats, 2021) - - - *"We present a very simple algorithm for attention that requires O(1) memory with respect to sequence length and an extension to self-attention that requires O(log n) memory. This is in contrast with the frequently stated belief that self-attention requires O(n^2) memory. While the time complexity is still O(n^2), device memory rather than compute capability is often the limiting factor on modern accelerators. Thus, reducing the memory requirements of attention allows processing of longer sequences than might otherwise be feasible..."* - -Computer vision -*************** - -Colorization Transformer (ColTran) -================================== - -- `Code on GitHub `__ - -- Research paper: - - - `Colorization Transformer `__ (Kumar et al., 2020) - - - *"We presented the Colorization Transformer (ColTran), an architecture that entirely relies on self-attention for image colorization. We introduce conditional transformer layers, a novel building block for conditional, generative models based on self-attention. Our ablations show the superiority of employing this mechanism over a number of different baselines. Finally, we demonstrate that ColTran can generate diverse, high-fidelity colorizations on ImageNet, which are largely indistinguishable from the ground-truth even for human raters."* - -Vision Transformer (ViT), MLP-Mixer Architectures *and* Big Vision -================================================================== - -- Code on GitHub: - - - `Vision Transformer and MLP-Mixer Architectures `__ - - - `Big Vision `__ - - - *"This codebase is designed for training large-scale vision models using Cloud TPU VMs or GPU machines. It is based on Jax/Flax libraries, and uses tf.data and TensorFlow Datasets for scalable and reproducible input pipelines."* - -- `Colab notebooks `__: - - - The JAX code of Vision Transformers and MLP Mixers - - More than 50k Vision Transformer and hybrid checkpoints that were used to generate the data of "How to train your ViT?" - -- Research papers: - - - `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `__ (Dosovitskiy et al., 2020) - - - *"In vision, attention is either applied in conjunction with convolutional networks, or used to replace certain components of convolutional networks while keeping their overall structure in place. We show that this reliance on CNNs is not necessary and a pure transformer applied directly to sequences of image patches can perform very well on image classification tasks. When pre-trained on large amounts of data and transferred to multiple mid-sized or small image recognition benchmarks (ImageNet, CIFAR-100, VTAB, etc.), Vision Transformer (ViT) attains excellent results compared to state-of-the-art convolutional networks while requiring substantially fewer computational resources to train."* - - - `MLP-Mixer: An All-MLP Architecture for Vision `__ (Tolstikhin et al., 2021) - - - *"In this paper we show that while convolutions and attention are both sufficient for good performance, neither of them are necessary. We present MLP-Mixer, an architecture based exclusively on multi-layer perceptrons (MLPs). MLP-Mixer contains two types of layers: one with MLPs applied independently to image patches (i.e. "mixing" the per-location features), and one with MLPs applied across patches (i.e. "mixing" spatial information). When trained on large datasets, or with modern regularization schemes, MLP-Mixer attains competitive scores on image classification benchmarks, with pre-training and inference cost comparable to state-of-the-art models."* - - - `How to Train Your ViT? Data, Augmentation, and Regularization in Vision Transformers `__ (Steiner et al., 2021) - - - *"Vision Transformers (ViT) have been shown to attain highly competitive performance for a wide range of vision applications, such as image classification, object detection and semantic image segmentation. In comparison to convolutional neural networks, the Vision Transformer's weaker inductive bias is generally found to cause an increased reliance on model regularization or data augmentation ("AugReg" for short) when training on smaller training datasets. We conduct a systematic empirical study in order to better understand the interplay between the amount of training data, AugReg, model size and compute budget."* - - - `When Vision Transformers Outperform ResNets without Pretraining or Strong Data Augmentations `__ (X. Chen et al., 2021) - - - *"Vision Transformers (ViTs) and MLPs signal further efforts on replacing hand-wired features or inductive biases with general-purpose neural architectures. Existing works empower the models by massive data, such as large-scale pre-training and/or repeated strong data augmentations, and still report optimization-related problems (e.g., sensitivity to initialization and learning rates). Hence, this paper investigates ViTs and MLP-Mixers from the lens of loss geometry, intending to improve the models' data efficiency at training and generalization at inference."* - - - `LiT: Zero-Shot Transfer with Locked-image Text Tuning `__ (X. Zhai et al., 2021) - - - *"This paper presents contrastive-tuning, a simple method employing contrastive training to align image and text models while still taking advantage of their pre-training. In our empirical study we find that locked pre-trained image models with unlocked text models work best. We call this instance of contrastive-tuning "Locked-image Tuning" (LiT), which just teaches a text model to read out good representations from a pre-trained image model for new tasks. A LiT model gains the capability of zero-shot transfer to new vision tasks, such as image classification or retrieval. The proposed LiT is widely applicable; it works reliably with multiple pre-training methods (supervised and unsupervised) and across diverse architectures (ResNet, Vision Transformers and MLP-Mixer) using three different image-text datasets."* - -Scaling Vision with Sparse Mixture of Experts (MoE) -=================================================== - -- `Code on GitHub `__ -- Research paper: - - - `Scaling Vision with Sparse Mixture of Experts `__ (Riquelme et al., 2021) - - - *"Sparsely-gated Mixture of Experts networks (MoEs) have demonstrated excellent scalability in Natural Language Processing. In Computer Vision, however, almost all performant networks are "dense", that is, every input is processed by every parameter. We present a Vision MoE (V-MoE), a sparse version of the Vision Transformer, that is scalable and competitive with the largest dense networks... we demonstrate the potential of V-MoE to scale vision models, and train a 15B parameter model that attains 90.35% on ImageNet..."* - -Diffusion -********* - -Variational Diffusion Models -============================ - -- `Code on GitHub `__ -- `Colab notebooks `__ -- Research paper: - - - `Variational Diffusion Models `__ (Kingma et al., 2021) - - - *"Diffusion-based generative models have demonstrated a capacity for perceptually impressive synthesis, but can they also be great likelihood-based models? We answer this in the affirmative, and introduce a family of diffusion-based generative models that obtain state-of-the-art likelihoods on standard image density estimation benchmarks. Unlike other diffusion-based models, our method allows for efficient optimization of the noise schedule jointly with the rest of the model. We show that the variational lower bound (VLB) simplifies to a remarkably short expression in terms of the signal-to-noise ratio of the diffused data, thereby improving our theoretical understanding of this model class. Using this insight, we prove an equivalence between several models proposed in the literature. In addition, we show that the continuous-time VLB is invariant to the noise schedule, except for the signal-to-noise ratio at its endpoints. This enables us to learn a noise schedule that minimizes the variance of the resulting VLB estimator, leading to faster optimization..."* - -Domain adaptation -***************** - -GIFT (Gradual Interpolation of Features toward Target) -====================================================== - -- `Code on GitHub `__ -- Research paper: - - - `Gradual Domain Adaptation in the Wild: When Intermediate Distributions are Absent `__ (Abnar et al., 2021) - - - *"We focus on the problem of domain adaptation when the goal is shifting the model towards the target distribution, rather than learning domain invariant representations. It has been shown that under the following two assumptions: (a) access to samples from intermediate distributions, and (b) samples being annotated with the amount of change from the source distribution, self-training can be successfully applied on gradually shifted samples to adapt the model toward the target distribution. We hypothesize having (a) is enough to enable iterative self-training to slowly adapt the model to the target distribution, by making use of an implicit curriculum. In the case where (a) does not hold, we observe that iterative self-training falls short. We propose GIFT, a method that creates virtual samples from intermediate distributions by interpolating representations of examples from source and target domains..."* - -Generalization -************** - -Surrogate Gap Minimization Improves Sharpness-Aware Training -============================================================ - -- `Code on GitHub `__ -- Research paper: - - - `Surrogate Gap Minimization Improves Sharpness-Aware Training `__ (J. Zhuang et al., 2022) - - - *"The recently proposed Sharpness-Aware Minimization (SAM) improves generalization by minimizing a perturbed loss defined as the maximum loss within a neighborhood in the parameter space. However, we show that both sharp and flat minima can have a low perturbed loss, implying that SAM does not always prefer flat minima. Instead, we define a surrogate gap, a measure equivalent to the dominant eigenvalue of Hessian at a local minimum when the radius of neighborhood (to derive the perturbed loss) is small. The surrogate gap is easy to compute and feasible for direct minimization during training. Based on the above observations, we propose Surrogate Gap Guided Sharpness-Aware Minimization (GSAM), a novel improvement over SAM with negligible computation overhead..."* - -Meta learning -************* - -``learned_optimization`` -======================= - -- Code on GitHub: `learned_optimization `__ -- `Colab notebooks `__ - -- Research papers: - - - `Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies `__ (Vicol et al., 2021) - - - *"We introduce a method called Persistent Evolution Strategies (PES), which divides the computation graph into a series of truncated unrolls, and performs an evolution strategies-based update step after each unroll. PES eliminates bias from these truncations by accumulating correction terms over the entire sequence of unrolls. PES allows for rapid parameter updates, has low memory usage, is unbiased, and has reasonable variance characteristics."* - - - `Gradients Are Not All You Need `__ (Metz et al., 2021) - - - *"...In this short report, we discuss a common chaos based failure mode which appears in a variety of differentiable circumstances, ranging from recurrent neural networks and numerical physics simulation to training learned optimizers. We trace this failure to the spectrum of the Jacobian of the system under study, and provide criteria for when a practitioner might expect this failure to spoil their differentiation based optimization algorithms."* - -Model efficiency -**************** - -Efficiently Scaling Transformer Inference -========================================= - -- Code on GitHub: - - - `T5X `__ - - `AQT: Accurate Quantized Training `__ - -- Research paper: - - - `Efficiently Scaling Transformer Inference `__ (Pope et al., 2022) - - - *"We develop a simple analytical model for inference efficiency to select the best multi-dimensional partitioning techniques optimized for TPU v4 slices based on the application requirements. We combine these with a suite of low-level optimizations to achieve a new Pareto frontier on the latency and model FLOPS utilization (MFU) tradeoffs on 500B+ parameter models that outperforms the FasterTransformer suite of benchmarks. We further show that with appropriate partitioning, the lower memory requirements of multiquery attention (i.e. multiple query heads share single key/value head) enables scaling up to 32× larger context lengths."* - -Neural rendering / NeRF -*********************** - -Generalizable Patch-Based Neural Rendering -========================================== - -- `Code on GitHub `__ -- Research paper: - - - `Generalizable Patch-Based Neural Rendering `__ (Suhail et al., 2022) - - - *"...We propose a different paradigm, where no deep features and no NeRF-like volume rendering are needed. Our method is capable of predicting the color of a target ray in a novel scene directly, just from a collection of patches sampled from the scene."* - -Voxel-based Radiance Fields in JAX and Flax -=========================================== - -- `Colab notebook `__ (Velez and Dellaert, 2022) - - - *"In this notebook we show how with JAX/Flax, it is relatively easy to quickly get a voxel-based NeRF variant up and running. Specifically, we will develop a simplified version of DVGO that directly regresses color instead of having a small MLP. It works remarkably well."* - -Optimization -************ - -Amos Optimizer *and* JEstimator -=============================== - -- Code on GitHub: - - - `Amos and JEstimator `__ - - - *"... implements Amos, an optimizer compatible with the optax library, and JEstimator, a light-weight library with a tf.Estimator-like interface to manage T5X-compatible checkpoints for machine learning programs in JAX, which we use to run experiments in the paper."* - -- Research paper: - - - `Amos: An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale `__ (Tian and Parikh, 2022) - - - Presents *"Amos, an optimizer compatible with the optax library, and JEstimator, a light-weight library with a tf.Estimator-like interface to manage T5X-compatible checkpoints for machine learning programs in JAX."* *"When used for pre-training BERT variants and T5, Amos consistently converges faster than the state-of-the-art settings of AdamW, achieving better validation loss within <=70% training steps and time, while requiring <=51% memory for slot variables."* - -Quantization -************ - -Pareto-Optimal Quantized ResNet Is Mostly 4-bit *and* AQT: Accurate Quantized Training -====================================================================================== - -- Code on GitHub: - - - `AQT: Accurate Quantized Training `__ - -- Research paper: - - - `Pareto-Optimal Quantized ResNet Is Mostly 4-bit `__ (Abdolrashidi et al., 2021) - - - *"In this work, we use ResNet as a case study to systematically investigate the effects of quantization on inference compute cost-quality tradeoff curves. Our results suggest that for each bfloat16 ResNet model, there are quantized models with lower cost and higher accuracy; in other words, the bfloat16 compute cost-quality tradeoff curve is Pareto-dominated by the 4-bit and 8-bit curves, with models primarily quantized to 4-bit yielding the best Pareto curve... The quantization method we used is optimized for practicality: It requires little tuning and is designed with hardware capabilities in mind... As part of this work, we contribute a quantization library written in JAX..."* - -Reinforcement learning -********************** - -Continuous Control with Action Quantization from Demonstrations (AQuaDem) -========================================================================= - -- `Code on GitHub `__ - -- Research paper: - - - `Continuous Control with Action Quantization from Demonstrations `__ (Dadashi et al., 2021) - - - Proposes *"a novel Reinforcement Learning (RL) framework for problems with continuous action spaces: Action Quantization from Demonstrations (AQuaDem). The proposed approach consists in learning a discretization of continuous action spaces from human demonstrations. This discretization returns a set of plausible actions (in light of the demonstrations) for each input state, thus capturing the priors of the demonstrator and their multimodal behavior. By discretizing the action space, any discrete action deep RL technique can be readily applied to the continuous control problem. Experiments show that the proposed approach outperforms state-of-the-art methods such as SAC in the RL setup, and GAIL in the Imitation Learning setup."* - -Sequence models / Model parallelism -*********************************** - -T5X: Scaling Up Models and Data with ``t5x`` and ``seqio`` -========================================================== - -- `Code on GitHub `__ - - - *"T5X is a modular, composable, research-friendly framework for high-performance, configurable, self-service training, evaluation, and inference of sequence models (starting with language) at many scales."* - -- Research paper: - - - `T5X: Scaling Up Models and Data with t5x and seqio `__ (Roberts et al., 2022) - - - *"Recent neural network-based language models have benefited greatly from scaling up the size of training datasets and the number of parameters in the models themselves. Scaling can be complicated due to various factors including the need to distribute computation on supercomputer clusters (e.g., TPUs), prevent bottlenecks when infeeding data, and ensure reproducible results. In this work, we present two software libraries that ease these issues: t5x simplifies the process of building and training large language models at scale while maintaining ease of use, and seqio provides a task-based API for simple creation of fast and reproducible training data and evaluation pipelines. These open-source libraries have been used to train models with hundreds of billions of parameters on datasets with multiple terabytes of training data. Along with the libraries, we release configurations and instructions for T5-like encoder-decoder models as well as GPT-like decoder-only architectures."* - -Simulation -********** - -Brax - A Differentiable Physics Engine for Large Scale Rigid Body Simulation -============================================================================ - -- `Code on GitHub `__ -- `Colab notebooks `__ -- Research paper: - - - `Brax - A Differentiable Physics Engine for Large Scale Rigid Body Simulation `__ (Freeman et al., 2021) - - - *"We present Brax, an open source library for rigid body simulation with a focus on performance and parallelism on accelerators, written in JAX. We present results on a suite of tasks inspired by the existing reinforcement learning literature, but remade in our engine. Additionally, we provide reimplementations of PPO, SAC, ES, and direct policy optimization in JAX that compile alongside our environments, allowing the learning algorithm and the environment processing to occur on the same device, and to scale seamlessly on accelerators."* diff --git a/docs_nnx/examples/index.rst b/docs_nnx/examples/index.rst index cd77fd9cee..1d5ebf7285 100644 --- a/docs_nnx/examples/index.rst +++ b/docs_nnx/examples/index.rst @@ -5,8 +5,5 @@ Examples :maxdepth: 2 core_examples - google_research_examples - repositories_that_use_flax - community_examples diff --git a/docs_nnx/examples/repositories_that_use_flax.rst b/docs_nnx/examples/repositories_that_use_flax.rst deleted file mode 100644 index dfc23f6ad4..0000000000 --- a/docs_nnx/examples/repositories_that_use_flax.rst +++ /dev/null @@ -1,51 +0,0 @@ -Repositories that use Flax -========================== - -The following code bases use Flax and provide training frameworks and a wealth -of examples. In many cases, you can also find pre-trained weights: - - -🤗 Hugging Face -*************** - -`🤗 Hugging Face `__ is a -very popular library for building, training, and deploying state of the art -machine learning models. -These models can be applied on text, images, and audio. After organizing the -`JAX/Flax community week `__, -they have now over 5,000 -`Flax/JAX models `__ in -their repository. - -🥑 DALLE Mini -************* - -`🥑 DALLE Mini `__ is a Transformer-based -text-to-image model implemented in JAX/Flax that follows the ideas from the -original `DALLE `__ paper by OpenAI. - -Scenic -****** - -`Scenic `__ is a codebase/library -for computer vision research and beyond. Scenic's main focus is around -attention-based models. Scenic has been successfully used to develop -classification, segmentation, and detection models for multiple modalities -including images, video, audio, and multimodal combinations of them. - -Big Vision -********** - -`Big Vision `__ is a codebase -designed for training large-scale vision models using Cloud TPU VMs or GPU -machines. It is based on Jax/Flax libraries, and uses tf.data and TensorFlow -Datasets for scalable and reproducible input pipelines. This is the original -codebase of ViT, MLP-Mixer, LiT, UViM, and many more models. - -T5X -*** - -`T5X `__ is a modular, composable, -research-friendly framework for high-performance, configurable, self-service -training, evaluation, and inference of sequence models (starting with -language) at many scales. \ No newline at end of file diff --git a/docs_nnx/index.rst b/docs_nnx/index.rst index d5e0c9d34d..30fd54d06c 100644 --- a/docs_nnx/index.rst +++ b/docs_nnx/index.rst @@ -8,10 +8,10 @@ Flax NNX ---- -Flax NNX is a new simplified API that is designed to make it easier to create, inspect, -debug, and analyze neural networks in JAX. It achieves this by adding first class support +**Flax NNX is a simplified API that makes it easier to create, inspect, +debug, and analyze neural networks in JAX.** It has first class support for Python reference semantics, allowing users to express their models using regular -Python objects. Flax NNX is an evolution of the previous Flax Linen APIs, it takes years of +Python objects. Flax NNX is an evolution of the previous Flax Linen APIs, and it takes years of experience to bring a simpler and more user-friendly experience. .. note:: @@ -104,14 +104,14 @@ Basic usage model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing - @nnx.jit # automatic state management + @nnx.jit # automatic state management for JAX transforms def train_step(model, optimizer, x, y): def loss_fn(model): y_pred = model(x) # call methods directly return ((y_pred - y) ** 2).mean() loss, grads = nnx.value_and_grad(loss_fn)(model) - optimizer.update(grads) # inplace updates + optimizer.update(grads) # in-place updates return loss @@ -184,6 +184,7 @@ Learn more nnx_basics mnist_tutorial guides/index + examples/index The Flax philosophy How to contribute api_reference/index