From f7be3c60033b298fd5a81ba610e591f08ef49a1e Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Mon, 17 Apr 2023 18:40:26 -0700 Subject: [PATCH 01/20] Implement two methods for avoiding large arraybuffer --- tfjs-core/src/io/weights_loader.ts | 147 ++++++++++++++++++++++++++--- 1 file changed, 134 insertions(+), 13 deletions(-) diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index 8b0f923fd8a..d5690285951 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -212,24 +212,80 @@ export function weightsLoaderFactory( groupIndicesToFetch.forEach(i => { const numBuffers = manifest[i].paths.length; - let groupBytes = 0; - for (let i = 0; i < numBuffers; i++) { - groupBytes += buffers[bufferIndexOffset + i].byteLength; - } + // let groupBytes = 0; + // for (let i = 0; i < numBuffers; i++) { + // groupBytes += buffers[bufferIndexOffset + i].byteLength; + // } // Create a buffer for the whole group. - const groupBuffer = new ArrayBuffer(groupBytes); - const groupByteBuffer = new Uint8Array(groupBuffer); - let groupBufferOffset = 0; - for (let i = 0; i < numBuffers; i++) { - const buffer = new Uint8Array(buffers[bufferIndexOffset + i]); - groupByteBuffer.set(buffer, groupBufferOffset); - groupBufferOffset += buffer.byteLength; + // const groupBuffer = new ArrayBuffer(groupBytes); + // const groupByteBuffer = new Uint8Array(groupBuffer); + // let groupBufferOffset = 0; + // for (let i = 0; i < numBuffers; i++) { + // const buffer = new Uint8Array(buffers[bufferIndexOffset + i]); + // groupByteBuffer.set(buffer, groupBufferOffset); + // groupBufferOffset += buffer.byteLength; + // } + + + + const weightsEntries = [...groupWeightsToFetch[i]]; + weightsEntries.sort((a, b) => a.groupOffset - b.groupOffset); + + let bufferIndex = 0; + let precedingBytes = 0; + + function advanceTo(byteIndex: number) { + if (byteIndex < precedingBytes) { + throw new Error(`Buffer reader at ${precedingBytes} is already past ${byteIndex}`); + } + + for (let i = bufferIndex; i < numBuffers; i++) { + const buffer = buffers[bufferIndexOffset + i]; + const nextBytes = precedingBytes + buffer.byteLength; + if (nextBytes > byteIndex) { + return; + } + bufferIndex = i + 1; + precedingBytes = nextBytes; + } + throw new Error('Advanced past end'); + } + + function slice(start: number, end: number): ArrayBuffer { + advanceTo(start); + const size = end - start; + const outputBuffer = new ArrayBuffer(size); + const outputArray = new Uint8Array(outputBuffer); + let sliced = 0; + for (let i = bufferIndex; i < numBuffers; i++) { + const buffer = buffers[bufferIndexOffset + i]; + const nextBytes = precedingBytes + buffer.byteLength; + + const globalStart = start + sliced; + const localStart = globalStart - precedingBytes; + const outputStart = sliced; + + const globalEnd = Math.min(end, nextBytes); + const localEnd = globalEnd - precedingBytes; + // const outputEnd = outputStart + (localEnd - localStart); + + const outputSlice = new Uint8Array(buffer.slice(localStart, localEnd)); + sliced += outputSlice.length; + outputArray.set(outputSlice, outputStart); + + if (end < nextBytes) { + break; + } + + bufferIndex = i + 1; + precedingBytes = nextBytes; + } + return outputBuffer; } - const weightsEntries = groupWeightsToFetch[i]; weightsEntries.forEach(weightsEntry => { - const byteBuffer = groupBuffer.slice( + const byteBuffer = slice( weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes); const nameToTensorMap = @@ -245,3 +301,68 @@ export function weightsLoaderFactory( return weightsTensorMap; }; } + +class CompositeArrayBuffer { + private ranges: Array<{ + start: number, + end: number, + buffer: ArrayBuffer, + }> = []; + + constructor(buffers: ArrayBuffer[]) { + let start = 0; + for (const buffer of buffers) { + const end = start + buffer.byteLength; + this.ranges.push({buffer, start, end,}); + start = end; + }; + } + + get size() { + return this.ranges[this.ranges.length - 1].end; + } + + slice(start: number, end: number): ArrayBuffer { + if (start < 0 || start >= this.size) { + throw new Error(`Start position ${start} is outside range [0, ${this.size})`); + } + if (end < start) { + throw new Error('End must be greater than start'); + } + + const startRange = this.search(start); + + const size = end - start; + const outputBuffer = new ArrayBuffer(size); + const outputArray = new Uint8Array(outputBuffer); + let sliced = 0; + for (let i = startRange; i < this.ranges.length; i++) { + const range = this.ranges[i]; + + const globalStart = start + sliced; + const localStart = globalStart - range.start; + const outputStart = sliced; + + const globalEnd = Math.min(end, range.end); + const localEnd = globalEnd - range.start; + + const outputSlice = new Uint8Array(range.buffer.slice(localStart, localEnd)); + outputArray.set(outputSlice, outputStart); + + if (end < range.end) { + break; + } + } + return outputBuffer + } + private search(byteIndex: number) { + // TODO: Binsearch + const val = this.ranges.find(r => r.start <= byteIndex && r.end > byteIndex); + if (!val) { + throw new Error(`${byteIndex} not found in ranges`); + } + return this.ranges.indexOf(val); + } +} + +CompositeArrayBuffer; From 076c8a2b8b2f9f6ee2db8b2b80827e0480e9b7fc Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 18 Apr 2023 09:36:13 -0700 Subject: [PATCH 02/20] Use the CompositeArrayBuffer method --- tfjs-core/src/io/weights_loader.ts | 76 ++---------------------------- 1 file changed, 5 insertions(+), 71 deletions(-) diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index d5690285951..6ea57ecc986 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -212,80 +212,13 @@ export function weightsLoaderFactory( groupIndicesToFetch.forEach(i => { const numBuffers = manifest[i].paths.length; - // let groupBytes = 0; - // for (let i = 0; i < numBuffers; i++) { - // groupBytes += buffers[bufferIndexOffset + i].byteLength; - // } + const weightsBuffer = new CompositeArrayBuffer( + buffers.slice(bufferIndexOffset, bufferIndexOffset + numBuffers)); - // Create a buffer for the whole group. - // const groupBuffer = new ArrayBuffer(groupBytes); - // const groupByteBuffer = new Uint8Array(groupBuffer); - // let groupBufferOffset = 0; - // for (let i = 0; i < numBuffers; i++) { - // const buffer = new Uint8Array(buffers[bufferIndexOffset + i]); - // groupByteBuffer.set(buffer, groupBufferOffset); - // groupBufferOffset += buffer.byteLength; - // } - - - - const weightsEntries = [...groupWeightsToFetch[i]]; - weightsEntries.sort((a, b) => a.groupOffset - b.groupOffset); - - let bufferIndex = 0; - let precedingBytes = 0; - - function advanceTo(byteIndex: number) { - if (byteIndex < precedingBytes) { - throw new Error(`Buffer reader at ${precedingBytes} is already past ${byteIndex}`); - } - - for (let i = bufferIndex; i < numBuffers; i++) { - const buffer = buffers[bufferIndexOffset + i]; - const nextBytes = precedingBytes + buffer.byteLength; - if (nextBytes > byteIndex) { - return; - } - bufferIndex = i + 1; - precedingBytes = nextBytes; - } - throw new Error('Advanced past end'); - } - - function slice(start: number, end: number): ArrayBuffer { - advanceTo(start); - const size = end - start; - const outputBuffer = new ArrayBuffer(size); - const outputArray = new Uint8Array(outputBuffer); - let sliced = 0; - for (let i = bufferIndex; i < numBuffers; i++) { - const buffer = buffers[bufferIndexOffset + i]; - const nextBytes = precedingBytes + buffer.byteLength; - - const globalStart = start + sliced; - const localStart = globalStart - precedingBytes; - const outputStart = sliced; - - const globalEnd = Math.min(end, nextBytes); - const localEnd = globalEnd - precedingBytes; - // const outputEnd = outputStart + (localEnd - localStart); - - const outputSlice = new Uint8Array(buffer.slice(localStart, localEnd)); - sliced += outputSlice.length; - outputArray.set(outputSlice, outputStart); - - if (end < nextBytes) { - break; - } - - bufferIndex = i + 1; - precedingBytes = nextBytes; - } - return outputBuffer; - } + const weightsEntries = groupWeightsToFetch[i]; weightsEntries.forEach(weightsEntry => { - const byteBuffer = slice( + const byteBuffer = weightsBuffer.slice( weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes); const nameToTensorMap = @@ -348,6 +281,7 @@ class CompositeArrayBuffer { const outputSlice = new Uint8Array(range.buffer.slice(localStart, localEnd)); outputArray.set(outputSlice, outputStart); + sliced += outputSlice.length; if (end < range.end) { break; From bca64fd103b53b1207e5d60cc10e6b791f72dac1 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 18 Apr 2023 10:02:35 -0700 Subject: [PATCH 03/20] Implement binsearch --- tfjs-core/src/io/weights_loader.ts | 40 +++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index 6ea57ecc986..50be63aca0f 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -290,12 +290,40 @@ class CompositeArrayBuffer { return outputBuffer } private search(byteIndex: number) { - // TODO: Binsearch - const val = this.ranges.find(r => r.start <= byteIndex && r.end > byteIndex); - if (!val) { - throw new Error(`${byteIndex} not found in ranges`); - } - return this.ranges.indexOf(val); + return binsearch(this.ranges, (range) => { + if (byteIndex < range.start) { + return -1; + } + if (byteIndex >= range.end) { + return 1; + } + return 0; + }); + } +} + +/** + * Binary search on list. + * + * @param list The list to search + * @param compare A function to compare the current value against the searched + * value. Return 0 on a match, negative if the searched value is less than + * the value passed to the function, and positive if the searched value is + * greater than the value passed to the function. + */ +function binsearch(list: T[], compare: (t: T) => number, min = 0, max = list.length): number { + if (min > max) { + return -1; + } + + const middle = Math.floor((max - min) / 2); + const side = compare(list[middle]); + if (side === 0) { + return middle; + } else if (side < 0) { + return binsearch(list, compare, min, middle); + } else { + return binsearch(list, compare, middle + 1, max); } } From 0d26418baec559a3c9c9e54b48aace5c9a50b530 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 18 Apr 2023 11:13:01 -0700 Subject: [PATCH 04/20] Check the last used range first for efficiency --- tfjs-core/src/io/weights_loader.ts | 43 ++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index 50be63aca0f..c00cc4aed28 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -235,12 +235,15 @@ export function weightsLoaderFactory( }; } +type BufferRange = { + start: number, + end: number, + buffer: ArrayBuffer, +}; + class CompositeArrayBuffer { - private ranges: Array<{ - start: number, - end: number, - buffer: ArrayBuffer, - }> = []; + private ranges: BufferRange[] = []; + private lastSearchIndex = 0; constructor(buffers: ArrayBuffer[]) { let start = 0; @@ -248,7 +251,7 @@ class CompositeArrayBuffer { const end = start + buffer.byteLength; this.ranges.push({buffer, start, end,}); start = end; - }; + } } get size() { @@ -257,7 +260,8 @@ class CompositeArrayBuffer { slice(start: number, end: number): ArrayBuffer { if (start < 0 || start >= this.size) { - throw new Error(`Start position ${start} is outside range [0, ${this.size})`); + throw new Error(`Start position ${start} is outside range ` + + `[0, ${this.size})`); } if (end < start) { throw new Error('End must be greater than start'); @@ -279,7 +283,8 @@ class CompositeArrayBuffer { const globalEnd = Math.min(end, range.end); const localEnd = globalEnd - range.start; - const outputSlice = new Uint8Array(range.buffer.slice(localStart, localEnd)); + const outputSlice = new Uint8Array(range.buffer.slice(localStart, + localEnd)); outputArray.set(outputSlice, outputStart); sliced += outputSlice.length; @@ -287,10 +292,11 @@ class CompositeArrayBuffer { break; } } - return outputBuffer + return outputBuffer; } private search(byteIndex: number) { - return binsearch(this.ranges, (range) => { + // Check where the byteIndex lies relative to a range. + function check(range: BufferRange) { if (byteIndex < range.start) { return -1; } @@ -298,7 +304,17 @@ class CompositeArrayBuffer { return 1; } return 0; - }); + } + + // For efficiency, try the last searched range + if (check(this.ranges[this.lastSearchIndex]) === 0) { + return this.lastSearchIndex; + } + + // Otherwise, binsearch for the range. + this.lastSearchIndex = binsearch(this.ranges, check); + + return this.lastSearchIndex; } } @@ -311,7 +327,8 @@ class CompositeArrayBuffer { * the value passed to the function, and positive if the searched value is * greater than the value passed to the function. */ -function binsearch(list: T[], compare: (t: T) => number, min = 0, max = list.length): number { +function binsearch(list: T[], compare: (t: T) => number, min = 0, + max = list.length): number { if (min > max) { return -1; } @@ -326,5 +343,3 @@ function binsearch(list: T[], compare: (t: T) => number, min = 0, max = list. return binsearch(list, compare, middle + 1, max); } } - -CompositeArrayBuffer; From 0051fd74173ffe229926d725f7ac2156ea5b5ca2 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 18 Apr 2023 14:07:35 -0700 Subject: [PATCH 05/20] Optimize for when buffers have the same size --- tfjs-core/src/io/weights_loader.ts | 83 ++++++++++++++++++++++-------- 1 file changed, 62 insertions(+), 21 deletions(-) diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index c00cc4aed28..7500242b519 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -243,31 +243,51 @@ type BufferRange = { class CompositeArrayBuffer { private ranges: BufferRange[] = []; - private lastSearchIndex = 0; + private previousRangeIndex = 0; + private bufferUniformSize?: number; constructor(buffers: ArrayBuffer[]) { + if (buffers.length === 0) { + return; + } + + this.bufferUniformSize = buffers[0].byteLength; let start = 0; - for (const buffer of buffers) { + + for (let i = 0; i < buffers.length; i++) { + const buffer = buffers[i]; + // Check that all buffers except the last one have the same length. + if (i !== buffers.length - 1 && + buffer.byteLength !== this.bufferUniformSize) { + // Unset the buffer uniform size, since the buffer sizes are not + // uniform. + this.bufferUniformSize = undefined; + } + + // Create the ranges, including their start and end poionts. const end = start + buffer.byteLength; this.ranges.push({buffer, start, end,}); start = end; } } - get size() { + get byteLength() { + if (this.ranges.length === 0) { + return 0; + } return this.ranges[this.ranges.length - 1].end; } slice(start: number, end: number): ArrayBuffer { - if (start < 0 || start >= this.size) { + if (start < 0 || start >= this.byteLength) { throw new Error(`Start position ${start} is outside range ` + - `[0, ${this.size})`); + `[0, ${this.byteLength})`); } if (end < start) { throw new Error('End must be greater than start'); } - const startRange = this.search(start); + const startRange = this.findRangeForByte(start); const size = end - start; const outputBuffer = new ArrayBuffer(size); @@ -294,8 +314,25 @@ class CompositeArrayBuffer { } return outputBuffer; } - private search(byteIndex: number) { - // Check where the byteIndex lies relative to a range. + + /** + * Get the index of the range that contains the byte at `byteIndex`. + */ + private findRangeForByte(byteIndex: number): number { + if (this.ranges.length === 0 || byteIndex < 0 || + byteIndex >= this.byteLength) { + return -1; + } + + // If the buffers have a uniform size, compute the range directly. + if (this.bufferUniformSize != null) { + this.previousRangeIndex = Math.floor(byteIndex / this.bufferUniformSize); + return this.previousRangeIndex; + } + + // If the buffers don't have a uniform size, we need to search for the + // range. That means we need a function to check where the byteIndex lies + // relative to a given range. function check(range: BufferRange) { if (byteIndex < range.start) { return -1; @@ -306,40 +343,44 @@ class CompositeArrayBuffer { return 0; } - // For efficiency, try the last searched range - if (check(this.ranges[this.lastSearchIndex]) === 0) { - return this.lastSearchIndex; + // For efficiency, try the previous range first. + if (check(this.ranges[this.previousRangeIndex]) === 0) { + return this.previousRangeIndex; } - // Otherwise, binsearch for the range. - this.lastSearchIndex = binsearch(this.ranges, check); + // Otherwise, use a generic search function. + const index = search(this.ranges, check); + if (index === -1) { + return -1; + } - return this.lastSearchIndex; + this.previousRangeIndex = index; + return this.previousRangeIndex; } } /** - * Binary search on list. + * Search for an element of a sorted array. * - * @param list The list to search + * @param sortedArray The sorted array to search * @param compare A function to compare the current value against the searched * value. Return 0 on a match, negative if the searched value is less than * the value passed to the function, and positive if the searched value is * greater than the value passed to the function. */ -function binsearch(list: T[], compare: (t: T) => number, min = 0, - max = list.length): number { +function search(sortedArray: T[], compare: (t: T) => number, min = 0, + max = sortedArray.length): number { if (min > max) { return -1; } const middle = Math.floor((max - min) / 2); - const side = compare(list[middle]); + const side = compare(sortedArray[middle]); if (side === 0) { return middle; } else if (side < 0) { - return binsearch(list, compare, min, middle); + return search(sortedArray, compare, min, middle); } else { - return binsearch(list, compare, middle + 1, max); + return search(sortedArray, compare, middle + 1, max); } } From c5b4f8eadade6f76364910344bff11d4bcedef47 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 18 Apr 2023 14:09:59 -0700 Subject: [PATCH 06/20] Replace recursive binsearch with iterative --- tfjs-core/src/io/weights_loader.ts | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index 7500242b519..fc048f562c7 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -368,19 +368,21 @@ class CompositeArrayBuffer { * the value passed to the function, and positive if the searched value is * greater than the value passed to the function. */ -function search(sortedArray: T[], compare: (t: T) => number, min = 0, - max = sortedArray.length): number { - if (min > max) { - return -1; - } - - const middle = Math.floor((max - min) / 2); - const side = compare(sortedArray[middle]); - if (side === 0) { - return middle; - } else if (side < 0) { - return search(sortedArray, compare, min, middle); - } else { - return search(sortedArray, compare, middle + 1, max); +function search(sortedArray: T[], compare: (t: T) => number): number { + let min = 0; + let max = sortedArray.length; + + while (min <= max) { + const middle = Math.floor((max - min) / 2); + const side = compare(sortedArray[middle]); + + if (side === 0) { + return middle; + } else if (side < 0) { + max = middle; + } else { + min = middle + 1; + } } + return -1; } From 05a98315e7027ae6eb110c9804ef75d57072f94b Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 18 Apr 2023 14:34:57 -0700 Subject: [PATCH 07/20] Comments --- tfjs-core/src/io/weights_loader_test.ts | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tfjs-core/src/io/weights_loader_test.ts b/tfjs-core/src/io/weights_loader_test.ts index 759f3b72b18..e9afae0c226 100644 --- a/tfjs-core/src/io/weights_loader_test.ts +++ b/tfjs-core/src/io/weights_loader_test.ts @@ -175,6 +175,18 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { expect(weight2.dtype).toEqual('float32'); }); + // it('1 group, out of order weights manifest', async () => { + // const shard0 = new Float32Array([1, 2, 3, 4, 5]); + // const shard1 = new Float32Array([1.1, 2.2]); + // const shard2 = new Float32Array([10, 20, 30]); + + // setupFakeWeightFiles({ + // './weightfile0': shard0, + // './weightsfile1': shard1, + // './weightsfile2': shard2 + // }); + // }); + it('1 group, sharded 1 weight across multiple files', async () => { const shard0 = new Float32Array([1, 2, 3, 4, 5]); const shard1 = new Float32Array([1.1, 2.2]); From 0c4e5177c9b0736216b689b3fecaae2266c52e3a Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 18 Apr 2023 14:40:58 -0700 Subject: [PATCH 08/20] Remove commented code. Fix typo --- tfjs-core/src/io/weights_loader.ts | 5 ++++- tfjs-core/src/io/weights_loader_test.ts | 12 ------------ 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index fc048f562c7..18bc3f6aa1b 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -264,7 +264,7 @@ class CompositeArrayBuffer { this.bufferUniformSize = undefined; } - // Create the ranges, including their start and end poionts. + // Create the ranges, including their start and end points. const end = start + buffer.byteLength; this.ranges.push({buffer, start, end,}); start = end; @@ -349,6 +349,8 @@ class CompositeArrayBuffer { } // Otherwise, use a generic search function. + // This should almost never end up being used in practice since the weight + // entries should always be in order. const index = search(this.ranges, check); if (index === -1) { return -1; @@ -369,6 +371,7 @@ class CompositeArrayBuffer { * greater than the value passed to the function. */ function search(sortedArray: T[], compare: (t: T) => number): number { + // Binary search let min = 0; let max = sortedArray.length; diff --git a/tfjs-core/src/io/weights_loader_test.ts b/tfjs-core/src/io/weights_loader_test.ts index e9afae0c226..759f3b72b18 100644 --- a/tfjs-core/src/io/weights_loader_test.ts +++ b/tfjs-core/src/io/weights_loader_test.ts @@ -175,18 +175,6 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { expect(weight2.dtype).toEqual('float32'); }); - // it('1 group, out of order weights manifest', async () => { - // const shard0 = new Float32Array([1, 2, 3, 4, 5]); - // const shard1 = new Float32Array([1.1, 2.2]); - // const shard2 = new Float32Array([10, 20, 30]); - - // setupFakeWeightFiles({ - // './weightfile0': shard0, - // './weightsfile1': shard1, - // './weightsfile2': shard2 - // }); - // }); - it('1 group, sharded 1 weight across multiple files', async () => { const shard0 = new Float32Array([1, 2, 3, 4, 5]); const shard1 = new Float32Array([1.1, 2.2]); From a530b923948b45962f06c2560887a271bbda9f0b Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 18 Apr 2023 14:43:07 -0700 Subject: [PATCH 09/20] Add @returns annotation to search function --- tfjs-core/src/io/weights_loader.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index 18bc3f6aa1b..95f6cbb892d 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -369,6 +369,7 @@ class CompositeArrayBuffer { * value. Return 0 on a match, negative if the searched value is less than * the value passed to the function, and positive if the searched value is * greater than the value passed to the function. + * @returns The index of the element, or -1 if it's not in the array. */ function search(sortedArray: T[], compare: (t: T) => number): number { // Binary search From 0294b53ac0055f766e6fa5297b7c30d182a82baa Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Wed, 19 Apr 2023 13:18:57 -0700 Subject: [PATCH 10/20] Export and test CompositeArrayBuffer --- tfjs-core/src/io/weights_loader.ts | 27 ++++++----- tfjs-core/src/io/weights_loader_test.ts | 64 +++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 13 deletions(-) diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index 95f6cbb892d..d5bb9800dbb 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -241,12 +241,16 @@ type BufferRange = { buffer: ArrayBuffer, }; -class CompositeArrayBuffer { +export class CompositeArrayBuffer { private ranges: BufferRange[] = []; private previousRangeIndex = 0; private bufferUniformSize?: number; + public readonly byteLength: number; - constructor(buffers: ArrayBuffer[]) { + constructor(buffers: ArrayBuffer | ArrayBuffer[]) { + if (buffers instanceof ArrayBuffer) { + buffers = [buffers]; + } if (buffers.length === 0) { return; } @@ -269,22 +273,19 @@ class CompositeArrayBuffer { this.ranges.push({buffer, start, end,}); start = end; } - } - get byteLength() { + // Set the byteLenghth if (this.ranges.length === 0) { - return 0; + this.byteLength = 0; } - return this.ranges[this.ranges.length - 1].end; + this.byteLength = this.ranges[this.ranges.length - 1].end; } - slice(start: number, end: number): ArrayBuffer { - if (start < 0 || start >= this.byteLength) { - throw new Error(`Start position ${start} is outside range ` + - `[0, ${this.byteLength})`); - } - if (end < start) { - throw new Error('End must be greater than start'); + slice(start = 0, end = this.byteLength): ArrayBuffer { + start = Math.max(0, start); + end = Math.min(this.byteLength, end); + if (end <= start) { + return new ArrayBuffer(0); } const startRange = this.findRangeForByte(start); diff --git a/tfjs-core/src/io/weights_loader_test.ts b/tfjs-core/src/io/weights_loader_test.ts index 759f3b72b18..697dc34d78a 100644 --- a/tfjs-core/src/io/weights_loader_test.ts +++ b/tfjs-core/src/io/weights_loader_test.ts @@ -18,6 +18,7 @@ import * as tf from '../index'; import {BROWSER_ENVS, describeWithFlags} from '../jasmine_util'; import {expectArraysClose, expectArraysEqual} from '../test_util'; import {WeightsManifestConfig} from './types'; +import { CompositeArrayBuffer } from './weights_loader'; describeWithFlags('loadWeights', BROWSER_ENVS, () => { const setupFakeWeightFiles = (fileBufferMap: { @@ -542,3 +543,66 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { expect(weight2.dtype).toEqual('float32'); }); }); + +describe('CompositeArrayBuffer', () => { + const uniformBuffers = [ + new Uint8Array([0, 1, 2, 3]).buffer, + new Uint8Array([4, 5, 6, 7]).buffer, + new Uint8Array([8, 9, 10, 11]).buffer, + new Uint8Array([12, 13, 14, 15]).buffer, + new Uint8Array([16]).buffer, + ]; + + const nonUniformBuffers = [ + new Uint8Array([0, 1, 2]).buffer, + new Uint8Array([3, 4, 5, 6, 7]).buffer, + new Uint8Array([8, 9, 10, 11]).buffer, + new Uint8Array([12, 13, 14, 15, 16]).buffer, + ]; + + const bufferTestCases = [ + ['uniform', uniformBuffers], + ['non-uniform', nonUniformBuffers] + ] as const; + + for (const [buffersType, buffers] of bufferTestCases) { + let composite: CompositeArrayBuffer; + beforeEach(() => { + composite = new CompositeArrayBuffer(buffers); + }); + + it(`${buffersType}: slices across multiple buffers`, () => { + expectArraysEqual(new Uint8Array(composite.slice(1, 13)), + [1,2,3,4,5,6,7,8,9,10,11,12]); + }); + + it(`${buffersType}: slices to the end of the array when \'end\' is not ` + + 'specified', () => { + expectArraysEqual(new Uint8Array(composite.slice(5)), + [5,6,7,8,9,10,11,12,13,14,15,16]); + }); + + it(`${buffersType}: makes a copy when slice() is called with no arguments`, + () => { + expectArraysEqual(new Uint8Array(composite.slice()), + [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]); + }); + + it(`${buffersType}: can be created from a single array`, () => { + const singleComposite = new CompositeArrayBuffer(buffers[0]); + expectArraysEqual(new Uint8Array(singleComposite.slice()), + new Uint8Array(buffers[0])); + }); + + it(`${buffersType}: slices from zero when start is negative`, () => { + expectArraysEqual(new Uint8Array(composite.slice(-4, 5)), + [0,1,2,3,4]) + }); + + it(`${buffersType}: slices to the end when end is greater than length`, + () => { + expectArraysEqual(new Uint8Array(composite.slice(7, 1000)), + [7,8,9,10,11,12,13,14,15,16]); + }); + } +}); From f81638fbd59d15c581d4494d23ae8186ffcc0709 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Wed, 19 Apr 2023 14:51:00 -0700 Subject: [PATCH 11/20] Support NaN as a start or end to CompositeArray slice --- tfjs-core/src/io/weights_loader.ts | 14 +++++++++-- tfjs-core/src/io/weights_loader_test.ts | 31 +++++++++++++++++++------ 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index d5bb9800dbb..8ba5d3d1d42 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -282,19 +282,29 @@ export class CompositeArrayBuffer { } slice(start = 0, end = this.byteLength): ArrayBuffer { + // NaN is treated as zero for slicing. This matches ArrayBuffer's behavior. + start = isNaN(start) ? 0 : start; + end = isNaN(end) ? 0 : end; + + // Fix the bounds to within the array. start = Math.max(0, start); end = Math.min(this.byteLength, end); if (end <= start) { return new ArrayBuffer(0); } - const startRange = this.findRangeForByte(start); + const startRangeIndex = this.findRangeForByte(start); + if (startRangeIndex === -1) { + // This should not happen since the start and end indices are always + // within 0 and the composite array's length. + throw new Error(`Could not find start range for byte ${start}`); + } const size = end - start; const outputBuffer = new ArrayBuffer(size); const outputArray = new Uint8Array(outputBuffer); let sliced = 0; - for (let i = startRange; i < this.ranges.length; i++) { + for (let i = startRangeIndex; i < this.ranges.length; i++) { const range = this.ranges[i]; const globalStart = start + sliced; diff --git a/tfjs-core/src/io/weights_loader_test.ts b/tfjs-core/src/io/weights_loader_test.ts index 697dc34d78a..320393bc965 100644 --- a/tfjs-core/src/io/weights_loader_test.ts +++ b/tfjs-core/src/io/weights_loader_test.ts @@ -404,7 +404,6 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { it('throws if requested weight has unknown dtype', async () => { setupFakeWeightFiles({'./weightfile0': new Float32Array([1, 2, 3])}); - const manifest: WeightsManifestConfig = [{ 'paths': ['weightfile0'], 'weights': [{ @@ -588,12 +587,6 @@ describe('CompositeArrayBuffer', () => { [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]); }); - it(`${buffersType}: can be created from a single array`, () => { - const singleComposite = new CompositeArrayBuffer(buffers[0]); - expectArraysEqual(new Uint8Array(singleComposite.slice()), - new Uint8Array(buffers[0])); - }); - it(`${buffersType}: slices from zero when start is negative`, () => { expectArraysEqual(new Uint8Array(composite.slice(-4, 5)), [0,1,2,3,4]) @@ -605,4 +598,28 @@ describe('CompositeArrayBuffer', () => { [7,8,9,10,11,12,13,14,15,16]); }); } + + it('can be passed an empty arraybuffer', () => { + const array = new Uint8Array([]); + const singleComposite = new CompositeArrayBuffer(array.buffer); + expectArraysEqual(new Uint8Array(singleComposite.slice()), []); + }); + + it('can be created from a single array', () => { + const array = new Uint8Array([1,2,3]); + const singleComposite = new CompositeArrayBuffer(array.buffer); + expectArraysEqual(new Uint8Array(singleComposite.slice()), array); + }); + + it('treats NaN as zero when passed as the start of slice', () => { + const array = new Uint8Array([1,2,3]); + const composite = new CompositeArrayBuffer(array.buffer); + expectArraysEqual(new Uint8Array(composite.slice(NaN, 2)), [1,2]); + }); + + it('treats NaN as zero when passed as the end of slice', () => { + const array = new Uint8Array([1,2,3]); + const composite = new CompositeArrayBuffer(array.buffer); + expectArraysEqual(new Uint8Array(composite.slice(0, NaN)), []); + }); }); From 6a0754736b33075ffd0b47ea2c4fb8efc22f8465 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Wed, 19 Apr 2023 17:01:07 -0700 Subject: [PATCH 12/20] CompositeArrayBuffer support TypedArrays in constructor --- tfjs-core/src/io/weights_loader.ts | 14 ++++++++++++-- tfjs-core/src/io/weights_loader_test.ts | 8 ++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index 8ba5d3d1d42..c6cc0014445 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -18,6 +18,7 @@ import {env} from '../environment'; import {NamedTensorMap} from '../tensor_types'; +import {TypedArray} from '../types'; import * as util from '../util'; import {decodeWeights} from './io_utils'; import {monitorPromisesProgress} from './progress'; @@ -247,10 +248,19 @@ export class CompositeArrayBuffer { private bufferUniformSize?: number; public readonly byteLength: number; - constructor(buffers: ArrayBuffer | ArrayBuffer[]) { - if (buffers instanceof ArrayBuffer) { + constructor(buffers: ArrayBuffer | ArrayBuffer[] | TypedArray | TypedArray[]) { + // Normalize the `buffers` input to be `ArrayBuffer[]`. + if (!(buffers instanceof Array)) { buffers = [buffers]; } + buffers = buffers.map((bufferOrTypedArray) => { + if (util.isTypedArray(bufferOrTypedArray)) { + return bufferOrTypedArray.buffer; + } + return bufferOrTypedArray; + }); + + // Skip setting up ranges if there are no buffers. if (buffers.length === 0) { return; } diff --git a/tfjs-core/src/io/weights_loader_test.ts b/tfjs-core/src/io/weights_loader_test.ts index 320393bc965..48bc53a6949 100644 --- a/tfjs-core/src/io/weights_loader_test.ts +++ b/tfjs-core/src/io/weights_loader_test.ts @@ -622,4 +622,12 @@ describe('CompositeArrayBuffer', () => { const composite = new CompositeArrayBuffer(array.buffer); expectArraysEqual(new Uint8Array(composite.slice(0, NaN)), []); }); + + it('supports TypedArray input', () => { + // This support is necessary for some tests in tfjs-converter. Maybe those + // tests are misconfigured? + const array = new Uint8Array([1,2,3]); + const composite = new CompositeArrayBuffer(array); + expectArraysEqual(new Uint8Array(composite.slice(0, 2)), [1,2]); + }); }); From 61caf8c5c9cb6c9d73d2e281920310e47b44474b Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Wed, 19 Apr 2023 17:14:22 -0700 Subject: [PATCH 13/20] Formatting --- tfjs-core/src/io/weights_loader_test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-core/src/io/weights_loader_test.ts b/tfjs-core/src/io/weights_loader_test.ts index 48bc53a6949..c8f77a4302f 100644 --- a/tfjs-core/src/io/weights_loader_test.ts +++ b/tfjs-core/src/io/weights_loader_test.ts @@ -18,7 +18,7 @@ import * as tf from '../index'; import {BROWSER_ENVS, describeWithFlags} from '../jasmine_util'; import {expectArraysClose, expectArraysEqual} from '../test_util'; import {WeightsManifestConfig} from './types'; -import { CompositeArrayBuffer } from './weights_loader'; +import {CompositeArrayBuffer} from './weights_loader'; describeWithFlags('loadWeights', BROWSER_ENVS, () => { const setupFakeWeightFiles = (fileBufferMap: { From b1fe5eb639e57d92a600142637c64af6ca61603b Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Wed, 19 Apr 2023 17:16:18 -0700 Subject: [PATCH 14/20] Lint --- tfjs-core/src/io/weights_loader.ts | 3 ++- tfjs-core/src/io/weights_loader_test.ts | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index c6cc0014445..2062d388bcc 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -248,7 +248,8 @@ export class CompositeArrayBuffer { private bufferUniformSize?: number; public readonly byteLength: number; - constructor(buffers: ArrayBuffer | ArrayBuffer[] | TypedArray | TypedArray[]) { + constructor(buffers: ArrayBuffer | ArrayBuffer[] | TypedArray + | TypedArray[]) { // Normalize the `buffers` input to be `ArrayBuffer[]`. if (!(buffers instanceof Array)) { buffers = [buffers]; diff --git a/tfjs-core/src/io/weights_loader_test.ts b/tfjs-core/src/io/weights_loader_test.ts index c8f77a4302f..93f00c8b25b 100644 --- a/tfjs-core/src/io/weights_loader_test.ts +++ b/tfjs-core/src/io/weights_loader_test.ts @@ -589,7 +589,7 @@ describe('CompositeArrayBuffer', () => { it(`${buffersType}: slices from zero when start is negative`, () => { expectArraysEqual(new Uint8Array(composite.slice(-4, 5)), - [0,1,2,3,4]) + [0,1,2,3,4]); }); it(`${buffersType}: slices to the end when end is greater than length`, From b86ee6e7b52079f5454e3545cf6a6f87b4f0deb2 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Wed, 19 Apr 2023 17:30:37 -0700 Subject: [PATCH 15/20] Fix slicing out of order --- tfjs-core/src/io/weights_loader.ts | 2 +- tfjs-core/src/io/weights_loader_test.ts | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index 2062d388bcc..e648612c049 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -399,7 +399,7 @@ function search(sortedArray: T[], compare: (t: T) => number): number { let max = sortedArray.length; while (min <= max) { - const middle = Math.floor((max - min) / 2); + const middle = Math.floor((max - min) / 2) + min; const side = compare(sortedArray[middle]); if (side === 0) { diff --git a/tfjs-core/src/io/weights_loader_test.ts b/tfjs-core/src/io/weights_loader_test.ts index 93f00c8b25b..cde09ef5bbe 100644 --- a/tfjs-core/src/io/weights_loader_test.ts +++ b/tfjs-core/src/io/weights_loader_test.ts @@ -597,6 +597,12 @@ describe('CompositeArrayBuffer', () => { expectArraysEqual(new Uint8Array(composite.slice(7, 1000)), [7,8,9,10,11,12,13,14,15,16]); }); + + it(`${buffersType}: slices multiple ranges out of order`, () => { + expectArraysEqual(new Uint8Array(composite.slice(13, 15)), [13, 14]); + expectArraysEqual(new Uint8Array(composite.slice(0, 2)), [0, 1]); + expectArraysEqual(new Uint8Array(composite.slice(9, 13)), [9, 10, 11, 12]); + }); } it('can be passed an empty arraybuffer', () => { From 7e88a8971f008a0c5284509fecd2c2f7ece5b0b6 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Thu, 20 Apr 2023 10:17:02 -0700 Subject: [PATCH 16/20] fix lint --- tfjs-core/src/io/weights_loader_test.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tfjs-core/src/io/weights_loader_test.ts b/tfjs-core/src/io/weights_loader_test.ts index cde09ef5bbe..62e3bf0f2e3 100644 --- a/tfjs-core/src/io/weights_loader_test.ts +++ b/tfjs-core/src/io/weights_loader_test.ts @@ -601,7 +601,8 @@ describe('CompositeArrayBuffer', () => { it(`${buffersType}: slices multiple ranges out of order`, () => { expectArraysEqual(new Uint8Array(composite.slice(13, 15)), [13, 14]); expectArraysEqual(new Uint8Array(composite.slice(0, 2)), [0, 1]); - expectArraysEqual(new Uint8Array(composite.slice(9, 13)), [9, 10, 11, 12]); + expectArraysEqual(new Uint8Array(composite.slice(9, 13)), + [9, 10, 11, 12]); }); } From f2fa6f89e415c7bbb97c09d8ef3bfe251f3dadb4 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Thu, 20 Apr 2023 10:32:38 -0700 Subject: [PATCH 17/20] Document CompositeArrayBuffer --- tfjs-core/src/io/weights_loader.ts | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index e648612c049..59a5fd02903 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -242,6 +242,16 @@ type BufferRange = { buffer: ArrayBuffer, }; +/** + * Wraps a list of ArrayBuffers into a `slice()`-able object without allocating + * a large ArrayBuffer. + * + * Allocating large ArrayBuffers (~2GB) can be unstable on Chrome. TFJS loads + * its weights as a list of (usually) 4MB ArrayBuffers and then slices the + * weight tensors out of them. For small models, it's safe to concatenate all + * the weight buffers into a single ArrayBuffer and then slice the weight + * tensors out of it, but for large models, a different approach is needed. + */ export class CompositeArrayBuffer { private ranges: BufferRange[] = []; private previousRangeIndex = 0; From 91db31aa56f2b5f47e8da534bcd7d8fa619c3027 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Thu, 20 Apr 2023 10:38:42 -0700 Subject: [PATCH 18/20] Rename range -> shard --- tfjs-core/src/io/weights_loader.ts | 74 +++++++++++++++--------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index 59a5fd02903..487a082e2c4 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -236,7 +236,7 @@ export function weightsLoaderFactory( }; } -type BufferRange = { +type BufferShard = { start: number, end: number, buffer: ArrayBuffer, @@ -253,8 +253,8 @@ type BufferRange = { * tensors out of it, but for large models, a different approach is needed. */ export class CompositeArrayBuffer { - private ranges: BufferRange[] = []; - private previousRangeIndex = 0; + private shards: BufferShard[] = []; + private previousShardIndex = 0; private bufferUniformSize?: number; public readonly byteLength: number; @@ -271,7 +271,7 @@ export class CompositeArrayBuffer { return bufferOrTypedArray; }); - // Skip setting up ranges if there are no buffers. + // Skip setting up shards if there are no buffers. if (buffers.length === 0) { return; } @@ -289,23 +289,23 @@ export class CompositeArrayBuffer { this.bufferUniformSize = undefined; } - // Create the ranges, including their start and end points. + // Create the shards, including their start and end points. const end = start + buffer.byteLength; - this.ranges.push({buffer, start, end,}); + this.shards.push({buffer, start, end}); start = end; } // Set the byteLenghth - if (this.ranges.length === 0) { + if (this.shards.length === 0) { this.byteLength = 0; } - this.byteLength = this.ranges[this.ranges.length - 1].end; + this.byteLength = this.shards[this.shards.length - 1].end; } slice(start = 0, end = this.byteLength): ArrayBuffer { // NaN is treated as zero for slicing. This matches ArrayBuffer's behavior. - start = isNaN(start) ? 0 : start; - end = isNaN(end) ? 0 : end; + start = isNaN(Number(start)) ? 0 : start; + end = isNaN(Number(end)) ? 0 : end; // Fix the bounds to within the array. start = Math.max(0, start); @@ -314,33 +314,33 @@ export class CompositeArrayBuffer { return new ArrayBuffer(0); } - const startRangeIndex = this.findRangeForByte(start); - if (startRangeIndex === -1) { + const startShardIndex = this.findShardForByte(start); + if (startShardIndex === -1) { // This should not happen since the start and end indices are always // within 0 and the composite array's length. - throw new Error(`Could not find start range for byte ${start}`); + throw new Error(`Could not find start shard for byte ${start}`); } const size = end - start; const outputBuffer = new ArrayBuffer(size); const outputArray = new Uint8Array(outputBuffer); let sliced = 0; - for (let i = startRangeIndex; i < this.ranges.length; i++) { - const range = this.ranges[i]; + for (let i = startShardIndex; i < this.shards.length; i++) { + const shard = this.shards[i]; const globalStart = start + sliced; - const localStart = globalStart - range.start; + const localStart = globalStart - shard.start; const outputStart = sliced; - const globalEnd = Math.min(end, range.end); - const localEnd = globalEnd - range.start; + const globalEnd = Math.min(end, shard.end); + const localEnd = globalEnd - shard.start; - const outputSlice = new Uint8Array(range.buffer.slice(localStart, + const outputSlice = new Uint8Array(shard.buffer.slice(localStart, localEnd)); outputArray.set(outputSlice, outputStart); sliced += outputSlice.length; - if (end < range.end) { + if (end < shard.end) { break; } } @@ -348,48 +348,48 @@ export class CompositeArrayBuffer { } /** - * Get the index of the range that contains the byte at `byteIndex`. + * Get the index of the shard that contains the byte at `byteIndex`. */ - private findRangeForByte(byteIndex: number): number { - if (this.ranges.length === 0 || byteIndex < 0 || + private findShardForByte(byteIndex: number): number { + if (this.shards.length === 0 || byteIndex < 0 || byteIndex >= this.byteLength) { return -1; } - // If the buffers have a uniform size, compute the range directly. + // If the buffers have a uniform size, compute the shard directly. if (this.bufferUniformSize != null) { - this.previousRangeIndex = Math.floor(byteIndex / this.bufferUniformSize); - return this.previousRangeIndex; + this.previousShardIndex = Math.floor(byteIndex / this.bufferUniformSize); + return this.previousShardIndex; } // If the buffers don't have a uniform size, we need to search for the - // range. That means we need a function to check where the byteIndex lies - // relative to a given range. - function check(range: BufferRange) { - if (byteIndex < range.start) { + // shard. That means we need a function to check where the byteIndex lies + // relative to a given shard. + function check(shard: BufferShard) { + if (byteIndex < shard.start) { return -1; } - if (byteIndex >= range.end) { + if (byteIndex >= shard.end) { return 1; } return 0; } - // For efficiency, try the previous range first. - if (check(this.ranges[this.previousRangeIndex]) === 0) { - return this.previousRangeIndex; + // For efficiency, try the previous shard first. + if (check(this.shards[this.previousShardIndex]) === 0) { + return this.previousShardIndex; } // Otherwise, use a generic search function. // This should almost never end up being used in practice since the weight // entries should always be in order. - const index = search(this.ranges, check); + const index = search(this.shards, check); if (index === -1) { return -1; } - this.previousRangeIndex = index; - return this.previousRangeIndex; + this.previousShardIndex = index; + return this.previousShardIndex; } } From 26c51df9efd22e455013a7713e2cd301c23bef0c Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Thu, 20 Apr 2023 10:45:14 -0700 Subject: [PATCH 19/20] Move CompositeArrayBuffer to a new file --- tfjs-core/src/io/composite_array_buffer.ts | 190 +++++++++++++ .../src/io/composite_array_buffer_test.ts | 114 ++++++++ tfjs-core/src/io/weights_loader.ts | 269 +++--------------- tfjs-core/src/io/weights_loader_test.ts | 97 ------- 4 files changed, 345 insertions(+), 325 deletions(-) create mode 100644 tfjs-core/src/io/composite_array_buffer.ts create mode 100644 tfjs-core/src/io/composite_array_buffer_test.ts diff --git a/tfjs-core/src/io/composite_array_buffer.ts b/tfjs-core/src/io/composite_array_buffer.ts new file mode 100644 index 00000000000..5dceda75ad1 --- /dev/null +++ b/tfjs-core/src/io/composite_array_buffer.ts @@ -0,0 +1,190 @@ +import {TypedArray} from '../types'; +import * as util from '../util'; + +type BufferShard = { + start: number, + end: number, + buffer: ArrayBuffer, +}; + +/** + * Wraps a list of ArrayBuffers into a `slice()`-able object without allocating + * a large ArrayBuffer. + * + * Allocating large ArrayBuffers (~2GB) can be unstable on Chrome. TFJS loads + * its weights as a list of (usually) 4MB ArrayBuffers and then slices the + * weight tensors out of them. For small models, it's safe to concatenate all + * the weight buffers into a single ArrayBuffer and then slice the weight + * tensors out of it, but for large models, a different approach is needed. + */ + +export class CompositeArrayBuffer { + private shards: BufferShard[] = []; + private previousShardIndex = 0; + private bufferUniformSize?: number; + public readonly byteLength: number; + + constructor(buffers: ArrayBuffer | ArrayBuffer[] | TypedArray | + TypedArray[]) { + // Normalize the `buffers` input to be `ArrayBuffer[]`. + if (!(buffers instanceof Array)) { + buffers = [buffers]; + } + buffers = buffers.map((bufferOrTypedArray) => { + if (util.isTypedArray(bufferOrTypedArray)) { + return bufferOrTypedArray.buffer; + } + return bufferOrTypedArray; + }); + + // Skip setting up shards if there are no buffers. + if (buffers.length === 0) { + return; + } + + this.bufferUniformSize = buffers[0].byteLength; + let start = 0; + + for (let i = 0; i < buffers.length; i++) { + const buffer = buffers[i]; + // Check that all buffers except the last one have the same length. + if (i !== buffers.length - 1 && + buffer.byteLength !== this.bufferUniformSize) { + // Unset the buffer uniform size, since the buffer sizes are not + // uniform. + this.bufferUniformSize = undefined; + } + + // Create the shards, including their start and end points. + const end = start + buffer.byteLength; + this.shards.push({ buffer, start, end }); + start = end; + } + + // Set the byteLenghth + if (this.shards.length === 0) { + this.byteLength = 0; + } + this.byteLength = this.shards[this.shards.length - 1].end; + } + + slice(start = 0, end = this.byteLength): ArrayBuffer { + // NaN is treated as zero for slicing. This matches ArrayBuffer's behavior. + start = isNaN(Number(start)) ? 0 : start; + end = isNaN(Number(end)) ? 0 : end; + + // Fix the bounds to within the array. + start = Math.max(0, start); + end = Math.min(this.byteLength, end); + if (end <= start) { + return new ArrayBuffer(0); + } + + const startShardIndex = this.findShardForByte(start); + if (startShardIndex === -1) { + // This should not happen since the start and end indices are always + // within 0 and the composite array's length. + throw new Error(`Could not find start shard for byte ${start}`); + } + + const size = end - start; + const outputBuffer = new ArrayBuffer(size); + const outputArray = new Uint8Array(outputBuffer); + let sliced = 0; + for (let i = startShardIndex; i < this.shards.length; i++) { + const shard = this.shards[i]; + + const globalStart = start + sliced; + const localStart = globalStart - shard.start; + const outputStart = sliced; + + const globalEnd = Math.min(end, shard.end); + const localEnd = globalEnd - shard.start; + + const outputSlice = new Uint8Array(shard.buffer.slice(localStart, + localEnd)); + outputArray.set(outputSlice, outputStart); + sliced += outputSlice.length; + + if (end < shard.end) { + break; + } + } + return outputBuffer; + } + + /** + * Get the index of the shard that contains the byte at `byteIndex`. + */ + private findShardForByte(byteIndex: number): number { + if (this.shards.length === 0 || byteIndex < 0 || + byteIndex >= this.byteLength) { + return -1; + } + + // If the buffers have a uniform size, compute the shard directly. + if (this.bufferUniformSize != null) { + this.previousShardIndex = Math.floor(byteIndex / this.bufferUniformSize); + return this.previousShardIndex; + } + + // If the buffers don't have a uniform size, we need to search for the + // shard. That means we need a function to check where the byteIndex lies + // relative to a given shard. + function check(shard: BufferShard) { + if (byteIndex < shard.start) { + return -1; + } + if (byteIndex >= shard.end) { + return 1; + } + return 0; + } + + // For efficiency, try the previous shard first. + if (check(this.shards[this.previousShardIndex]) === 0) { + return this.previousShardIndex; + } + + // Otherwise, use a generic search function. + // This should almost never end up being used in practice since the weight + // entries should always be in order. + const index = search(this.shards, check); + if (index === -1) { + return -1; + } + + this.previousShardIndex = index; + return this.previousShardIndex; + } +} + +/** + * Search for an element of a sorted array. + * + * @param sortedArray The sorted array to search + * @param compare A function to compare the current value against the searched + * value. Return 0 on a match, negative if the searched value is less than + * the value passed to the function, and positive if the searched value is + * greater than the value passed to the function. + * @returns The index of the element, or -1 if it's not in the array. + */ +export function search(sortedArray: T[], compare: (t: T) => number): number { + // Binary search + let min = 0; + let max = sortedArray.length; + + while (min <= max) { + const middle = Math.floor((max - min) / 2) + min; + const side = compare(sortedArray[middle]); + + if (side === 0) { + return middle; + } else if (side < 0) { + max = middle; + } else { + min = middle + 1; + } + } + return -1; +} diff --git a/tfjs-core/src/io/composite_array_buffer_test.ts b/tfjs-core/src/io/composite_array_buffer_test.ts new file mode 100644 index 00000000000..fa64532ad8c --- /dev/null +++ b/tfjs-core/src/io/composite_array_buffer_test.ts @@ -0,0 +1,114 @@ +/** + * @license + * Copyright 2023 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {expectArraysEqual} from '../test_util'; +import {CompositeArrayBuffer} from './composite_array_buffer'; + +describe('CompositeArrayBuffer', () => { + const uniformBuffers = [ + new Uint8Array([0, 1, 2, 3]).buffer, + new Uint8Array([4, 5, 6, 7]).buffer, + new Uint8Array([8, 9, 10, 11]).buffer, + new Uint8Array([12, 13, 14, 15]).buffer, + new Uint8Array([16]).buffer, + ]; + + const nonUniformBuffers = [ + new Uint8Array([0, 1, 2]).buffer, + new Uint8Array([3, 4, 5, 6, 7]).buffer, + new Uint8Array([8, 9, 10, 11]).buffer, + new Uint8Array([12, 13, 14, 15, 16]).buffer, + ]; + + const bufferTestCases = [ + ['uniform', uniformBuffers], + ['non-uniform', nonUniformBuffers] + ] as const; + + for (const [buffersType, buffers] of bufferTestCases) { + let composite: CompositeArrayBuffer; + beforeEach(() => { + composite = new CompositeArrayBuffer(buffers); + }); + + it(`${buffersType}: slices across multiple buffers`, () => { + expectArraysEqual(new Uint8Array(composite.slice(1, 13)), + [1,2,3,4,5,6,7,8,9,10,11,12]); + }); + + it(`${buffersType}: slices to the end of the array when \'end\' is not ` + + 'specified', () => { + expectArraysEqual(new Uint8Array(composite.slice(5)), + [5,6,7,8,9,10,11,12,13,14,15,16]); + }); + + it(`${buffersType}: makes a copy when slice() is called with no arguments`, + () => { + expectArraysEqual(new Uint8Array(composite.slice()), + [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]); + }); + + it(`${buffersType}: slices from zero when start is negative`, () => { + expectArraysEqual(new Uint8Array(composite.slice(-4, 5)), + [0,1,2,3,4]); + }); + + it(`${buffersType}: slices to the end when end is greater than length`, + () => { + expectArraysEqual(new Uint8Array(composite.slice(7, 1000)), + [7,8,9,10,11,12,13,14,15,16]); + }); + + it(`${buffersType}: slices multiple ranges out of order`, () => { + expectArraysEqual(new Uint8Array(composite.slice(13, 15)), [13, 14]); + expectArraysEqual(new Uint8Array(composite.slice(0, 2)), [0, 1]); + expectArraysEqual(new Uint8Array(composite.slice(9, 13)), + [9, 10, 11, 12]); + }); + } + + it('can be passed an empty arraybuffer', () => { + const array = new Uint8Array([]); + const singleComposite = new CompositeArrayBuffer(array.buffer); + expectArraysEqual(new Uint8Array(singleComposite.slice()), []); + }); + + it('can be created from a single array', () => { + const array = new Uint8Array([1,2,3]); + const singleComposite = new CompositeArrayBuffer(array.buffer); + expectArraysEqual(new Uint8Array(singleComposite.slice()), array); + }); + + it('treats NaN as zero when passed as the start of slice', () => { + const array = new Uint8Array([1,2,3]); + const composite = new CompositeArrayBuffer(array.buffer); + expectArraysEqual(new Uint8Array(composite.slice(NaN, 2)), [1,2]); + }); + + it('treats NaN as zero when passed as the end of slice', () => { + const array = new Uint8Array([1,2,3]); + const composite = new CompositeArrayBuffer(array.buffer); + expectArraysEqual(new Uint8Array(composite.slice(0, NaN)), []); + }); + + it('supports TypedArray input', () => { + // This support is necessary for some tests in tfjs-converter. Maybe those + // tests are misconfigured? + const array = new Uint8Array([1,2,3]); + const composite = new CompositeArrayBuffer(array); + expectArraysEqual(new Uint8Array(composite.slice(0, 2)), [1,2]); + }); +}); diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index 487a082e2c4..8ad0ef2f85b 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -18,8 +18,8 @@ import {env} from '../environment'; import {NamedTensorMap} from '../tensor_types'; -import {TypedArray} from '../types'; import * as util from '../util'; +import {CompositeArrayBuffer} from './composite_array_buffer'; import {decodeWeights} from './io_utils'; import {monitorPromisesProgress} from './progress'; import {DTYPE_VALUE_SIZE_MAP, LoadOptions, WeightsManifestConfig, WeightsManifestEntry} from './types'; @@ -36,27 +36,27 @@ import {DTYPE_VALUE_SIZE_MAP, LoadOptions, WeightsManifestConfig, WeightsManifes * length as `fetchURLs`. */ export async function loadWeightsAsArrayBuffer( - fetchURLs: string[], loadOptions?: LoadOptions): Promise { + fetchURLs: string[], loadOptions?: LoadOptions): Promise { if (loadOptions == null) { loadOptions = {}; } const fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch : - loadOptions.fetchFunc; + loadOptions.fetchFunc; // Create the requests for all of the weights in parallel. const requests = fetchURLs.map( - fetchURL => - fetchFunc(fetchURL, loadOptions.requestInit, {isBinary: true})); + fetchURL => + fetchFunc(fetchURL, loadOptions.requestInit, { isBinary: true })); const fetchStartFraction = 0; const fetchEndFraction = 0.5; const responses = loadOptions.onProgress == null ? - await Promise.all(requests) : - await monitorPromisesProgress( - requests, loadOptions.onProgress, fetchStartFraction, - fetchEndFraction); + await Promise.all(requests) : + await monitorPromisesProgress( + requests, loadOptions.onProgress, fetchStartFraction, + fetchEndFraction); const bufferPromises = responses.map(response => response.arrayBuffer()); @@ -64,10 +64,10 @@ export async function loadWeightsAsArrayBuffer( const bufferEndFraction = 1; const buffers = loadOptions.onProgress == null ? - await Promise.all(bufferPromises) : - await monitorPromisesProgress( - bufferPromises, loadOptions.onProgress, bufferStartFraction, - bufferEndFraction); + await Promise.all(bufferPromises) : + await monitorPromisesProgress( + bufferPromises, loadOptions.onProgress, bufferStartFraction, + bufferEndFraction); return buffers; } @@ -81,9 +81,9 @@ export async function loadWeightsAsArrayBuffer( * @param weightNames The names of the weights to be fetched. */ export async function loadWeights( - manifest: WeightsManifestConfig, filePathPrefix = '', - weightNames?: string[], - requestInit?: RequestInit): Promise { + manifest: WeightsManifestConfig, filePathPrefix = '', + weightNames?: string[], + requestInit?: RequestInit): Promise { // TODO(nsthorat): Groups are currently fetched atomically. If you need a // single weight from a group, the whole group will be fetched. At a future // date, we should support fetching only the individual shards within a @@ -91,7 +91,7 @@ export async function loadWeights( // TODO(cais): Use `decodeWeights` for implementation. const fetchWeights = (fetchUrls: string[]) => - loadWeightsAsArrayBuffer(fetchUrls, {requestInit}); + loadWeightsAsArrayBuffer(fetchUrls, { requestInit }); const loadWeights = weightsLoaderFactory(fetchWeights); return loadWeights(manifest, filePathPrefix, weightNames); @@ -122,12 +122,12 @@ export async function loadWeights( * @returns Weight loading function. */ export function weightsLoaderFactory( - fetchWeightsFunction: (fetchUrls: string[]) => Promise): - (manifest: WeightsManifestConfig, filePathPrefix?: string, - weightNames?: string[]) => Promise { - return async( - manifest: WeightsManifestConfig, filePathPrefix = '', - weightNames?: string[]): Promise => { + fetchWeightsFunction: (fetchUrls: string[]) => Promise): + (manifest: WeightsManifestConfig, filePathPrefix?: string, + weightNames?: string[]) => Promise { + return async ( + manifest: WeightsManifestConfig, filePathPrefix = '', + weightNames?: string[]): Promise => { // Collect all the groups, weights, and their relative offsets to be // fetched. const groupIndicesToFetchMap = manifest.map(() => false); @@ -138,17 +138,17 @@ export function weightsLoaderFactory( }> } = {}; const weightsFound = - weightNames != null ? weightNames.map(() => false) : []; + weightNames != null ? weightNames.map(() => false) : []; const allManifestWeightNames: string[] = []; manifest.forEach((manifestGroupConfig, groupIndex) => { let groupOffset = 0; manifestGroupConfig.weights.forEach(weightsEntry => { const rawDtype = ('quantization' in weightsEntry) ? - weightsEntry.quantization.dtype : - weightsEntry.dtype; + weightsEntry.quantization.dtype : + weightsEntry.dtype; const weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] * - util.sizeFromShape(weightsEntry.shape); + util.sizeFromShape(weightsEntry.shape); const enqueueWeightsForFetchingFn = () => { groupIndicesToFetchMap[groupIndex] = true; @@ -182,27 +182,27 @@ export function weightsLoaderFactory( if (!weightsFound.every(found => found)) { const weightsNotFound = weightNames.filter((_, i) => !weightsFound[i]); throw new Error( - `Could not find weights in manifest with names: ` + - `${weightsNotFound.join(', ')}. \n` + - `Manifest JSON has weights with names: ` + - `${allManifestWeightNames.join(', ')}.`); + `Could not find weights in manifest with names: ` + + `${weightsNotFound.join(', ')}. \n` + + `Manifest JSON has weights with names: ` + + `${allManifestWeightNames.join(', ')}.`); } // Convert the one-hot boolean groupId => shouldFetch map to a list of group // IDs. const groupIndicesToFetch = - groupIndicesToFetchMap.reduce((accumulator, shouldFetch, i) => { - if (shouldFetch) { - accumulator.push(i); - } - return accumulator; - }, []); + groupIndicesToFetchMap.reduce((accumulator, shouldFetch, i) => { + if (shouldFetch) { + accumulator.push(i); + } + return accumulator; + }, []); const fetchUrls: string[] = []; groupIndicesToFetch.forEach(i => { manifest[i].paths.forEach(filepath => { const fetchUrl = filePathPrefix + - (!filePathPrefix.endsWith('/') ? '/' : '') + filepath; + (!filePathPrefix.endsWith('/') ? '/' : '') + filepath; fetchUrls.push(fetchUrl); }); }); @@ -220,10 +220,10 @@ export function weightsLoaderFactory( weightsEntries.forEach(weightsEntry => { const byteBuffer = weightsBuffer.slice( - weightsEntry.groupOffset, - weightsEntry.groupOffset + weightsEntry.sizeBytes); + weightsEntry.groupOffset, + weightsEntry.groupOffset + weightsEntry.sizeBytes); const nameToTensorMap = - decodeWeights(byteBuffer, [weightsEntry.manifestEntry]); + decodeWeights(byteBuffer, [weightsEntry.manifestEntry]); for (const name in nameToTensorMap) { weightsTensorMap[name] = nameToTensorMap[name]; } @@ -235,190 +235,3 @@ export function weightsLoaderFactory( return weightsTensorMap; }; } - -type BufferShard = { - start: number, - end: number, - buffer: ArrayBuffer, -}; - -/** - * Wraps a list of ArrayBuffers into a `slice()`-able object without allocating - * a large ArrayBuffer. - * - * Allocating large ArrayBuffers (~2GB) can be unstable on Chrome. TFJS loads - * its weights as a list of (usually) 4MB ArrayBuffers and then slices the - * weight tensors out of them. For small models, it's safe to concatenate all - * the weight buffers into a single ArrayBuffer and then slice the weight - * tensors out of it, but for large models, a different approach is needed. - */ -export class CompositeArrayBuffer { - private shards: BufferShard[] = []; - private previousShardIndex = 0; - private bufferUniformSize?: number; - public readonly byteLength: number; - - constructor(buffers: ArrayBuffer | ArrayBuffer[] | TypedArray - | TypedArray[]) { - // Normalize the `buffers` input to be `ArrayBuffer[]`. - if (!(buffers instanceof Array)) { - buffers = [buffers]; - } - buffers = buffers.map((bufferOrTypedArray) => { - if (util.isTypedArray(bufferOrTypedArray)) { - return bufferOrTypedArray.buffer; - } - return bufferOrTypedArray; - }); - - // Skip setting up shards if there are no buffers. - if (buffers.length === 0) { - return; - } - - this.bufferUniformSize = buffers[0].byteLength; - let start = 0; - - for (let i = 0; i < buffers.length; i++) { - const buffer = buffers[i]; - // Check that all buffers except the last one have the same length. - if (i !== buffers.length - 1 && - buffer.byteLength !== this.bufferUniformSize) { - // Unset the buffer uniform size, since the buffer sizes are not - // uniform. - this.bufferUniformSize = undefined; - } - - // Create the shards, including their start and end points. - const end = start + buffer.byteLength; - this.shards.push({buffer, start, end}); - start = end; - } - - // Set the byteLenghth - if (this.shards.length === 0) { - this.byteLength = 0; - } - this.byteLength = this.shards[this.shards.length - 1].end; - } - - slice(start = 0, end = this.byteLength): ArrayBuffer { - // NaN is treated as zero for slicing. This matches ArrayBuffer's behavior. - start = isNaN(Number(start)) ? 0 : start; - end = isNaN(Number(end)) ? 0 : end; - - // Fix the bounds to within the array. - start = Math.max(0, start); - end = Math.min(this.byteLength, end); - if (end <= start) { - return new ArrayBuffer(0); - } - - const startShardIndex = this.findShardForByte(start); - if (startShardIndex === -1) { - // This should not happen since the start and end indices are always - // within 0 and the composite array's length. - throw new Error(`Could not find start shard for byte ${start}`); - } - - const size = end - start; - const outputBuffer = new ArrayBuffer(size); - const outputArray = new Uint8Array(outputBuffer); - let sliced = 0; - for (let i = startShardIndex; i < this.shards.length; i++) { - const shard = this.shards[i]; - - const globalStart = start + sliced; - const localStart = globalStart - shard.start; - const outputStart = sliced; - - const globalEnd = Math.min(end, shard.end); - const localEnd = globalEnd - shard.start; - - const outputSlice = new Uint8Array(shard.buffer.slice(localStart, - localEnd)); - outputArray.set(outputSlice, outputStart); - sliced += outputSlice.length; - - if (end < shard.end) { - break; - } - } - return outputBuffer; - } - - /** - * Get the index of the shard that contains the byte at `byteIndex`. - */ - private findShardForByte(byteIndex: number): number { - if (this.shards.length === 0 || byteIndex < 0 || - byteIndex >= this.byteLength) { - return -1; - } - - // If the buffers have a uniform size, compute the shard directly. - if (this.bufferUniformSize != null) { - this.previousShardIndex = Math.floor(byteIndex / this.bufferUniformSize); - return this.previousShardIndex; - } - - // If the buffers don't have a uniform size, we need to search for the - // shard. That means we need a function to check where the byteIndex lies - // relative to a given shard. - function check(shard: BufferShard) { - if (byteIndex < shard.start) { - return -1; - } - if (byteIndex >= shard.end) { - return 1; - } - return 0; - } - - // For efficiency, try the previous shard first. - if (check(this.shards[this.previousShardIndex]) === 0) { - return this.previousShardIndex; - } - - // Otherwise, use a generic search function. - // This should almost never end up being used in practice since the weight - // entries should always be in order. - const index = search(this.shards, check); - if (index === -1) { - return -1; - } - - this.previousShardIndex = index; - return this.previousShardIndex; - } -} - -/** - * Search for an element of a sorted array. - * - * @param sortedArray The sorted array to search - * @param compare A function to compare the current value against the searched - * value. Return 0 on a match, negative if the searched value is less than - * the value passed to the function, and positive if the searched value is - * greater than the value passed to the function. - * @returns The index of the element, or -1 if it's not in the array. - */ -function search(sortedArray: T[], compare: (t: T) => number): number { - // Binary search - let min = 0; - let max = sortedArray.length; - - while (min <= max) { - const middle = Math.floor((max - min) / 2) + min; - const side = compare(sortedArray[middle]); - - if (side === 0) { - return middle; - } else if (side < 0) { - max = middle; - } else { - min = middle + 1; - } - } - return -1; -} diff --git a/tfjs-core/src/io/weights_loader_test.ts b/tfjs-core/src/io/weights_loader_test.ts index 62e3bf0f2e3..69a00607d39 100644 --- a/tfjs-core/src/io/weights_loader_test.ts +++ b/tfjs-core/src/io/weights_loader_test.ts @@ -18,7 +18,6 @@ import * as tf from '../index'; import {BROWSER_ENVS, describeWithFlags} from '../jasmine_util'; import {expectArraysClose, expectArraysEqual} from '../test_util'; import {WeightsManifestConfig} from './types'; -import {CompositeArrayBuffer} from './weights_loader'; describeWithFlags('loadWeights', BROWSER_ENVS, () => { const setupFakeWeightFiles = (fileBufferMap: { @@ -542,99 +541,3 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { expect(weight2.dtype).toEqual('float32'); }); }); - -describe('CompositeArrayBuffer', () => { - const uniformBuffers = [ - new Uint8Array([0, 1, 2, 3]).buffer, - new Uint8Array([4, 5, 6, 7]).buffer, - new Uint8Array([8, 9, 10, 11]).buffer, - new Uint8Array([12, 13, 14, 15]).buffer, - new Uint8Array([16]).buffer, - ]; - - const nonUniformBuffers = [ - new Uint8Array([0, 1, 2]).buffer, - new Uint8Array([3, 4, 5, 6, 7]).buffer, - new Uint8Array([8, 9, 10, 11]).buffer, - new Uint8Array([12, 13, 14, 15, 16]).buffer, - ]; - - const bufferTestCases = [ - ['uniform', uniformBuffers], - ['non-uniform', nonUniformBuffers] - ] as const; - - for (const [buffersType, buffers] of bufferTestCases) { - let composite: CompositeArrayBuffer; - beforeEach(() => { - composite = new CompositeArrayBuffer(buffers); - }); - - it(`${buffersType}: slices across multiple buffers`, () => { - expectArraysEqual(new Uint8Array(composite.slice(1, 13)), - [1,2,3,4,5,6,7,8,9,10,11,12]); - }); - - it(`${buffersType}: slices to the end of the array when \'end\' is not ` + - 'specified', () => { - expectArraysEqual(new Uint8Array(composite.slice(5)), - [5,6,7,8,9,10,11,12,13,14,15,16]); - }); - - it(`${buffersType}: makes a copy when slice() is called with no arguments`, - () => { - expectArraysEqual(new Uint8Array(composite.slice()), - [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]); - }); - - it(`${buffersType}: slices from zero when start is negative`, () => { - expectArraysEqual(new Uint8Array(composite.slice(-4, 5)), - [0,1,2,3,4]); - }); - - it(`${buffersType}: slices to the end when end is greater than length`, - () => { - expectArraysEqual(new Uint8Array(composite.slice(7, 1000)), - [7,8,9,10,11,12,13,14,15,16]); - }); - - it(`${buffersType}: slices multiple ranges out of order`, () => { - expectArraysEqual(new Uint8Array(composite.slice(13, 15)), [13, 14]); - expectArraysEqual(new Uint8Array(composite.slice(0, 2)), [0, 1]); - expectArraysEqual(new Uint8Array(composite.slice(9, 13)), - [9, 10, 11, 12]); - }); - } - - it('can be passed an empty arraybuffer', () => { - const array = new Uint8Array([]); - const singleComposite = new CompositeArrayBuffer(array.buffer); - expectArraysEqual(new Uint8Array(singleComposite.slice()), []); - }); - - it('can be created from a single array', () => { - const array = new Uint8Array([1,2,3]); - const singleComposite = new CompositeArrayBuffer(array.buffer); - expectArraysEqual(new Uint8Array(singleComposite.slice()), array); - }); - - it('treats NaN as zero when passed as the start of slice', () => { - const array = new Uint8Array([1,2,3]); - const composite = new CompositeArrayBuffer(array.buffer); - expectArraysEqual(new Uint8Array(composite.slice(NaN, 2)), [1,2]); - }); - - it('treats NaN as zero when passed as the end of slice', () => { - const array = new Uint8Array([1,2,3]); - const composite = new CompositeArrayBuffer(array.buffer); - expectArraysEqual(new Uint8Array(composite.slice(0, NaN)), []); - }); - - it('supports TypedArray input', () => { - // This support is necessary for some tests in tfjs-converter. Maybe those - // tests are misconfigured? - const array = new Uint8Array([1,2,3]); - const composite = new CompositeArrayBuffer(array); - expectArraysEqual(new Uint8Array(composite.slice(0, 2)), [1,2]); - }); -}); From ec9a3827e8fa5829029bc9474173ab83e23b43fe Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Thu, 20 Apr 2023 10:48:59 -0700 Subject: [PATCH 20/20] Add license --- tfjs-core/src/io/composite_array_buffer.ts | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tfjs-core/src/io/composite_array_buffer.ts b/tfjs-core/src/io/composite_array_buffer.ts index 5dceda75ad1..6dc67da73f2 100644 --- a/tfjs-core/src/io/composite_array_buffer.ts +++ b/tfjs-core/src/io/composite_array_buffer.ts @@ -1,3 +1,19 @@ +/** + * @license + * Copyright 2023 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ import {TypedArray} from '../types'; import * as util from '../util';