From 564030314cf46dd83ba6ad6326b725036d1eae85 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 25 Oct 2024 13:04:49 +0100 Subject: [PATCH] [nnx] cleanup gemma notebook --- docs_nnx/guides/gemma.ipynb | 216 ++++++++++++++++++++++++++---------- docs_nnx/guides/gemma.md | 51 ++++----- pyproject.toml | 4 +- uv.lock | 115 ++++++++++++++----- 4 files changed, 268 insertions(+), 118 deletions(-) diff --git a/docs_nnx/guides/gemma.ipynb b/docs_nnx/guides/gemma.ipynb index 7230ca1b02..7eca9e3a3e 100644 --- a/docs_nnx/guides/gemma.ipynb +++ b/docs_nnx/guides/gemma.ipynb @@ -4,22 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Copyright 2024 The Flax Authors.\n", - "\n", - "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n", - "\n", - "http://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n", - "\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Getting Started with Gemma Sampling using NNX: A Step-by-Step Guide\n", + "# Getting Started with Gemma Sampling\n", "\n", "You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it." ] @@ -33,12 +18,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "! pip install --no-deps -U flax\n", - "! pip install jaxtyping kagglehub penzai" + "# ! pip install --no-deps -U flax\n", + "# ! pip install jaxtyping kagglehub penzai" ] }, { @@ -58,19 +43,22 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [ { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'kagglehub'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mkagglehub\u001b[39;00m\n\u001b[1;32m 2\u001b[0m kagglehub\u001b[38;5;241m.\u001b[39mlogin()\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'kagglehub'" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2e7cf9f0345845f1a3edc72fa4411eb4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(HTML(value='
(()=>{ if (customElements.get('treescope-container') === undefined) { class TreescopeContainer extends HTMLElement { constructor() { super(); this.attachShadow({mode: \"open\"}); this.defns = {}; this.state = {}; } } customElements.define(\"treescope-container\", TreescopeContainer); } if (customElements.get('treescope-run-here') === undefined) { class RunHere extends HTMLElement { constructor() { super() } connectedCallback() { const run = child => { const fn = new Function(child.textContent); child.textContent = \"\"; fn.call(this); this.remove(); }; const child = this.querySelector(\"script\"); if (child) { run(child); } else { new MutationObserver(()=>{ run(this.querySelector(\"script\")); }).observe(this, {childList: true}); } } } customElements.define(\"treescope-run-here\", RunHere); } })();
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "transformer = transformer_lib.Transformer.from_params(params)\n", "nnx.display(transformer)" @@ -217,7 +255,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": { "cellView": "form" }, @@ -227,7 +265,6 @@ "sampler = sampler_lib.Sampler(\n", " transformer=transformer,\n", " vocab=vocab,\n", - " params=params['transformer'],\n", ")" ] }, @@ -240,16 +277,73 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": { "cellView": "form" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prompt:\n", + "\n", + "# Python program for implementation of Bubble Sort\n", + "\n", + "def bubbleSort(arr):\n", + "Output:\n", + "\n", + " for i in range(len(arr)):\n", + " for j in range(len(arr) - i - 1):\n", + " if arr[j] > arr[j + 1]:\n", + " swap(arr, j, j + 1)\n", + "\n", + "\n", + "def swap(arr, i, j):\n", + " temp = arr[i]\n", + " arr[i] = arr[j]\n", + " arr[j] = temp\n", + "\n", + "\n", + "# Driver code\n", + "arr = [5, 2, 8, 3, 1, 9]\n", + "print(\"Unsorted array:\")\n", + "print(arr)\n", + "bubbleSort(arr)\n", + "print(\"Sorted array:\")\n", + "print(arr)\n", + "\n", + "\n", + "# Time complexity of Bubble sort O(n^2)\n", + "# where n is the length of the array\n", + "\n", + "\n", + "# Space complexity of Bubble sort O(1)\n", + "# as it only requires constant extra space for the swap operation\n", + "\n", + "\n", + "# This program uses the bubble sort algorithm to sort the given array in ascending order.\n", + "\n", + "```python\n", + "# This program uses the bubble sort algorithm to sort the given array in ascending order.\n", + "\n", + "def bubbleSort(arr):\n", + " for i in range(len(arr)):\n", + " for j in range(len(arr) - i - 1):\n", + " if arr[j] > arr[j + 1]:\n", + " swap(arr, j, j + 1)\n", + "\n", + "\n", + "def swap(\n", + "\n", + "##########\n" + ] + } + ], "source": [ "input_batch = [\n", - " \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n", - " \"What are the planets of the solar system?\",\n", - " ]\n", + " \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n", + "]\n", "\n", "out_data = sampler(\n", " input_strings=input_batch,\n", @@ -266,7 +360,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You should get an implementation of bubble sort and a description of the solar system." + "You should get an implementation of bubble sort." ] } ], diff --git a/docs_nnx/guides/gemma.md b/docs_nnx/guides/gemma.md index 69ff17acb5..af6d761048 100644 --- a/docs_nnx/guides/gemma.md +++ b/docs_nnx/guides/gemma.md @@ -8,19 +8,7 @@ jupytext: jupytext_version: 1.13.8 --- -Copyright 2024 The Flax Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - ---- - -+++ - -# Getting Started with Gemma Sampling using NNX: A Step-by-Step Guide +# Getting Started with Gemma Sampling You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it. @@ -29,8 +17,8 @@ You will find in this colab a detailed tutorial explaining how to use NNX to loa ## Installation ```{code-cell} ipython3 -! pip install --no-deps -U flax -! pip install jaxtyping kagglehub penzai +# ! pip install --no-deps -U flax +# ! pip install jaxtyping kagglehub penzai ``` ## Downloading the checkpoint @@ -57,10 +45,14 @@ Kaggle credentials successfully validated. Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models. ```{code-cell} ipython3 +from IPython.display import clear_output + VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"} weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}') ckpt_path = f'{weights_dir}/{VARIANT}' vocab_path = f'{weights_dir}/tokenizer.model' + +clear_output() ``` ## Python imports @@ -72,18 +64,19 @@ import sentencepiece as spm Flax examples are not exposed as packages so you need to use the workaround in the next cells to import from NNX's Gemma example. -```{code-cell} ipython3 -! git clone https://github.com/google/flax.git flax_examples -``` - ```{code-cell} ipython3 import sys - -sys.path.append("./flax_examples/flax/nnx/examples/gemma") -import params as params_lib -import sampler as sampler_lib -import transformer as transformer_lib -sys.path.pop(); +import tempfile + +with tempfile.TemporaryDirectory() as tmp: + # Here we create a temporary directory and clone the flax repo + # Then we append the examples/gemma folder to the path to load the gemma modules + ! git clone https://github.com/google/flax.git {tmp}/flax + sys.path.append(f"{tmp}/flax/examples/gemma") + import params as params_lib + import sampler as sampler_lib + import transformer as transformer_lib + sys.path.pop(); ``` ## Start Generating with Your Model @@ -122,7 +115,6 @@ Finally, build a sampler on top of your model and your tokenizer. sampler = sampler_lib.Sampler( transformer=transformer, vocab=vocab, - params=params['transformer'], ) ``` @@ -132,9 +124,8 @@ You're ready to start sampling ! This sampler uses just-in-time compilation, so :cellView: form input_batch = [ - "\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):", - "What are the planets of the solar system?", - ] + "\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):", +] out_data = sampler( input_strings=input_batch, @@ -147,4 +138,4 @@ for input_string, out_string in zip(input_batch, out_data.text): print(10*'#') ``` -You should get an implementation of bubble sort and a description of the solar system. +You should get an implementation of bubble sort. diff --git a/pyproject.toml b/pyproject.toml index 8381caaabc..baab1da052 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,11 +76,9 @@ docs = [ "sphinx-design", "jupytext==1.13.8", "dm-haiku", - # Need to pin docutils to 0.16 to make bulleted lists appear correctly on # ReadTheDocs: https://stackoverflow.com/a/68008428 "docutils==0.16", - # The next packages are for notebooks. "matplotlib", "scikit-learn", @@ -88,6 +86,8 @@ docs = [ "ml_collections", # notebooks "einops", + "kagglehub>=0.3.3", + "ipywidgets>=8.1.5", ] dev = [ "pre-commit>=3.8.0", diff --git a/uv.lock b/uv.lock index bb7c89e2db..6c9e67edfa 100644 --- a/uv.lock +++ b/uv.lock @@ -504,7 +504,7 @@ wheels = [ [package.optional-dependencies] toml = [ - { name = "tomli", marker = "python_full_version <= '3.11'" }, + { name = "tomli", marker = "python_full_version == '3.11'" }, ] [[package]] @@ -794,7 +794,9 @@ docs = [ { name = "einops" }, { name = "ipykernel" }, { name = "ipython-genutils" }, + { name = "ipywidgets" }, { name = "jupytext" }, + { name = "kagglehub" }, { name = "matplotlib" }, { name = "ml-collections" }, { name = "myst-nb" }, @@ -842,11 +844,13 @@ requires-dist = [ { name = "gymnasium", extras = ["accept-rom-license", "atari"], marker = "extra == 'testing'" }, { name = "ipykernel", marker = "extra == 'docs'" }, { name = "ipython-genutils", marker = "extra == 'docs'" }, + { name = "ipywidgets", marker = "extra == 'docs'", specifier = ">=8.1.5" }, { name = "jax", specifier = ">=0.4.27" }, { name = "jaxlib", marker = "extra == 'testing'" }, { name = "jaxtyping", marker = "extra == 'testing'" }, { name = "jraph", marker = "extra == 'testing'", specifier = ">=0.0.6.dev0" }, { name = "jupytext", marker = "extra == 'docs'", specifier = "==1.13.8" }, + { name = "kagglehub", marker = "extra == 'docs'", specifier = ">=0.3.3" }, { name = "matplotlib", marker = "extra == 'all'" }, { name = "matplotlib", marker = "extra == 'docs'" }, { name = "ml-collections", marker = "extra == 'docs'" }, @@ -1217,9 +1221,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/bc/9bd3b5c2b4774d5f33b2d544f1460be9df7df2fe42f352135381c347c69a/ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8", size = 26343 }, ] +[[package]] +name = "ipywidgets" +version = "8.1.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "comm" }, + { name = "ipython" }, + { name = "jupyterlab-widgets" }, + { name = "traitlets" }, + { name = "widgetsnbextension" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c7/4c/dab2a281b07596a5fc220d49827fe6c794c66f1493d7a74f1df0640f2cc5/ipywidgets-8.1.5.tar.gz", hash = "sha256:870e43b1a35656a80c18c9503bbf2d16802db1cb487eec6fab27d683381dde17", size = 116723 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/22/2d/9c0b76f2f9cc0ebede1b9371b6f317243028ed60b90705863d493bae622e/ipywidgets-8.1.5-py3-none-any.whl", hash = "sha256:3290f526f87ae6e77655555baba4f36681c555b8bdbbff430b70e52c34c86245", size = 139767 }, +] + [[package]] name = "jax" -version = "0.4.34" +version = "0.4.35" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxlib" }, @@ -1228,14 +1248,14 @@ dependencies = [ { name = "opt-einsum" }, { name = "scipy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/19/6a/cacfcdf77841a4562e555ef35e0dbc5f8ca79c9f1010aaa4cf3973e79c69/jax-0.4.34.tar.gz", hash = "sha256:44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db", size = 1848472 } +sdist = { url = "https://files.pythonhosted.org/packages/e3/34/21da583b9596e72bb8e95b6197dee0a44b96b9ea2c147fccabd43ca5515b/jax-0.4.35.tar.gz", hash = "sha256:c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e", size = 1861189 } wheels = [ - { url = "https://files.pythonhosted.org/packages/06/f3/c499d358dd7f267a63d7d38ef54aadad82e28d2c28bafff15360c3091946/jax-0.4.34-py3-none-any.whl", hash = "sha256:b957ca1fc91f7343f91a186af9f19c7f342c946f95a8c11c7f1e5cdfe2e58d9e", size = 2144294 }, + { url = "https://files.pythonhosted.org/packages/62/20/6c57c50c0ccc645fea1895950f1e5cd02f961ee44b3ffe83617fa46b0c1d/jax-0.4.35-py3-none-any.whl", hash = "sha256:fa99e909a31424abfec750019a6dd36f6acc18a6e7d40e2c0086b932cc351325", size = 2158621 }, ] [[package]] name = "jaxlib" -version = "0.4.34" +version = "0.4.35" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "ml-dtypes" }, @@ -1243,26 +1263,26 @@ dependencies = [ { name = "scipy" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/24/31/2e254fe2fc23201775a7d0ccd1bcde892cfa349eb805744b81b15e0dcf74/jaxlib-0.4.34-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:b7a212a3cb5c6acc201c32ae4f4b5f5a9ac09457fbb77ba8db5ce7e7d4adc214", size = 87399257 }, - { url = "https://files.pythonhosted.org/packages/1e/67/6a344c357caad33e84b871925cd043b4218fc13a427266d1a1dedcb1c095/jaxlib-0.4.34-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:45d719a2ce0ebf21255a277b71d756f3609b7b5be70cddc5d88fd58c35219de0", size = 67617952 }, - { url = "https://files.pythonhosted.org/packages/dd/ea/12c836126419ca80248228f2236831617eedb1e3640c34c942606f33bb08/jaxlib-0.4.34-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:3e60bc826933082e99b19b87c21818a8d26fcdb01f418d47cedff554746fd6cc", size = 69391770 }, - { url = "https://files.pythonhosted.org/packages/e4/b0/a5bd34643c070e50829beec217189eab1acdfea334df1f9ddb4e5f8bec0f/jaxlib-0.4.34-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:d840e64b85f8865404d6d225b9bb340e158df1457152a361b05680e24792b232", size = 86094116 }, - { url = "https://files.pythonhosted.org/packages/d8/c9/35a4233fe74ddd5aabe89aac1b3992b0e463982564252d21fd263d4d9992/jaxlib-0.4.34-cp310-cp310-win_amd64.whl", hash = "sha256:b0001c8f0e2b1c7bc99e4f314b524a340d25653505c1a1484d4041a9d3617f6f", size = 55206389 }, - { url = "https://files.pythonhosted.org/packages/bf/14/00a3385532d72ab51bd8e9f8c3e19a2e257667955565e9fc10236771dd06/jaxlib-0.4.34-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:8ee3f93836e53c86556ccd9449a4ea43516ee05184d031a71dd692e81259f7d9", size = 87420889 }, - { url = "https://files.pythonhosted.org/packages/66/78/d1535ee73fe505dc6c8831c19c4846afdce7df5acefb9f8ee885aa73d700/jaxlib-0.4.34-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c9d3adcae43a33aad4332be9c2aedc5ef751d1e755f917a5afb30c7872eacaa8", size = 67635880 }, - { url = "https://files.pythonhosted.org/packages/aa/06/3e09e794acf308e170905d732eca0d041449503c47505cc22e8ef78a989d/jaxlib-0.4.34-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:571ef03259835458111596a71a2f4a6fabf4ec34595df4cea555035362ac5bf0", size = 69421901 }, - { url = "https://files.pythonhosted.org/packages/c7/d0/6bc81c0b1d507f403e6085ce76a429e6d7f94749d742199252e299dd1424/jaxlib-0.4.34-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:3bcfa639ca3cfaf86c8ceebd5fc0d47300fd98a078014a1d0cc03133e1523d5f", size = 86114491 }, - { url = "https://files.pythonhosted.org/packages/9d/5d/7e71019af5f6fdebe6c10eab97d01f44b931d94609330da9e142cb155f8c/jaxlib-0.4.34-cp311-cp311-win_amd64.whl", hash = "sha256:133070d4fec5525ffea4dc72956398c1cf647a04dcb37f8a935ee82af78d9965", size = 55241262 }, - { url = "https://files.pythonhosted.org/packages/bc/42/5038983664494dfb50f8669a662d965d7ea62f9250e40d8cd36dcf9ac3dd/jaxlib-0.4.34-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:c7b3e724a30426a856070aba0192b5d199e95b4411070e7ad96ad8b196877b10", size = 87473956 }, - { url = "https://files.pythonhosted.org/packages/87/2e/8a75d3107c019c370c50c01acc205da33f9d6fba830950401a772a8e9f6d/jaxlib-0.4.34-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:096f0ca309d41fa692a9d1f2f9baab1c5c8ca0749876ebb3f748e738a27c7ff4", size = 67650276 }, - { url = "https://files.pythonhosted.org/packages/af/09/cceae2d251a506b4297679d10ee9f5e905a6b992b0687d553c9470ffd1db/jaxlib-0.4.34-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:1a30771d85fa77f9ab8f18e63240f455ab3a3f87660ed7b8d5eea6ceecbe5c1e", size = 69431284 }, - { url = "https://files.pythonhosted.org/packages/e7/0d/4faf839e3c8ce2a5b615df64427be3e870899c72c0ebfb5859348150aba1/jaxlib-0.4.34-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:48272e9034ff868d4328cf0055a07882fd2be93f59dfb6283af7de491f9d1290", size = 86151183 }, - { url = "https://files.pythonhosted.org/packages/a4/bc/a38f99071fca6cc31ae949e508a23b0de5de559da594443bb625a1adb8f3/jaxlib-0.4.34-cp312-cp312-win_amd64.whl", hash = "sha256:901cb4040ed24eae40071d8114ea8d10dff436277fa74a1a5b9e7206f641151c", size = 55278745 }, - { url = "https://files.pythonhosted.org/packages/21/4e/fab0606683af7aa9284a32d2b188ff132cffb0ee3ea04d941a547eb776d1/jaxlib-0.4.34-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:72e22e99a5dc890a64443c3fc12f13f20091f578c405a76de077ba42b4c62cd7", size = 87474367 }, - { url = "https://files.pythonhosted.org/packages/3e/1b/709be16d543a3db5b471ee5e7d089c57484c386b08499923e43bd8da5d0b/jaxlib-0.4.34-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c303f5acaf6c56ce5ff133a923c9b6247bdebedde15bd2c893c24be4d8f71306", size = 67651281 }, - { url = "https://files.pythonhosted.org/packages/85/9e/f3801096cd4a2c764af7a1f6b683c769706602ea72b27ec35bacfcc4cd4f/jaxlib-0.4.34-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7be673a876ebd1aef440fb7e3ebaf99a91abeb550c9728c644b7d7c7b5d7c108", size = 69432987 }, - { url = "https://files.pythonhosted.org/packages/e6/79/61301f55b24c3a898ef9bc4e13600b66e3f838623fc6f87648ac1ccbca01/jaxlib-0.4.34-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:87f25a477cd279840e53718403f97092eba0e8a945fcab47bcf435b6f9119dda", size = 86152550 }, - { url = "https://files.pythonhosted.org/packages/16/b0/e682d02126e0062b58dec0f0851048592396f74c24b4a4412dce4ddbbadb/jaxlib-0.4.34-cp313-cp313-win_amd64.whl", hash = "sha256:6b43a974c5d91a19912d138f2658dd8dbb7d30dcdff5c961d896c673e872b611", size = 55279410 }, + { url = "https://files.pythonhosted.org/packages/f4/67/c025520d2c548569f73cd68b885862e56e8946a10c9d43834460007c2671/jaxlib-0.4.35-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:907e548ad6ce53b242a55c5f36c2a2a4c37d38f6cd8c356fc550a2f18ab0e82f", size = 87876323 }, + { url = "https://files.pythonhosted.org/packages/a8/e7/7962830da208ad3fa6596dc2df77824da9bc0196b549ae549ce53d1d1de1/jaxlib-0.4.35-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f8c499644660aefd0ae2ee31039da6d4df0f26d0ee67ba9fb316183a5304288", size = 68025360 }, + { url = "https://files.pythonhosted.org/packages/fa/91/2a1a1551845dd634bb1647fd37157f6f4ea71481e63f4100d08923c29d22/jaxlib-0.4.35-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:5d2d8a5b89d334b875ede98d7fcee946bebef1a1b5abd118ff543bcef4ab09f5", size = 70588250 }, + { url = "https://files.pythonhosted.org/packages/d7/16/6a9053d8b4b2790e330f9143030ab9d456556da5d98887b7e071bd08ffed/jaxlib-0.4.35-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:91a283a72263feebe0d110d1136df96950744e47530f12df42c03f36888c971e", size = 87282292 }, + { url = "https://files.pythonhosted.org/packages/6c/a9/b6bdff31e21a485190985dccbdd5ae1130fe2e4af826c83c10ae1d0d14a9/jaxlib-0.4.35-cp310-cp310-win_amd64.whl", hash = "sha256:d210bab7e1ce0b2f2e568548b3903ea6aec349019fc1398cd2a0c069e8342e62", size = 56484115 }, + { url = "https://files.pythonhosted.org/packages/ee/01/4be899cf8d05920877b46b8acf51083dedaba206e951d88ddf7b098bed80/jaxlib-0.4.35-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:7f8bfc90f68857b223b7e38a9bdf466a4f1cb405c9a4aa11698dc9ab7b35c29b", size = 87895891 }, + { url = "https://files.pythonhosted.org/packages/55/77/ca1e70bc3a161c1043d2e169a618263f865bf959433e5bf40ea56ec13e5e/jaxlib-0.4.35-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:261570c94b169dc90f3af903282eeec856b52736c0944d243504ced93d19b217", size = 68045181 }, + { url = "https://files.pythonhosted.org/packages/cd/2f/a8f4c441718558406cf27749415d1aa14bdac9dbd06fadb7bb4742c53637/jaxlib-0.4.35-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e1cee6dc291251f3fb6b0127fdd96c0439ac1ea97e01571d06910df72d6ac6e1", size = 70614621 }, + { url = "https://files.pythonhosted.org/packages/c8/a6/1abe8d682d46cf2989f9c4928866ae80c30a54d607221a262cff8a5d9366/jaxlib-0.4.35-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:bc9eafba001ff8569cfa252fe7f04ba553622702b4b473b656dd0866edf6b8d4", size = 87309681 }, + { url = "https://files.pythonhosted.org/packages/7d/7c/73a4c4a34f2bbfce63e8baefee11753b0d58a71e0d2c33f210e00edba3cb/jaxlib-0.4.35-cp311-cp311-win_amd64.whl", hash = "sha256:0fd990354d5623d3a34493fcd7213493390dbf5039bea19b62e2aaee1049eda9", size = 56520062 }, + { url = "https://files.pythonhosted.org/packages/ef/1c/901a59d9bc051b2a991163c46f58c50724d18ab25e71fa5556e5f68b84a4/jaxlib-0.4.35-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:b44f3e6e9fb748bb43df914356cf9d0d0c9a6e446a12c21fe843db25ed0df65f", size = 87936215 }, + { url = "https://files.pythonhosted.org/packages/da/ff/38030bc3c96fae50f629830afe9c63a8a040aae332f6e28cd529397ba114/jaxlib-0.4.35-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:504d0a2e2117724359d99d7e3663022686dcdddd85aa14bdad02008d444481ad", size = 68063993 }, + { url = "https://files.pythonhosted.org/packages/55/27/83b6d2a1b380e20610e1449231c30c948cc4352c9a7e74a0d0d01bff8339/jaxlib-0.4.35-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:187cb6929dc139b75d952d67c33118473c1b4105525a3e5607f064e7b8efdc74", size = 70629159 }, + { url = "https://files.pythonhosted.org/packages/6d/3f/5ac6dfef795f4f58645ccff0ebd65234cb77d7dbf1bdd2b6c49a677b64b0/jaxlib-0.4.35-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:04d1db3bf0050d120238bfb9b686b58fefcc4d9dd9e2d96aecd3f68a1f1f5e0a", size = 87349348 }, + { url = "https://files.pythonhosted.org/packages/97/05/093b3c511837ba514f0b97581f7b21e1bb79768b8b9c29013a406b00d484/jaxlib-0.4.35-cp312-cp312-win_amd64.whl", hash = "sha256:dddffce48d7e6057008999aed2d8a9daecc57a48c45a4f8c475e00880eb2e41d", size = 56561679 }, + { url = "https://files.pythonhosted.org/packages/99/40/aedef37c44797779a01bf71a392145724e3e0fc369e5f08f55c3c82de733/jaxlib-0.4.35-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:14aeac3fea2ca1d5afb1878f72470b159cc89adb2633c5f0686f5d7c39f2ac18", size = 87934299 }, + { url = "https://files.pythonhosted.org/packages/94/42/62d4d13078886f4d22ca95ca07135f740cf9dd925f4cdb23d7b7d432403b/jaxlib-0.4.35-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e8c9579e20d5ecdc4f61336cdd032710cb8c38d5ae9c4fce0cf9ea031cef21cb", size = 68065641 }, + { url = "https://files.pythonhosted.org/packages/4d/a0/87a4eae3811ce7014ce2c59b811ad930273bfbbb8252ba78079606f9ec40/jaxlib-0.4.35-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7b11ad7c13f7f96f36efd303711ecac425f19ca2ddf65cf1be1541167a959ee5", size = 70629568 }, + { url = "https://files.pythonhosted.org/packages/b3/89/59d6fe10e30ff5a48a73319bafa9a11cd999f91a47e4f08f7dc3651c899c/jaxlib-0.4.35-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:0be3cf9df879d9ae1b5b92fc281f77d21f522fcbae1a48a02661026bbd9b9309", size = 87350315 }, + { url = "https://files.pythonhosted.org/packages/79/d7/d7600c65fe0412a6584d84ca172816a8cf19965219ee3dd59542447ffe2f/jaxlib-0.4.35-cp313-cp313-win_amd64.whl", hash = "sha256:330c090bb9af413f552d8a92d097e50baec6b75823430fb2966a49f5298d4c43", size = 56562022 }, ] [[package]] @@ -1412,6 +1432,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c9/fb/108ecd1fe961941959ad0ee4e12ee7b8b1477247f30b1fdfd83ceaf017f0/jupyter_core-5.7.2-py3-none-any.whl", hash = "sha256:4f7315d2f6b4bcf2e3e7cb6e46772eba760ae459cd1f59d29eb57b0a01bd7409", size = 28965 }, ] +[[package]] +name = "jupyterlab-widgets" +version = "3.0.13" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/59/73/fa26bbb747a9ea4fca6b01453aa22990d52ab62dd61384f1ac0dc9d4e7ba/jupyterlab_widgets-3.0.13.tar.gz", hash = "sha256:a2966d385328c1942b683a8cd96b89b8dd82c8b8f81dda902bb2bc06d46f5bed", size = 203556 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a9/93/858e87edc634d628e5d752ba944c2833133a28fa87bb093e6832ced36a3e/jupyterlab_widgets-3.0.13-py3-none-any.whl", hash = "sha256:e3cda2c233ce144192f1e29914ad522b2f4c40e77214b0cc97377ca3d323db54", size = 214392 }, +] + [[package]] name = "jupytext" version = "1.13.8" @@ -1428,6 +1457,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/e3/538509410372acd6d41f12c028dfc75ebddfbc4f7544f933bff7b5cc3e97/jupytext-1.13.8-py3-none-any.whl", hash = "sha256:625d2d2012763cc87d3f0dd60383516cec442c11894f53ad0c5ee5aa2a52caa2", size = 297592 }, ] +[[package]] +name = "kagglehub" +version = "0.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, + { name = "requests" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b0/69/3e3d9533b44535903011157102bcf08ad4124f12b5d2c294850e6fad5032/kagglehub-0.3.3.tar.gz", hash = "sha256:0777d4d1ee1e59d4125b14ba62a46b2eadedb68bc6517479f6fb02a522a262f8", size = 60620 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c4/d1/4ab25019a168f5c414202f124d156e11ac79f07845d67288929311f1b1b2/kagglehub-0.3.3-py3-none-any.whl", hash = "sha256:5370acde855d04b6d8a7bc242edff339266913fffc8b198d31859b25b7d095f7", size = 42852 }, +] + [[package]] name = "keras" version = "3.5.0" @@ -2014,6 +2057,7 @@ version = "12.1.3.1" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/37/6d/121efd7382d5b0284239f4ab1fc1590d86d34ed4a4a2fdb13b30ca8e5740/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728", size = 410594774 }, + { url = "https://files.pythonhosted.org/packages/c5/ef/32a375b74bea706c93deea5613552f7c9104f961b21df423f5887eca713b/nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906", size = 439918445 }, ] [[package]] @@ -2022,6 +2066,7 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/7e/00/6b218edd739ecfc60524e585ba8e6b00554dd908de2c9c66c1af3e44e18d/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e", size = 14109015 }, + { url = "https://files.pythonhosted.org/packages/d0/56/0021e32ea2848c24242f6b56790bd0ccc8bf99f973ca790569c6ca028107/nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4", size = 10154340 }, ] [[package]] @@ -2030,6 +2075,7 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/b6/9f/c64c03f49d6fbc56196664d05dba14e3a561038a81a638eeb47f4d4cfd48/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2", size = 23671734 }, + { url = "https://files.pythonhosted.org/packages/ad/1d/f76987c4f454eb86e0b9a0e4f57c3bf1ac1d13ad13cd1a4da4eb0e0c0ce9/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed", size = 19331863 }, ] [[package]] @@ -2038,6 +2084,7 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/eb/d5/c68b1d2cdfcc59e72e8a5949a37ddb22ae6cade80cd4a57a84d4c8b55472/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40", size = 823596 }, + { url = "https://files.pythonhosted.org/packages/9f/e2/7a2b4b5064af56ea8ea2d8b2776c0f2960d95c88716138806121ae52a9c9/nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344", size = 821226 }, ] [[package]] @@ -2049,6 +2096,7 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, + { url = "https://files.pythonhosted.org/packages/3f/d0/f90ee6956a628f9f04bf467932c0a25e5a7e706a684b896593c06c82f460/nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a", size = 679925892 }, ] [[package]] @@ -2057,6 +2105,7 @@ version = "11.0.2.54" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/86/94/eb540db023ce1d162e7bea9f8f5aa781d57c65aed513c33ee9a5123ead4d/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56", size = 121635161 }, + { url = "https://files.pythonhosted.org/packages/f7/57/7927a3aa0e19927dfed30256d1c854caf991655d847a4e7c01fe87e3d4ac/nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253", size = 121344196 }, ] [[package]] @@ -2065,6 +2114,7 @@ version = "10.3.2.106" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/44/31/4890b1c9abc496303412947fc7dcea3d14861720642b49e8ceed89636705/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0", size = 56467784 }, + { url = "https://files.pythonhosted.org/packages/5c/97/4c9c7c79efcdf5b70374241d48cf03b94ef6707fd18ea0c0f53684931d0b/nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a", size = 55995813 }, ] [[package]] @@ -2078,6 +2128,7 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 }, + { url = "https://files.pythonhosted.org/packages/b8/80/8fca0bf819122a631c3976b6fc517c1b10741b643b94046bd8dd451522c5/nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5", size = 121643081 }, ] [[package]] @@ -2089,6 +2140,7 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 }, + { url = "https://files.pythonhosted.org/packages/0f/95/48fdbba24c93614d1ecd35bc6bdc6087bd17cbacc3abc4b05a9c2a1ca232/nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a", size = 195414588 }, ] [[package]] @@ -2107,6 +2159,7 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/81/b3/e456a1b2d499bb84bdc6670bfbcf41ff3bac58bd2fae6880d62834641558/nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_aarch64.whl", hash = "sha256:84fb38465a5bc7c70cbc320cfd0963eb302ee25a5e939e9f512bbba55b6072fb", size = 19252608 }, { url = "https://files.pythonhosted.org/packages/59/65/7ff0569494fbaea45ad2814972cc88da843d53cc96eb8554fcd0908941d9/nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl", hash = "sha256:562ab97ea2c23164823b2a89cb328d01d45cb99634b8c65fe7cd60d14562bd79", size = 19724950 }, + { url = "https://files.pythonhosted.org/packages/cb/ef/8f96c82e1cfcf6d5b770f7b043c3cc24841fc247b37629a7cc643dbf72a1/nvidia_nvjitlink_cu12-12.6.20-py3-none-win_amd64.whl", hash = "sha256:ed3c43a17f37b0c922a919203d2d36cbef24d41cc3e6b625182f8b58203644f6", size = 162012830 }, ] [[package]] @@ -2115,6 +2168,7 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/da/d3/8057f0587683ed2fcd4dbfbdfdfa807b9160b809976099d36b8f60d08f03/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5", size = 99138 }, + { url = "https://files.pythonhosted.org/packages/b8/d7/bd7cb2d95ac6ac6e8d05bfa96cdce69619f1ef2808e072919044c2d47a8c/nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82", size = 66307 }, ] [[package]] @@ -3403,7 +3457,9 @@ dependencies = [ { name = "tensorflow", marker = "platform_system != 'Darwin'" }, ] wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/e3/33fc5957790cf4710e0a9116cf37c0a881eda673e5f8b569bfff5654a48c/tensorflow_text-2.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8eba0b5804235519b571c827c97337c332de270107f06af6d2171cdefdc4c6a0", size = 6109587 }, { url = "https://files.pythonhosted.org/packages/61/59/2090318555d98dc9dc868b3c585ada2e1139be538d954340726aa3d3899a/tensorflow_text-2.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89f04c3f478f1885ad4c7380643a768a72a3de79e1f8f40d50b48cc1fbf73893", size = 5205819 }, + { url = "https://files.pythonhosted.org/packages/92/65/e2d3d9300173a0927e8b7e3cf9a35f9539e9269786c1e1d9d945223fe21a/tensorflow_text-2.17.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a9b9f9c8a06878714a14f4e086fa8122beb2e141f82d0aa5a8f6b8f9b694db51", size = 6109684 }, { url = "https://files.pythonhosted.org/packages/de/32/182ecf4eb1432942876d9b0b089625564084c5ed4d03c02ddf2872177e95/tensorflow_text-2.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:161c09380b090774ed721cdcce973194458708250d7dfbac7cb9ea8a3e9ac762", size = 5205866 }, ] @@ -3653,6 +3709,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1b/d1/9babe2ccaecff775992753d8686970b1e2755d21c8a63be73aba7a4e7d77/wheel-0.44.0-py3-none-any.whl", hash = "sha256:2376a90c98cc337d18623527a97c31797bd02bad0033d41547043a1cbfbe448f", size = 67059 }, ] +[[package]] +name = "widgetsnbextension" +version = "4.0.13" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/56/fc/238c424fd7f4ebb25f8b1da9a934a3ad7c848286732ae04263661eb0fc03/widgetsnbextension-4.0.13.tar.gz", hash = "sha256:ffcb67bc9febd10234a362795f643927f4e0c05d9342c727b65d2384f8feacb6", size = 1164730 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/21/02/88b65cc394961a60c43c70517066b6b679738caf78506a5da7b88ffcb643/widgetsnbextension-4.0.13-py3-none-any.whl", hash = "sha256:74b2692e8500525cc38c2b877236ba51d34541e6385eeed5aec15a70f88a6c71", size = 2335872 }, +] + [[package]] name = "wrapt" version = "1.16.0"