Skip to content

Commit

Permalink
Merge branch 'main' into no-map
Browse files Browse the repository at this point in the history
  • Loading branch information
sbatten authored Apr 16, 2024
2 parents f3de338 + 40a2b85 commit 5de1f47
Show file tree
Hide file tree
Showing 9 changed files with 546 additions and 60 deletions.
4 changes: 3 additions & 1 deletion tokenizer_ts/.npmignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ dist/test/*
debug.ts
*.map
*.tiktoken
.eslintrc.js
.eslintrc.js
/perf/*
*.map
1 change: 1 addition & 0 deletions tokenizer_ts/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"scripts": {
"test": "mocha -u tdd --require ts-node/register test/**/*.ts",
"build": "tsc -p ./tsconfig.json",
"watch": "tsc -p ./tsconfig.json --watch",
"eslint": "eslint src --ext ts",
"format": "prettier --write \"./**/*.{ts,tsx}\""
},
Expand Down
3 changes: 3 additions & 0 deletions tokenizer_ts/perf/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package.json
package-lock.json
*.cpuprofile
65 changes: 65 additions & 0 deletions tokenizer_ts/perf/benchmark-folder.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
const fs = require('fs/promises');
const path = require('path');
const inspector = require('inspector');
const { promisify } = require('util');

// Command-line arguments, read positionally after the node binary and script path:
//   <encoderName> <folderPath> <method> <modulePath>
// - encoderName: passed to createByEncoderName to build the tokenizer.
// - folderPath:  root folder whose source files are read as benchmark input.
// - method:      tokenizer method to benchmark ('encode' or 'encodeTrimSuffix').
// - modulePath:  module to require; it must export createByEncoderName.
const [,, encoderName, folderPath, method, modulePath] = process.argv;
const { createByEncoderName } = require(modulePath);
// Minimum total benchmark duration, compared against performance.now() deltas (milliseconds).
const minTime = 10_000;
// Minimum number of benchmark cycles to run, even if minTime has already elapsed.
const minCycles = 5;

// Only files whose path ends with one of these suffixes are used as benchmark input.
const fileExtensions = ['.ts', '.js', '.py'];

/**
 * Recursively reads every file below `folderPath` whose path ends with one of
 * the suffixes in `fileExtensions`.
 *
 * @param folderPath Folder to scan.
 * @returns A flat array with the UTF-8 contents of each matching file.
 */
async function readAllFilesInFolder(folderPath) {
  const entries = await fs.readdir(folderPath, { withFileTypes: true });

  // Resolves to a string (matching file), an array of strings (sub-folder),
  // or an empty array (ignored entry); the final flat() merges all three.
  const loadEntry = async (entry) => {
    const fullPath = path.resolve(folderPath, entry.name);
    if (entry.isDirectory()) {
      return readAllFilesInFolder(fullPath);
    }

    const isWanted = fileExtensions.some(ext => fullPath.endsWith(ext));
    return isWanted ? fs.readFile(fullPath, 'utf8') : [];
  };

  const loaded = await Promise.all(entries.map(entry => loadEntry(entry)));
  return loaded.flat();
}

/**
 * Maps the requested method name to a function that runs one benchmark cycle,
 * i.e. tokenizes every file once.
 *
 * @param methodName 'encode' or 'encodeTrimSuffix'.
 * @returns A function `(tokenizer, files) => void`.
 * @throws Error if the method name is not recognised.
 */
function selectCycleRunner(methodName) {
  switch (methodName) {
    case 'encode':
      return (tokenizer, files) => files.forEach(file => tokenizer.encode(file));
    case 'encodeTrimSuffix':
      return (tokenizer, files) => files.forEach(file => tokenizer.encodeTrimSuffix(file, 1337));
    default:
      throw new Error(`unknown method ${methodName}`);
  }
}

/**
 * Reads the input files, creates the tokenizer, and repeatedly runs the
 * selected method under the CPU profiler until both `minTime` milliseconds
 * have elapsed and `minCycles` cycles have completed.
 *
 * Writes the CPU profile to `profile.cpuprofile` in the current working
 * directory and prints `{ totalSize, cycles }` as JSON on stdout, where
 * `cycles` holds the duration of each cycle in milliseconds.
 */
