From a46578ec8e1f483d6354ce9a40cf63176fd992d7 Mon Sep 17 00:00:00 2001
From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com>
Date: Mon, 4 Nov 2024 22:19:39 +0000
Subject: [PATCH] Upgrade Flax NNX Gemma

---
 docs_nnx/guides/gemma.ipynb | 81 ++++++++++++++++++++++---------------
 docs_nnx/guides/gemma.md    | 77 +++++++++++++++++++++--------------
 2 files changed, 96 insertions(+), 62 deletions(-)

diff --git a/docs_nnx/guides/gemma.ipynb b/docs_nnx/guides/gemma.ipynb
index 1c59c951df..412aae7107 100644
--- a/docs_nnx/guides/gemma.ipynb
+++ b/docs_nnx/guides/gemma.ipynb
@@ -4,16 +4,22 @@
 "cell_type": "markdown", "metadata": {}, "source": [
- "# Example: Using Pretrained Gemma\n",
+ "# Example: Using pretrained Gemma for inference with Flax NNX\n",
 "\n",
- "You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it."
+ "This example shows how to use Flax NNX to load the [Gemma](https://ai.google.dev/gemma) open model files and use them to perform sampling/inference for generating text. You will use [Flax NNX `gemma` modules](https://github.com/google/flax/tree/main/examples/gemma) written with Flax and JAX for model parameter configuration and inference.\n",
 "\n",
+ "> Gemma is a family of lightweight, state-of-the-art open models based on Google DeepMind’s [Gemini](https://deepmind.google/technologies/gemini/#introduction). Read more about [Gemma](https://blog.google/technology/developers/gemma-open-models/) and [Gemma 2](https://blog.google/technology/developers/google-gemma-2/).\n",
 "\n",
+ "We recommend using [Google Colab](https://colab.research.google.com/) with access to A100 GPU acceleration to run the code."
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [
- "## Installation"
+ "## Installation\n",
+ "\n",
+ "Install the necessary dependencies, including `kagglehub`."
 ] }, {
@@ -30,13 +36,13 @@
 "cell_type": "markdown", "metadata": {}, "source": [
- "## Downloading the checkpoint\n",
+ "## Download the model\n",
 "\n",
- "\"To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:\n",
+ "To use the Gemma model, you'll need a [Kaggle](https://www.kaggle.com/models/google/gemma/) account and API key:\n",
 "\n",
- "1. Visit https://www.kaggle.com/ and create an account.\n",
- "2. Go to your account settings, then the 'API' section.\n",
- "3. Click 'Create new token' to download your key.\n",
+ "1. To create an account, visit [Kaggle](https://www.kaggle.com/) and click on 'Register'.\n",
+ "2. Once you have an account, sign in, go to your ['Settings'](https://www.kaggle.com/settings), and under 'API' click on 'Create New Token' to generate and download your Kaggle API key.\n",
+ "3. In [Google Colab](https://colab.research.google.com/), under 'Secrets' add your Kaggle username and API key, storing the username as `KAGGLE_USERNAME` and the key as `KAGGLE_KEY`. If you are using a [Kaggle Notebook](https://www.kaggle.com/code) for free TPU or other hardware acceleration, it has a key storage feature under 'Add-ons' > 'Secrets', along with instructions for accessing stored keys.\n",
 "\n",
 "Then run the cell below." ] },
@@ -70,13 +76,21 @@
 "cell_type": "markdown", "metadata": {}, "source": [
- "If everything went well, you should see:\n",
- "```\n",
- "Kaggle credentials set.\n",
- "Kaggle credentials successfully validated.\n",
+ "If everything went well, it should say `Kaggle credentials set. Kaggle credentials successfully validated.`.\n",
 "\n",
 "**Note:** In Google Colab, you can instead authenticate into Kaggle using the code below after following optional step 3 above.\n",
 "\n",
 "```\n",
 "import os\n",
 "from google.colab import userdata # `userdata` is a Colab API.\n",
 "\n",
 "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
 "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')\n",
 "``` \n",
 "\n",
- "Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models."
+ "Now, download the Gemma model you want to try. The code in the next cell uses [`kagglehub.model_download`](https://github.com/Kaggle/kagglehub/blob/8efe3e99477aa4f41885840de6903e61a49df4aa/src/kagglehub/models.py#L16) to download the model files.\n",
 "\n",
 "**Note:** For larger models, such as `gemma 7b` and `gemma 7b-it` (instruct), you may require a hardware accelerator with plenty of memory, such as the NVIDIA A100."
 ] }, {
@@ -90,9 +104,7 @@
 "VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n", "weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')\n", "ckpt_path = f'{weights_dir}/{VARIANT}'\n",
- "vocab_path = f'{weights_dir}/tokenizer.model'\n",
- "\n",
- "clear_output()"
+ "vocab_path = f'{weights_dir}/tokenizer.model'"
 ] }, {
@@ -116,7 +128,7 @@
 "cell_type": "markdown", "metadata": {}, "source": [
- "Flax examples are not exposed as packages so you need to use the workaround in the next cells to import from NNX's Gemma example."
+ "To interact with the Gemma model, you will use the Flax NNX `gemma` code from [`google/flax` examples on GitHub](https://github.com/google/flax/tree/main/examples/gemma). Since it is not exposed as a package, you need to use the following workaround to import modules from the `examples/gemma` directory."
 ] }, {
@@ -141,10 +153,9 @@
 "source": [ "import sys\n", "import tempfile\n",
- "\n",
 "with tempfile.TemporaryDirectory() as tmp:\n",
- " # Here we create a temporary directory and clone the flax repo\n",
- " # Then we append the examples/gemma folder to the path to load the gemma modules\n",
+ " # Create a temporary directory and clone the `flax` repo.\n",
+ " # Then, append the `examples/gemma` folder to the path for loading the `gemma` modules.\n",
 " ! git clone https://github.com/google/flax.git {tmp}/flax\n", " sys.path.append(f\"{tmp}/flax/examples/gemma\")\n", " import params as params_lib\n",
@@ -157,9 +168,9 @@
 "cell_type": "markdown", "metadata": {}, "source": [
- "## Start Generating with Your Model\n",
+ "## Load and prepare the Gemma model\n",
 "\n",
- "Load and prepare your LLM's checkpoint for use with Flax."
+ "First, load the Gemma model parameters for use with Flax."
 ] }, {
@@ -170,7 +181,6 @@
 }, "outputs": [], "source": [
- "# Load parameters\n",
 "params = params_lib.load_and_format_params(ckpt_path)"
 ] },
@@ -178,7 +188,7 @@
 "cell_type": "markdown", "metadata": {}, "source": [
- "Load your tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library."
+ "Next, load the tokenizer file constructed using the [SentencePiece](https://github.com/google/sentencepiece) library."
 ] }, {
@@ -208,7 +218,9 @@
 "cell_type": "markdown", "metadata": {}, "source": [
- "Use the `transformer_lib.TransformerConfig.from_params` function to automatically load the correct configuration from a checkpoint. Note that the vocabulary size is smaller than the number of input embeddings due to unused tokens in this release."
+ "Then, use the Flax NNX [`gemma.transformer.TransformerConfig.from_params`](https://github.com/google/flax/blob/3f3c03b23d4fd3d85d1c5d4d97381a8a2c48b475/examples/gemma/transformer.py#L193) function to automatically load the correct configuration from a checkpoint.\n",
+ "\n",
+ "**Note:** The vocabulary size is smaller than the number of input embeddings due to unused tokens in this release."
 ] }, {
@@ -250,7 +262,9 @@
 "cell_type": "markdown", "metadata": {}, "source": [
- "Finally, build a sampler on top of your model and your tokenizer."
+ "## Perform sampling/inference\n",
+ "\n",
+ "Build a Flax NNX [`gemma.Sampler`](https://github.com/google/flax/blob/main/examples/gemma/sampler.py) on top of your model and tokenizer with the right parameter shapes."
 ] }, {
@@ -261,7 +275,6 @@
 }, "outputs": [], "source": [
- "# Create a sampler with the right param shapes.\n",
 "sampler = sampler_lib.Sampler(\n", " transformer=transformer,\n", " vocab=vocab,\n",
@@ -272,7 +285,11 @@
 "cell_type": "markdown", "metadata": {}, "source": [
- "You're ready to start sampling ! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent."
+ "You're ready to start sampling!\n",
+ "\n",
+ "**Note:** This Flax NNX [`gemma.Sampler`](https://github.com/google/flax/blob/main/examples/gemma/sampler.py) uses JAX’s [just-in-time (JIT) compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html), so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent.\n",
+ "\n",
+ "Write a prompt in `input_batch` and perform inference. Feel free to tweak `total_generation_steps` (the number of steps performed when generating a response)."
 ] }, {
@@ -342,12 +359,12 @@
 ], "source": [ "input_batch = [\n",
- " \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n",
- "]\n",
+ " \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n",
+ " ]\n",
 "\n",
 "out_data = sampler(\n", " input_strings=input_batch,\n",
- " total_generation_steps=300, # number of steps performed when generating\n",
+ " total_generation_steps=300, # The number of steps performed when generating a response.\n",
 " )\n",
 "\n",
 "for input_string, out_string in zip(input_batch, out_data.text):\n",
@@ -360,7 +377,7 @@
 "cell_type": "markdown", "metadata": {}, "source": [
- "You should get an implementation of bubble sort."
+ "You should get a Python implementation of the bubble sort algorithm."
 ] } ],
diff --git a/docs_nnx/guides/gemma.md b/docs_nnx/guides/gemma.md
index e479201af0..1ef7aaf75f 100644
--- a/docs_nnx/guides/gemma.md
+++ b/docs_nnx/guides/gemma.md
@@ -8,26 +8,32 @@ jupytext:
 jupytext_version: 1.13.8
 ---

-# Example: Using Pretrained Gemma
+# Example: Using pretrained Gemma for inference with Flax NNX

-You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it.
+This example shows how to use Flax NNX to load the [Gemma](https://ai.google.dev/gemma) open model files and use them to perform sampling/inference for generating text. You will use [Flax NNX `gemma` modules](https://github.com/google/flax/tree/main/examples/gemma) written with Flax and JAX for model parameter configuration and inference.
+
+> Gemma is a family of lightweight, state-of-the-art open models based on Google DeepMind’s [Gemini](https://deepmind.google/technologies/gemini/#introduction). Read more about [Gemma](https://blog.google/technology/developers/gemma-open-models/) and [Gemma 2](https://blog.google/technology/developers/google-gemma-2/).
+
+We recommend using [Google Colab](https://colab.research.google.com/) with access to A100 GPU acceleration to run the code.

 +++

 ## Installation

+Install the necessary dependencies, including `kagglehub`.
+
 ```{code-cell} ipython3
 ! pip install --no-deps -U flax
 ! pip install jaxtyping kagglehub treescope
 ```

-## Downloading the checkpoint
+## Download the model

-"To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:
+To use the Gemma model, you'll need a [Kaggle](https://www.kaggle.com/models/google/gemma/) account and API key:

-1. Visit https://www.kaggle.com/ and create an account.
-2. Go to your account settings, then the 'API' section.
-3. Click 'Create new token' to download your key.
+1. To create an account, visit [Kaggle](https://www.kaggle.com/) and click on 'Register'.
+2. Once you have an account, sign in, go to your ['Settings'](https://www.kaggle.com/settings), and under 'API' click on 'Create New Token' to generate and download your Kaggle API key.
+3. In [Google Colab](https://colab.research.google.com/), under 'Secrets' add your Kaggle username and API key, storing the username as `KAGGLE_USERNAME` and the key as `KAGGLE_KEY`. If you are using a [Kaggle Notebook](https://www.kaggle.com/code) for free TPU or other hardware acceleration, it has a key storage feature under 'Add-ons' > 'Secrets', along with instructions for accessing stored keys.

 Then run the cell below.

 ```{code-cell} ipython3
 import kagglehub
 kagglehub.login()
 ```

-If everything went well, you should see:
-```
-Kaggle credentials set.
-Kaggle credentials successfully validated.
+If everything went well, it should say `Kaggle credentials set. Kaggle credentials successfully validated.`.
+
+**Note:** In Google Colab, you can instead authenticate into Kaggle using the code below after following optional step 3 above.
+
 ```
+import os
+from google.colab import userdata # `userdata` is a Colab API.
+
+os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
+os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
+```

-Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models.
+Now, download the Gemma model you want to try. The code in the next cell uses [`kagglehub.model_download`](https://github.com/Kaggle/kagglehub/blob/8efe3e99477aa4f41885840de6903e61a49df4aa/src/kagglehub/models.py#L16) to download the model files.
+
+**Note:** For larger models, such as `gemma 7b` and `gemma 7b-it` (instruct), you may require a hardware accelerator with plenty of memory, such as the NVIDIA A100.

 ```{code-cell} ipython3
 from IPython.display import clear_output

 VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
 weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
 ckpt_path = f'{weights_dir}/{VARIANT}'
 vocab_path = f'{weights_dir}/tokenizer.model'
-
-clear_output()
 ```

 ## Python imports

 ```{code-cell} ipython3
 from flax import nnx
 import sentencepiece as spm
 ```

-Flax examples are not exposed as packages so you need to use the workaround in the next cells to import from NNX's Gemma example.
+To interact with the Gemma model, you will use the Flax NNX `gemma` code from [`google/flax` examples on GitHub](https://github.com/google/flax/tree/main/examples/gemma). Since it is not exposed as a package, you need to use the following workaround to import modules from the `examples/gemma` directory.

 ```{code-cell} ipython3
 import sys
 import tempfile
-
 with tempfile.TemporaryDirectory() as tmp:
-  # Here we create a temporary directory and clone the flax repo
-  # Then we append the examples/gemma folder to the path to load the gemma modules
+  # Create a temporary directory and clone the `flax` repo.
+  # Then, append the `examples/gemma` folder to the path for loading the `gemma` modules.
   ! git clone https://github.com/google/flax.git {tmp}/flax
   sys.path.append(f"{tmp}/flax/examples/gemma")
   import params as params_lib
   import sampler as sampler_lib
   import transformer as transformer_lib
   sys.path.pop();
 ```

-## Start Generating with Your Model
+## Load and prepare the Gemma model

-Load and prepare your LLM's checkpoint for use with Flax.
+First, load the Gemma model parameters for use with Flax.

 ```{code-cell} ipython3
 :cellView: form

-# Load parameters
 params = params_lib.load_and_format_params(ckpt_path)
 ```
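+Optionally, you can inspect what was just loaded. The snippet below is a small sketch, not part of the original example; it assumes `params` is a nested dictionary (pytree) of JAX arrays, as returned by `load_and_format_params`:
+
+```{code-cell} ipython3
+import jax
+
+# Map each array in the `params` pytree to its shape, to get an
+# overview of the model structure without printing the weights.
+print(jax.tree_util.tree_map(lambda x: x.shape, params))
+```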
-Load your tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library.
+Next, load the tokenizer file constructed using the [SentencePiece](https://github.com/google/sentencepiece) library.

 ```{code-cell} ipython3
 :cellView: form

 vocab = spm.SentencePieceProcessor()
 vocab.Load(vocab_path)
 ```
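+To sanity-check the tokenizer, you can round-trip a string through it. A minimal sketch using the standard SentencePiece `EncodeAsIds`/`DecodeIds` APIs:
+
+```{code-cell} ipython3
+# Encode a string into token IDs, then decode the IDs back into text.
+ids = vocab.EncodeAsIds('Hello, Gemma!')
+print(ids)
+print(vocab.DecodeIds(ids))  # Should print the original string.
+```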
-Use the `transformer_lib.TransformerConfig.from_params` function to automatically load the correct configuration from a checkpoint. Note that the vocabulary size is smaller than the number of input embeddings due to unused tokens in this release.
+Then, use the Flax NNX [`gemma.transformer.TransformerConfig.from_params`](https://github.com/google/flax/blob/3f3c03b23d4fd3d85d1c5d4d97381a8a2c48b475/examples/gemma/transformer.py#L193) function to automatically load the correct configuration from a checkpoint.
+
+**Note:** The vocabulary size is smaller than the number of input embeddings due to unused tokens in this release.

 ```{code-cell} ipython3
 transformer = transformer_lib.Transformer.from_params(params)
 nnx.display(transformer)
 ```

-Finally, build a sampler on top of your model and your tokenizer.
+## Perform sampling/inference
+
+Build a Flax NNX [`gemma.Sampler`](https://github.com/google/flax/blob/main/examples/gemma/sampler.py) on top of your model and tokenizer with the right parameter shapes.

 ```{code-cell} ipython3
 :cellView: form

-# Create a sampler with the right param shapes.
 sampler = sampler_lib.Sampler(
   transformer=transformer,
   vocab=vocab,
 )
 ```

-You're ready to start sampling ! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent.
+You're ready to start sampling!
+
+**Note:** This Flax NNX [`gemma.Sampler`](https://github.com/google/flax/blob/main/examples/gemma/sampler.py) uses JAX’s [just-in-time (JIT) compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html), so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent.
+
+Write a prompt in `input_batch` and perform inference. Feel free to tweak `total_generation_steps` (the number of steps performed when generating a response).

 ```{code-cell} ipython3
 :cellView: form

 input_batch = [
-  "\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
-]
+    "\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
+  ]

 out_data = sampler(
   input_strings=input_batch,
-  total_generation_steps=300,  # number of steps performed when generating
+  total_generation_steps=300,  # The number of steps performed when generating a response.
   )

 for input_string, out_string in zip(input_batch, out_data.text):
   print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
   print()
   print(10*'#')
 ```

-You should get an implementation of bubble sort.
+You should get a Python implementation of the bubble sort algorithm.
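+If you downloaded an instruction-tuned variant such as `2b-it`, you can also try a chat-style prompt. The following is a sketch, not part of the original example; it assumes Gemma's standard dialogue turn markers:
+
+```{code-cell} ipython3
+# Gemma instruction-tuned checkpoints expect dialogue turns delimited
+# by `<start_of_turn>`/`<end_of_turn>` markers.
+chat_prompt = (
+    '<start_of_turn>user\n'
+    'Explain bubble sort in one short paragraph.<end_of_turn>\n'
+    '<start_of_turn>model\n'
+)
+
+out_data = sampler(
+    input_strings=[chat_prompt],  # Keep the batch size consistent to avoid recompilation.
+    total_generation_steps=300,
+)
+print(out_data.text[0])
+```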