Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] cleanup gemma notebook #4334

Merged
merged 1 commit into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 153 additions & 59 deletions docs_nnx/guides/gemma.ipynb

Large diffs are not rendered by default.

47 changes: 19 additions & 28 deletions docs_nnx/guides/gemma.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,7 @@ jupytext:
jupytext_version: 1.13.8
---

Copyright 2024 The Flax Authors.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

---

+++

# Getting Started with Gemma Sampling using NNX: A Step-by-Step Guide
# Example: Using Pretrained Gemma

You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it.

Expand Down Expand Up @@ -57,10 +45,14 @@ Kaggle credentials successfully validated.
Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models.

```{code-cell} ipython3
from IPython.display import clear_output

VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
ckpt_path = f'{weights_dir}/{VARIANT}'
vocab_path = f'{weights_dir}/tokenizer.model'

clear_output()
```

## Python imports
Expand All @@ -72,18 +64,19 @@ import sentencepiece as spm

Flax examples are not exposed as packages so you need to use the workaround in the next cells to import from NNX's Gemma example.

```{code-cell} ipython3
! git clone https://github.com/google/flax.git flax_examples
```

```{code-cell} ipython3
import sys

sys.path.append("./flax_examples/flax/nnx/examples/gemma")
import params as params_lib
import sampler as sampler_lib
import transformer as transformer_lib
sys.path.pop();
import tempfile

with tempfile.TemporaryDirectory() as tmp:
# Here we create a temporary directory and clone the flax repo
# Then we append the examples/gemma folder to the path to load the gemma modules
! git clone https://github.com/google/flax.git {tmp}/flax
sys.path.append(f"{tmp}/flax/examples/gemma")
import params as params_lib
import sampler as sampler_lib
import transformer as transformer_lib
sys.path.pop();
```

## Start Generating with Your Model
Expand Down Expand Up @@ -122,7 +115,6 @@ Finally, build a sampler on top of your model and your tokenizer.
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer'],
)
```

Expand All @@ -132,9 +124,8 @@ You're ready to start sampling ! This sampler uses just-in-time compilation, so
:cellView: form

input_batch = [
"\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
"What are the planets of the solar system?",
]
"\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
]

out_data = sampler(
input_strings=input_batch,
Expand All @@ -147,4 +138,4 @@ for input_string, out_string in zip(input_batch, out_data.text):
print(10*'#')
```

You should get an implementation of bubble sort and a description of the solar system.
You should get an implementation of bubble sort.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,18 @@ docs = [
"sphinx-design",
"jupytext==1.13.8",
"dm-haiku",

# Need to pin docutils to 0.16 to make bulleted lists appear correctly on
# ReadTheDocs: https://stackoverflow.com/a/68008428
"docutils==0.16",

# The next packages are for notebooks.
"matplotlib",
"scikit-learn",
# The next packages are used in testcode blocks.
"ml_collections",
# notebooks
"einops",
"kagglehub>=0.3.3",
"ipywidgets>=8.1.5",
]
dev = [
"pre-commit>=3.8.0",
Expand Down
Loading
Loading