[nnx] cleanup gemma notebook
cgarciae committed Oct 29, 2024
1 parent ab122af commit 158bcea
Showing 4 changed files with 264 additions and 114 deletions.
212 changes: 153 additions & 59 deletions docs_nnx/guides/gemma.ipynb

Large diffs are not rendered by default.

47 changes: 19 additions & 28 deletions docs_nnx/guides/gemma.md
@@ -8,19 +8,7 @@ jupytext:
jupytext_version: 1.13.8
---

-Copyright 2024 The Flax Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
-
----
-
+++

-# Getting Started with Gemma Sampling using NNX: A Step-by-Step Guide
+# Example: Using Pretrained Gemma

You will find in this Colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it.

@@ -57,10 +45,14 @@ Kaggle credentials successfully validated.
Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models.

```{code-cell} ipython3
+from IPython.display import clear_output
VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
ckpt_path = f'{weights_dir}/{VARIANT}'
vocab_path = f'{weights_dir}/tokenizer.model'
+clear_output()
```

## Python imports

@@ -72,18 +64,19 @@ import sentencepiece as spm

Flax examples are not exposed as packages, so you need to use the workaround in the next cells to import from NNX's Gemma example.

```{code-cell} ipython3
-! git clone https://github.com/google/flax.git flax_examples
-```
-
-```{code-cell} ipython3
import sys
-sys.path.append("./flax_examples/flax/nnx/examples/gemma")
-import params as params_lib
-import sampler as sampler_lib
-import transformer as transformer_lib
-sys.path.pop();
+import tempfile
+
+with tempfile.TemporaryDirectory() as tmp:
+    # Here we create a temporary directory and clone the flax repo
+    # Then we append the examples/gemma folder to the path to load the gemma modules
+    ! git clone https://github.com/google/flax.git {tmp}/flax
+    sys.path.append(f"{tmp}/flax/examples/gemma")
+    import params as params_lib
+    import sampler as sampler_lib
+    import transformer as transformer_lib
+    sys.path.pop();
```
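The import workaround above (make a cloned folder importable via a temporary `sys.path` entry, import the modules, then pop the entry again) can be sketched in a self-contained form. Since cloning the Flax repo is not reproducible here, this sketch uses a hypothetical stand-in module written into the temporary directory; only the `sys.path` mechanics mirror the guide.

```python
import sys
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    # Stand-in for the cloned repo: write a tiny hypothetical module
    # into the temporary directory.
    pkg = Path(tmp) / "gemma_example"
    pkg.mkdir()
    (pkg / "params.py").write_text(
        "def format_params(raw):\n"
        "    return {'transformer': raw}\n"
    )

    # Same trick as the guide: append the folder to sys.path, import
    # the module under a library-style alias, then remove the entry.
    sys.path.append(str(pkg))
    import params as params_lib
    sys.path.pop()

    result = params_lib.format_params({'w0': 1})

print(result)  # {'transformer': {'w0': 1}}
```

Popping the path entry right after the import keeps later imports from accidentally resolving against the cloned (or temporary) folder; the already-imported module object stays usable.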

## Start Generating with Your Model
@@ -122,7 +115,6 @@ Finally, build a sampler on top of your model and your tokenizer.
sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
-    params=params['transformer'],
)
```

@@ -132,9 +124,8 @@ You're ready to start sampling! This sampler uses just-in-time compilation, so
:cellView: form
input_batch = [
-    "\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
-    "What are the planets of the solar system?",
-]
+    "\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
+]
out_data = sampler(
input_strings=input_batch,
@@ -147,4 +138,4 @@ for input_string, out_string in zip(input_batch, out_data.text):
print(10*'#')
```

-You should get an implementation of bubble sort and a description of the solar system.
+You should get an implementation of bubble sort.
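The sampler call in the hunks above takes `input_strings` and returns an object whose `text` field holds one generated string per input. As a toy illustration of that interface (all logic here is a hypothetical stand-in; the real `sampler_lib.Sampler` tokenizes inputs, runs the Gemma transformer autoregressively under JIT, and detokenizes):

```python
from dataclasses import dataclass

@dataclass
class ToySamplerOutput:
    # Mirrors the out_data.text field used in the guide.
    text: list

class ToySampler:
    """Toy stand-in for sampler_lib.Sampler (hypothetical logic).

    Instead of running a model, it just truncates each input string,
    to keep the interface sketch self-contained and runnable.
    """
    def __call__(self, input_strings, total_generation_steps):
        return ToySamplerOutput(
            text=[s[:total_generation_steps] for s in input_strings]
        )

sampler = ToySampler()
out_data = sampler(
    input_strings=["What are the planets of the solar system?"],
    total_generation_steps=4,
)
print(out_data.text)  # ['What']
```

The batched call shape matters for the real sampler: since it is JIT-compiled, keeping input batches the same size and padded length avoids recompilation between calls.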
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -76,18 +76,18 @@ docs = [
"sphinx-design",
"jupytext==1.13.8",
"dm-haiku",

# Need to pin docutils to 0.16 to make bulleted lists appear correctly on
# ReadTheDocs: https://stackoverflow.com/a/68008428
"docutils==0.16",

# The next packages are for notebooks.
"matplotlib",
"scikit-learn",
# The next packages are used in testcode blocks.
"ml_collections",
# notebooks
"einops",
"kagglehub>=0.3.3",
"ipywidgets>=8.1.5",
]
dev = [
"pre-commit>=3.8.0",