async function runBenchmark() {
  // Fail fast on a bad method name, before reading files or starting the profiler.
  const runCycle = selectCycleRunner(method);

  const [files, tokenizer] = await Promise.all([
    readAllFilesInFolder(folderPath),
    createByEncoderName(encoderName)
  ]);

  let totalSize = 0;
  for (const file of files) {
    totalSize += file.length;
  }

  const session = new inspector.Session();
  session.connect();
  try {
    const post = promisify(session.post).bind(session);
    await post('Profiler.enable');
    await post('Profiler.start');

    const start = performance.now();
    const cycles = [];
    while (performance.now() - start < minTime || cycles.length < minCycles) {
      const cycleStart = performance.now();
      runCycle(tokenizer, files);
      cycles.push(performance.now() - cycleStart);
    }

    const data = await post('Profiler.stop');
    await fs.writeFile('profile.cpuprofile', JSON.stringify(data.profile));

    process.stdout.write(JSON.stringify({ totalSize, cycles }));
  } finally {
    // Always release the inspector session, even if a cycle or the write fails.
    session.disconnect();
  }
}

runBenchmark().catch((err) => {
  // Report the failure explicitly and signal it to the caller via the exit code.
  console.error(err);
  process.exitCode = 1;
});
188 changes: 188 additions & 0 deletions tokenizer_ts/perf/notebook.ipynb

Large diffs are not rendered by default.

152 changes: 110 additions & 42 deletions tokenizer_ts/src/bytePairEncode.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,84 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

