Skip to content

Commit

Permalink
Port gemma sampling tutorial
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 643090672
  • Loading branch information
sibonaci authored and Flax Authors committed Jun 28, 2024
1 parent f610e6c commit 9b08759
Showing 1 changed file with 329 additions and 0 deletions.
329 changes: 329 additions & 0 deletions flax/nnx/examples/gemma/colabs/sampling_tutorial.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,329 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "SC77q_zBESaM"
},
"source": [
"Copyright 2024 The Flax Authors.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
"\n",
"http://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TpESp4p5ESaM"
},
"source": [
"# Getting Started with Gemma Sampling using NNX: A Step-by-Step Guide\n",
"\n",
"You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LtzOe_3XY9R5"
},
"source": [
"## Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iq2ebV_6YNiU"
},
"outputs": [],
"source": [
"! pip install --no-deps -U flax\n",
"! pip install jaxtyping kagglehub penzai"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QOzN-gxIYSB4"
},
"source": [
"## Downloading the checkpoint\n",
"\n",
"\"To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:\n",
"\n",
"1. Visit https://www.kaggle.com/ and create an account.\n",
"2. Go to your account settings, then the 'API' section.\n",
"3. Click 'Create new token' to download your key.\n",
"\n",
"Then run the cell below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "likVQiEEYS5X"
},
"outputs": [],
"source": [
"import kagglehub\n",
"kagglehub.login()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QRxOFyGbYUjZ"
},
"source": [
"If everything went well, you should see:\n",
"```\n",
"Kaggle credentials set.\n",
"Kaggle credentials successfully validated.\n",
"```\n",
"\n",
"Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "O-sxcasvESaM"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n",
"weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')\n",
"ckpt_path = f'{weights_dir}/{VARIANT}'\n",
"vocab_path = f'{weights_dir}/tokenizer.model'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "F85d5qWF2C5U"
},
"source": [
"## Python imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DkSYYJZ4HkbK"
},
"outputs": [],
"source": [
"from flax import nnx\n",
"import sentencepiece as spm"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AD2HuOLjHnfN"
},
"source": [
"Flax examples are not exposed as packages so you need to use the workaround in the next cells to import from NNX's Gemma example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UcIo5YtkHhLH"
},
"outputs": [],
"source": [
"! git clone https://github.com/google/flax.git flax_examples"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "F_sFY1SQ2O_k"
},
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.path.append(\"./flax_examples/flax/nnx/examples/gemma\")\n",
"import params as params_lib\n",
"import sampler as sampler_lib\n",
"import transformer as transformer_lib\n",
"sys.path.pop();"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4fDQsC87ESaN"
},
"source": [
"## Start Generating with Your Model\n",
"\n",
"Load and prepare your LLM's checkpoint for use with Flax."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "57nMYQ4HESaN"
},
"outputs": [],
"source": [
"# Load parameters\n",
"params = params_lib.load_and_format_params(ckpt_path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NWJ3UvHXESaN"
},
"source": [
"Load your tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "khXrjEF0ESaN"
},
"outputs": [],
"source": [
"vocab = spm.SentencePieceProcessor()\n",
"vocab.Load(vocab_path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tCRtZMg0ESaN"
},
"source": [
"Use the `transformer_lib.TransformerConfig.from_params` function to automatically load the correct configuration from a checkpoint. Note that the vocabulary size is smaller than the number of input embeddings due to unused tokens in this release."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Bfjh4d3x52PW"
},
"outputs": [],
"source": [
"transformer = transformer_lib.Transformer.from_params(params)\n",
"nnx.display(transformer)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KaU-X3_jESaN"
},
"source": [
"Finally, build a sampler on top of your model and your tokenizer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "bdstASGrESaN"
},
"outputs": [],
"source": [
"# Create a sampler with the right param shapes.\n",
"sampler = sampler_lib.Sampler(\n",
" transformer=transformer,\n",
" vocab=vocab,\n",
" params=params['transformer'],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "C1fLns-_ESaN"
},
"source": [
"You're ready to start sampling ! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "qA0BhNQvESaN"
},
"outputs": [],
"source": [
"input_batch = [\n",
" \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n",
" \"What are the planets of the solar system?\",\n",
" ]\n",
"\n",
"out_data = sampler(\n",
" input_strings=input_batch,\n",
" total_generation_steps=300, # number of steps performed when generating\n",
" )\n",
"\n",
"for input_string, out_string in zip(input_batch, out_data.text):\n",
" print(f\"Prompt:\\n{input_string}\\nOutput:\\n{out_string}\")\n",
" print()\n",
" print(10*'#')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tqbJ1SUcESaN"
},
"source": [
"You should get an implementation of bubble sort and a description of the solar system.\n"
]
}
],
"metadata": {
"accelerator": "TPU",
"colab": {
"gpuType": "V28",
"last_runtime": {
"build_target": "//learning/deepmind/evergreen/processors/examples/audio_streaming:notebook",
"kind": "private"
},
"private_outputs": true,
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

0 comments on commit 9b08759

Please sign in to comment.