diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7f6aeccf92..4fa3328e4a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -103,7 +103,7 @@ vNext
   to keyword arguments. See more details in [#3389](https://github.com/google/flax/discussions/3389).
 - Use new typed PRNG keys throughout flax: this essentially involved changing
   uses of `jax.random.PRNGKey` to `jax.random.key`.
-  (See [JEP 9263](https://github.com/google/jax/pull/17297) for details).
+  (See [JEP 9263](https://github.com/jax-ml/jax/pull/17297) for details).
   If you notice dispatch performance regressions after this change, be sure
   you update `jax` to version 0.4.16 or newer.
 - Added `has_improved` field to EarlyStopping and changed the return signature of
diff --git a/docs/contributing.md b/docs/contributing.md
index 72c48ae1af..27b1043322 100644
--- a/docs/contributing.md
+++ b/docs/contributing.md
@@ -240,7 +240,7 @@ section above to keep the contents of both Markdown and Jupyter Notebook files i
 Some of the notebooks are built automatically as part of the pre-submit checks
 and as part of the [Read the Docs](https://flax.readthedocs.io/en/latest) build.
 The build will fail if cells raise errors. If the errors are intentional, you can either catch them,
-or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/google/jax/pull/2402/files)).
+or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/jax-ml/jax/pull/2402/files)).
 You have to add this metadata by hand in the `.ipynb` file. It will be preserved
 when somebody else re-saves the notebook.
diff --git a/docs/flip/1777-default-dtype.md b/docs/flip/1777-default-dtype.md
index 6344b6bb0e..6d1bca9d01 100644
--- a/docs/flip/1777-default-dtype.md
+++ b/docs/flip/1777-default-dtype.md
@@ -21,7 +21,7 @@ The current behavior is problematic and results in silent bugs, especially for d
 
 ### Dtypes in JAX
 
-JAX uses a NumPy-inspired [dtype promotion](https://github.com/google/jax/blob/main/jax/_src/dtypes.py) mechanism as explained [here](https://jax.readthedocs.io/en/latest/type_promotion.html?highlight=lattice#type-promotion-semantics). The type promotion rules are summarized by the following type lattice:
+JAX uses a NumPy-inspired [dtype promotion](https://github.com/jax-ml/jax/blob/main/jax/_src/dtypes.py) mechanism as explained [here](https://jax.readthedocs.io/en/latest/type_promotion.html?highlight=lattice#type-promotion-semantics). The type promotion rules are summarized by the following type lattice:
 
 ![JAX type promotion lattice](https://jax.readthedocs.io/en/latest/_images/type_lattice.svg)
diff --git a/docs/guides/converting_and_upgrading/convert_pytorch_to_flax.rst b/docs/guides/converting_and_upgrading/convert_pytorch_to_flax.rst
index 0ca3edf69b..2b0d0ed5c4 100644
--- a/docs/guides/converting_and_upgrading/convert_pytorch_to_flax.rst
+++ b/docs/guides/converting_and_upgrading/convert_pytorch_to_flax.rst
@@ -300,7 +300,7 @@ To load ``torch.nn.ConvTranspose2d`` parameters into Flax, we need to use the ``
     np.testing.assert_almost_equal(j_out, t_out, decimal=6)
 
 
-.. _`pull request`: https://github.com/google/jax/pull/5772
+.. _`pull request`: https://github.com/jax-ml/jax/pull/5772
 .. |nn.ConvTranspose| replace:: ``nn.ConvTranspose``
 .. _nn.ConvTranspose: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.ConvTranspose
diff --git a/docs/guides/flax_fundamentals/flax_basics.ipynb b/docs/guides/flax_fundamentals/flax_basics.ipynb
index e8e43f21c1..c2b16e3dd6 100644
--- a/docs/guides/flax_fundamentals/flax_basics.ipynb
+++ b/docs/guides/flax_fundamentals/flax_basics.ipynb
@@ -951,7 +951,7 @@
    "source": [
     "### Exporting to Tensorflow's SavedModel with jax2tf\n",
     "\n",
-    "JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax."
+    "JAX released an experimental converter called [jax2tf](https://github.com/jax-ml/jax/tree/main/jax/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax."
    ]
  }
 ],
diff --git a/docs/guides/flax_fundamentals/flax_basics.md b/docs/guides/flax_fundamentals/flax_basics.md
index 52755e9b5c..5b7e0657ff 100644
--- a/docs/guides/flax_fundamentals/flax_basics.md
+++ b/docs/guides/flax_fundamentals/flax_basics.md
@@ -469,4 +469,4 @@ Flax provides a handy wrapper - `TrainState` - that simplifies the above code. C
 
 ### Exporting to Tensorflow's SavedModel with jax2tf
 
-JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax.
+JAX released an experimental converter called [jax2tf](https://github.com/jax-ml/jax/tree/main/jax/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax.
diff --git a/docs/guides/flax_fundamentals/rng_guide.ipynb b/docs/guides/flax_fundamentals/rng_guide.ipynb
index b2677f1a7f..708b750121 100644
--- a/docs/guides/flax_fundamentals/rng_guide.ipynb
+++ b/docs/guides/flax_fundamentals/rng_guide.ipynb
@@ -105,7 +105,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Set the JAX config variable `jax_threefry_partitionable` to `True`. This will be the default value in the future and makes the PRNG more efficiently auto-parallelizable under `jax.jit`. Refer to [JAX discussion](https://github.com/google/jax/discussions/18480) for more details."
+    "Set the JAX config variable `jax_threefry_partitionable` to `True`. This will be the default value in the future and makes the PRNG more efficiently auto-parallelizable under `jax.jit`. Refer to [JAX discussion](https://github.com/jax-ml/jax/discussions/18480) for more details."
    ]
   },
   {
@@ -1467,7 +1467,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "[Flax lifted transforms](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html) allow you to use [JAX transforms](https://github.com/google/jax#transformations) with `Module` arguments. This section will show you how to control how PRNG keys are split in Flax lifted transforms.\n",
+    "[Flax lifted transforms](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html) allow you to use [JAX transforms](https://github.com/jax-ml/jax#transformations) with `Module` arguments. This section will show you how to control how PRNG keys are split in Flax lifted transforms.\n",
    "\n",
    "Refer to [Lifted transformations](https://flax.readthedocs.io/en/latest/developer_notes/lift.html) for more detail."
   ]
diff --git a/docs/guides/flax_fundamentals/rng_guide.md b/docs/guides/flax_fundamentals/rng_guide.md
index efb5432a17..838feff1ad 100644
--- a/docs/guides/flax_fundamentals/rng_guide.md
+++ b/docs/guides/flax_fundamentals/rng_guide.md
@@ -55,7 +55,7 @@ import hashlib
 jax.devices()
 ```
 
-Set the JAX config variable `jax_threefry_partitionable` to `True`. This will be the default value in the future and makes the PRNG more efficiently auto-parallelizable under `jax.jit`. Refer to [JAX discussion](https://github.com/google/jax/discussions/18480) for more details.
+Set the JAX config variable `jax_threefry_partitionable` to `True`. This will be the default value in the future and makes the PRNG more efficiently auto-parallelizable under `jax.jit`. Refer to [JAX discussion](https://github.com/jax-ml/jax/discussions/18480) for more details.
 
 ```{code-cell}
 jax.config.update('jax_threefry_partitionable', True)
 ```
@@ -647,7 +647,7 @@ jax.debug.visualize_array_sharding(out)
 
 +++
 
-[Flax lifted transforms](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html) allow you to use [JAX transforms](https://github.com/google/jax#transformations) with `Module` arguments. This section will show you how to control how PRNG keys are split in Flax lifted transforms.
+[Flax lifted transforms](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html) allow you to use [JAX transforms](https://github.com/jax-ml/jax#transformations) with `Module` arguments. This section will show you how to control how PRNG keys are split in Flax lifted transforms.
 
 Refer to [Lifted transformations](https://flax.readthedocs.io/en/latest/developer_notes/lift.html) for more detail.
diff --git a/docs/guides/flax_sharp_bits.ipynb b/docs/guides/flax_sharp_bits.ipynb
index 22ccae6124..8055e829ae 100644
--- a/docs/guides/flax_sharp_bits.ipynb
+++ b/docs/guides/flax_sharp_bits.ipynb
@@ -49,7 +49,7 @@
    "\n",
    "### Background \n",
    "\n",
-    "The [dropout](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf) stochastic regularization technique randomly removes hidden and visible units in a network. Dropout is a random operation, requiring a PRNG state, and Flax (like JAX) uses [Threefry](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) PRNG that is splittable. \n",
+    "The [dropout](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf) stochastic regularization technique randomly removes hidden and visible units in a network. Dropout is a random operation, requiring a PRNG state, and Flax (like JAX) uses [Threefry](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) PRNG that is splittable. \n",
    "\n",
    "> Note: Recall that JAX has an explicit way of giving you PRNG keys: you can fork the main PRNG state (such as `key = jax.random.key(seed=0)`) into multiple new PRNG keys with `key, subkey = jax.random.split(key)`. Refresh your memory in [🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers).\n",
    "\n",
diff --git a/docs/guides/flax_sharp_bits.md b/docs/guides/flax_sharp_bits.md
index 617571144a..a01fc3232c 100644
--- a/docs/guides/flax_sharp_bits.md
+++ b/docs/guides/flax_sharp_bits.md
@@ -41,7 +41,7 @@ Check out a full example below.
 
 ### Background
 
-The [dropout](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf) stochastic regularization technique randomly removes hidden and visible units in a network. Dropout is a random operation, requiring a PRNG state, and Flax (like JAX) uses [Threefry](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) PRNG that is splittable.
+The [dropout](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf) stochastic regularization technique randomly removes hidden and visible units in a network. Dropout is a random operation, requiring a PRNG state, and Flax (like JAX) uses [Threefry](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) PRNG that is splittable.
 
 > Note: Recall that JAX has an explicit way of giving you PRNG keys: you can fork the main PRNG state (such as `key = jax.random.key(seed=0)`) into multiple new PRNG keys with `key, subkey = jax.random.split(key)`. Refresh your memory in [🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers).
diff --git a/docs/guides/parallel_training/flax_on_pjit.ipynb b/docs/guides/parallel_training/flax_on_pjit.ipynb
index 2eddfb04a8..19861778b9 100644
--- a/docs/guides/parallel_training/flax_on_pjit.ipynb
+++ b/docs/guides/parallel_training/flax_on_pjit.ipynb
@@ -220,7 +220,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Note that device axis names like `'data'`, `'model'` or `None` are passed into both [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) and [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all.\n",
+    "Note that device axis names like `'data'`, `'model'` or `None` are passed into both [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) and [`jax.lax.with_sharding_constraint`](https://github.com/jax-ml/jax/blob/main/jax/_src/pjit.py#L1516) API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all.\n",
    "\n",
    "For example:\n",
    "\n",
diff --git a/docs/guides/parallel_training/flax_on_pjit.md b/docs/guides/parallel_training/flax_on_pjit.md
index 98ccdfb248..a20c296627 100644
--- a/docs/guides/parallel_training/flax_on_pjit.md
+++ b/docs/guides/parallel_training/flax_on_pjit.md
@@ -135,7 +135,7 @@ class DotReluDot(nn.Module):
     return z, None
 ```
 
-Note that device axis names like `'data'`, `'model'` or `None` are passed into both [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) and [`jax.lax.with_sharding_constraint`](https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516) API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all.
+Note that device axis names like `'data'`, `'model'` or `None` are passed into both [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) and [`jax.lax.with_sharding_constraint`](https://github.com/jax-ml/jax/blob/main/jax/_src/pjit.py#L1516) API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all.
 
 For example:
diff --git a/docs/guides/training_techniques/transfer_learning.ipynb b/docs/guides/training_techniques/transfer_learning.ipynb
index ae5400c6f2..909a11715a 100644
--- a/docs/guides/training_techniques/transfer_learning.ipynb
+++ b/docs/guides/training_techniques/transfer_learning.ipynb
@@ -48,7 +48,7 @@
    "# Note that the Transformers library doesn't use the latest Flax version.\n",
    "! pip install -q \"transformers[flax]\"\n",
    "# Install/upgrade Flax and JAX. For JAX installation with GPU/TPU support,\n",
-    "# visit https://github.com/google/jax#installation.\n",
+    "# visit https://github.com/jax-ml/jax#installation.\n",
    "! pip install -U -q flax jax jaxlib"
   ]
  },
diff --git a/docs/guides/training_techniques/transfer_learning.md b/docs/guides/training_techniques/transfer_learning.md
index e596ec6c75..acef5b72f6 100644
--- a/docs/guides/training_techniques/transfer_learning.md
+++ b/docs/guides/training_techniques/transfer_learning.md
@@ -38,7 +38,7 @@ Depending on your task, some of the content in this guide may be suboptimal. For
 # Note that the Transformers library doesn't use the latest Flax version.
 ! pip install -q "transformers[flax]"
 # Install/upgrade Flax and JAX. For JAX installation with GPU/TPU support,
-# visit https://github.com/google/jax#installation.
+# visit https://github.com/jax-ml/jax#installation.
 ! pip install -U -q flax jax jaxlib
 ```
diff --git a/docs/guides/training_techniques/use_checkpointing.ipynb b/docs/guides/training_techniques/use_checkpointing.ipynb
index aa13a126b9..38e7ac9e86 100644
--- a/docs/guides/training_techniques/use_checkpointing.ipynb
+++ b/docs/guides/training_techniques/use_checkpointing.ipynb
@@ -48,7 +48,7 @@
    "source": [
    "## Setup\n",
    "\n",
-   "Install/upgrade Flax and [Orbax](https://github.com/google/orbax). For JAX installation with GPU/TPU support, visit [this section on GitHub](https://github.com/google/jax#installation)."
+   "Install/upgrade Flax and [Orbax](https://github.com/google/orbax). For JAX installation with GPU/TPU support, visit [this section on GitHub](https://github.com/jax-ml/jax#installation)."
    ]
   },
   {
diff --git a/docs/guides/training_techniques/use_checkpointing.md b/docs/guides/training_techniques/use_checkpointing.md
index 1be8c5e0e0..9f01b5526e 100644
--- a/docs/guides/training_techniques/use_checkpointing.md
+++ b/docs/guides/training_techniques/use_checkpointing.md
@@ -48,7 +48,7 @@ If you need to learn more about `orbax.checkpoint`, refer to the [Orbax docs](ht
 
 ## Setup
 
-Install/upgrade Flax and [Orbax](https://github.com/google/orbax). For JAX installation with GPU/TPU support, visit [this section on GitHub](https://github.com/google/jax#installation).
+Install/upgrade Flax and [Orbax](https://github.com/google/orbax). For JAX installation with GPU/TPU support, visit [this section on GitHub](https://github.com/jax-ml/jax#installation).
 
 Note: Before running `import jax`, create eight fake devices to mimic a [multi-host environment](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html?#aside-hosts-and-devices-in-jax) in this notebook. Note that the order of imports is important here. The `os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'` command works only with the CPU backend, which means it won't work with GPU/TPU acceleration on if you're running this notebook in Google Colab. If you are already running the code on multiple devices (for example, in a 4x2 TPU environment), you can skip running the next cell.
diff --git a/docs/index.rst b/docs/index.rst
index c286d4f0a0..f4c13ba9d8 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -102,7 +102,7 @@ Installation
    # or to install the latest version of Flax:
    pip install --upgrade git+https://github.com/google/flax.git
 
-Flax installs the vanilla CPU version of JAX, if you need a custom version please check out `JAX's installation page <https://github.com/google/jax#installation>`__.
+Flax installs the vanilla CPU version of JAX, if you need a custom version please check out `JAX's installation page <https://github.com/jax-ml/jax#installation>`__.
 
 Basic usage
 ^^^^^^^^^^^^
diff --git a/docs/quick_start.ipynb b/docs/quick_start.ipynb
index 32530b9bed..ee1224d601 100644
--- a/docs/quick_start.ipynb
+++ b/docs/quick_start.ipynb
@@ -12,7 +12,7 @@
    "\n",
    "Welcome to Flax!\n",
    "\n",
-    "Flax is an open source Python neural network library built on top of [JAX](https://github.com/google/jax). This tutorial demonstrates how to construct a simple convolutional neural\n",
+    "Flax is an open source Python neural network library built on top of [JAX](https://github.com/jax-ml/jax). This tutorial demonstrates how to construct a simple convolutional neural\n",
    "network (CNN) using the [Flax](https://flax.readthedocs.io) Linen API and train\n",
    "the network for image classification on the MNIST dataset."
   ]
diff --git a/docs/quick_start.md b/docs/quick_start.md
index ac8a9fb860..1bd713d56d 100644
--- a/docs/quick_start.md
+++ b/docs/quick_start.md
@@ -16,7 +16,7 @@ jupytext:
 
 Welcome to Flax!
 
-Flax is an open source Python neural network library built on top of [JAX](https://github.com/google/jax). This tutorial demonstrates how to construct a simple convolutional neural
+Flax is an open source Python neural network library built on top of [JAX](https://github.com/jax-ml/jax). This tutorial demonstrates how to construct a simple convolutional neural
 network (CNN) using the [Flax](https://flax.readthedocs.io) Linen API and train
 the network for image classification on the MNIST dataset.
diff --git a/docs_nnx/contributing.md b/docs_nnx/contributing.md
index 72c48ae1af..27b1043322 100644
--- a/docs_nnx/contributing.md
+++ b/docs_nnx/contributing.md
@@ -240,7 +240,7 @@ section above to keep the contents of both Markdown and Jupyter Notebook files i
 Some of the notebooks are built automatically as part of the pre-submit checks
 and as part of the [Read the Docs](https://flax.readthedocs.io/en/latest) build.
 The build will fail if cells raise errors. If the errors are intentional, you can either catch them,
-or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/google/jax/pull/2402/files)).
+or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/jax-ml/jax/pull/2402/files)).
 You have to add this metadata by hand in the `.ipynb` file. It will be preserved
 when somebody else re-saves the notebook.
diff --git a/docs_nnx/flip/1777-default-dtype.md b/docs_nnx/flip/1777-default-dtype.md
index 6344b6bb0e..6d1bca9d01 100644
--- a/docs_nnx/flip/1777-default-dtype.md
+++ b/docs_nnx/flip/1777-default-dtype.md
@@ -21,7 +21,7 @@ The current behavior is problematic and results in silent bugs, especially for d
 
 ### Dtypes in JAX
 
-JAX uses a NumPy-inspired [dtype promotion](https://github.com/google/jax/blob/main/jax/_src/dtypes.py) mechanism as explained [here](https://jax.readthedocs.io/en/latest/type_promotion.html?highlight=lattice#type-promotion-semantics). The type promotion rules are summarized by the following type lattice:
+JAX uses a NumPy-inspired [dtype promotion](https://github.com/jax-ml/jax/blob/main/jax/_src/dtypes.py) mechanism as explained [here](https://jax.readthedocs.io/en/latest/type_promotion.html?highlight=lattice#type-promotion-semantics). The type promotion rules are summarized by the following type lattice:
 
 ![JAX type promotion lattice](https://jax.readthedocs.io/en/latest/_images/type_lattice.svg)
diff --git a/docs_nnx/mnist_tutorial.ipynb b/docs_nnx/mnist_tutorial.ipynb
index 147ed07b06..a1aa4eae89 100644
--- a/docs_nnx/mnist_tutorial.ipynb
+++ b/docs_nnx/mnist_tutorial.ipynb
@@ -12,7 +12,7 @@
    "\n",
    "Welcome to Flax NNX! In this tutorial you will learn how to build and train a simple convolutional neural network (CNN) to classify handwritten digits on the MNIST dataset using the Flax NNX API.\n",
    "\n",
-    "Flax NNX is a Python neural network library built upon [JAX](https://github.com/google/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning.\n",
+    "Flax NNX is a Python neural network library built upon [JAX](https://github.com/jax-ml/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning.\n",
    "\n",
    "Let’s get started!"
   ]
diff --git a/docs_nnx/mnist_tutorial.md b/docs_nnx/mnist_tutorial.md
index 74be0ec5b4..a4a05cf4ba 100644
--- a/docs_nnx/mnist_tutorial.md
+++ b/docs_nnx/mnist_tutorial.md
@@ -16,7 +16,7 @@ jupytext:
 
 Welcome to Flax NNX! In this tutorial you will learn how to build and train a simple convolutional neural network (CNN) to classify handwritten digits on the MNIST dataset using the Flax NNX API.
 
-Flax NNX is a Python neural network library built upon [JAX](https://github.com/google/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning.
+Flax NNX is a Python neural network library built upon [JAX](https://github.com/jax-ml/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning.
 
 Let’s get started!