Merge pull request #4325 from 8bitmp3:update-nnx-gemma
PiperOrigin-RevId: 707189220
Flax Authors committed Dec 17, 2024
2 parents fc38f21 + a46578e commit cc29a7b
Showing 2 changed files with 96 additions and 62 deletions.
81 changes: 49 additions & 32 deletions docs_nnx/guides/gemma.ipynb
@@ -4,16 +4,22 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example: Using Pretrained Gemma\n",
"# Example: Using pretrained Gemma for inference with Flax NNX\n",
"\n",
"You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it."
"This example shows how to use Flax NNX to load the [Gemma](https://ai.google.dev/gemma) open model files and use them to perform sampling/inference for generating text. You will use [Flax NNX `gemma` modules](https://github.com/google/flax/tree/main/examples/gemma) written with Flax and JAX for model parameter configuration and inference.\n",
"\n",
"> Gemma is a family of lightweight, state-of-the-art open models based on Google DeepMind’s [Gemini](https://deepmind.google/technologies/gemini/#introduction). Read more about [Gemma](https://blog.google/technology/developers/gemma-open-models/) and [Gemma 2](https://blog.google/technology/developers/google-gemma-2/).\n",
"\n",
"You are recommended to use [Google Colab](https://colab.research.google.com/) with access to A100 GPU acceleration to run the code."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Installation"
"## Installation\n",
"\n",
"Install the necessary dependencies, including `kagglehub`."
]
},
{
@@ -30,13 +36,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Downloading the checkpoint\n",
"## Download the model\n",
"\n",
"\"To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:\n",
"To use Gemma model, you'll need a [Kaggle](https://www.kaggle.com/models/google/gemma/) account and API key:\n",
"\n",
"1. Visit https://www.kaggle.com/ and create an account.\n",
"2. Go to your account settings, then the 'API' section.\n",
"3. Click 'Create new token' to download your key.\n",
"1. To create an account, visit [Kaggle](https://www.kaggle.com/) and click on 'Register'.\n",
"2. If/once you have an account, you need to sign in, go to your ['Settings'](https://www.kaggle.com/settings), and under 'API' click on 'Create New Token' to generate and download your Kaggle API key.\n",
"3. In [Google Colab](https://colab.research.google.com/), under 'Secrets' add your Kaggle username and API key, storing the username as `KAGGLE_USERNAME` and the key as `KAGGLE_KEY`. If you are using a [Kaggle Notebook](https://www.kaggle.com/code) for free TPU or other hardware acceleration, it has a key storage feature under 'Add-ons' > 'Secrets', along with instructions for accessing stored keys.\n",
"\n",
"Then run the cell below."
]
@@ -70,13 +76,21 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"If everything went well, you should see:\n",
"```\n",
"Kaggle credentials set.\n",
"Kaggle credentials successfully validated.\n",
"If everything went well, it should say `Kaggle credentials set. Kaggle credentials successfully validated.`.\n",
"\n",
"**Note:** In Google Colab, you can instead authenticate into Kaggle using the code below after following the optional step 3 from above.\n",
"\n",
"```\n",
"import os\n",
"from google.colab import userdata # `userdata` is a Colab API.\n",
"\n",
"os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
"os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')\n",
"``` \n",
"\n",
"Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models."
"Now, load the Gemma model you want to try. The code in the next cell utilizes [`kagglehub.model_download`](https://github.com/Kaggle/kagglehub/blob/8efe3e99477aa4f41885840de6903e61a49df4aa/src/kagglehub/models.py#L16) to download model files.\n",
"\n",
"**Note:** For larger models, such as `gemma 7b` and `gemma 7b-it` (instruct), you may require a hardware accelerator with plenty of memory, such as the NVIDIA A100."
]
},
{
@@ -90,9 +104,7 @@
"VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n",
"weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')\n",
"ckpt_path = f'{weights_dir}/{VARIANT}'\n",
"vocab_path = f'{weights_dir}/tokenizer.model'\n",
"\n",
"clear_output()"
"vocab_path = f'{weights_dir}/tokenizer.model'"
]
},
{
@@ -116,7 +128,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Flax examples are not exposed as packages so you need to use the workaround in the next cells to import from NNX's Gemma example."
"To interact with the Gemma model, you will use the Flax NNX `gemma` code from [`google/flax` examples on GitHub](https://github.com/google/flax/tree/main/examples/gemma). Since it is not exposed as a package, you need to use the following workaround to import from the Flax NNX `examples/gemma` on GitHub."
]
},
{
@@ -141,10 +153,9 @@
"source": [
"import sys\n",
"import tempfile\n",
"\n",
"with tempfile.TemporaryDirectory() as tmp:\n",
" # Here we create a temporary directory and clone the flax repo\n",
" # Then we append the examples/gemma folder to the path to load the gemma modules\n",
" # Create a temporary directory and clone the `flax` repo.\n",
" # Then, append the `examples/gemma` folder to the path for loading the `gemma` modules.\n",
" ! git clone https://github.com/google/flax.git {tmp}/flax\n",
" sys.path.append(f\"{tmp}/flax/examples/gemma\")\n",
" import params as params_lib\n",
@@ -157,9 +168,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Start Generating with Your Model\n",
"## Load and prepare the Gemma model\n",
"\n",
"Load and prepare your LLM's checkpoint for use with Flax."
"First, load the Gemma model parameters for use with Flax."
]
},
{
@@ -170,15 +181,14 @@
},
"outputs": [],
"source": [
"# Load parameters\n",
"params = params_lib.load_and_format_params(ckpt_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load your tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library."
"Next, load the tokenizer file constructed using the [SentencePiece](https://github.com/google/sentencepiece) library."
]
},
{
@@ -208,7 +218,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the `transformer_lib.TransformerConfig.from_params` function to automatically load the correct configuration from a checkpoint. Note that the vocabulary size is smaller than the number of input embeddings due to unused tokens in this release."
"Then, use the Flax NNX [`gemma.transformer.TransformerConfig.from_params`](https://github.com/google/flax/blob/3f3c03b23d4fd3d85d1c5d4d97381a8a2c48b475/examples/gemma/transformer.py#L193) function to automatically load the correct configuration from a checkpoint.\n",
"\n",
"**Note:** The vocabulary size is smaller than the number of input embeddings due to unused tokens in this release."
]
},
{
@@ -250,7 +262,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, build a sampler on top of your model and your tokenizer."
"## Perform sampling/inference\n",
"\n",
"Build a Flax NNX [`gemma.Sampler`](https://github.com/google/flax/blob/main/examples/gemma/sampler.py) on top of your model and tokenizer with the right parameter shapes."
]
},
{
@@ -261,7 +275,6 @@
},
"outputs": [],
"source": [
"# Create a sampler with the right param shapes.\n",
"sampler = sampler_lib.Sampler(\n",
" transformer=transformer,\n",
" vocab=vocab,\n",
Expand All @@ -272,7 +285,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"You're ready to start sampling ! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent."
"You're ready to start sampling!\n",
"\n",
"**Note:** This Flax NNX [`gemma.Sampler`](https://github.com/google/flax/blob/main/examples/gemma/sampler.py) uses JAX’s [just-in-time (JIT) compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html), so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent.\n",
"\n",
"Write a prompt in `input_batch` and perform inference. Feel free to tweak `total_generation_steps` (the number of steps performed when generating a response)."
]
},
{
@@ -342,12 +359,12 @@
],
"source": [
"input_batch = [\n",
" \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n",
"]\n",
" \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n",
" ]\n",
"\n",
"out_data = sampler(\n",
" input_strings=input_batch,\n",
" total_generation_steps=300, # number of steps performed when generating\n",
" total_generation_steps=300, # The number of steps performed when generating a response.\n",
" )\n",
"\n",
"for input_string, out_string in zip(input_batch, out_data.text):\n",
@@ -360,7 +377,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"You should get an implementation of bubble sort."
"You should get a Python implementation of the bubble sort algorithm."
]
}
],
77 changes: 47 additions & 30 deletions docs_nnx/guides/gemma.md
@@ -8,26 +8,32 @@ jupytext:
jupytext_version: 1.13.8
---

# Example: Using Pretrained Gemma
# Example: Using pretrained Gemma for inference with Flax NNX

You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it.
This example shows how to use Flax NNX to load the [Gemma](https://ai.google.dev/gemma) open model files and use them to perform sampling/inference for generating text. You will use [Flax NNX `gemma` modules](https://github.com/google/flax/tree/main/examples/gemma) written with Flax and JAX for model parameter configuration and inference.

> Gemma is a family of lightweight, state-of-the-art open models based on Google DeepMind’s [Gemini](https://deepmind.google/technologies/gemini/#introduction). Read more about [Gemma](https://blog.google/technology/developers/gemma-open-models/) and [Gemma 2](https://blog.google/technology/developers/google-gemma-2/).
It is recommended to run the code in [Google Colab](https://colab.research.google.com/) with access to A100 GPU acceleration.

+++

## Installation

Install the necessary dependencies, including `kagglehub`.

```{code-cell} ipython3
! pip install --no-deps -U flax
! pip install jaxtyping kagglehub treescope
```

## Downloading the checkpoint
## Download the model

"To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:
To use the Gemma model, you'll need a [Kaggle](https://www.kaggle.com/models/google/gemma/) account and API key:

1. Visit https://www.kaggle.com/ and create an account.
2. Go to your account settings, then the 'API' section.
3. Click 'Create new token' to download your key.
1. To create an account, visit [Kaggle](https://www.kaggle.com/) and click on 'Register'.
2. Once you have an account, sign in, go to your ['Settings'](https://www.kaggle.com/settings), and under 'API' click 'Create New Token' to generate and download your Kaggle API key.
3. In [Google Colab](https://colab.research.google.com/), under 'Secrets' add your Kaggle username and API key, storing the username as `KAGGLE_USERNAME` and the key as `KAGGLE_KEY`. If you are using a [Kaggle Notebook](https://www.kaggle.com/code) for free TPU or other hardware acceleration, it has a key storage feature under 'Add-ons' > 'Secrets', along with instructions for accessing stored keys.

Then run the cell below.

@@ -36,13 +42,21 @@ import kagglehub
kagglehub.login()
```

If everything went well, you should see:
```
Kaggle credentials set.
Kaggle credentials successfully validated.
If everything went well, you should see `Kaggle credentials set. Kaggle credentials successfully validated.`

**Note:** In Google Colab, you can instead authenticate with Kaggle using the code below, after completing step 3 above.

```
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
```

Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models.
Now, load the Gemma model you want to try. The code in the next cell utilizes [`kagglehub.model_download`](https://github.com/Kaggle/kagglehub/blob/8efe3e99477aa4f41885840de6903e61a49df4aa/src/kagglehub/models.py#L16) to download model files.

**Note:** For larger models, such as the `7b` and `7b-it` (instruct) variants, you may need a hardware accelerator with plenty of memory, such as an NVIDIA A100.
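Before downloading a large variant, you can check which accelerator your runtime actually exposes. This quick check is not part of the original notebook, but uses only the standard JAX API:

```
import jax

# Lists the devices JAX can see, e.g. a CUDA device on a GPU runtime
# or a list of TPU cores on a TPU runtime.
print(jax.devices())
```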

```{code-cell} ipython3
from IPython.display import clear_output
@@ -51,8 +65,6 @@ VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
ckpt_path = f'{weights_dir}/{VARIANT}'
vocab_path = f'{weights_dir}/tokenizer.model'
clear_output()
```
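A quick way to confirm the download succeeded is to list the fetched files. This is a minimal sketch, not part of the original guide; the exact file layout depends on the Kaggle release you downloaded:

```
import os

# The checkpoint directory and the SentencePiece vocabulary should both exist.
print(sorted(os.listdir(weights_dir)))
assert os.path.exists(ckpt_path), ckpt_path
assert os.path.exists(vocab_path), vocab_path
```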

## Python imports
@@ -62,15 +74,14 @@ from flax import nnx
import sentencepiece as spm
```

Flax examples are not exposed as packages so you need to use the workaround in the next cells to import from NNX's Gemma example.
To interact with the Gemma model, you will use the Flax NNX `gemma` code from [`google/flax` examples on GitHub](https://github.com/google/flax/tree/main/examples/gemma). Since it is not exposed as a package, you need to use the following workaround to import from the Flax NNX `examples/gemma` on GitHub.

```{code-cell} ipython3
import sys
import tempfile
with tempfile.TemporaryDirectory() as tmp:
# Here we create a temporary directory and clone the flax repo
# Then we append the examples/gemma folder to the path to load the gemma modules
# Create a temporary directory and clone the `flax` repo.
# Then, append the `examples/gemma` folder to the path for loading the `gemma` modules.
! git clone https://github.com/google/flax.git {tmp}/flax
sys.path.append(f"{tmp}/flax/examples/gemma")
import params as params_lib
@@ -79,18 +90,17 @@ with tempfile.TemporaryDirectory() as tmp:
sys.path.pop();
```
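One subtlety worth noting: the temporary directory is deleted as soon as the `with` block exits, so the imports keep working only because Python caches loaded modules in `sys.modules`. If you prefer a clone that persists for re-imports, here is a minimal alternative sketch; the `/content/flax` target path is an assumption, so use any writable directory:

```
import pathlib
import subprocess
import sys

flax_dir = pathlib.Path('/content/flax')  # Assumed writable location.
if not flax_dir.exists():
    # A shallow clone is enough; only examples/gemma is needed.
    subprocess.run(
        ['git', 'clone', '--depth=1',
         'https://github.com/google/flax.git', str(flax_dir)],
        check=True)
sys.path.append(str(flax_dir / 'examples' / 'gemma'))

import params as params_lib
import sampler as sampler_lib
import transformer as transformer_lib
```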

## Start Generating with Your Model
## Load and prepare the Gemma model

Load and prepare your LLM's checkpoint for use with Flax.
First, load the Gemma model parameters for use with Flax.

```{code-cell} ipython3
:cellView: form
# Load parameters
params = params_lib.load_and_format_params(ckpt_path)
```
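`load_and_format_params` returns a nested dictionary of arrays (a JAX pytree). To get a feel for its structure without printing full tensors, you can map each leaf to its shape; this is a sketch, and the key names are whatever the checkpoint provides:

```
import jax

# Replace each array leaf with its shape; non-array leaves pass through.
shapes = jax.tree_util.tree_map(lambda leaf: getattr(leaf, 'shape', leaf), params)
print(shapes)
```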

Load your tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library.
Next, load the tokenizer file constructed using the [SentencePiece](https://github.com/google/sentencepiece) library.

```{code-cell} ipython3
:cellView: form
@@ -99,37 +109,44 @@ vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
```
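You can sanity-check the tokenizer with a quick round trip using standard SentencePiece calls; the sample string below is arbitrary:

```
ids = vocab.EncodeAsIds('Hello, Gemma!')
print(ids)                              # Token ids for the sample string.
print(vocab.DecodeIds(ids))             # Should reproduce 'Hello, Gemma!'.
print('pieces:', vocab.GetPieceSize())  # Tokenizer vocabulary size.
```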

Use the `transformer_lib.TransformerConfig.from_params` function to automatically load the correct configuration from a checkpoint. Note that the vocabulary size is smaller than the number of input embeddings due to unused tokens in this release.
Then, use the Flax NNX [`gemma.transformer.TransformerConfig.from_params`](https://github.com/google/flax/blob/3f3c03b23d4fd3d85d1c5d4d97381a8a2c48b475/examples/gemma/transformer.py#L193) function to automatically load the correct configuration from a checkpoint.

**Note:** The vocabulary size is smaller than the number of input embeddings due to unused tokens in this release.

```{code-cell} ipython3
transformer = transformer_lib.Transformer.from_params(params)
nnx.display(transformer)
```
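To see the vocabulary-size mismatch mentioned above for yourself, you can compare the tokenizer against the loaded configuration. This is a sketch: treating `num_embed` as the embedding-count field is an assumption about the example's `TransformerConfig`:

```
config = transformer_lib.TransformerConfig.from_params(params)
print('tokenizer pieces :', vocab.GetPieceSize())
# Expected to be somewhat larger than the tokenizer size in this release.
print('input embeddings :', config.num_embed)
```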

Finally, build a sampler on top of your model and your tokenizer.
## Perform sampling/inference

Build a Flax NNX [`gemma.Sampler`](https://github.com/google/flax/blob/main/examples/gemma/sampler.py) on top of your model and tokenizer with the right parameter shapes.

```{code-cell} ipython3
:cellView: form
# Create a sampler with the right param shapes.
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
)
```

You're ready to start sampling ! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent.
You're ready to start sampling!

**Note:** This Flax NNX [`gemma.Sampler`](https://github.com/google/flax/blob/main/examples/gemma/sampler.py) uses JAX’s [just-in-time (JIT) compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html), so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent.

Write a prompt in `input_batch` and perform inference. Feel free to tweak `total_generation_steps` (the number of steps performed when generating a response).

```{code-cell} ipython3
:cellView: form
input_batch = [
"\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
]
"\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
]
out_data = sampler(
input_strings=input_batch,
total_generation_steps=300, # number of steps performed when generating
total_generation_steps=300, # The number of steps performed when generating a response.
)
for input_string, out_string in zip(input_batch, out_data.text):
@@ -138,4 +155,4 @@ for input_string, out_string in zip(input_batch, out_data.text):
print(10*'#')
```

You should get an implementation of bubble sort.
You should get a Python implementation of the bubble sort algorithm.
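
Because the sampler recompiles for every new input shape, one common pattern is to pad the prompt list to a fixed batch size so repeated calls reuse the compiled program. The sketch below illustrates the idea; padding with empty strings is a convention assumed here, not something the guide prescribes:

```
BATCH_SIZE = 4

def sample_batch(prompts, steps=300):
    # Pad to a fixed batch size so the sampler's JIT cache is reused.
    padded = list(prompts) + [''] * (BATCH_SIZE - len(prompts))
    out = sampler(input_strings=padded, total_generation_steps=steps)
    return out.text[:len(prompts)]

print(sample_batch(['\ndef fibonacci(n):'])[0])
```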
