diff --git a/tokenizer_ts/.npmignore b/tokenizer_ts/.npmignore index 0363f59..10160e0 100644 --- a/tokenizer_ts/.npmignore +++ b/tokenizer_ts/.npmignore @@ -8,4 +8,6 @@ dist/test/* debug.ts *.map *.tiktoken -.eslintrc.js \ No newline at end of file +.eslintrc.js +/perf/* +*.map diff --git a/tokenizer_ts/package.json b/tokenizer_ts/package.json index 3e258b4..e52441d 100644 --- a/tokenizer_ts/package.json +++ b/tokenizer_ts/package.json @@ -34,6 +34,7 @@ "scripts": { "test": "mocha -u tdd --require ts-node/register test/**/*.ts", "build": "tsc -p ./tsconfig.json", + "watch": "tsc -p ./tsconfig.json --watch", "eslint": "eslint src --ext ts", "format": "prettier --write \"./**/*.{ts,tsx}\"" }, diff --git a/tokenizer_ts/perf/.gitignore b/tokenizer_ts/perf/.gitignore new file mode 100644 index 0000000..ff67559 --- /dev/null +++ b/tokenizer_ts/perf/.gitignore @@ -0,0 +1,3 @@ +package.json +package-lock.json +*.cpuprofile diff --git a/tokenizer_ts/perf/benchmark-folder.js b/tokenizer_ts/perf/benchmark-folder.js new file mode 100644 index 0000000..1047065 --- /dev/null +++ b/tokenizer_ts/perf/benchmark-folder.js @@ -0,0 +1,65 @@ +const fs = require('fs/promises'); +const path = require('path'); +const inspector = require('inspector'); +const { promisify } = require('util'); + +const [,, encoderName, folderPath, method, modulePath] = process.argv; +const { createByEncoderName } = require(modulePath); +const minTime = 10_000; +const minCycles = 5; + +const fileExtensions = ['.ts', '.js', '.py']; + +async function readAllFilesInFolder(folderPath) { + const files = await fs.readdir(folderPath, { withFileTypes: true }); + const fileContents = await Promise.all(files.map(async (file) => { + const res = path.resolve(folderPath, file.name); + if (file.isDirectory()) { + return readAllFilesInFolder(res); + } else if (fileExtensions.some(f => res.endsWith(f))) { + return fs.readFile(res, 'utf8'); + } else { + return []; + } + })); + + return fileContents.flat(); +} + +Promise.all([ 
+ readAllFilesInFolder(folderPath), + createByEncoderName(encoderName) +]).then(async ([files, tokenizer]) => { + let totalSize = 0; + for (const file of files) { + totalSize += file.length; + } + + const session = new inspector.Session(); + session.connect(); + const post = promisify(session.post).bind(session); + await post('Profiler.enable'); + await post('Profiler.start'); + + const start = performance.now(); + let cycles = []; + while (performance.now() - start < minTime || cycles.length < minCycles) { + const cycleStart = performance.now(); + switch (method) { + case 'encode': + files.forEach(file => tokenizer.encode(file)); + break; + case 'encodeTrimSuffix': + files.forEach(file => tokenizer.encodeTrimSuffix(file, 1337)); + break; + default: + throw new Error(`unknown method ${method}`); + } + cycles.push(performance.now() - cycleStart); + } + + const data = await post('Profiler.stop'); + await fs.writeFile('profile.cpuprofile', JSON.stringify(data.profile)); + + process.stdout.write(JSON.stringify({ totalSize, cycles })); +}); diff --git a/tokenizer_ts/perf/notebook.ipynb b/tokenizer_ts/perf/notebook.ipynb new file mode 100644 index 0000000..89c926f --- /dev/null +++ b/tokenizer_ts/perf/notebook.ipynb @@ -0,0 +1,188 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# TS Tiktokenizer Performance\n", + "\n", + "This notebook is used for analyzing the performance of and performance improvements to the Tokenizer. It uses the VS Code repo as its corpus. First, let's grab the last released version of `@microsoft/tiktokenizer`, and get a baseline." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "import os\n", + "import subprocess\n", + "import json\n", + "\n", + "vscode_repo_path = \"../../../vscode\"\n", + "if not os.path.exists(vscode_repo_path):\n", + " print(\"The repo does not exist.\")\n", + "\n", + "def run_benchmark(module_path, encoder_name = 'cl100k_base', method = 'encode'):\n", + " command = f\"node ./benchmark-folder.js {encoder_name} {vscode_repo_path}/src {method} {module_path}\"\n", + " result = subprocess.check_output(command, shell=True)\n", + " parsed = json.loads(result)\n", + " return parsed\n", + "\n", + "#os.system('npm install @microsoft/tiktokenizer --prefix ./')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Performance can vary machine to machine, make sure to collect a baseline before you start working. Every time you run a benchmark, there'll be a `profile.cpuprofile` written out that you can inspect." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# This can take a minute, make some tea 🍵\n", + "baseline = run_benchmark('@microsoft/tiktokenizer')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZwAAAEICAYAAABrtkJsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAbKUlEQVR4nO3deVTVdf7H8dcllEwKUVxASRR3ltJgEmvUyqVFTRsrChPb1BptN5uxGc0sGxudpvHMKWc0tRzrlJYokTalNpmpaJMpiWXhguYyYoSpbO/fH56+vyiuiMoHpOfjHM7xftfX/Rzuffld7sVnZiYAAKpYQHUHAAD8MlA4AAAnKBwAgBMUDgDACQoHAOAEhQMAcILCQY0XExOjFStWVHeMM27YsGF6/PHHz/h2g4OD9dVXX53x7dZUVTWOOPMoHFS74OBg7ycgIED16tXzHs+bN0+bN29Wz549qzzHhAkTNGTIkFNev6a88RUUFKh169ZnZFtHjx5VgwYN9P777/9s3oMPPqjBgwdLkj788EN169ZNISEhatiwoS677DKtW7fO73a3bt2qG2+8UWFhYQoJCVF8fLymTZumkpKSM5IbNROFg2pXUFDg/Vx44YVavHix9zglJaW64/2inXvuubr55ps1d+7cMtNLSko0f/58paamKj8/X/369dPo0aN18OBB5ebmavz48QoKCip3m9u2bdOll16qyMhIffbZZ/r222/1+uuvKzMzU999952Lp4XqYkAN0rJlS3v33Xf9Ths/frwNHjzYUlJSLDg42GJjYy07O9uefvppa9y4sbVo0cKWLl3qrXvo0CG74447rFmzZhYREWHjxo2z4uLin+03IyPD6tSpY4GBgVa/fn2Lj483M7Pc3Fzr37+/hYaGWnR0tM2YMaPc3C+++KIFBgZanTp1rH79+tavXz8zM8vKyrIePXpYSEiIderUyRYtWuStk5qaauPGjTMzs/z8fOvZs6eNHj3aSktL7fPPP7devXpZaGiotWvXzl577bUy691777127bXXWnBwsP3qV7+yL7/80psvyb744gvLzc21+vXrez/16tWzH7/kZ86caR06dLAGDRpYnz59LCcnp9zntmrVKgsODrbDhw9709LT061x48ZWVFRk69ats5CQkHLXLU9KSopde+21J1xm0aJF1qlTJwsJCbEePXpYVlaWN2/Dhg3WuXNnCw4OtptuusluvvlmbxzNzBYvXmwXXXSRhYSEWFJSkn366acnnQ1Vi8JBjXIyhRMUFGTvvPOOFRUV2W233WZRUVE2adIkKywstBkzZlhUVJS37vXXX2/Dhw+3goIC27t3ryUmJtoLL7xQ7r7Hjx9vKSkpZaZ1797d7rnnHjty5Ih98sknFhYWZv/+97/LXf/HBWJmVlhYaNHR0fbUU0/ZsWPH7L333rPg4GDbsmVLmeUPHDhgiYmJ3roFBQXWokULmzVrlhUVFdn69eutUaNGtmnTJm+90NBQW7NmjRUVFdmtt95qN998s7ffHwrnp2699VZLTk42M7M333zToqOjLSsry4qKiuzJJ5+0pKSkcp+XmVnbtm3t5Zdf9h4nJyfb/fffb2Zm3377rTVs2NCGDh1qb7/9th08eNDvdszMmjZtar
NmzfI7Pzs728477zxbtmyZFRYW2p/+9CeLjo62Y8eO2bFjx+zCCy+0adOmWWFhob3++usWGBjojd369eutcePG9vHHH1txcbHNnj3bWrZsaUePHj1hJrhB4aBGOZnC6dWrlzcvLS3N6tev7x215OfnmyTLy8uzb775xurWrWvff/+9t/y//vUv69mzZ7n7/mnh7NixwwICAiw/P9+b9thjj1lqamq56/+0cD744ANr2rSplZSUeNOSk5Nt/Pjx3vK33367xcTE2JQpU7xlXn31Vbv88svLbHv48OE2YcIEb70777zTm5eenm7t27f3HpdXOM8884x16dLFG4urr77a/vnPf3rzS0pKrF69en6Pcp588knr3bu3mR0vmHr16tmGDRu8+VlZWZaammrNmze3c845x/r372/ffPNNudsKDAy0jIyMcueZmU2cONFuvPHGMtkiIiJs+fLltnLlSgsPD7fS0lJvflJSkjfuI0eOtMcff7zM9tq1a2crVqzwuz+4wzUcnHWaNm3q/btevXoKCwvTOeec4z2Wjl8X2r59u4qKihQeHq4GDRqoQYMGGjFihPbt23dS+9m9e7caNmyo888/35vWsmVL5ebmnvT6kZGRCgj4/5fZT9dPT0/XkSNHNHLkSG/a9u3btWbNGi9zgwYNNG/ePH3zzTfeMs2aNfP+fd5556mgoMBvjoyMDP31r3/VW2+95Y3P9u3bdf/993vbb9iwoczM73MbOnSoli9frtzcXL3xxhtq06aNOnfu7M3v2LGjZs+erV27dmnTpk3avXu3HnjggXK31ahRI+3Zs8dv3t27d6tly5be44CAAEVGRio3N1e7d+9W8+bN5fP5vPk/Xnb79u2aOnVqmbHbuXOndu/e7Xd/cCewugMAVSUyMlJBQUE6cOCAAgMr/lX/8ZuYJEVEROjgwYP67rvvvNLZsWOHmjdvftLr79y5U6WlpV7p7NixQ+3atfOWufvuu5WXl6drr71W77zzjurXr6/IyEj16NFD7777bqWeb3mys7OVmpqqhQsXKjIy0pseGRmpcePGnfRNGRdeeKF+/etfa968ecrIyNDQoUP9LtuhQwcNGzZML774Yrnze/XqpQULFuj2228vd35ERIQ+++wz77GZaefOnV7R5Obmysy88d6xY4eio6PLPK9x48ad1POCWxzhoNYKDw9Xnz599PDDDys/P1+lpaXatm2bVq5cWe7yTZs2VU5OjkpLSyUdf/Pq1q2bfve73+no0aPauHGjZs6c6fdNumnTpmU+/3LppZeqfv36mjJlioqKirRixQotXrxYycnJZdabPn262rdvr379+unIkSPq16+ftm7dqpdffllFRUUqKirSunXr9Pnnn1fq+efn5+v666/XpEmTdPnll5eZN3LkSE2ePFmbN2+WJO9OsRNJTU3V9OnTtWrVqjJjsGXLFk2dOlW7du2SJO3cuVPz589X165dy93OE088oY8++khjxozxjtq+/PJLDRkyRIcOHdJNN92k9PR0vffeeyoqKtLUqVMVFBSkbt26KSkpSYGBgXr++edVXFyshQsXau3atd627777br3wwgtas2aNzEyHDx9Weno6d7/VEBQOarW5c+eqsLBQnTp1UmhoqAYPHuz3dM6NN94o6fgpny5dukiS5s+fr5ycHEVERGjQoEF64okn1Lt373LXv/POO5WVlaUGDRpo4MCBqlu3rtLS0pSRkaGwsDDde++9mjt3rjp06FBmPZ/PpxkzZigyMlLXX3+96tSpo2XLlunVV19VRESEmjVrprFjx+rYsWOVeu4bNmxQdna2HnrooTKfdZKkQYMGaezYsUpOTtYFF1yg2NhYZWRknHB7gwcPVl5enq666iqFh4d7088//3ytWbPGK9iuXbsqNjZWU6dOLXc70dHRWr16tXJychQTE6OQkBD95je/UUJCgs4//3y1b99er7zyikaPHq2wsDAtXrxYixcvVt26dVW3bl0tXLhQs2
fPVmhoqF577TXdcMMN3rYTEhL0j3/8Q6NGjVJoaKjatGmj2bNnV2rcUHV8ZvwBNgBA1eMIBwDgBIUDAHCCwgEAOEHhAACc4HM4foSFhSkqKqq6YwDAWSUnJ0cHDhwodx6F40dUVJQyMzOrOwYAnFUSEhL8zuOUGgDACQoHAOAEhQMAcILCAQA4QeEAAJygcAAATlA4AAAnKBwAgBMUDgDACQoHAOAEhQMAcILCAQA4QeEAAJygcAAATlA4AAAnKBwAgBMUDgDACQoHAOAEhQMAcILCAQA4QeEAAJygcAAATlA4AAAnKBwAgBMUDgDACQoHAOAEhQMAcILCAQA4QeEAAJygcAAATlA4AAAnKBwAgBMUDgDACQoHAOAEhQMAcILCAQA4QeEAAJygcAAATlA4AAAnKBwAgBMUDgDACQoHAOAEhQMAcILCAQA4QeEAAJygcAAATlA4AAAnKBwAgBMUDgDACQoHAOAEhQMAcILCAQA4QeEAAJygcAAATlA4AAAnKBwAgBMUDgDACQoHAOAEhQMAcILCAQA4QeEAAJygcAAATlA4AAAnKBwAgBMUDgDACQoHAOAEhQMAcILCAQA4QeEAAJygcAAATlA4AAAnKBwAgBMUDgDACQoHAOAEhQMAcILCAQA4QeEAAJygcAAATlA4AAAnKBwAgBMUDgDACQoHAOAEhQMAcILCAQA4UWHhHD58WKWlpZKkrVu3Ki0tTUVFRVUeDABQu1RYON27d9fRo0eVm5urq666Si+99JKGDRvmIBoAoDapsHDMTOedd54WLlyo0aNH680331RWVpaLbACAWuSkCmf16tWaN2+errvuOklScXFxlQcDANQuFRbOc889p8mTJ2vQoEGKiYnRV199pSuuuMJFNgBALeIzM6vuEDVRQkKCMjMzqzsGAJxVTvTeGVjRypmZmXr66aeVk5NT5lTaxo0bz1xCAECtV2HhpKSk6Nlnn1VcXJwCAvjYDqpGw4YNlZeX52x/Nv4C+Z7Id7Y/1D6hoaE6ePBgdcc4q1RYOI0bN9aAAQNcZMEvWF5enpye3Z0Q4nZ/qHV8Pl91RzjrVFg4TzzxhO666y5dddVVCgoK8qbfcMMNVRoMAFC7VFg4L730krZs2aKioiLvlJrP56NwAACVUmHhfPrpp/rss89cZAEA1GIV3gXQtWtXvlkAAHDaKjzC+fDDDzVnzhy1atVKQUFBMjP5fD5uiwYAVEqFhfPOO++4yAEAqOUqLJyWLVu6yFGr+Hw+brkFcNaqqvcwPskJAHCCwgEAOOG3cPr27au//OUv2rJli8s8AIBaym/hzJkzR6GhoZowYYK6dOmie+65R4sWLVJBQYHLfACAWsJv4TRr1kzDhg3Tq6++qszMTA0dOlTr169X37591atXL02ZMuWEG87JyVFsbOwZDyxJK1asUL9+/SRJaWlpeuaZZ6pkPwCAM6fCu9QkKSAgQElJSUpKStLEiRN14MABLV26tKqznZQBAwbw5aIAcBY4pZsGwsLClJKSUuFyxcXFSk1NVXx8vAYPHqzvv/9eEydOVGJiomJjYzV8+HDv1rvnn39enTp1Unx8vJKTkyVJhw8f1h133KHExER17txZixYt+tk+Zs+erVGjRkmShg0bpvvuu0/dunVT69at9cYbb3jLPfvss0pMTFR8fLzGjx9/Kk8bAHAaqvQutezsbA0fPlwbN27UBRdcoL///e8aNWqU1q1bp02bNunIkSNasmSJJOmZZ57RJ598oo0bN+qFF16QJD311FO68sortW7dOi1fvlxjxozR4cOHT7jPPXv26MMPP9SSJUv02GOPSZKWLVumL774QmvXrtV///tfrV+/Xh988MHP1p0xY4YSEhKUkJCg/fv3n9Zz9/l8/FTiBzgbVffr5mx7PZ7UKbVTFRkZqcsuu0ySNGTIED3//PNq1aqVpkyZou+//14HDx5UTEyM+vfvr/j4eKWkpGjgwIEaOHCgpONFkZaWpj//+c
+SpKNHj2rHjh0n3OfAgQMVEBCgTp06ae/evd52li1bps6dO0uSCgoK9MUXX6h79+5l1h0+fLiGDx8u6fifST0dfPCzcigdnI1q6+u8ql6PFRbO3r179fvf/167d+9WRkaGsrKytHr1at15550VbvynoX0+n+69915lZmYqMjJSEyZM0NGjRyVJ6enp+uCDD5SWlqYnn3xSmzdvlplpwYIFat++/c8y+fPjv9nzwy+Dmel3v/udRowYUWFmAEDVqPCU2rBhw9S3b1/t3r1bktSuXTs999xzJ7XxHTt2aPXq1ZKk+fPn6/LLL5d0/BpQQUGBd42ltLRUO3fu1BVXXKEpU6bo0KFDKigoUN++ffW3v/3NK45PPvmk0k9QOv6ZolmzZnm3dOfm5mrfvn2ntC0AwKmp8AjnwIEDuummmzR58uTjKwQG6pxzzjmpjXfs2FFz5szRiBEj1LZtW91zzz3Ky8tTXFycoqKilJiYKEkqKSnRkCFD9O2338rM9OCDD6pBgwb6wx/+oAceeEDx8fEyM0VFRXnXfCqjT58++vzzz5WUlCRJCg4O1iuvvKImTZpUelsAgFPjswpOQvbs2VMLFixQ7969tWHDBn388ccaO3asVq5c6SpjtUhISFBmZuYprevz8eWdleV8zCaESBO+dbc/1Dq1+XV+Os/tRO+dFR7hTJs2TQMGDNC2bdt02WWXaf/+/WVuNwYA4GRUWDhdunTRypUrlZ2dLTNT+/btVadOHRfZAAC1SIWFU1JSorfffls5OTkqLi7WsmXLJEkPPfRQlYcDANQeFRZO//79de655youLk4BAfw1g5NRW8/rAvhlqKr3sAoLZ9euXdq4cWOV7BwA8MtR4SHLNddc451GAwDgVFV4hNO1a1cNGjRIpaWlqlOnjsxMPp9P+fn5LvIBAGqJCgvn4Ycf1urVqxUXF8f3XQEATlmFp9Tatm2r2NhYygYAcFoqPMIJDw9Xz549dc0115T5YkxuiwYAVEaFhdOqVSu1atVKhYWFKiwsdJEJv1Auj6Jt/AUcteO0hIaGVneEs06FhcNfx4QL1fHZJZvgfJfAL5rfwhk1apSmT5+u/v37l/s/wbS0tCoNBgCoXfwWzty5czV9+nQ98sgjLvMAAGopv4UTHR0tSerRo4ezMACA2stv4ezfv1/Tpk3zuyJ3qQEAKsNv4ZSUlKigoIAvogQAnBF+Cyc8PFx//OMfXWYBANRifr9pgCMbAMCZ5Ldw3nvvPZc5AAC1nN/CadiwocscAIBajj/hCQBwgsIBADhB4QAAnKBwAABOUDgAACcoHACAExQOAMAJCgcA4ASFAwBwgsIBADhB4QAAnKBwAABOUDgAACcoHACAExQOAMAJCgcA4ASFAwBwgsIBADhB4QAAnKBwAABOUDgAACcoHACAExQOAMAJCgcA4ASFAwBwgsIBADhB4QAAnKBwAABOUDgAACcoHACAExQOAMAJCgcA4ASFAwBwgsIBADhB4QAAnKBwAABOUDgAACcoHACAExQOAMAJCgcA4ASFAwBwgsIBADhB4QAAnKBwAABOUDgAACcoHACAExQOAMAJCgcA4ASFAwBwgsIBADhB4QAAnKBwAABOUDgAACcoHACAExQOAMAJCgcA4ASFAwBwgsIBADhB4QAAnKBwAABOUDgAACcoHACAExQOAMAJCgcA4ASFAwBwgsIBADhB4QAAnKBwAABOUDgAACcoHACAExQOAMAJCgcA4ASFAwBwgsIBADhB4QAAnKBwAABOUDgAACcoHACAExQOAMAJCgcA4ASFAwBwgsIBADjhMzOr7hA1UVhYmKKioqps+/v371fjxo2rbPuni3ynr6ZnJN/pIV/5cnJydODAgXLnUTjVJCEhQZmZmdUdwy/ynb6anpF8p4d8lccpNQCAExQOAMAJCqeaDB8+vLojnBD5Tl9Nz0i+00O+yuMaDgDACY5wAABOUDgAACconDOspKREnTt3Vr9+/SRJn376qZ
KSkhQXF6f+/fsrPz/fW3by5Mlq06aN2rdvr6VLl3rT169fr7i4OLVp00b33XefzuRZz5PNl5OTo3r16uniiy/WxRdfrJEjR1Z5vqioKMXFxeniiy9WQkKCJOngwYPq3bu32rZtq969eysvL89bvjrGrzIZa8oYvv7664qJiVFAQMDPbpN1PYaVyVdTxm/MmDHq0KGD4uPjNWjQIB06dMhbviaMn7981TF+FTKcUVOnTrVbbrnFrrvuOjMzS0hIsBUrVpiZ2cyZM+3xxx83M7PNmzdbfHy8HT161L766itr3bq1FRcXm5lZYmKiffTRR1ZaWmpXX321vf32287zff311xYTE1PuNqoqX8uWLW3//v1lpo0ZM8YmT55sZmaTJ0+2Rx991Myqb/wqk7GmjGFWVpZt2bLFevToYevWrfOmV8cYViZfTRm/pUuXWlFRkZmZPfroo9X6O1iZfNUxfhXhCOcM2rVrl9LT03XXXXd507Kzs9W9e3dJUu/evbVgwQJJ0qJFi5ScnKygoCC1atVKbdq00dq1a7Vnzx7l5+crKSlJPp9PQ4cO1VtvveU8nz9Vma88ixYtUmpqqiQpNTXV21d1jF9lM/rjOmPHjh3Vvn37n02vKWPoL58/rvP16dNHgYGBkqSuXbtq165dkmrO+PnL5091vEZ+QOGcQQ888ICmTJmigID/H9bY2FilpaVJOn7qYOfOnZKk3NxcRUZGesu1aNFCubm5ys3NVYsWLX423XU+Sfr666/VuXNn9ejRQ//5z3+83FWVz+fzqU+fPrrkkks0Y8YMSdLevXsVHh4uSQoPD9e+ffu8HK7Hr7IZpZoxhv5UxxhWJp9U88Zv1qxZuuaaa7wcNW38fpxPcj9+FQl0spdfgCVLlqhJkya65JJLtGLFCm/6rFmzdN9992nixIkaMGCA6tatK0nlnjP1+Xx+p7vOFx4erh07dqhRo0Zav369Bg4cqM2bN1dZPklatWqVIiIitG/fPvXu3VsdOnTwu6zr8TuVjDVlDH84gv2p6hjDyuSraeP31FNPKTAwUCkpKZJq3vj9NF91jF9FOMI5Q1atWqW0tDRFRUUpOTlZ77//voYMGaIOHTpo2bJlWr9+vW655RZFR0dLOv6/ih8fTezatUsRERFq0aJFmUPiH6a7zhcUFKRGjRpJki655BJFR0dr69atVZZPkredJk2aaNCgQVq7dq2aNm2qPXv2SDp+KqBJkyaS3I/fqWSsKWPoT3WMYWXy1aTxmzNnjpYsWaJ58+Z5b841afzKy1cd41chJ1eKfmGWL1/uXZTfu3evmZmVlJTYbbfdZjNnzjQzs02bNpW54NiqVSvvgmNCQoKtXr3au6CXnp7uPN++ffu8PNu2bbOIiAj73//+V2X5CgoKLD8/3/t3UlKSZWRk2COPPFLmgvyYMWPMrHrGr7IZa8oY/uCnF+Vdj2Fl89WU8cvIyLCOHTvavn37yixfU8bPXz7X43cyKJwq8OM39Oeee87atm1rbdu2tbFjx1ppaam33KRJk6x169bWrl27MneJrFu3zmJiYqx169b229/+tsw6rvK98cYb1qlTJ4uPj7fOnTtbWlpalebbtm2bxcfHW3x8vHXq1MkmTZpkZmYHDhywK6+80tq0aWNXXnml94Ixcz9+lc1YU8Zw4cKF1rx5c6tbt641adLE+vTp463jcgwrm6+mjF90dLS1aNHCLrroIrvoootsxIgR3jo1Yfz85XM9fieDr7YBADjBNRwAgBMUDgDACQoHAOAEhQMAcILCAQA4QeEAAJygcAAATvwfym6iFz1VaH4AAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "cycles = baseline['cycles']\n", + "fig, ax = plt.subplots()\n", + "ax.boxplot(cycles, vert=False, labels=[\"baseline\"])\n", + "ax.set_title('Time to tokenize VS Code')\n", + "ax.set_ylabel('Time / ms')\n", + "fig.patch.set_facecolor('white')\n", + "ax.set_facecolor('white')\n", + "plt.show()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here's the current performance of the repo. Make sure to `npm run build` or `npm run watch` first!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "updated = run_benchmark('../')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "fig, ax = plt.subplots()\n", + "ax.boxplot([baseline['cycles'], updated['cycles']], vert=False, labels=[\"baseline\", \"updated\"])\n", + "ax.set_title('Time to tokenize VS Code')\n", + "ax.set_ylabel('Time / ms')\n", + "fig.patch.set_facecolor('white')\n", + "ax.set_facecolor('white')\n", + "plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "\n", + "# Calculate average time in seconds\n", + "baseline_avg_time = sum(baseline['cycles']) / len(baseline['cycles']) / 1000\n", + "updated_avg_time = sum(updated['cycles']) / len(updated['cycles']) / 1000\n", + "\n", + "# Calculate total size in MB\n", + "total_size_MB = baseline['totalSize'] / (1024 * 1024)\n", + "\n", + "# Calculate average speed in MB/s\n", + "baseline_speed = total_size_MB / baseline_avg_time\n", + "updated_speed = total_size_MB / updated_avg_time\n", + "\n", + "# Plot the bar chart\n", + "ax.bar(['Baseline', 'Updated'], [baseline_speed, updated_speed])\n", + "ax.set_ylabel('Tokenization Speed / MBs^-1')\n", + "fig.patch.set_facecolor('white')\n", + "ax.set_facecolor('white')\n", + "plt.title('Tokenization Speed')\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tokenizer_ts/src/bytePairEncode.ts b/tokenizer_ts/src/bytePairEncode.ts index d74acd6..fc0f353 100644 --- a/tokenizer_ts/src/bytePairEncode.ts +++ b/tokenizer_ts/src/bytePairEncode.ts @@ -1,17 +1,84 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
-/** - * Convert a Uint8Array to a string - * @param uint8Array - * @returns string - */ -export function uint8ArrayToString(uint8Array: Uint8Array): string { - return Array.from(uint8Array) - .map(num => num.toString()) - .join("_"); +const enum Constant { + // we have 48 bits per level, we can safely bitwise encode 32 bits at a time, + // so this works in two passes + BytesPerLevel = 6, +} + +// exported for testing. Packs up to six bytes of k[start, end) into one number: bytes 0-2 form the low 24 bits, bytes 3-5 the high 24 bits; bytes at or past `end` are masked to zero. NOTE(review): the key does not encode the slice length, so slices that differ only by trailing 0x00 bytes (e.g. [1] and [1, 0]) produce the same key - confirm the vocabulary never needs to tell them apart. +export const binaryMapKey = (k: Uint8Array, start: number, end: number): number => { + const length = end - start; + + // 'lower' and 'upper' are both 24-bit integers, like + // 0xFF FF FF + // ^3 ^2 ^1 + // If we say have a length of 2, we should disregard the last "3" byte, so we + // create a mask like + // 0x00 FF FF (started at 0xFF FF FF and shifted over by 8 bits) + // ^3 ^2 ^1 + // so that we discard the data outside our range + const lowerMask = 0xFFFFFF >>> Math.max(0, (3 - length) * 8); + const lower = (k[start + 0] | (k[start + 1] << 8) | (k[start + 2] << 16)) & lowerMask; + + const upperMask = 0xFFFFFF >>> Math.min(31, Math.max(0, (6 - length) * 8)); + const upper = (k[start + 3] | (k[start + 4] << 8) | (k[start + 5] << 16)) & upperMask; + return lower + (0x1000000 * upper); +}; + +export class BinaryMap { + private readonly map: Map | V> = new Map(); + private thisValue?: V; + + public get(key: Uint8Array, start: number = 0, end: number = key.length): V | undefined { + const value = this.map.get(binaryMapKey(key, start, end)); + const isFinal = end < Constant.BytesPerLevel + start; + + if (isFinal) { + return value instanceof BinaryMap ? 
value.thisValue : value; + } else if (value instanceof BinaryMap) { + return value.get(key, Constant.BytesPerLevel + start, end); + } else { + return undefined; + } + } + + public set(key: Uint8Array, value: V): void { + const k = binaryMapKey(key, 0, key.length); + const existing = this.map.get(k); + const isFinal = key.length < Constant.BytesPerLevel; + + if (existing === undefined) { + if (isFinal) { + this.map.set(k, value); + } else { + const newMap = new BinaryMap(); + newMap.set(key.subarray(Constant.BytesPerLevel), value); + this.map.set(k, newMap); + } + } else if (isFinal) { + if (existing instanceof BinaryMap) { + existing.thisValue = value; + } else { + this.map.set(k, value); + } + } else { + if (existing instanceof BinaryMap) { + existing.set(key.subarray(Constant.BytesPerLevel), value); + } else { + const newMap = new BinaryMap(); + newMap.set(key.subarray(Constant.BytesPerLevel), value); + newMap.thisValue = existing; + this.map.set(k, newMap); + } + + } + } } +const maxRank = 0x7FFFFFFF; // max int32, try and keep things in integer space + /** * This function implements the byte pair encoding algorithm. * @param mergingBytes: bytes to be merged @@ -20,67 +87,68 @@ export function uint8ArrayToString(uint8Array: Uint8Array): string { */ export function bytePairEncode( mergingBytes: Uint8Array, - ranks: ReadonlyMap + ranks: BinaryMap, + length: number, ): number[] { - if (mergingBytes.length === 1) { - return [ranks.get(mergingBytes[0].toString())!]; + if (length === 1) { + return [ranks.get(mergingBytes, 0, length)!]; /* look up only the first `length` bytes: mergingBytes may be a reused encoder buffer that is longer than the encoded text */ } + let minRank = maxRank; + let minIndex = -1; + const byteIndicesAndRanks: [number, number][] = []; - for (let i = 0; i < mergingBytes.length + 1; i++) { - byteIndicesAndRanks.push([i, Number.MAX_SAFE_INTEGER]); + for (let i = 0; i < length - 1; i++) { + const rank = ranks.get(mergingBytes, i, i + 2) ?? 
maxRank; + if (rank < minRank) { + minRank = rank; + minIndex = i; + } + + byteIndicesAndRanks.push([i, rank]); } + byteIndicesAndRanks.push([length - 1, maxRank]); + byteIndicesAndRanks.push([length, maxRank]); function getRank(startIndex: number, skip = 0): number { if (startIndex + skip + 2 < byteIndicesAndRanks.length) { - const slice = mergingBytes.slice( + const rank = ranks.get( + mergingBytes, byteIndicesAndRanks[startIndex][0], byteIndicesAndRanks[startIndex + skip + 2][0] ); - const rank = ranks.get(uint8ArrayToString(slice)); if (rank !== undefined) { return rank; } } - return Number.MAX_SAFE_INTEGER; + return maxRank; } - for (let i = 0; i < byteIndicesAndRanks.length - 2; i++) { - const rank = getRank(i); - if (rank !== Number.MAX_SAFE_INTEGER) { - byteIndicesAndRanks[i][1] = rank; + while (minRank !== maxRank) { + byteIndicesAndRanks[minIndex][1] = getRank(minIndex, 1); + if (minIndex > 0) { + byteIndicesAndRanks[minIndex - 1][1] = getRank(minIndex - 1, 1); } - } + byteIndicesAndRanks.splice(minIndex + 1, 1); + - while (byteIndicesAndRanks.length > 1) { - let minRank: [number, number] = [0, Number.MAX_SAFE_INTEGER]; + minIndex = -1; + minRank = maxRank; for (let i = 0; i < byteIndicesAndRanks.length - 1; i++) { - if (byteIndicesAndRanks[i][1] < minRank[1]) { - minRank = [i, byteIndicesAndRanks[i][1]]; + if (byteIndicesAndRanks[i][1] < minRank) { + minRank = byteIndicesAndRanks[i][1]; + minIndex = i; } } - if (minRank[1] !== Number.MAX_SAFE_INTEGER) { - const j = minRank[0]; - byteIndicesAndRanks[j][1] = getRank(j, 1); - if (j > 0) { - byteIndicesAndRanks[j - 1][1] = getRank(j - 1, 1); - } - byteIndicesAndRanks.splice(j + 1, 1); - } else { - break; - } } const outList: number[] = []; for (let i = 0; i < byteIndicesAndRanks.length - 1; i++) { outList.push( ranks.get( - uint8ArrayToString( - mergingBytes.slice( - byteIndicesAndRanks[i][0], - byteIndicesAndRanks[i + 1][0] - ) - ) + mergingBytes, + byteIndicesAndRanks[i][0], + byteIndicesAndRanks[i + 1][0] 
)! ); } diff --git a/tokenizer_ts/src/textEncoder.ts b/tokenizer_ts/src/textEncoder.ts new file mode 100644 index 0000000..a5393cd --- /dev/null +++ b/tokenizer_ts/src/textEncoder.ts @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +/** + * A text encoder interface. + */ +export interface ITextEncoder { + /** + * Number of bytes written in the last call to {@link encode} + */ + length: number; + + /** + * Encodes the text and returns the Uint8Array it was written to. The length + * of data written to the array can be found in {@link length}. + * + * The data returned in the array is only valid until the next call to encode. + */ + encode(text: string): Uint8Array; +} + +class UniversalTextEncoder implements ITextEncoder { + public length = 0; + private encoder = new TextEncoder(); + + public encode(text: string): Uint8Array { + const arr = this.encoder.encode(text); + this.length = arr.length; + return arr; + } +} + +class NodeTextEncoder implements ITextEncoder { + private buffer = Buffer.alloc(256); + public length = 0; + + public encode(text: string): Uint8Array { + while (true) { + this.length = this.buffer.write(text, 'utf8'); + + // buffer.write returns the number of bytes written and can write less + // than the length of the string if the buffer is too small. If this + // might have happened (4 bytes is the longest utf8 codepoint), make + // the buffer bigger and try again. + if (this.length < this.buffer.length - 4) { + return this.buffer; + } + + this.buffer = Buffer.alloc(this.length * 2); + this.length = this.buffer.write(text); + } + } +} + +export const makeTextEncoder = (): ITextEncoder => + typeof Buffer !== 'undefined' ? 
new NodeTextEncoder() : new UniversalTextEncoder(); diff --git a/tokenizer_ts/src/tikTokenizer.ts b/tokenizer_ts/src/tikTokenizer.ts index 88e6183..15eea14 100644 --- a/tokenizer_ts/src/tikTokenizer.ts +++ b/tokenizer_ts/src/tikTokenizer.ts @@ -3,8 +3,9 @@ import * as fs from "fs"; import { LRUCache } from "lru-cache"; -import { TextDecoder, TextEncoder } from "util"; -import { bytePairEncode, uint8ArrayToString } from "./bytePairEncode"; +import { TextDecoder } from "util"; +import { BinaryMap, bytePairEncode } from "./bytePairEncode"; +import { makeTextEncoder } from './textEncoder'; /** * Load BPE ranks from a file @@ -59,12 +60,12 @@ function escapeRegExp(regex: string) { */ export class TikTokenizer { private regex?: RegExp; - private encoder?: Map; + private encoder?: BinaryMap; private decoder?: Map; private specialTokensRegex?: RegExp; private specialTokensEncoder?: ReadonlyMap; private specialTokensDecoder?: Map; - private textEncoder = new TextEncoder(); + private textEncoder = makeTextEncoder(); private textDecoder = new TextDecoder("utf-8"); public readonly cache: LRUCache; @@ -94,9 +95,9 @@ export class TikTokenizer { specialTokensEncoder: ReadonlyMap, regexPattern: string ): void { - this.encoder = new Map(); + this.encoder = new BinaryMap(); for (const [key, value] of bpeDict) { - this.encoder.set(uint8ArrayToString(key), value); + this.encoder.set(key, value); } this.regex = new RegExp(regexPattern, "gu"); this.specialTokensRegex = new RegExp( @@ -111,7 +112,7 @@ export class TikTokenizer { this.decoder.set(value, key); } - if (this.encoder.size !== this.decoder.size) { + if (bpeDict.size !== this.decoder.size) { throw new Error("Encoder and decoder sizes do not match"); } @@ -200,17 +201,18 @@ export class TikTokenizer { const substring = text.substring(start, end); this.regex!.lastIndex = 0; while ((match = this.regex!.exec(substring))) { - if (this.cache.has(match[0])) { - tokenIds.push(...this.cache.get(match[0])!); + const cached = 
this.cache.get(match[0]); + if (cached) { + tokenIds.push(...cached); } else { // cache miss const bytes = this.textEncoder.encode(match[0]); - const token = this.encoder?.get(uint8ArrayToString(bytes)); + const token = this.encoder?.get(bytes, 0, this.textEncoder.length); if (token !== undefined) { tokenIds.push(token); this.cache.set(match[0], [token]); } else { - const encodedTokens = bytePairEncode(bytes, this.encoder!); + const encodedTokens = bytePairEncode(bytes, this.encoder!, this.textEncoder.length); tokenIds.push(...encodedTokens); this.cache.set(match[0], encodedTokens); } @@ -248,7 +250,7 @@ export class TikTokenizer { } else { // cache miss const bytes = this.textEncoder.encode(piece); - const token = this.encoder!.get(uint8ArrayToString(bytes)); + const token = this.encoder!.get(bytes, 0, this.textEncoder.length); /* bytes may be the encoder's reused buffer; only the first textEncoder.length bytes belong to this piece */ if (token !== undefined) { this.cache.set(piece, [token]); if (tokenCount + 1 <= maxTokenCount) { @@ -259,7 +261,7 @@ export class TikTokenizer { break; } } else { - const encodedTokens = bytePairEncode(bytes, this.encoder!); + const encodedTokens = bytePairEncode(bytes, this.encoder!, this.textEncoder.length); this.cache.set(piece, encodedTokens); if (tokenCount + encodedTokens.length <= maxTokenCount) { tokenCount += encodedTokens.length; @@ -394,8 +396,8 @@ export class TikTokenizer { tokenIds.push(...tokens!); tokenCountMap.set(tokenCount, encodeLength); } else { - const bytes = new TextEncoder().encode(piece); - const token = this.encoder!.get(uint8ArrayToString(bytes)); + const bytes = this.textEncoder.encode(piece); + const token = this.encoder!.get(bytes, 0, this.textEncoder.length); /* bound the key by the encoded length, not by the size of the returned buffer */ if (token !== undefined) { this.cache.set(piece, [token]); tokenCount++; @@ -403,7 +405,7 @@ export class TikTokenizer { tokenIds.push(token); tokenCountMap.set(tokenCount, encodeLength); } else { - const encodedTokens = bytePairEncode(bytes, this.encoder!); + const encodedTokens = bytePairEncode(bytes, this.encoder!, this.textEncoder.length); this.cache.set(piece, encodedTokens); tokenCount 
+= encodedTokens.length; encodeLength += piece.length; @@ -473,7 +475,8 @@ export class TikTokenizer { } else { const specialTokenValue = this.specialTokensDecoder?.get(token); if (specialTokenValue !== undefined) { - tokenBytes = Array.from(this.textEncoder.encode(specialTokenValue)); + const bytes = this.textEncoder.encode(specialTokenValue); + tokenBytes = Array.from(bytes.subarray(0, this.textEncoder.length)); } } decoded.push(...tokenBytes); diff --git a/tokenizer_ts/test/binaryMap.test.ts b/tokenizer_ts/test/binaryMap.test.ts new file mode 100644 index 0000000..6d6d3e7 --- /dev/null +++ b/tokenizer_ts/test/binaryMap.test.ts @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import * as assert from "assert"; +import { BinaryMap, binaryMapKey } from "../src/bytePairEncode"; +suite("BinaryMap Test Suite", function () { + test("Test basic input to map - one level", done => { + const binMap: BinaryMap = new BinaryMap(); + binMap.set(new Uint8Array([1, 50, 24]), 1); + assert(binMap.get(new Uint8Array([1, 50, 24])) === 1); + assert(binMap.get(new Uint8Array([1, 50])) === undefined); + assert(binMap.get(new Uint8Array([1, 50, 24, 100])) === undefined); + + binMap.set(new Uint8Array([1, 50, 24, 100]), 100); + assert(binMap.get(new Uint8Array([1, 50, 24, 100])) === 100); + done(); + }); + test("Test basic input to map - one or two levels", done => { + const binMap: BinaryMap = new BinaryMap(); + binMap.set(new Uint8Array([1, 50, 24, 34, 64, 23]), 1); + binMap.set(new Uint8Array([1, 50, 24, 34, 64, 23, 60, 120, 40]), 2); + binMap.set(new Uint8Array([1, 50, 24, 34, 64, 23, 60, 120, 40, 21, 54, 232]), 3); + assert(binMap.get(new Uint8Array([1, 50, 24, 34, 64, 23])) === 1); + assert(binMap.get(new Uint8Array([1, 50, 24, 34, 64, 23, 60, 120, 40])) === 2); + assert(binMap.get(new Uint8Array([1, 50, 24, 34, 64, 23, 60, 120, 40, 21, 54, 232])) === 3); + done(); + }); + test("Test `get` with start and end specified", done => { + 
const binMap: BinaryMap = new BinaryMap(); + binMap.set(new Uint8Array([64, 23]), 100); + binMap.set(new Uint8Array([1, 50, 24]), 1); + binMap.set(new Uint8Array([24, 34, 64]), 2); + binMap.set(new Uint8Array([23, 60, 120, 1, 50, 24]), 255); + const mainArray = new Uint8Array([64, 23, 60, 120, 1, 50, 24, 34, 64]); + assert(binMap.get(mainArray, 4, 7) === 1); + assert(binMap.get(mainArray, 6, 9) === 2); + assert(binMap.get(mainArray, 1, 7) === 255); + assert(binMap.get(mainArray, 7, 7) === undefined); + assert(binMap.get(mainArray, 6, 10) === 2); + assert(binMap.get(mainArray, 0, 2) === 100); + done(); + }); +}); +suite("Binary Map Key Function Test", function () { + test("First 3 Max Bytes", done => { + const arr = new Uint8Array([0xFF, 0xFF, 0xFF, 0xAB, 0xCD, 0xEF]); + const result = binaryMapKey(arr, 0, arr.length); + assert.strictEqual(result, 0xEFCDABFFFFFF); + done(); + }); + + test("All 6 Max Bytes", done => { + const arr = new Uint8Array([0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]); + const result = binaryMapKey(arr, 0, arr.length); + assert.strictEqual(result, 0xFFFFFFFFFFFF); + done(); + }); + + test("First 3 Min Bytes", done => { + const arr = new Uint8Array([0x00, 0x00, 0x00, 0xAB, 0xCD, 0xEF]); + const result = binaryMapKey(arr, 0, arr.length); + assert.strictEqual(result, 0xEFCDAB000000); + done(); + }); + + test("Last 3 Min Bytes", done => { + const arr = new Uint8Array([0xAB, 0xCD, 0xEF, 0x00, 0x00, 0x00]); + const result = binaryMapKey(arr, 0, arr.length); + assert.strictEqual(result, 0x000000EFCDAB); + done(); + }); + + test("Assorted Bytes", done => { + const arr = new Uint8Array([0xBA, 0xDC, 0xFE, 0xEF, 0xCD, 0xAB]); + const result = binaryMapKey(arr, 0, arr.length); + assert.strictEqual(result, 0xABCDEFFEDCBA); + done(); + }); + + test("Assorted Bytes with start/end defined in lower bits", done => { + const arr = new Uint8Array([0xBA, 0xDC, 0xFE, 0xEF, 0xCD, 0xAB]); + const result = binaryMapKey(arr, 1, 3); + assert.strictEqual(result, 0x00000000FEDC); 
+ done(); + }); + + test("Assorted Bytes with start/end defined in upper bits", done => { + const arr = new Uint8Array([0xBA, 0xDC, 0xFE, 0xEF, 0xCD, 0xAB]); + const result = binaryMapKey(arr, 3, 6); + assert.strictEqual(result, 0x000000ABCDEF); + done(); + }); + + test("Assorted Bytes with start/end defined across upper and lower bits", done => { + const arr = new Uint8Array([0xBA, 0xDC, 0xFE, 0xEF, 0xCD, 0xAB]); + const result = binaryMapKey(arr, 2, 5); + assert.strictEqual(result, 0x000000CDEFFE); + done(); + }); +});