diff --git a/tfjs-core/src/io/composite_array_buffer.ts b/tfjs-core/src/io/composite_array_buffer.ts new file mode 100644 index 00000000000..6dc67da73f2 --- /dev/null +++ b/tfjs-core/src/io/composite_array_buffer.ts @@ -0,0 +1,206 @@ +/** + * @license + * Copyright 2023 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {TypedArray} from '../types'; +import * as util from '../util'; + +type BufferShard = { + start: number, + end: number, + buffer: ArrayBuffer, +}; + +/** + * Wraps a list of ArrayBuffers into a `slice()`-able object without allocating + * a large ArrayBuffer. + * + * Allocating large ArrayBuffers (~2GB) can be unstable on Chrome. TFJS loads + * its weights as a list of (usually) 4MB ArrayBuffers and then slices the + * weight tensors out of them. For small models, it's safe to concatenate all + * the weight buffers into a single ArrayBuffer and then slice the weight + * tensors out of it, but for large models, a different approach is needed. + */ + +export class CompositeArrayBuffer { + private shards: BufferShard[] = []; + private previousShardIndex = 0; + private bufferUniformSize?: number; + public readonly byteLength: number; + + constructor(buffers: ArrayBuffer | ArrayBuffer[] | TypedArray | + TypedArray[]) { + // Normalize the `buffers` input to be `ArrayBuffer[]`. + if (!(buffers instanceof Array)) { + buffers = [buffers]; + } + buffers = buffers.map((bufferOrTypedArray) => { + if (util.isTypedArray(bufferOrTypedArray)) { + return bufferOrTypedArray.buffer; + } + return bufferOrTypedArray; + }); + + // Skip setting up shards if there are no buffers. + if (buffers.length === 0) { + return; + } + + this.bufferUniformSize = buffers[0].byteLength; + let start = 0; + + for (let i = 0; i < buffers.length; i++) { + const buffer = buffers[i]; + // Check that all buffers except the last one have the same length. + if (i !== buffers.length - 1 && + buffer.byteLength !== this.bufferUniformSize) { + // Unset the buffer uniform size, since the buffer sizes are not + // uniform. + this.bufferUniformSize = undefined; + } + + // Create the shards, including their start and end points. + const end = start + buffer.byteLength; + this.shards.push({ buffer, start, end }); + start = end; + } + + // Set the byteLenghth + if (this.shards.length === 0) { + this.byteLength = 0; + } + this.byteLength = this.shards[this.shards.length - 1].end; + } + + slice(start = 0, end = this.byteLength): ArrayBuffer { + // NaN is treated as zero for slicing. This matches ArrayBuffer's behavior. + start = isNaN(Number(start)) ? 0 : start; + end = isNaN(Number(end)) ? 0 : end; + + // Fix the bounds to within the array. + start = Math.max(0, start); + end = Math.min(this.byteLength, end); + if (end <= start) { + return new ArrayBuffer(0); + } + + const startShardIndex = this.findShardForByte(start); + if (startShardIndex === -1) { + // This should not happen since the start and end indices are always + // within 0 and the composite array's length. + throw new Error(`Could not find start shard for byte ${start}`); + } + + const size = end - start; + const outputBuffer = new ArrayBuffer(size); + const outputArray = new Uint8Array(outputBuffer); + let sliced = 0; + for (let i = startShardIndex; i < this.shards.length; i++) { + const shard = this.shards[i]; + + const globalStart = start + sliced; + const localStart = globalStart - shard.start; + const outputStart = sliced; + + const globalEnd = Math.min(end, shard.end); + const localEnd = globalEnd - shard.start; + + const outputSlice = new Uint8Array(shard.buffer.slice(localStart, + localEnd)); + outputArray.set(outputSlice, outputStart); + sliced += outputSlice.length; + + if (end < shard.end) { + break; + } + } + return outputBuffer; + } + + /** + * Get the index of the shard that contains the byte at `byteIndex`. + */ + private findShardForByte(byteIndex: number): number { + if (this.shards.length === 0 || byteIndex < 0 || + byteIndex >= this.byteLength) { + return -1; + } + + // If the buffers have a uniform size, compute the shard directly. + if (this.bufferUniformSize != null) { + this.previousShardIndex = Math.floor(byteIndex / this.bufferUniformSize); + return this.previousShardIndex; + } + + // If the buffers don't have a uniform size, we need to search for the + // shard. That means we need a function to check where the byteIndex lies + // relative to a given shard. + function check(shard: BufferShard) { + if (byteIndex < shard.start) { + return -1; + } + if (byteIndex >= shard.end) { + return 1; + } + return 0; + } + + // For efficiency, try the previous shard first. + if (check(this.shards[this.previousShardIndex]) === 0) { + return this.previousShardIndex; + } + + // Otherwise, use a generic search function. + // This should almost never end up being used in practice since the weight + // entries should always be in order. + const index = search(this.shards, check); + if (index === -1) { + return -1; + } + + this.previousShardIndex = index; + return this.previousShardIndex; + } +} + +/** + * Search for an element of a sorted array. + * + * @param sortedArray The sorted array to search + * @param compare A function to compare the current value against the searched + * value. Return 0 on a match, negative if the searched value is less than + * the value passed to the function, and positive if the searched value is + * greater than the value passed to the function. + * @returns The index of the element, or -1 if it's not in the array. + */ +export function search(sortedArray: T[], compare: (t: T) => number): number { + // Binary search + let min = 0; + let max = sortedArray.length; + + while (min <= max) { + const middle = Math.floor((max - min) / 2) + min; + const side = compare(sortedArray[middle]); + + if (side === 0) { + return middle; + } else if (side < 0) { + max = middle; + } else { + min = middle + 1; + } + } + return -1; +} diff --git a/tfjs-core/src/io/composite_array_buffer_test.ts b/tfjs-core/src/io/composite_array_buffer_test.ts new file mode 100644 index 00000000000..fa64532ad8c --- /dev/null +++ b/tfjs-core/src/io/composite_array_buffer_test.ts @@ -0,0 +1,114 @@ +/** + * @license + * Copyright 2023 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {expectArraysEqual} from '../test_util'; +import {CompositeArrayBuffer} from './composite_array_buffer'; + +describe('CompositeArrayBuffer', () => { + const uniformBuffers = [ + new Uint8Array([0, 1, 2, 3]).buffer, + new Uint8Array([4, 5, 6, 7]).buffer, + new Uint8Array([8, 9, 10, 11]).buffer, + new Uint8Array([12, 13, 14, 15]).buffer, + new Uint8Array([16]).buffer, + ]; + + const nonUniformBuffers = [ + new Uint8Array([0, 1, 2]).buffer, + new Uint8Array([3, 4, 5, 6, 7]).buffer, + new Uint8Array([8, 9, 10, 11]).buffer, + new Uint8Array([12, 13, 14, 15, 16]).buffer, + ]; + + const bufferTestCases = [ + ['uniform', uniformBuffers], + ['non-uniform', nonUniformBuffers] + ] as const; + + for (const [buffersType, buffers] of bufferTestCases) { + let composite: CompositeArrayBuffer; + beforeEach(() => { + composite = new CompositeArrayBuffer(buffers); + }); + + it(`${buffersType}: slices across multiple buffers`, () => { + expectArraysEqual(new Uint8Array(composite.slice(1, 13)), + [1,2,3,4,5,6,7,8,9,10,11,12]); + }); + + it(`${buffersType}: slices to the end of the array when \'end\' is not ` + + 'specified', () => { + expectArraysEqual(new Uint8Array(composite.slice(5)), + [5,6,7,8,9,10,11,12,13,14,15,16]); + }); + + it(`${buffersType}: makes a copy when slice() is called with no arguments`, + () => { + expectArraysEqual(new Uint8Array(composite.slice()), + [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]); + }); + + it(`${buffersType}: slices from zero when start is negative`, () => { + expectArraysEqual(new Uint8Array(composite.slice(-4, 5)), + [0,1,2,3,4]); + }); + + it(`${buffersType}: slices to the end when end is greater than length`, + () => { + expectArraysEqual(new Uint8Array(composite.slice(7, 1000)), + [7,8,9,10,11,12,13,14,15,16]); + }); + + it(`${buffersType}: slices multiple ranges out of order`, () => { + expectArraysEqual(new Uint8Array(composite.slice(13, 15)), [13, 14]); + expectArraysEqual(new Uint8Array(composite.slice(0, 2)), [0, 1]); + expectArraysEqual(new Uint8Array(composite.slice(9, 13)), + [9, 10, 11, 12]); + }); + } + + it('can be passed an empty arraybuffer', () => { + const array = new Uint8Array([]); + const singleComposite = new CompositeArrayBuffer(array.buffer); + expectArraysEqual(new Uint8Array(singleComposite.slice()), []); + }); + + it('can be created from a single array', () => { + const array = new Uint8Array([1,2,3]); + const singleComposite = new CompositeArrayBuffer(array.buffer); + expectArraysEqual(new Uint8Array(singleComposite.slice()), array); + }); + + it('treats NaN as zero when passed as the start of slice', () => { + const array = new Uint8Array([1,2,3]); + const composite = new CompositeArrayBuffer(array.buffer); + expectArraysEqual(new Uint8Array(composite.slice(NaN, 2)), [1,2]); + }); + + it('treats NaN as zero when passed as the end of slice', () => { + const array = new Uint8Array([1,2,3]); + const composite = new CompositeArrayBuffer(array.buffer); + expectArraysEqual(new Uint8Array(composite.slice(0, NaN)), []); + }); + + it('supports TypedArray input', () => { + // This support is necessary for some tests in tfjs-converter. Maybe those + // tests are misconfigured? + const array = new Uint8Array([1,2,3]); + const composite = new CompositeArrayBuffer(array); + expectArraysEqual(new Uint8Array(composite.slice(0, 2)), [1,2]); + }); +}); diff --git a/tfjs-core/src/io/weights_loader.ts b/tfjs-core/src/io/weights_loader.ts index 8b0f923fd8a..8ad0ef2f85b 100644 --- a/tfjs-core/src/io/weights_loader.ts +++ b/tfjs-core/src/io/weights_loader.ts @@ -19,6 +19,7 @@ import {env} from '../environment'; import {NamedTensorMap} from '../tensor_types'; import * as util from '../util'; +import {CompositeArrayBuffer} from './composite_array_buffer'; import {decodeWeights} from './io_utils'; import {monitorPromisesProgress} from './progress'; import {DTYPE_VALUE_SIZE_MAP, LoadOptions, WeightsManifestConfig, WeightsManifestEntry} from './types'; @@ -35,27 +36,27 @@ import {DTYPE_VALUE_SIZE_MAP, LoadOptions, WeightsManifestConfig, WeightsManifes * length as `fetchURLs`. */ export async function loadWeightsAsArrayBuffer( - fetchURLs: string[], loadOptions?: LoadOptions): Promise { + fetchURLs: string[], loadOptions?: LoadOptions): Promise { if (loadOptions == null) { loadOptions = {}; } const fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch : - loadOptions.fetchFunc; + loadOptions.fetchFunc; // Create the requests for all of the weights in parallel. const requests = fetchURLs.map( - fetchURL => - fetchFunc(fetchURL, loadOptions.requestInit, {isBinary: true})); + fetchURL => + fetchFunc(fetchURL, loadOptions.requestInit, { isBinary: true })); const fetchStartFraction = 0; const fetchEndFraction = 0.5; const responses = loadOptions.onProgress == null ? - await Promise.all(requests) : - await monitorPromisesProgress( - requests, loadOptions.onProgress, fetchStartFraction, - fetchEndFraction); + await Promise.all(requests) : + await monitorPromisesProgress( + requests, loadOptions.onProgress, fetchStartFraction, + fetchEndFraction); const bufferPromises = responses.map(response => response.arrayBuffer()); @@ -63,10 +64,10 @@ export async function loadWeightsAsArrayBuffer( const bufferEndFraction = 1; const buffers = loadOptions.onProgress == null ? - await Promise.all(bufferPromises) : - await monitorPromisesProgress( - bufferPromises, loadOptions.onProgress, bufferStartFraction, - bufferEndFraction); + await Promise.all(bufferPromises) : + await monitorPromisesProgress( + bufferPromises, loadOptions.onProgress, bufferStartFraction, + bufferEndFraction); return buffers; } @@ -80,9 +81,9 @@ export async function loadWeightsAsArrayBuffer( * @param weightNames The names of the weights to be fetched. */ export async function loadWeights( - manifest: WeightsManifestConfig, filePathPrefix = '', - weightNames?: string[], - requestInit?: RequestInit): Promise { + manifest: WeightsManifestConfig, filePathPrefix = '', + weightNames?: string[], + requestInit?: RequestInit): Promise { // TODO(nsthorat): Groups are currently fetched atomically. If you need a // single weight from a group, the whole group will be fetched. At a future // date, we should support fetching only the individual shards within a @@ -90,7 +91,7 @@ export async function loadWeights( // TODO(cais): Use `decodeWeights` for implementation. const fetchWeights = (fetchUrls: string[]) => - loadWeightsAsArrayBuffer(fetchUrls, {requestInit}); + loadWeightsAsArrayBuffer(fetchUrls, { requestInit }); const loadWeights = weightsLoaderFactory(fetchWeights); return loadWeights(manifest, filePathPrefix, weightNames); @@ -121,12 +122,12 @@ export async function loadWeights( * @returns Weight loading function. */ export function weightsLoaderFactory( - fetchWeightsFunction: (fetchUrls: string[]) => Promise): - (manifest: WeightsManifestConfig, filePathPrefix?: string, - weightNames?: string[]) => Promise { - return async( - manifest: WeightsManifestConfig, filePathPrefix = '', - weightNames?: string[]): Promise => { + fetchWeightsFunction: (fetchUrls: string[]) => Promise): + (manifest: WeightsManifestConfig, filePathPrefix?: string, + weightNames?: string[]) => Promise { + return async ( + manifest: WeightsManifestConfig, filePathPrefix = '', + weightNames?: string[]): Promise => { // Collect all the groups, weights, and their relative offsets to be // fetched. const groupIndicesToFetchMap = manifest.map(() => false); @@ -137,17 +138,17 @@ export function weightsLoaderFactory( }> } = {}; const weightsFound = - weightNames != null ? weightNames.map(() => false) : []; + weightNames != null ? weightNames.map(() => false) : []; const allManifestWeightNames: string[] = []; manifest.forEach((manifestGroupConfig, groupIndex) => { let groupOffset = 0; manifestGroupConfig.weights.forEach(weightsEntry => { const rawDtype = ('quantization' in weightsEntry) ? - weightsEntry.quantization.dtype : - weightsEntry.dtype; + weightsEntry.quantization.dtype : + weightsEntry.dtype; const weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] * - util.sizeFromShape(weightsEntry.shape); + util.sizeFromShape(weightsEntry.shape); const enqueueWeightsForFetchingFn = () => { groupIndicesToFetchMap[groupIndex] = true; @@ -181,27 +182,27 @@ export function weightsLoaderFactory( if (!weightsFound.every(found => found)) { const weightsNotFound = weightNames.filter((_, i) => !weightsFound[i]); throw new Error( - `Could not find weights in manifest with names: ` + - `${weightsNotFound.join(', ')}. \n` + - `Manifest JSON has weights with names: ` + - `${allManifestWeightNames.join(', ')}.`); + `Could not find weights in manifest with names: ` + + `${weightsNotFound.join(', ')}. \n` + + `Manifest JSON has weights with names: ` + + `${allManifestWeightNames.join(', ')}.`); } // Convert the one-hot boolean groupId => shouldFetch map to a list of group // IDs. const groupIndicesToFetch = - groupIndicesToFetchMap.reduce((accumulator, shouldFetch, i) => { - if (shouldFetch) { - accumulator.push(i); - } - return accumulator; - }, []); + groupIndicesToFetchMap.reduce((accumulator, shouldFetch, i) => { + if (shouldFetch) { + accumulator.push(i); + } + return accumulator; + }, []); const fetchUrls: string[] = []; groupIndicesToFetch.forEach(i => { manifest[i].paths.forEach(filepath => { const fetchUrl = filePathPrefix + - (!filePathPrefix.endsWith('/') ? '/' : '') + filepath; + (!filePathPrefix.endsWith('/') ? '/' : '') + filepath; fetchUrls.push(fetchUrl); }); }); @@ -212,28 +213,17 @@ export function weightsLoaderFactory( groupIndicesToFetch.forEach(i => { const numBuffers = manifest[i].paths.length; - let groupBytes = 0; - for (let i = 0; i < numBuffers; i++) { - groupBytes += buffers[bufferIndexOffset + i].byteLength; - } - - // Create a buffer for the whole group. - const groupBuffer = new ArrayBuffer(groupBytes); - const groupByteBuffer = new Uint8Array(groupBuffer); - let groupBufferOffset = 0; - for (let i = 0; i < numBuffers; i++) { - const buffer = new Uint8Array(buffers[bufferIndexOffset + i]); - groupByteBuffer.set(buffer, groupBufferOffset); - groupBufferOffset += buffer.byteLength; - } + const weightsBuffer = new CompositeArrayBuffer( + buffers.slice(bufferIndexOffset, bufferIndexOffset + numBuffers)); const weightsEntries = groupWeightsToFetch[i]; + weightsEntries.forEach(weightsEntry => { - const byteBuffer = groupBuffer.slice( - weightsEntry.groupOffset, - weightsEntry.groupOffset + weightsEntry.sizeBytes); + const byteBuffer = weightsBuffer.slice( + weightsEntry.groupOffset, + weightsEntry.groupOffset + weightsEntry.sizeBytes); const nameToTensorMap = - decodeWeights(byteBuffer, [weightsEntry.manifestEntry]); + decodeWeights(byteBuffer, [weightsEntry.manifestEntry]); for (const name in nameToTensorMap) { weightsTensorMap[name] = nameToTensorMap[name]; } diff --git a/tfjs-core/src/io/weights_loader_test.ts b/tfjs-core/src/io/weights_loader_test.ts index 759f3b72b18..69a00607d39 100644 --- a/tfjs-core/src/io/weights_loader_test.ts +++ b/tfjs-core/src/io/weights_loader_test.ts @@ -403,7 +403,6 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => { it('throws if requested weight has unknown dtype', async () => { setupFakeWeightFiles({'./weightfile0': new Float32Array([1, 2, 3])}); - const manifest: WeightsManifestConfig = [{ 'paths': ['weightfile0'], 'weights': [{