/**
* Convert a Uint8Array to a string
* @param uint8Array
* @returns string
*/
export function uint8ArrayToString(uint8Array: Uint8Array): string {
return Array.from(uint8Array)
.map(num => num.toString())
.join("_");
const enum Constant {
  // Each level of the map consumes up to 6 bytes (48 bits) of the key. Bitwise
  // operators only work on 32-bit integers, so a level key is assembled from
  // two 24-bit halves.
  BytesPerLevel = 6,
  // 2^48: multiplier that places a short key's length above the 48 bits used
  // by the packed key bytes when building a leaf key.
  LeafLengthUnit = 0x1000000000000,
}

// exported for testing
/**
 * Packs up to six bytes of `k`, taken from the range `[start, end)`, into a
 * single non-negative number of at most 48 bits (little-endian: `k[start]` is
 * the least significant byte). Bytes beyond `start + 6` are ignored.
 *
 * Positions past `end` contribute zero, so two ranges that differ only by
 * trailing zero bytes (e.g. `[0x41]` and `[0x41, 0x00]`) produce the same
 * number. Callers that need to tell those apart must also account for the
 * range length.
 *
 * @param k bytes to read from
 * @param start index of the first byte (inclusive)
 * @param end index after the last byte (exclusive)
 * @returns the packed key
 */
export const binaryMapKey = (k: Uint8Array, start: number, end: number): number => {
  const length = end - start;

  // 'lower' and 'upper' are both 24-bit integers, like
  //   0xFF FF FF
  //     ^3 ^2 ^1
  // With a length of 2 the "3" byte lies outside the range, so the mask
  //   0x00 FF FF   (0xFF FF FF shifted right by 8 bits)
  // discards it.
  const lowerMask = 0xFFFFFF >>> Math.max(0, (3 - length) * 8);
  const lower = (k[start + 0] | (k[start + 1] << 8) | (k[start + 2] << 16)) & lowerMask;

  // The shift is capped at 31 because `>>>` only uses the low 5 bits of its
  // right operand; a shift of 32 would otherwise behave like a shift of 0.
  const upperMask = 0xFFFFFF >>> Math.min(31, Math.max(0, (6 - length) * 8));
  const upper = (k[start + 3] | (k[start + 4] << 8) | (k[start + 5] << 16)) & upperMask;
  return lower + (0x1000000 * upper);
};

/**
 * A map keyed by byte sequences. Keys are consumed six bytes at a time: every
 * full six-byte chunk selects a nested `BinaryMap`, and the remaining zero to
 * five bytes select the stored value.
 *
 * Lookups can address a sub-range of a larger array without copying it.
 */
export class BinaryMap<V> {
  /** Values whose remaining key is shorter than one level, by length-tagged key. */
  private readonly leaves: Map<number, V> = new Map();
  /** Nested maps for keys that continue past this level, by packed six-byte chunk. */
  private readonly children: Map<number, BinaryMap<V>> = new Map();

  /**
   * Builds the key for a remainder of fewer than six bytes. The length is
   * folded in so that remainders differing only by trailing zero bytes do not
   * collide. The result stays below 2^51 and is therefore an exact integer.
   */
  private static leafKey(key: Uint8Array, start: number, end: number): number {
    return binaryMapKey(key, start, end) + (end - start) * Constant.LeafLengthUnit;
  }

  /**
   * Looks up the value stored for the bytes `key[start..end)`.
   *
   * @param key array containing the key bytes
   * @param start index of the first key byte (inclusive), defaults to 0
   * @param end index after the last key byte (exclusive), defaults to `key.length`
   * @returns the stored value, or `undefined` if there is none
   */
  public get(key: Uint8Array, start: number = 0, end: number = key.length): V | undefined {
    if (end - start < Constant.BytesPerLevel) {
      return this.leaves.get(BinaryMap.leafKey(key, start, end));
    }

    const child = this.children.get(binaryMapKey(key, start, end));
    return child?.get(key, start + Constant.BytesPerLevel, end);
  }

  /**
   * Stores `value` under the full contents of `key`, replacing any value
   * previously stored for the same bytes.
   *
   * @param key the key bytes
   * @param value the value to store
   */
  public set(key: Uint8Array, value: V): void {
    if (key.length < Constant.BytesPerLevel) {
      this.leaves.set(BinaryMap.leafKey(key, 0, key.length), value);
      return;
    }

    const chunk = binaryMapKey(key, 0, key.length);
    let child = this.children.get(chunk);
    if (child === undefined) {
      child = new BinaryMap<V>();
      this.children.set(chunk, child);
    }
    child.set(key.subarray(Constant.BytesPerLevel), value);
  }
}

const maxRank = 0x7FFFFFFF; // max int32, try and keep things in integer space

/**
* This function implements the byte pair encoding algorithm.
* @param mergingBytes: bytes to be merged
Expand All @@ -20,67 +87,68 @@ export function uint8ArrayToString(uint8Array: Uint8Array): string {
*/
export function bytePairEncode(
mergingBytes: Uint8Array,
ranks: ReadonlyMap<string, number>
ranks: BinaryMap<number>,
length: number,
): number[] {
if (mergingBytes.length === 1) {
return [ranks.get(mergingBytes[0].toString())!];
if (length === 1) {
return [ranks.get(mergingBytes)!];
}

let minRank = maxRank;
let minIndex = -1;

const byteIndicesAndRanks: [number, number][] = [];
for (let i = 0; i < mergingBytes.length + 1; i++) {
byteIndicesAndRanks.push([i, Number.MAX_SAFE_INTEGER]);
for (let i = 0; i < length - 1; i++) {
const rank = ranks.get(mergingBytes, i, i + 2) ?? maxRank;
if (rank < minRank) {
minRank = rank;
minIndex = i;
}

byteIndicesAndRanks.push([i, rank]);
}
byteIndicesAndRanks.push([length - 1, maxRank]);
byteIndicesAndRanks.push([length, maxRank]);

function getRank(startIndex: number, skip = 0): number {
if (startIndex + skip + 2 < byteIndicesAndRanks.length) {
const slice = mergingBytes.slice(
const rank = ranks.get(
mergingBytes,
byteIndicesAndRanks[startIndex][0],
byteIndicesAndRanks[startIndex + skip + 2][0]
);
const rank = ranks.get(uint8ArrayToString(slice));
if (rank !== undefined) {
return rank;
}
}
return Number.MAX_SAFE_INTEGER;
return maxRank;
}

for (let i = 0; i < byteIndicesAndRanks.length - 2; i++) {
const rank = getRank(i);
if (rank !== Number.MAX_SAFE_INTEGER) {
byteIndicesAndRanks[i][1] = rank;
while (minRank !== maxRank) {
byteIndicesAndRanks[minIndex][1] = getRank(minIndex, 1);
if (minIndex > 0) {
byteIndicesAndRanks[minIndex - 1][1] = getRank(minIndex - 1, 1);
}
}
byteIndicesAndRanks.splice(minIndex + 1, 1);


while (byteIndicesAndRanks.length > 1) {
let minRank: [number, number] = [0, Number.MAX_SAFE_INTEGER];
minIndex = -1;
minRank = maxRank;
for (let i = 0; i < byteIndicesAndRanks.length - 1; i++) {
if (byteIndicesAndRanks[i][1] < minRank[1]) {
minRank = [i, byteIndicesAndRanks[i][1]];
if (byteIndicesAndRanks[i][1] < minRank) {
minRank = byteIndicesAndRanks[i][1];
minIndex = i;
}
}
if (minRank[1] !== Number.MAX_SAFE_INTEGER) {
const j = minRank[0];
byteIndicesAndRanks[j][1] = getRank(j, 1);
if (j > 0) {
byteIndicesAndRanks[j - 1][1] = getRank(j - 1, 1);
}
byteIndicesAndRanks.splice(j + 1, 1);
} else {
break;
}
}

const outList: number[] = [];
for (let i = 0; i < byteIndicesAndRanks.length - 1; i++) {
outList.push(
ranks.get(
uint8ArrayToString(
mergingBytes.slice(
byteIndicesAndRanks[i][0],
byteIndicesAndRanks[i + 1][0]
)
)
mergingBytes,
byteIndicesAndRanks[i][0],
byteIndicesAndRanks[i + 1][0]
)!
);
}
Expand Down
56 changes: 56 additions & 0 deletions tokenizer_ts/src/textEncoder.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

/**
* A text encoder that turns a string into bytes and reports how many bytes it
* produced. The byte count is exposed separately because the returned array
* is not guaranteed to be exactly as long as the encoded data.
*/
export interface ITextEncoder {
/**
* Number of bytes written in the last call to {@link encode}. Only the first
* `length` bytes of the array returned by that call hold encoded data.
*/
length: number;

/**
* Encodes the text and returns the Uint8Array it was written to. The length
* of data written to the array can be found in {@link length}; any bytes in
* the array past that point must be ignored.
*
* The data returned in the array is only valid until the next call to encode,
* since an implementation may reuse the same array across calls.
*/
encode(text: string): Uint8Array;
}

/**
 * Encoder built on the standard `TextEncoder`. Every call returns a newly
 * encoded array, and `length` is set to that array's size.
 */
class UniversalTextEncoder implements ITextEncoder {
  public length = 0;
  private encoder = new TextEncoder();

  /**
   * Encodes `text` as UTF-8.
   *
   * @param text string to encode
   * @returns the encoded bytes; `length` equals the size of this array
   */
  public encode(text: string): Uint8Array {
    const encoded = this.encoder.encode(text);
    this.length = encoded.byteLength;
    return encoded;
  }
}

/**
 * Encoder for Node.js that writes into a reusable `Buffer` instead of
 * allocating a new array for every call. The returned buffer is usually
 * larger than the encoded data; `length` gives the number of valid bytes.
 */
class NodeTextEncoder implements ITextEncoder {
  private buffer = Buffer.alloc(256);
  public length = 0;

  /**
   * Encodes `text` as UTF-8 into the internal buffer, growing it as needed.
   *
   * @param text string to encode
   * @returns the internal buffer; only its first `length` bytes are valid, and
   * only until the next call to `encode`
   */
  public encode(text: string): Uint8Array {
    this.length = this.buffer.write(text, 'utf8');

    // buffer.write returns the number of bytes written and can write less
    // than the length of the string if the buffer is too small. If this
    // might have happened (4 bytes is the longest utf8 codepoint), make
    // the buffer bigger and try again. Each attempt writes the text once.
    while (this.length >= this.buffer.length - 4) {
      this.buffer = Buffer.alloc(this.length * 2);
      this.length = this.buffer.write(text, 'utf8');
    }

    return this.buffer;
  }
}

/**
 * Creates a text encoder suited to the current runtime: the Buffer-backed
 * encoder when a global `Buffer` exists, otherwise the `TextEncoder`-based one.
 */
export const makeTextEncoder = (): ITextEncoder => {
  if (typeof Buffer === 'undefined') {
    return new UniversalTextEncoder();
  }
  return new NodeTextEncoder();
};
Loading

0 comments on commit 5de1f47

Please sign in to comment.