From 71365a01de860a1cc6db9a4b379570caf2113041 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Wed, 19 Apr 2023 17:06:01 -0700 Subject: [PATCH 01/11] Support using a list of ArrayBuffers as model weight data --- .../src/executor/graph_model_test.ts | 9 +++-- .../src/operations/executors/spy_ops.ts | 14 ++++++- tfjs-core/src/io/browser_files.ts | 14 +++++-- tfjs-core/src/io/browser_files_test.ts | 26 +++++++------ tfjs-core/src/io/http.ts | 13 +++++-- tfjs-core/src/io/http_test.ts | 37 ++++++++++++------- tfjs-core/src/io/indexed_db_test.ts | 10 +++-- tfjs-core/src/io/io.ts | 5 ++- tfjs-core/src/io/io_utils.ts | 35 ++++++++++-------- tfjs-core/src/io/local_storage.ts | 9 ++++- tfjs-core/src/io/model_management_test.ts | 8 +++- tfjs-core/src/io/passthrough.ts | 6 +-- tfjs-core/src/io/types.ts | 4 +- tfjs-layers/src/models.ts | 4 +- tfjs-layers/src/models_test.ts | 21 +++++++---- 15 files changed, 142 insertions(+), 73 deletions(-) diff --git a/tfjs-converter/src/executor/graph_model_test.ts b/tfjs-converter/src/executor/graph_model_test.ts index 7d3a3ade14f..97adc903882 100644 --- a/tfjs-converter/src/executor/graph_model_test.ts +++ b/tfjs-converter/src/executor/graph_model_test.ts @@ -586,7 +586,8 @@ describe('Model', () => { expect(handler.savedArtifacts.modelTopology).toEqual(CUSTOM_OP_MODEL); expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest); tfc.test_util.expectArraysClose( - new Int32Array(handler.savedArtifacts.weightData), bias.dataSync()); + new Int32Array(new io.CompositeArrayBuffer( + handler.savedArtifacts.weightData).slice()), bias.dataSync()); }); }); }); @@ -616,7 +617,8 @@ describe('Model', () => { }); expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest); tfc.test_util.expectArraysClose( - new Int32Array(handler.savedArtifacts.weightData), bias.dataSync()); + new Int32Array(new io.CompositeArrayBuffer( + handler.savedArtifacts.weightData).slice()), bias.dataSync()); }); }); @@ -904,7 +906,8 @@ describe('Model', () => { }); expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest); tfc.test_util.expectArraysClose( - new Int32Array(handler.savedArtifacts.weightData), bias.dataSync()); + new Int32Array(new io.CompositeArrayBuffer(handler.savedArtifacts + .weightData).slice()), bias.dataSync()); }); }); diff --git a/tfjs-converter/src/operations/executors/spy_ops.ts b/tfjs-converter/src/operations/executors/spy_ops.ts index e8f9fb8e1bb..ae1ef993664 100644 --- a/tfjs-converter/src/operations/executors/spy_ops.ts +++ b/tfjs-converter/src/operations/executors/spy_ops.ts @@ -15,11 +15,23 @@ * ============================================================================= */ +// The opposite of Extract +type Without = T extends U ? never : T; + +// Do not spy on CompositeArrayBuffer because it is a class constructor. +type NotSpiedOn = 'CompositeArrayBuffer'; + export type RecursiveSpy = - T extends Function ? jasmine.Spy : {[K in keyof T]: RecursiveSpy}; + T extends Function ? jasmine.Spy : + {[K in Without]: RecursiveSpy} & + {[K in Extract]: T[K]}; export function spyOnAllFunctions(obj: T): RecursiveSpy { return Object.fromEntries(Object.entries(obj).map(([key, val]) => { + // TODO(mattSoulanille): Do not hard code this + if (key === 'CompositeArrayBuffer') { + return val; + } if (val instanceof Function) { return [key, jasmine.createSpy(`${key} spy`, val).and.callThrough()]; } else if (val instanceof Array) { diff --git a/tfjs-core/src/io/browser_files.ts b/tfjs-core/src/io/browser_files.ts index 90b11058b1c..07ce146143e 100644 --- a/tfjs-core/src/io/browser_files.ts +++ b/tfjs-core/src/io/browser_files.ts @@ -25,7 +25,7 @@ import {env} from '../environment'; import {basename, concatenateArrayBuffers, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts} from './io_utils'; import {IORouter, IORouterRegistry} from './router_registry'; -import {IOHandler, ModelArtifacts, ModelJSON, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types'; +import {IOHandler, ModelArtifacts, ModelJSON, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types'; const DEFAULT_FILE_NAME_PREFIX = 'model'; const DEFAULT_JSON_EXTENSION_NAME = '.json'; @@ -70,8 +70,14 @@ export class BrowserDownloads implements IOHandler { 'Browser downloads are not supported in ' + 'this environment since `document` is not present'); } + + // TODO(mattsoulanille): Support saving models over 2GB that exceed + // Chrome's ArrayBuffer size limit. + const weightBuffer = concatenateArrayBuffers( + [modelArtifacts.weightData].flat()); + const weightsURL = window.URL.createObjectURL(new Blob( - [modelArtifacts.weightData], {type: 'application/octet-stream'})); + [weightBuffer], {type: 'application/octet-stream'})); if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error( @@ -169,7 +175,7 @@ class BrowserFiles implements IOHandler { } private loadWeights(weightsManifest: WeightsManifestConfig): Promise<[ - /* weightSpecs */ WeightsManifestEntry[], /* weightData */ ArrayBuffer + /* weightSpecs */ WeightsManifestEntry[], WeightData, ]> { const weightSpecs: WeightsManifestEntry[] = []; const paths: string[] = []; @@ -185,7 +191,7 @@ class BrowserFiles implements IOHandler { paths.map(path => this.loadWeightsFile(path, pathToFile[path])); return Promise.all(promises).then( - buffers => [weightSpecs, concatenateArrayBuffers(buffers)]); + buffers => [weightSpecs, buffers]); } private loadWeightsFile(path: string, file: File): Promise { diff --git a/tfjs-core/src/io/browser_files_test.ts b/tfjs-core/src/io/browser_files_test.ts index d9f63f7f26c..728264510fd 100644 --- a/tfjs-core/src/io/browser_files_test.ts +++ b/tfjs-core/src/io/browser_files_test.ts @@ -23,6 +23,7 @@ import * as tf from '../index'; import {BROWSER_ENVS, describeWithFlags} from '../jasmine_util'; import {browserDownloads, BrowserDownloads, browserDownloadsRouter} from './browser_files'; import {WeightsManifestConfig, WeightsManifestEntry} from './types'; +import {CompositeArrayBuffer} from './composite_array_buffer'; const modelTopology1: {} = { 'class_name': 'Sequential', @@ -310,8 +311,8 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => { expect(modelArtifacts.modelInitializer).toEqual({}); expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1); - expect(new Uint8Array(modelArtifacts.weightData)) - .toEqual(new Uint8Array(weightData1)); + expect(new Uint8Array(new CompositeArrayBuffer(modelArtifacts.weightData) + .slice())).toEqual(new Uint8Array(weightData1)); }); it(`One group, two paths`, async () => { @@ -351,9 +352,10 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => { const modelArtifacts = await filesHandler.load(); expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightSpecs); - expect(new Uint8Array(modelArtifacts.weightData)).toEqual(new Uint8Array([ - 1, 2, 3, 4, 10, 20, 30, 40 - ])); + expect(new Uint8Array(new CompositeArrayBuffer(modelArtifacts.weightData) + .slice())).toEqual(new Uint8Array([ + 1, 2, 3, 4, 10, 20, 30, 40 + ])); }); it(`Two groups, four paths, reverseOrder=false`, async () => { @@ -418,9 +420,10 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => { expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs) .toEqual(weightSpecs1.concat(weightSpecs2)); - expect(new Uint8Array(modelArtifacts.weightData)).toEqual(new Uint8Array([ - 1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80 - ])); + expect(new Uint8Array(new CompositeArrayBuffer(modelArtifacts.weightData) + .slice())).toEqual(new Uint8Array([ + 1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80 + ])); }); it(`Two groups, four paths, reverseOrder=true`, async () => { @@ -485,9 +488,10 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => { expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs) .toEqual(weightSpecs1.concat(weightSpecs2)); - expect(new Uint8Array(modelArtifacts.weightData)).toEqual(new Uint8Array([ - 1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80 - ])); + expect(new Uint8Array(new CompositeArrayBuffer(modelArtifacts.weightData) + .slice())).toEqual(new Uint8Array([ + 1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80 + ])); }); it('Upload model topology only', async () => { diff --git a/tfjs-core/src/io/http.ts b/tfjs-core/src/io/http.ts index 5b3aab81fb9..58a2f1ec1dc 100644 --- a/tfjs-core/src/io/http.ts +++ b/tfjs-core/src/io/http.ts @@ -26,7 +26,7 @@ import {env} from '../environment'; import {assert} from '../util'; import {concatenateArrayBuffers, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts, getWeightSpecs} from './io_utils'; import {IORouter, IORouterRegistry} from './router_registry'; -import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types'; +import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types'; import {loadWeightsAsArrayBuffer} from './weights_loader'; const OCTET_STREAM_MIME_TYPE = 'application/octet-stream'; @@ -110,9 +110,14 @@ export class HTTPRequest implements IOHandler { 'model.json'); if (modelArtifacts.weightData != null) { + // TODO(mattsoulanille): Support saving models over 2GB that exceed + // Chrome's ArrayBuffer size limit. + const weightBuffer = concatenateArrayBuffers( + [modelArtifacts.weightData].flat()); + init.body.append( 'model.weights.bin', - new Blob([modelArtifacts.weightData], {type: OCTET_STREAM_MIME_TYPE}), + new Blob([weightBuffer], {type: OCTET_STREAM_MIME_TYPE}), 'model.weights.bin'); } @@ -182,7 +187,7 @@ export class HTTPRequest implements IOHandler { } private async loadWeights(weightsManifest: WeightsManifestConfig): - Promise<[WeightsManifestEntry[], ArrayBuffer]> { + Promise<[WeightsManifestEntry[], WeightData]> { const weightPath = Array.isArray(this.path) ? this.path[1] : this.path; const [prefix, suffix] = parseUrl(weightPath); const pathPrefix = this.weightPathPrefix || prefix; @@ -210,7 +215,7 @@ export class HTTPRequest implements IOHandler { fetchFunc: this.fetch, onProgress: this.onProgress }); - return [weightSpecs, concatenateArrayBuffers(buffers)]; + return [weightSpecs, buffers]; } } diff --git a/tfjs-core/src/io/http_test.ts b/tfjs-core/src/io/http_test.ts index 4be84eb9eef..1b387a4b6bd 100644 --- a/tfjs-core/src/io/http_test.ts +++ b/tfjs-core/src/io/http_test.ts @@ -18,6 +18,7 @@ import * as tf from '../index'; import {BROWSER_ENVS, CHROME_ENVS, describeWithFlags, NODE_ENVS} from '../jasmine_util'; import {HTTPRequest, httpRouter, parseUrl} from './http'; +import {CompositeArrayBuffer} from './composite_array_buffer'; // Test data. const modelTopology1: {} = { @@ -161,7 +162,8 @@ describeWithFlags('http-load fetch', NODE_ENVS, () => { expect(modelArtifacts.generatedBy).toEqual('1.15'); expect(modelArtifacts.convertedBy).toEqual('1.3.1'); expect(modelArtifacts.userDefinedMetadata).toEqual({}); - expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); + expect(new Float32Array(new CompositeArrayBuffer(modelArtifacts.weightData) + .slice())).toEqual(floatData); }); it('throw exception if no fetch polyfill', () => { @@ -507,7 +509,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.userDefinedMetadata).toEqual({}); expect(modelArtifacts.modelInitializer).toEqual({}); - expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); + expect(new Float32Array(new CompositeArrayBuffer( + modelArtifacts.weightData).slice())).toEqual(floatData); expect(Object.keys(requestInits).length).toEqual(2); // Assert that fetch is invoked with `window` as the context. expect(fetchSpy.calls.mostRecent().object).toEqual(window); @@ -550,7 +553,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { const modelArtifacts = await handler.load(); expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); - expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); + expect(new Float32Array(new CompositeArrayBuffer( + modelArtifacts.weightData).slice())).toEqual(floatData); expect(Object.keys(requestInits).length).toEqual(2); expect(Object.keys(requestInits).length).toEqual(2); expect(requestInits['./model.json'].headers['header_key_1']) @@ -599,8 +603,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { const modelArtifacts = await handler.load(); expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); - expect(new Float32Array(modelArtifacts.weightData)) - .toEqual(new Float32Array([1, 3, 3, 7, 4])); + expect(new Float32Array(new CompositeArrayBuffer(modelArtifacts + .weightData).slice())).toEqual(new Float32Array([1, 3, 3, 7, 4])); }); it('2 groups, 2 weight, 2 paths', async () => { @@ -644,7 +648,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.weightSpecs) .toEqual( weightsManifest[0].weights.concat(weightsManifest[1].weights)); - expect(new Float32Array(modelArtifacts.weightData)) + expect(new Float32Array(new CompositeArrayBuffer( + modelArtifacts.weightData).slice())) .toEqual(new Float32Array([1, 3, 3, 7, 4])); }); @@ -689,10 +694,10 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.weightSpecs) .toEqual( weightsManifest[0].weights.concat(weightsManifest[1].weights)); - expect(new Int32Array(modelArtifacts.weightData.slice(0, 12))) - .toEqual(new Int32Array([1, 3, 3])); - expect(new Uint8Array(modelArtifacts.weightData.slice(12, 14))) - .toEqual(new Uint8Array([7, 4])); + expect(new Int32Array(new CompositeArrayBuffer(modelArtifacts.weightData) + .slice(0, 12))).toEqual(new Int32Array([1, 3, 3])); + expect(new Uint8Array(new CompositeArrayBuffer(modelArtifacts.weightData) + .slice(12, 14))).toEqual(new Uint8Array([7, 4])); }); it('topology only', async () => { @@ -752,9 +757,11 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.weightSpecs) .toEqual( weightsManifest[0].weights.concat(weightsManifest[1].weights)); - expect(new Int32Array(modelArtifacts.weightData.slice(0, 12))) + expect(new Int32Array(new CompositeArrayBuffer( + modelArtifacts.weightData).slice(0, 12))) .toEqual(new Int32Array([1, 3, 3])); - expect(new Float32Array(modelArtifacts.weightData.slice(12, 20))) + expect(new Float32Array(new CompositeArrayBuffer( + modelArtifacts.weightData).slice(12, 20))) .toEqual(new Float32Array([-7, -4])); }); @@ -840,7 +847,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { const modelArtifacts = await handler.load(); expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); - expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); + expect(new Float32Array(new CompositeArrayBuffer( + modelArtifacts.weightData).slice())).toEqual(floatData); expect(Object.keys(requestInits).length).toEqual(2); expect(Object.keys(requestInits).length).toEqual(2); expect(requestInits['./model.json'].headers['header_key_1']) @@ -902,7 +910,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); - expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); + expect(new Float32Array(new CompositeArrayBuffer(modelArtifacts.weightData) + .slice())).toEqual(floatData); expect(fetchInputs).toEqual(['./model.json', './weightfile0']); expect(fetchInits.length).toEqual(2); diff --git a/tfjs-core/src/io/indexed_db_test.ts b/tfjs-core/src/io/indexed_db_test.ts index 308fce2e464..2502afc716d 100644 --- a/tfjs-core/src/io/indexed_db_test.ts +++ b/tfjs-core/src/io/indexed_db_test.ts @@ -23,6 +23,7 @@ import * as tf from '../index'; import {BROWSER_ENVS, describeWithFlags, runWithLock} from '../jasmine_util'; import {expectArrayBuffersEqual} from '../test_util'; import {browserIndexedDB, BrowserIndexedDB, BrowserIndexedDBManager, deleteDatabase, indexedDBRouter} from './indexed_db'; +import {CompositeArrayBuffer} from './composite_array_buffer'; describeWithFlags('IndexedDB', BROWSER_ENVS, () => { // Test data. @@ -121,8 +122,9 @@ describeWithFlags('IndexedDB', BROWSER_ENVS, () => { expect(loadedArtifacts.generatedBy).toEqual('TensorFlow.js v0.0.0'); expect(loadedArtifacts.convertedBy).toEqual(null); expect(loadedArtifacts.modelInitializer).toEqual({}); - expectArrayBuffersEqual(loadedArtifacts.weightData, weightData1); - })); + expectArrayBuffersEqual(new CompositeArrayBuffer( + loadedArtifacts.weightData).slice(), weightData1); + })); it('Save two models and load one', runWithLock(async () => { const weightData2 = new ArrayBuffer(24); @@ -160,7 +162,9 @@ describeWithFlags('IndexedDB', BROWSER_ENVS, () => { const loadedArtifacts = await handler1.load(); expect(loadedArtifacts.modelTopology).toEqual(modelTopology1); expect(loadedArtifacts.weightSpecs).toEqual(weightSpecs1); - expectArrayBuffersEqual(loadedArtifacts.weightData, weightData1); + expect(loadedArtifacts.weightData).toBeDefined(); + expectArrayBuffersEqual(new CompositeArrayBuffer( + loadedArtifacts.weightData).slice(), weightData1); })); it('Loading nonexistent model fails', runWithLock(async () => { diff --git a/tfjs-core/src/io/io.ts b/tfjs-core/src/io/io.ts index 29383e451c5..49e9a1e2e06 100644 --- a/tfjs-core/src/io/io.ts +++ b/tfjs-core/src/io/io.ts @@ -25,13 +25,15 @@ import {browserHTTPRequest, http, isHTTPScheme} from './http'; import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsForJSON, getModelArtifactsForJSONSync, getModelArtifactsInfoForJSON, getWeightSpecs} from './io_utils'; import {fromMemory, fromMemorySync, withSaveHandler, withSaveHandlerSync} from './passthrough'; import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry'; -import {IOHandler, IOHandlerSync, LoadHandler, LoadOptions, ModelArtifacts, ModelArtifactsInfo, ModelJSON, ModelStoreManager, OnProgressCallback, RequestDetails, SaveConfig, SaveHandler, SaveResult, TrainingConfig, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types'; +import {IOHandler, IOHandlerSync, LoadHandler, LoadOptions, ModelArtifacts, ModelArtifactsInfo, ModelJSON, ModelStoreManager, OnProgressCallback, RequestDetails, SaveConfig, SaveHandler, SaveResult, TrainingConfig, WeightGroup, WeightsManifestConfig, WeightsManifestEntry, WeightData} from './types'; import {loadWeights, weightsLoaderFactory} from './weights_loader'; +import {CompositeArrayBuffer} from './composite_array_buffer'; export {copyModel, listModels, moveModel, removeModel} from './model_management'; export { browserFiles, browserHTTPRequest, + CompositeArrayBuffer, concatenateArrayBuffers, decodeWeights, encodeWeights, @@ -62,6 +64,7 @@ export { SaveHandler, SaveResult, TrainingConfig, + WeightData, WeightGroup, weightsLoaderFactory, WeightsManifestConfig, diff --git a/tfjs-core/src/io/io_utils.ts b/tfjs-core/src/io/io_utils.ts index ce20b7c9fee..e9024f17d72 100644 --- a/tfjs-core/src/io/io_utils.ts +++ b/tfjs-core/src/io/io_utils.ts @@ -21,7 +21,8 @@ import {NamedTensor, NamedTensorMap} from '../tensor_types'; import {TypedArray} from '../types'; import {sizeFromShape} from '../util'; -import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, ModelJSON, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types'; +import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, ModelJSON, WeightData, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types'; +import {CompositeArrayBuffer} from './composite_array_buffer'; /** Number of bytes reserved for the length of the string. (32bit integer). */ const NUM_BYTES_STRING_LENGTH = 4; @@ -101,8 +102,8 @@ export async function encodeWeights( * * This function is the reverse of `encodeWeights`. * - * @param buffer A flat ArrayBuffer carrying the binary values of the tensors - * concatenated in the order specified in `specs`. + * @param weightData A flat ArrayBuffer carrying the binary values of the + * tensors concatenated in the order specified in `specs`. * @param specs Specifications of the names, dtypes and shapes of the tensors * whose value are encoded by `buffer`. * @return A map from tensor name to tensor value, with the names corresponding @@ -110,8 +111,10 @@ export async function encodeWeights( * @throws Error, if any of the tensors has unsupported dtype. */ export function decodeWeights( - buffer: ArrayBuffer, specs: WeightsManifestEntry[]): NamedTensorMap { + weightData: WeightData, + specs: WeightsManifestEntry[]): NamedTensorMap { // TODO(adarob, cais): Support quantization. + const compositeBuffer = new CompositeArrayBuffer(weightData); const out: NamedTensorMap = {}; let float16Decode: (buffer: Uint16Array) => Float32Array | undefined; let offset = 0; @@ -145,7 +148,7 @@ export function decodeWeights( } const quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype]; const byteBuffer = - buffer.slice(offset, offset + size * quantizationSizeFactor); + compositeBuffer.slice(offset, offset + size * quantizationSizeFactor); const quantizedArray = (quantization.dtype === 'uint8') ? new Uint8Array(byteBuffer) : new Uint16Array(byteBuffer); @@ -186,15 +189,17 @@ export function decodeWeights( values = []; for (let i = 0; i < size; i++) { const byteLength = new Uint32Array( - buffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0]; + compositeBuffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0]; offset += NUM_BYTES_STRING_LENGTH; - const bytes = new Uint8Array(buffer.slice(offset, offset + byteLength)); + const bytes = new Uint8Array( + compositeBuffer.slice(offset, offset + byteLength)); (values as Uint8Array[]).push(bytes); offset += byteLength; } } else { const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype]; - const byteBuffer = buffer.slice(offset, offset + size * dtypeFactor); + const byteBuffer = compositeBuffer.slice(offset, + offset + size * dtypeFactor); if (dtype === 'float32') { values = new Float32Array(byteBuffer); @@ -411,14 +416,14 @@ export function getModelJSONForModelArtifacts( * @param modelJSON Object containing the parsed JSON of `model.json` * @param weightSpecs The list of WeightsManifestEntry for the model. Must be * passed if the modelJSON has a weightsManifest. - * @param weightData An ArrayBuffer of weight data for the model corresponding - * to the weights in weightSpecs. Must be passed if the modelJSON has a - * weightsManifest. + * @param weightData An ArrayBuffer or array of ArrayBuffers of weight data for + * the model corresponding to the weights in weightSpecs. Must be passed if + * the modelJSON has a weightsManifest. * @returns A Promise of the `ModelArtifacts`, as described by the JSON file. */ export function getModelArtifactsForJSONSync( modelJSON: ModelJSON, weightSpecs?: WeightsManifestEntry[], - weightData?: ArrayBuffer): ModelArtifacts { + weightData?: WeightData): ModelArtifacts { const modelArtifacts: ModelArtifacts = { modelTopology: modelJSON.modelTopology, @@ -468,10 +473,10 @@ export function getModelArtifactsForJSONSync( export async function getModelArtifactsForJSON( modelJSON: ModelJSON, loadWeights: (weightsManifest: WeightsManifestConfig) => Promise<[ - /* weightSpecs */ WeightsManifestEntry[], /* weightData */ ArrayBuffer + /* weightSpecs */ WeightsManifestEntry[], WeightData, ]>): Promise { let weightSpecs: WeightsManifestEntry[] | undefined; - let weightData: ArrayBuffer | undefined; + let weightData: WeightData | undefined; if (modelJSON.weightsManifest != null) { [weightSpecs, weightData] = await loadWeights(modelJSON.weightsManifest); @@ -502,7 +507,7 @@ export function getModelArtifactsInfoForJSON(modelArtifacts: ModelArtifacts): stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)), weightDataBytes: modelArtifacts.weightData == null ? 0 : - modelArtifacts.weightData.byteLength, + new CompositeArrayBuffer(modelArtifacts.weightData).byteLength, }; } diff --git a/tfjs-core/src/io/local_storage.ts b/tfjs-core/src/io/local_storage.ts index 2de2639c35c..f362fd22bc9 100644 --- a/tfjs-core/src/io/local_storage.ts +++ b/tfjs-core/src/io/local_storage.ts @@ -19,7 +19,7 @@ import '../flags'; import {env} from '../environment'; import {assert} from '../util'; -import {arrayBufferToBase64String, base64StringToArrayBuffer, getModelArtifactsInfoForJSON} from './io_utils'; +import {arrayBufferToBase64String, base64StringToArrayBuffer, concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils'; import {IORouter, IORouterRegistry} from './router_registry'; import {IOHandler, ModelArtifacts, ModelArtifactsInfo, ModelJSON, ModelStoreManager, SaveResult} from './types'; @@ -174,13 +174,18 @@ export class BrowserLocalStorage implements IOHandler { const modelArtifactsInfo: ModelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts); + // TODO(mattsoulanille): Support saving models over 2GB that exceed + // Chrome's ArrayBuffer size limit. + const weightBuffer = concatenateArrayBuffers( + [modelArtifacts.weightData].flat()); + try { this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo)); this.LS.setItem(this.keys.topology, topology); this.LS.setItem(this.keys.weightSpecs, weightSpecs); this.LS.setItem( this.keys.weightData, - arrayBufferToBase64String(modelArtifacts.weightData)); + arrayBufferToBase64String(weightBuffer)); // Note that JSON.stringify doesn't write out keys that have undefined // values, so for some keys, we set undefined instead of a null-ish diff --git a/tfjs-core/src/io/model_management_test.ts b/tfjs-core/src/io/model_management_test.ts index db66e42367f..81111bbc460 100644 --- a/tfjs-core/src/io/model_management_test.ts +++ b/tfjs-core/src/io/model_management_test.ts @@ -19,6 +19,7 @@ import * as tf from '../index'; import {CHROME_ENVS, describeWithFlags, runWithLock} from '../jasmine_util'; import {deleteDatabase} from './indexed_db'; import {purgeLocalStorageArtifacts} from './local_storage'; +import {CompositeArrayBuffer} from './composite_array_buffer'; // Disabled for non-Chrome browsers due to: // https://github.com/tensorflow/tfjs/issues/427 @@ -268,7 +269,9 @@ describeWithFlags('ModelManagement', CHROME_ENVS, () => { .then(loaded => { expect(loaded.modelTopology).toEqual(modelTopology1); expect(loaded.weightSpecs).toEqual(weightSpecs1); - expect(new Uint8Array(loaded.weightData)) + expect(loaded.weightData).toBeDefined(); + expect(new Uint8Array( + new CompositeArrayBuffer(loaded.weightData).slice())) .toEqual(new Uint8Array(weightData1)); done(); }) @@ -311,7 +314,8 @@ describeWithFlags('ModelManagement', CHROME_ENVS, () => { .then(loaded => { expect(loaded.modelTopology).toEqual(modelTopology1); expect(loaded.weightSpecs).toEqual(weightSpecs1); - expect(new Uint8Array(loaded.weightData)) + expect(new Uint8Array( + new CompositeArrayBuffer(loaded.weightData).slice())) .toEqual(new Uint8Array(weightData1)); done(); }) diff --git a/tfjs-core/src/io/passthrough.ts b/tfjs-core/src/io/passthrough.ts index 5e9cd8636c6..e5416002d59 100644 --- a/tfjs-core/src/io/passthrough.ts +++ b/tfjs-core/src/io/passthrough.ts @@ -19,7 +19,7 @@ * IOHandlers that pass through the in-memory ModelArtifacts format. */ -import {IOHandler, IOHandlerSync, LoadHandler, ModelArtifacts, SaveHandler, SaveResult, TrainingConfig, WeightsManifestEntry} from './types'; +import {IOHandler, IOHandlerSync, LoadHandler, ModelArtifacts, SaveHandler, SaveResult, TrainingConfig, WeightData, WeightsManifestEntry} from './types'; class PassthroughLoader implements IOHandlerSync { constructor(private readonly modelArtifacts?: ModelArtifacts) {} @@ -76,7 +76,7 @@ class PassthroughAsync implements IOHandler { */ export function fromMemory( modelArtifacts: {}|ModelArtifacts, weightSpecs?: WeightsManifestEntry[], - weightData?: ArrayBuffer, trainingConfig?: TrainingConfig): IOHandler { + weightData?: WeightData, trainingConfig?: TrainingConfig): IOHandler { const args = arguments as unknown as Parameters; return new PassthroughAsync(fromMemorySync(...args)); @@ -105,7 +105,7 @@ export function fromMemory( */ export function fromMemorySync( modelArtifacts: {}|ModelArtifacts, weightSpecs?: WeightsManifestEntry[], - weightData?: ArrayBuffer, trainingConfig?: TrainingConfig): IOHandlerSync { + weightData?: WeightData, trainingConfig?: TrainingConfig): IOHandlerSync { if (arguments.length === 1) { const isModelArtifacts = (modelArtifacts as ModelArtifacts).modelTopology != null || diff --git a/tfjs-core/src/io/types.ts b/tfjs-core/src/io/types.ts index ece6b97e8b5..ff098b6ca61 100644 --- a/tfjs-core/src/io/types.ts +++ b/tfjs-core/src/io/types.ts @@ -215,6 +215,8 @@ export declare interface TrainingConfig { loss_weights?: number[]|{[key: string]: number}; } +export type WeightData = ArrayBuffer | ArrayBuffer[]; + /** * The serialized artifacts of a model, including topology and weights. * @@ -251,7 +253,7 @@ export declare interface ModelArtifacts { * Binary buffer for all weight values concatenated in the order specified * by `weightSpecs`. */ - weightData?: ArrayBuffer; + weightData?: WeightData; /** * Hard-coded format name for models saved from TensorFlow.js or converted diff --git a/tfjs-layers/src/models.ts b/tfjs-layers/src/models.ts index 833ab887833..40d6e79a320 100644 --- a/tfjs-layers/src/models.ts +++ b/tfjs-layers/src/models.ts @@ -342,9 +342,9 @@ export async function loadLayersModelFromIOHandler( } function decodeModelAndOptimizerWeights( - buffer: ArrayBuffer, specs: io.WeightsManifestEntry[]): + weightData: io.WeightData, specs: io.WeightsManifestEntry[]): {modelWeights: NamedTensorMap, optimizerWeights: NamedTensor[]} { - const name2Tensor = io.decodeWeights(buffer, specs); + const name2Tensor = io.decodeWeights(weightData, specs); const modelWeights: NamedTensorMap = {}; const optimizerWeights: NamedTensor[] = []; specs.forEach(spec => { diff --git a/tfjs-layers/src/models_test.ts b/tfjs-layers/src/models_test.ts index 1f470ed1d7c..e2db066cf49 100644 --- a/tfjs-layers/src/models_test.ts +++ b/tfjs-layers/src/models_test.ts @@ -1293,7 +1293,8 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(weightData.byteLength).toEqual(4 * 8 + 4 * 1 + 4); + expect(new io.CompositeArrayBuffer(weightData).byteLength) + .toEqual(4 * 8 + 4 * 1 + 4); // Load the model back, with the optimizer. const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts)); @@ -1352,7 +1353,8 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(weightData.byteLength).toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); + expect(new io.CompositeArrayBuffer(weightData) + .byteLength).toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); // Load the model back, with the optimizer. const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts)); @@ -1411,7 +1413,8 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(weightData.byteLength).toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); + expect(new io.CompositeArrayBuffer(weightData).byteLength) + .toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); // Load the model back, with the optimizer. const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts)); @@ -1468,7 +1471,8 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(weightData.byteLength).toEqual(4 + 4 * 8 * 2 + 4 * 1 * 2); + expect(new io.CompositeArrayBuffer(weightData).byteLength) + .toEqual(4 + 4 * 8 * 2 + 4 * 1 * 2); // Load the model back, with the optimizer. const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts)); @@ -1530,7 +1534,8 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(weightData.byteLength).toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); + expect(new io.CompositeArrayBuffer(weightData).byteLength) + .toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); // Load the model back, with the optimizer. const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts)); @@ -1588,7 +1593,8 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(weightData.byteLength).toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); + expect(new io.CompositeArrayBuffer(weightData).byteLength) + .toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); // Load the model back, with the optimizer. const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts)); @@ -1645,7 +1651,8 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(weightData.byteLength).toEqual(4 + 4 * 8 * 2 + 4 * 1 * 2); + expect(new io.CompositeArrayBuffer(weightData).byteLength) + .toEqual(4 + 4 * 8 * 2 + 4 * 1 * 2); // Load the model back, with the optimizer. const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts)); From b6de59758779e02abad20ec49d2f53f664f15371 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Thu, 20 Apr 2023 12:47:11 -0700 Subject: [PATCH 02/11] Avoid 'Array.flat()' --- tfjs-core/src/io/browser_files.ts | 3 +-- tfjs-core/src/io/http.ts | 3 +-- tfjs-core/src/io/io_utils.ts | 14 ++++++++++---- tfjs-core/src/io/local_storage.ts | 3 +-- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/tfjs-core/src/io/browser_files.ts b/tfjs-core/src/io/browser_files.ts index 07ce146143e..1f8f8e70968 100644 --- a/tfjs-core/src/io/browser_files.ts +++ b/tfjs-core/src/io/browser_files.ts @@ -73,8 +73,7 @@ export class BrowserDownloads implements IOHandler { // TODO(mattsoulanille): Support saving models over 2GB that exceed // Chrome's ArrayBuffer size limit. - const weightBuffer = concatenateArrayBuffers( - [modelArtifacts.weightData].flat()); + const weightBuffer = concatenateArrayBuffers(modelArtifacts.weightData) const weightsURL = window.URL.createObjectURL(new Blob( [weightBuffer], {type: 'application/octet-stream'})); diff --git a/tfjs-core/src/io/http.ts b/tfjs-core/src/io/http.ts index 58a2f1ec1dc..78fb60f4b11 100644 --- a/tfjs-core/src/io/http.ts +++ b/tfjs-core/src/io/http.ts @@ -112,8 +112,7 @@ export class HTTPRequest implements IOHandler { if (modelArtifacts.weightData != null) { // TODO(mattsoulanille): Support saving models over 2GB that exceed // Chrome's ArrayBuffer size limit. - const weightBuffer = concatenateArrayBuffers( - [modelArtifacts.weightData].flat()); + const weightBuffer = concatenateArrayBuffers(modelArtifacts.weightData); init.body.append( 'model.weights.bin', diff --git a/tfjs-core/src/io/io_utils.ts b/tfjs-core/src/io/io_utils.ts index e9024f17d72..18c14816428 100644 --- a/tfjs-core/src/io/io_utils.ts +++ b/tfjs-core/src/io/io_utils.ts @@ -102,8 +102,9 @@ export async function encodeWeights( * * This function is the reverse of `encodeWeights`. * - * @param weightData A flat ArrayBuffer carrying the binary values of the - * tensors concatenated in the order specified in `specs`. + * @param weightData A flat ArrayBuffer or an array of ArrayBuffers carrying the + * binary values of the tensors concatenated in the order specified in + * `specs`. * @param specs Specifications of the names, dtypes and shapes of the tensors * whose value are encoded by `buffer`. * @return A map from tensor name to tensor value, with the names corresponding @@ -335,10 +336,15 @@ export function base64StringToArrayBuffer(str: string): ArrayBuffer { /** * Concatenate a number of ArrayBuffers into one. * - * @param buffers A number of array buffers to concatenate. + * @param buffers An array of ArrayBuffers to concatenate, or a single + * ArrayBuffer. * @returns Result of concatenating `buffers` in order. */ -export function concatenateArrayBuffers(buffers: ArrayBuffer[]): ArrayBuffer { +export function concatenateArrayBuffers(buffers: ArrayBuffer[] + | ArrayBuffer): ArrayBuffer { + if (!(buffers instanceof Array)) { + return buffers; + } if (buffers.length === 1) { return buffers[0]; } diff --git a/tfjs-core/src/io/local_storage.ts b/tfjs-core/src/io/local_storage.ts index f362fd22bc9..14ca64ad117 100644 --- a/tfjs-core/src/io/local_storage.ts +++ b/tfjs-core/src/io/local_storage.ts @@ -176,8 +176,7 @@ export class BrowserLocalStorage implements IOHandler { // TODO(mattsoulanille): Support saving models over 2GB that exceed // Chrome's ArrayBuffer size limit. - const weightBuffer = concatenateArrayBuffers( - [modelArtifacts.weightData].flat()); + const weightBuffer = concatenateArrayBuffers(modelArtifacts.weightData); try { this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo)); From ffcfa680770f07367ae99728781fa98a03a009b2 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Thu, 20 Apr 2023 13:11:26 -0700 Subject: [PATCH 03/11] Simplify some of the tests --- .../src/executor/graph_model_test.ts | 4 +- tfjs-core/src/io/browser_files.ts | 2 +- tfjs-core/src/io/browser_files_test.ts | 30 ++++++------ tfjs-core/src/io/http_test.ts | 46 +++++++++---------- tfjs-core/src/io/indexed_db_test.ts | 10 ++-- tfjs-core/src/io/model_management_test.ts | 6 +-- 6 files changed, 48 insertions(+), 50 deletions(-) diff --git a/tfjs-converter/src/executor/graph_model_test.ts b/tfjs-converter/src/executor/graph_model_test.ts index 97adc903882..ba538561e00 100644 --- a/tfjs-converter/src/executor/graph_model_test.ts +++ b/tfjs-converter/src/executor/graph_model_test.ts @@ -586,8 +586,8 @@ describe('Model', () => { expect(handler.savedArtifacts.modelTopology).toEqual(CUSTOM_OP_MODEL); expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest); tfc.test_util.expectArraysClose( - new Int32Array(new io.CompositeArrayBuffer( - handler.savedArtifacts.weightData).slice()), bias.dataSync()); + new Int32Array(new io.CompositeArrayBuffer( + handler.savedArtifacts.weightData).slice()), bias.dataSync()); }); }); }); diff --git a/tfjs-core/src/io/browser_files.ts b/tfjs-core/src/io/browser_files.ts index 1f8f8e70968..56e1df7d391 100644 --- a/tfjs-core/src/io/browser_files.ts +++ b/tfjs-core/src/io/browser_files.ts @@ -73,7 +73,7 @@ export class BrowserDownloads implements IOHandler { // TODO(mattsoulanille): Support saving models over 2GB that exceed // Chrome's ArrayBuffer size limit. - const weightBuffer = concatenateArrayBuffers(modelArtifacts.weightData) + const weightBuffer = concatenateArrayBuffers(modelArtifacts.weightData); const weightsURL = window.URL.createObjectURL(new Blob( [weightBuffer], {type: 'application/octet-stream'})); diff --git a/tfjs-core/src/io/browser_files_test.ts b/tfjs-core/src/io/browser_files_test.ts index 728264510fd..4c2ae3445ec 100644 --- a/tfjs-core/src/io/browser_files_test.ts +++ b/tfjs-core/src/io/browser_files_test.ts @@ -23,7 +23,7 @@ import * as tf from '../index'; import {BROWSER_ENVS, describeWithFlags} from '../jasmine_util'; import {browserDownloads, BrowserDownloads, browserDownloadsRouter} from './browser_files'; import {WeightsManifestConfig, WeightsManifestEntry} from './types'; -import {CompositeArrayBuffer} from './composite_array_buffer'; +import {concatenateArrayBuffers} from './io_utils'; const modelTopology1: {} = { 'class_name': 'Sequential', @@ -311,8 +311,8 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => { expect(modelArtifacts.modelInitializer).toEqual({}); expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1); - expect(new Uint8Array(new CompositeArrayBuffer(modelArtifacts.weightData) - .slice())).toEqual(new Uint8Array(weightData1)); + expect(new Uint8Array(concatenateArrayBuffers(modelArtifacts.weightData))) + .toEqual(new Uint8Array(weightData1)); }); it(`One group, two paths`, async () => { @@ -352,10 +352,10 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => { const modelArtifacts = await filesHandler.load(); expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightSpecs); - expect(new Uint8Array(new CompositeArrayBuffer(modelArtifacts.weightData) - .slice())).toEqual(new Uint8Array([ - 1, 2, 3, 4, 10, 20, 30, 40 - ])); + expect(new Uint8Array(concatenateArrayBuffers(modelArtifacts.weightData))) + .toEqual(new Uint8Array([ + 1, 2, 3, 4, 10, 20, 30, 40 + ])); }); it(`Two groups, four paths, reverseOrder=false`, async () => { @@ -420,10 +420,10 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => { expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs) .toEqual(weightSpecs1.concat(weightSpecs2)); - expect(new Uint8Array(new CompositeArrayBuffer(modelArtifacts.weightData) - .slice())).toEqual(new Uint8Array([ - 1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80 - ])); + expect(new Uint8Array(concatenateArrayBuffers(modelArtifacts.weightData))) + .toEqual(new Uint8Array([ + 1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80 + ])); }); it(`Two groups, four paths, reverseOrder=true`, async () => { @@ -488,10 +488,10 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => { expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs) .toEqual(weightSpecs1.concat(weightSpecs2)); - expect(new Uint8Array(new CompositeArrayBuffer(modelArtifacts.weightData) - .slice())).toEqual(new Uint8Array([ - 1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80 - ])); + expect(new Uint8Array(concatenateArrayBuffers(modelArtifacts.weightData))) + .toEqual(new Uint8Array([ + 1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80 + ])); }); it('Upload model topology only', async () => { diff --git a/tfjs-core/src/io/http_test.ts b/tfjs-core/src/io/http_test.ts index 1b387a4b6bd..1e6d2f32c9f 100644 --- a/tfjs-core/src/io/http_test.ts +++ b/tfjs-core/src/io/http_test.ts @@ -18,7 +18,7 @@ import * as tf from '../index'; import {BROWSER_ENVS, CHROME_ENVS, describeWithFlags, NODE_ENVS} from '../jasmine_util'; import {HTTPRequest, httpRouter, parseUrl} from './http'; -import {CompositeArrayBuffer} from './composite_array_buffer'; +import {concatenateArrayBuffers} from './io_utils'; // Test data. const modelTopology1: {} = { @@ -162,8 +162,8 @@ describeWithFlags('http-load fetch', NODE_ENVS, () => { expect(modelArtifacts.generatedBy).toEqual('1.15'); expect(modelArtifacts.convertedBy).toEqual('1.3.1'); expect(modelArtifacts.userDefinedMetadata).toEqual({}); - expect(new Float32Array(new CompositeArrayBuffer(modelArtifacts.weightData) - .slice())).toEqual(floatData); + expect(new Float32Array(concatenateArrayBuffers(modelArtifacts.weightData))) + .toEqual(floatData); }); it('throw exception if no fetch polyfill', () => { @@ -509,8 +509,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.userDefinedMetadata).toEqual({}); expect(modelArtifacts.modelInitializer).toEqual({}); - expect(new Float32Array(new CompositeArrayBuffer( - modelArtifacts.weightData).slice())).toEqual(floatData); + expect(new Float32Array(concatenateArrayBuffers(modelArtifacts + .weightData))).toEqual(floatData); expect(Object.keys(requestInits).length).toEqual(2); // Assert that fetch is invoked with `window` as the context. expect(fetchSpy.calls.mostRecent().object).toEqual(window); @@ -553,8 +553,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { const modelArtifacts = await handler.load(); expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); - expect(new Float32Array(new CompositeArrayBuffer( - modelArtifacts.weightData).slice())).toEqual(floatData); + expect(new Float32Array(concatenateArrayBuffers(modelArtifacts + .weightData))).toEqual(floatData); expect(Object.keys(requestInits).length).toEqual(2); expect(Object.keys(requestInits).length).toEqual(2); expect(requestInits['./model.json'].headers['header_key_1']) @@ -603,8 +603,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { const modelArtifacts = await handler.load(); expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); - expect(new Float32Array(new CompositeArrayBuffer(modelArtifacts - .weightData).slice())).toEqual(new Float32Array([1, 3, 3, 7, 4])); + expect(new Float32Array(concatenateArrayBuffers(modelArtifacts + .weightData))).toEqual(new Float32Array([1, 3, 3, 7, 4])); }); it('2 groups, 2 weight, 2 paths', async () => { @@ -648,9 +648,9 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.weightSpecs) .toEqual( weightsManifest[0].weights.concat(weightsManifest[1].weights)); - expect(new Float32Array(new CompositeArrayBuffer( - modelArtifacts.weightData).slice())) - .toEqual(new Float32Array([1, 3, 3, 7, 4])); + expect(new Float32Array(concatenateArrayBuffers( + modelArtifacts.weightData))) + .toEqual(new Float32Array([1, 3, 3, 7, 4])); }); it('2 groups, 2 weight, 2 paths, Int32 and Uint8 Data', async () => { @@ -694,9 +694,9 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.weightSpecs) .toEqual( weightsManifest[0].weights.concat(weightsManifest[1].weights)); - expect(new Int32Array(new CompositeArrayBuffer(modelArtifacts.weightData) + expect(new Int32Array(concatenateArrayBuffers(modelArtifacts.weightData) .slice(0, 12))).toEqual(new Int32Array([1, 3, 3])); - expect(new Uint8Array(new CompositeArrayBuffer(modelArtifacts.weightData) + expect(new Uint8Array(concatenateArrayBuffers(modelArtifacts.weightData) .slice(12, 14))).toEqual(new Uint8Array([7, 4])); }); @@ -757,12 +757,10 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.weightSpecs) .toEqual( weightsManifest[0].weights.concat(weightsManifest[1].weights)); - expect(new Int32Array(new CompositeArrayBuffer( - modelArtifacts.weightData).slice(0, 12))) - .toEqual(new Int32Array([1, 3, 3])); - expect(new Float32Array(new CompositeArrayBuffer( - modelArtifacts.weightData).slice(12, 20))) - .toEqual(new Float32Array([-7, -4])); + expect(new Int32Array(concatenateArrayBuffers(modelArtifacts.weightData) + .slice(0, 12))).toEqual(new Int32Array([1, 3, 3])); + expect(new Float32Array(concatenateArrayBuffers(modelArtifacts.weightData) + .slice(12, 20))).toEqual(new Float32Array([-7, -4])); }); it('Missing modelTopology and weightsManifest leads to error', async () => { @@ -847,8 +845,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { const modelArtifacts = await handler.load(); expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); - expect(new Float32Array(new CompositeArrayBuffer( - modelArtifacts.weightData).slice())).toEqual(floatData); + expect(new Float32Array(concatenateArrayBuffers( + modelArtifacts.weightData))).toEqual(floatData); expect(Object.keys(requestInits).length).toEqual(2); expect(Object.keys(requestInits).length).toEqual(2); expect(requestInits['./model.json'].headers['header_key_1']) @@ -910,8 +908,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); - expect(new Float32Array(new CompositeArrayBuffer(modelArtifacts.weightData) - .slice())).toEqual(floatData); + expect(new Float32Array(concatenateArrayBuffers(modelArtifacts.weightData))) + .toEqual(floatData); expect(fetchInputs).toEqual(['./model.json', './weightfile0']); expect(fetchInits.length).toEqual(2); diff --git a/tfjs-core/src/io/indexed_db_test.ts b/tfjs-core/src/io/indexed_db_test.ts index 2502afc716d..ed078cc21ad 100644 --- a/tfjs-core/src/io/indexed_db_test.ts +++ b/tfjs-core/src/io/indexed_db_test.ts @@ -23,7 +23,7 @@ import * as tf from '../index'; import {BROWSER_ENVS, describeWithFlags, runWithLock} from '../jasmine_util'; import {expectArrayBuffersEqual} from '../test_util'; import {browserIndexedDB, BrowserIndexedDB, BrowserIndexedDBManager, deleteDatabase, indexedDBRouter} from './indexed_db'; -import {CompositeArrayBuffer} from './composite_array_buffer'; +import {concatenateArrayBuffers} from './io_utils'; describeWithFlags('IndexedDB', BROWSER_ENVS, () => { // Test data. @@ -122,8 +122,8 @@ describeWithFlags('IndexedDB', BROWSER_ENVS, () => { expect(loadedArtifacts.generatedBy).toEqual('TensorFlow.js v0.0.0'); expect(loadedArtifacts.convertedBy).toEqual(null); expect(loadedArtifacts.modelInitializer).toEqual({}); - expectArrayBuffersEqual(new CompositeArrayBuffer( - loadedArtifacts.weightData).slice(), weightData1); + expectArrayBuffersEqual(concatenateArrayBuffers( + loadedArtifacts.weightData), weightData1); })); it('Save two models and load one', runWithLock(async () => { @@ -163,8 +163,8 @@ describeWithFlags('IndexedDB', BROWSER_ENVS, () => { expect(loadedArtifacts.modelTopology).toEqual(modelTopology1); expect(loadedArtifacts.weightSpecs).toEqual(weightSpecs1); expect(loadedArtifacts.weightData).toBeDefined(); - expectArrayBuffersEqual(new CompositeArrayBuffer( - loadedArtifacts.weightData).slice(), weightData1); + expectArrayBuffersEqual(concatenateArrayBuffers( + loadedArtifacts.weightData), weightData1); })); it('Loading nonexistent model fails', runWithLock(async () => { diff --git a/tfjs-core/src/io/model_management_test.ts b/tfjs-core/src/io/model_management_test.ts index 81111bbc460..156f22e4b8a 100644 --- a/tfjs-core/src/io/model_management_test.ts +++ b/tfjs-core/src/io/model_management_test.ts @@ -18,8 +18,8 @@ import * as tf from '../index'; import {CHROME_ENVS, describeWithFlags, runWithLock} from '../jasmine_util'; import {deleteDatabase} from './indexed_db'; +import {concatenateArrayBuffers} from './io_utils'; import {purgeLocalStorageArtifacts} from './local_storage'; -import {CompositeArrayBuffer} from './composite_array_buffer'; // Disabled for non-Chrome browsers due to: // https://github.com/tensorflow/tfjs/issues/427 @@ -271,7 +271,7 @@ describeWithFlags('ModelManagement', CHROME_ENVS, () => { expect(loaded.weightSpecs).toEqual(weightSpecs1); expect(loaded.weightData).toBeDefined(); expect(new Uint8Array( - new CompositeArrayBuffer(loaded.weightData).slice())) + concatenateArrayBuffers(loaded.weightData))) .toEqual(new Uint8Array(weightData1)); done(); }) @@ -315,7 +315,7 @@ describeWithFlags('ModelManagement', CHROME_ENVS, () => { expect(loaded.modelTopology).toEqual(modelTopology1); expect(loaded.weightSpecs).toEqual(weightSpecs1); expect(new Uint8Array( - new CompositeArrayBuffer(loaded.weightData).slice())) + concatenateArrayBuffers(loaded.weightData))) .toEqual(new Uint8Array(weightData1)); done(); }) From d2fdededc47cac5898ec73aabfdb9db7566bfb91 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Thu, 20 Apr 2023 13:20:32 -0700 Subject: [PATCH 04/11] Do not export 'CompositeArrayBuffer' from tfjs-core --- tfjs-converter/src/executor/graph_model_test.ts | 12 ++++++------ tfjs-converter/src/operations/executors/spy_ops.ts | 14 +------------- tfjs-core/src/io/io.ts | 2 -- tfjs-layers/src/models_test.ts | 14 +++++++------- 4 files changed, 14 insertions(+), 28 deletions(-) diff --git a/tfjs-converter/src/executor/graph_model_test.ts b/tfjs-converter/src/executor/graph_model_test.ts index ba538561e00..ab987cdbc73 100644 --- a/tfjs-converter/src/executor/graph_model_test.ts +++ b/tfjs-converter/src/executor/graph_model_test.ts @@ -586,8 +586,8 @@ describe('Model', () => { expect(handler.savedArtifacts.modelTopology).toEqual(CUSTOM_OP_MODEL); expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest); tfc.test_util.expectArraysClose( - new Int32Array(new io.CompositeArrayBuffer( - handler.savedArtifacts.weightData).slice()), bias.dataSync()); + new Int32Array(io.concatenateArrayBuffers( + handler.savedArtifacts.weightData)), bias.dataSync()); }); }); }); @@ -617,8 +617,8 @@ describe('Model', () => { }); expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest); tfc.test_util.expectArraysClose( - new Int32Array(new io.CompositeArrayBuffer( - handler.savedArtifacts.weightData).slice()), bias.dataSync()); + new Int32Array(io.concatenateArrayBuffers( + handler.savedArtifacts.weightData)), bias.dataSync()); }); }); @@ -906,8 +906,8 @@ describe('Model', () => { }); expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest); tfc.test_util.expectArraysClose( - new Int32Array(new io.CompositeArrayBuffer(handler.savedArtifacts - .weightData).slice()), bias.dataSync()); + new Int32Array(io.concatenateArrayBuffers(handler.savedArtifacts + .weightData)), bias.dataSync()); }); }); diff --git a/tfjs-converter/src/operations/executors/spy_ops.ts b/tfjs-converter/src/operations/executors/spy_ops.ts index ae1ef993664..e8f9fb8e1bb 100644 --- a/tfjs-converter/src/operations/executors/spy_ops.ts +++ b/tfjs-converter/src/operations/executors/spy_ops.ts @@ -15,23 +15,11 @@ * ============================================================================= */ -// The opposite of Extract -type Without = T extends U ? never : T; - -// Do not spy on CompositeArrayBuffer because it is a class constructor. -type NotSpiedOn = 'CompositeArrayBuffer'; - export type RecursiveSpy = - T extends Function ? jasmine.Spy : - {[K in Without]: RecursiveSpy} & - {[K in Extract]: T[K]}; + T extends Function ? jasmine.Spy : {[K in keyof T]: RecursiveSpy}; export function spyOnAllFunctions(obj: T): RecursiveSpy { return Object.fromEntries(Object.entries(obj).map(([key, val]) => { - // TODO(mattSoulanille): Do not hard code this - if (key === 'CompositeArrayBuffer') { - return val; - } if (val instanceof Function) { return [key, jasmine.createSpy(`${key} spy`, val).and.callThrough()]; } else if (val instanceof Array) { diff --git a/tfjs-core/src/io/io.ts b/tfjs-core/src/io/io.ts index 49e9a1e2e06..81548f95d55 100644 --- a/tfjs-core/src/io/io.ts +++ b/tfjs-core/src/io/io.ts @@ -27,13 +27,11 @@ import {fromMemory, fromMemorySync, withSaveHandler, withSaveHandlerSync} from ' import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry'; import {IOHandler, IOHandlerSync, LoadHandler, LoadOptions, ModelArtifacts, ModelArtifactsInfo, ModelJSON, ModelStoreManager, OnProgressCallback, RequestDetails, SaveConfig, SaveHandler, SaveResult, TrainingConfig, WeightGroup, WeightsManifestConfig, WeightsManifestEntry, WeightData} from './types'; import {loadWeights, weightsLoaderFactory} from './weights_loader'; -import {CompositeArrayBuffer} from './composite_array_buffer'; export {copyModel, listModels, moveModel, removeModel} from './model_management'; export { browserFiles, browserHTTPRequest, - CompositeArrayBuffer, concatenateArrayBuffers, decodeWeights, encodeWeights, diff --git a/tfjs-layers/src/models_test.ts b/tfjs-layers/src/models_test.ts index e2db066cf49..4358873c753 100644 --- a/tfjs-layers/src/models_test.ts +++ b/tfjs-layers/src/models_test.ts @@ -1293,7 +1293,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(new io.CompositeArrayBuffer(weightData).byteLength) + expect(io.concatenateArrayBuffers(weightData).byteLength) .toEqual(4 * 8 + 4 * 1 + 4); // Load the model back, with the optimizer. @@ -1353,7 +1353,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(new io.CompositeArrayBuffer(weightData) + expect(io.concatenateArrayBuffers(weightData) .byteLength).toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); // Load the model back, with the optimizer. @@ -1413,7 +1413,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(new io.CompositeArrayBuffer(weightData).byteLength) + expect(io.concatenateArrayBuffers(weightData).byteLength) .toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); // Load the model back, with the optimizer. @@ -1471,7 +1471,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(new io.CompositeArrayBuffer(weightData).byteLength) + expect(io.concatenateArrayBuffers(weightData).byteLength) .toEqual(4 + 4 * 8 * 2 + 4 * 1 * 2); // Load the model back, with the optimizer. @@ -1534,7 +1534,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(new io.CompositeArrayBuffer(weightData).byteLength) + expect(io.concatenateArrayBuffers(weightData).byteLength) .toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); // Load the model back, with the optimizer. @@ -1593,7 +1593,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(new io.CompositeArrayBuffer(weightData).byteLength) + expect(io.concatenateArrayBuffers(weightData).byteLength) .toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); // Load the model back, with the optimizer. @@ -1651,7 +1651,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(new io.CompositeArrayBuffer(weightData).byteLength) + expect(io.concatenateArrayBuffers(weightData).byteLength) .toEqual(4 + 4 * 8 * 2 + 4 * 1 * 2); // Load the model back, with the optimizer. From 121396d89466ed6169e391a41b87e02e2c593c81 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Thu, 20 Apr 2023 13:26:33 -0700 Subject: [PATCH 05/11] Update doc for weightData --- tfjs-core/src/io/types.ts | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tfjs-core/src/io/types.ts b/tfjs-core/src/io/types.ts index ff098b6ca61..2dc0893a82f 100644 --- a/tfjs-core/src/io/types.ts +++ b/tfjs-core/src/io/types.ts @@ -250,8 +250,10 @@ export declare interface ModelArtifacts { weightSpecs?: WeightsManifestEntry[]; /** - * Binary buffer for all weight values concatenated in the order specified - * by `weightSpecs`. + * Binary buffer(s) for all weight values in the order specified by + * `weightSpecs`. This may be a single ArrayBuffer of all the weights + * concatenated together or an Array of ArrayBuffers containing the weights + * (weights may be sharded across multiple ArrayBuffers). */ weightData?: WeightData; From feaa673edd045ee2ec382072418219ea984f787d Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Thu, 20 Apr 2023 14:20:21 -0700 Subject: [PATCH 06/11] Fix tfjs-node --- tfjs-core/src/io/io_utils.ts | 2 +- tfjs-node/src/io/file_system.ts | 4 ++-- tfjs-node/src/io/file_system_test.ts | 9 ++++++--- tfjs-node/src/io/io_utils.ts | 27 --------------------------- tfjs-node/src/io/node_http_test.ts | 3 ++- 5 files changed, 11 insertions(+), 34 deletions(-) diff --git a/tfjs-core/src/io/io_utils.ts b/tfjs-core/src/io/io_utils.ts index 18c14816428..db57a7cb9e7 100644 --- a/tfjs-core/src/io/io_utils.ts +++ b/tfjs-core/src/io/io_utils.ts @@ -291,7 +291,7 @@ const useNodeBuffer = typeof Buffer !== 'undefined' && */ export function stringByteLength(str: string): number { if (useNodeBuffer) { - return Buffer.byteLength(str); + return Buffer.byteLength(str, 'utf8'); } return new Blob([str]).size; } diff --git a/tfjs-node/src/io/file_system.ts b/tfjs-node/src/io/file_system.ts index ad1148d68c0..74b4cc582d6 100644 --- a/tfjs-node/src/io/file_system.ts +++ b/tfjs-node/src/io/file_system.ts @@ -19,7 +19,7 @@ import * as tf from '@tensorflow/tfjs'; import * as fs from 'fs'; import {dirname, join, resolve} from 'path'; import {promisify} from 'util'; -import {getModelArtifactsInfoForJSON, toArrayBuffer} from './io_utils'; +import {toArrayBuffer} from './io_utils'; const stat = promisify(fs.stat); const writeFile = promisify(fs.writeFile); @@ -121,7 +121,7 @@ export class NodeFileSystem implements tf.io.IOHandler { // TODO(cais): Use explicit tf.io.ModelArtifactsInfo type below once it // is available. // tslint:disable-next-line:no-any - modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts) as any + modelArtifactsInfo: tf.io.getModelArtifactsInfoForJSON(modelArtifacts), }; } } diff --git a/tfjs-node/src/io/file_system_test.ts b/tfjs-node/src/io/file_system_test.ts index 7a918aa545b..12d7bfd3f48 100644 --- a/tfjs-node/src/io/file_system_test.ts +++ b/tfjs-node/src/io/file_system_test.ts @@ -150,7 +150,8 @@ describe('File system IOHandler', () => { expect(modelArtifacts.weightSpecs).toEqual(weightSpecs1); expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1); - expect(new Float32Array(modelArtifacts.weightData)) + expect(new Float32Array(tf.io.concatenateArrayBuffers( + modelArtifacts.weightData))) .toEqual(new Float32Array([0, 0, 0, 0])); done(); }) @@ -216,7 +217,8 @@ describe('File system IOHandler', () => { } ]); tf.test_util.expectArraysClose( - new Float32Array(modelArtifacts.weightData), + new Float32Array(tf.io.concatenateArrayBuffers( + modelArtifacts.weightData)), new Float32Array([-1.1, -3.3, -3.3, -7.7])); }); @@ -341,7 +343,8 @@ describe('File system IOHandler', () => { } ]); tf.test_util.expectArraysClose( - new Float32Array(modelArtifacts.weightData), + new Float32Array(tf.io.concatenateArrayBuffers( + modelArtifacts.weightData)), new Float32Array([-1.1, -3.3, -3.3, -7.7])); }); diff --git a/tfjs-node/src/io/io_utils.ts b/tfjs-node/src/io/io_utils.ts index 74582196eb4..9a486d8bfce 100644 --- a/tfjs-node/src/io/io_utils.ts +++ b/tfjs-node/src/io/io_utils.ts @@ -51,30 +51,3 @@ export function toArrayBuffer(buf: Buffer|Buffer[]): ArrayBuffer { return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength); } } - -// TODO(cais): Use explicit tf.io.ModelArtifactsInfo return type below once it -// is available. -/** - * Populate ModelArtifactsInfo fields for a model with JSON topology. - * @param modelArtifacts - * @returns A ModelArtifactsInfo object. - */ -export function getModelArtifactsInfoForJSON( - modelArtifacts: tf.io.ModelArtifacts) { - if (modelArtifacts.modelTopology instanceof ArrayBuffer) { - throw new Error('Expected JSON model topology, received ArrayBuffer.'); - } - return { - dateSaved: new Date(), - modelTopologyType: 'JSON', - modelTopologyBytes: modelArtifacts.modelTopology == null ? - 0 : - Buffer.byteLength(JSON.stringify(modelArtifacts.modelTopology), 'utf8'), - weightSpecsBytes: modelArtifacts.weightSpecs == null ? - 0 : - Buffer.byteLength(JSON.stringify(modelArtifacts.weightSpecs), 'utf8'), - weightDataBytes: modelArtifacts.weightData == null ? - 0 : - modelArtifacts.weightData.byteLength, - }; -} diff --git a/tfjs-node/src/io/node_http_test.ts b/tfjs-node/src/io/node_http_test.ts index ff3a52b68a0..0ad981cf34e 100644 --- a/tfjs-node/src/io/node_http_test.ts +++ b/tfjs-node/src/io/node_http_test.ts @@ -136,7 +136,8 @@ describe('nodeHTTPRequest-load', () => { expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1); - expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); + expect(new Float32Array(tf.io.concatenateArrayBuffers( + modelArtifacts.weightData))).toEqual(floatData); expect(requestInits).toEqual([ {credentials: 'include', cache: 'no-cache'}, From 94ed22e6b7600efac4e6abfa7ecc10877f3e2092 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Thu, 20 Apr 2023 14:22:18 -0700 Subject: [PATCH 07/11] Remove unused import --- tfjs-node/src/io/io_utils.ts | 2 -- 1 file changed, 2 deletions(-) diff --git a/tfjs-node/src/io/io_utils.ts b/tfjs-node/src/io/io_utils.ts index 9a486d8bfce..28608855664 100644 --- a/tfjs-node/src/io/io_utils.ts +++ b/tfjs-node/src/io/io_utils.ts @@ -15,8 +15,6 @@ * ============================================================================= */ -import * as tf from '@tensorflow/tfjs'; - /** * Convert an ArrayBuffer to a Buffer. */ From de30e69c585e82dbd1a1fd78be5ff48451653e61 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Wed, 3 May 2023 15:00:49 -0700 Subject: [PATCH 08/11] Replace concatenateArrayBuffers implementation with CompositeArrayBuffer.slice() --- tfjs-core/src/io/composite_array_buffer.ts | 26 ++++++++++++++++--- .../src/io/composite_array_buffer_test.ts | 14 +++++++++- tfjs-core/src/io/io.ts | 2 ++ tfjs-core/src/io/io_utils.ts | 22 +++------------- 4 files changed, 41 insertions(+), 23 deletions(-) diff --git a/tfjs-core/src/io/composite_array_buffer.ts b/tfjs-core/src/io/composite_array_buffer.ts index 6dc67da73f2..069d2204166 100644 --- a/tfjs-core/src/io/composite_array_buffer.ts +++ b/tfjs-core/src/io/composite_array_buffer.ts @@ -40,8 +40,22 @@ export class CompositeArrayBuffer { private bufferUniformSize?: number; public readonly byteLength: number; - constructor(buffers: ArrayBuffer | ArrayBuffer[] | TypedArray | + /** + * Concatenate a number of ArrayBuffers into one. + * + * @param buffers An array of ArrayBuffers to concatenate, or a single + * ArrayBuffer. + * @returns Result of concatenating `buffers` in order. + */ + static concatenateArrayBuffers(buffers?: ArrayBuffer[] | ArrayBuffer) { + return new CompositeArrayBuffer(buffers).slice(); + } + + constructor(buffers?: ArrayBuffer | ArrayBuffer[] | TypedArray | TypedArray[]) { + if (buffers == null) { + return; + } // Normalize the `buffers` input to be `ArrayBuffer[]`. if (!(buffers instanceof Array)) { buffers = [buffers]; @@ -85,6 +99,12 @@ export class CompositeArrayBuffer { } slice(start = 0, end = this.byteLength): ArrayBuffer { + // If there are no shards, then the CompositeArrayBuffer was initialized + // with no data. + if (this.shards.length === 0) { + return new ArrayBuffer(0); + } + // NaN is treated as zero for slicing. This matches ArrayBuffer's behavior. start = isNaN(Number(start)) ? 0 : start; end = isNaN(Number(end)) ? 0 : end; @@ -117,8 +137,8 @@ export class CompositeArrayBuffer { const globalEnd = Math.min(end, shard.end); const localEnd = globalEnd - shard.start; - const outputSlice = new Uint8Array(shard.buffer.slice(localStart, - localEnd)); + const outputSlice = new Uint8Array(shard.buffer, localStart, + localEnd - localStart); outputArray.set(outputSlice, outputStart); sliced += outputSlice.length; diff --git a/tfjs-core/src/io/composite_array_buffer_test.ts b/tfjs-core/src/io/composite_array_buffer_test.ts index fa64532ad8c..0c88bc4ad42 100644 --- a/tfjs-core/src/io/composite_array_buffer_test.ts +++ b/tfjs-core/src/io/composite_array_buffer_test.ts @@ -80,7 +80,7 @@ describe('CompositeArrayBuffer', () => { }); } - it('can be passed an empty arraybuffer', () => { + it('can be created from an empty arraybuffer', () => { const array = new Uint8Array([]); const singleComposite = new CompositeArrayBuffer(array.buffer); expectArraysEqual(new Uint8Array(singleComposite.slice()), []); @@ -92,6 +92,18 @@ describe('CompositeArrayBuffer', () => { expectArraysEqual(new Uint8Array(singleComposite.slice()), array); }); + it('can be created from zero arrays', () => { + const singleComposite = new CompositeArrayBuffer([]); + expectArraysEqual(new Uint8Array(singleComposite.slice()), + new Uint8Array()); + }); + + it('can be created from undefined input', () => { + const singleComposite = new CompositeArrayBuffer(); + expectArraysEqual(new Uint8Array(singleComposite.slice()), + new Uint8Array()); + }); + it('treats NaN as zero when passed as the start of slice', () => { const array = new Uint8Array([1,2,3]); const composite = new CompositeArrayBuffer(array.buffer); diff --git a/tfjs-core/src/io/io.ts b/tfjs-core/src/io/io.ts index 81548f95d55..49e9a1e2e06 100644 --- a/tfjs-core/src/io/io.ts +++ b/tfjs-core/src/io/io.ts @@ -27,11 +27,13 @@ import {fromMemory, fromMemorySync, withSaveHandler, withSaveHandlerSync} from ' import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry'; import {IOHandler, IOHandlerSync, LoadHandler, LoadOptions, ModelArtifacts, ModelArtifactsInfo, ModelJSON, ModelStoreManager, OnProgressCallback, RequestDetails, SaveConfig, SaveHandler, SaveResult, TrainingConfig, WeightGroup, WeightsManifestConfig, WeightsManifestEntry, WeightData} from './types'; import {loadWeights, weightsLoaderFactory} from './weights_loader'; +import {CompositeArrayBuffer} from './composite_array_buffer'; export {copyModel, listModels, moveModel, removeModel} from './model_management'; export { browserFiles, browserHTTPRequest, + CompositeArrayBuffer, concatenateArrayBuffers, decodeWeights, encodeWeights, diff --git a/tfjs-core/src/io/io_utils.ts b/tfjs-core/src/io/io_utils.ts index db57a7cb9e7..09f1a7176d3 100644 --- a/tfjs-core/src/io/io_utils.ts +++ b/tfjs-core/src/io/io_utils.ts @@ -339,28 +339,12 @@ export function base64StringToArrayBuffer(str: string): ArrayBuffer { * @param buffers An array of ArrayBuffers to concatenate, or a single * ArrayBuffer. * @returns Result of concatenating `buffers` in order. + * + * @deprecated Use tf.io.CompositeArrayBuffer.concatenateArrayBuffers instead. */ export function concatenateArrayBuffers(buffers: ArrayBuffer[] | ArrayBuffer): ArrayBuffer { - if (!(buffers instanceof Array)) { - return buffers; - } - if (buffers.length === 1) { - return buffers[0]; - } - - let totalByteLength = 0; - buffers.forEach((buffer: ArrayBuffer) => { - totalByteLength += buffer.byteLength; - }); - - const temp = new Uint8Array(totalByteLength); - let offset = 0; - buffers.forEach((buffer: ArrayBuffer) => { - temp.set(new Uint8Array(buffer), offset); - offset += buffer.byteLength; - }); - return temp.buffer; + return CompositeArrayBuffer.concatenateArrayBuffers(buffers); } /** From eb279cc0407f4c55a53bc943bed1fe2e4a7afe15 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Wed, 3 May 2023 15:04:03 -0700 Subject: [PATCH 09/11] Rename CompositeArrayBuffer.concatenateArrayBuffers to .join --- tfjs-core/src/io/composite_array_buffer.ts | 2 +- tfjs-core/src/io/io_utils.ts | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tfjs-core/src/io/composite_array_buffer.ts b/tfjs-core/src/io/composite_array_buffer.ts index 069d2204166..411fb074083 100644 --- a/tfjs-core/src/io/composite_array_buffer.ts +++ b/tfjs-core/src/io/composite_array_buffer.ts @@ -47,7 +47,7 @@ export class CompositeArrayBuffer { * ArrayBuffer. * @returns Result of concatenating `buffers` in order. */ - static concatenateArrayBuffers(buffers?: ArrayBuffer[] | ArrayBuffer) { + static join(buffers?: ArrayBuffer[] | ArrayBuffer) { return new CompositeArrayBuffer(buffers).slice(); } diff --git a/tfjs-core/src/io/io_utils.ts b/tfjs-core/src/io/io_utils.ts index 09f1a7176d3..fa9005a9ba8 100644 --- a/tfjs-core/src/io/io_utils.ts +++ b/tfjs-core/src/io/io_utils.ts @@ -340,11 +340,11 @@ export function base64StringToArrayBuffer(str: string): ArrayBuffer { * ArrayBuffer. * @returns Result of concatenating `buffers` in order. * - * @deprecated Use tf.io.CompositeArrayBuffer.concatenateArrayBuffers instead. + * @deprecated Use tf.io.CompositeArrayBuffer.join() instead. */ export function concatenateArrayBuffers(buffers: ArrayBuffer[] | ArrayBuffer): ArrayBuffer { - return CompositeArrayBuffer.concatenateArrayBuffers(buffers); + return CompositeArrayBuffer.join(buffers); } /** From dd2326137f808e4c9ade45688b318f2a94cca3bf Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Wed, 3 May 2023 15:16:21 -0700 Subject: [PATCH 10/11] Replace concatenateArrayBuffers with CompositeArrayBuffer.join in core --- .../src/operations/executors/spy_ops.ts | 14 ++++++++- tfjs-core/src/io/browser_files.ts | 5 ++-- tfjs-core/src/io/browser_files_test.ts | 10 +++---- tfjs-core/src/io/http.ts | 5 ++-- tfjs-core/src/io/http_test.ts | 29 ++++++++++--------- tfjs-core/src/io/indexed_db_test.ts | 6 ++-- tfjs-core/src/io/io_utils_test.ts | 2 ++ tfjs-core/src/io/local_storage.ts | 5 ++-- tfjs-core/src/io/model_management_test.ts | 6 ++-- 9 files changed, 50 insertions(+), 32 deletions(-) diff --git a/tfjs-converter/src/operations/executors/spy_ops.ts b/tfjs-converter/src/operations/executors/spy_ops.ts index e8f9fb8e1bb..ae1ef993664 100644 --- a/tfjs-converter/src/operations/executors/spy_ops.ts +++ b/tfjs-converter/src/operations/executors/spy_ops.ts @@ -15,11 +15,23 @@ * ============================================================================= */ +// The opposite of Extract +type Without = T extends U ? never : T; + +// Do not spy on CompositeArrayBuffer because it is a class constructor. +type NotSpiedOn = 'CompositeArrayBuffer'; + export type RecursiveSpy = - T extends Function ? jasmine.Spy : {[K in keyof T]: RecursiveSpy}; + T extends Function ? jasmine.Spy : + {[K in Without]: RecursiveSpy} & + {[K in Extract]: T[K]}; export function spyOnAllFunctions(obj: T): RecursiveSpy { return Object.fromEntries(Object.entries(obj).map(([key, val]) => { + // TODO(mattSoulanille): Do not hard code this + if (key === 'CompositeArrayBuffer') { + return val; + } if (val instanceof Function) { return [key, jasmine.createSpy(`${key} spy`, val).and.callThrough()]; } else if (val instanceof Array) { diff --git a/tfjs-core/src/io/browser_files.ts b/tfjs-core/src/io/browser_files.ts index 56e1df7d391..816e00a0820 100644 --- a/tfjs-core/src/io/browser_files.ts +++ b/tfjs-core/src/io/browser_files.ts @@ -23,9 +23,10 @@ import '../flags'; import {env} from '../environment'; -import {basename, concatenateArrayBuffers, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts} from './io_utils'; +import {basename, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts} from './io_utils'; import {IORouter, IORouterRegistry} from './router_registry'; import {IOHandler, ModelArtifacts, ModelJSON, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types'; +import {CompositeArrayBuffer} from './composite_array_buffer'; const DEFAULT_FILE_NAME_PREFIX = 'model'; const DEFAULT_JSON_EXTENSION_NAME = '.json'; @@ -73,7 +74,7 @@ export class BrowserDownloads implements IOHandler { // TODO(mattsoulanille): Support saving models over 2GB that exceed // Chrome's ArrayBuffer size limit. - const weightBuffer = concatenateArrayBuffers(modelArtifacts.weightData); + const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData); const weightsURL = window.URL.createObjectURL(new Blob( [weightBuffer], {type: 'application/octet-stream'})); diff --git a/tfjs-core/src/io/browser_files_test.ts b/tfjs-core/src/io/browser_files_test.ts index 4c2ae3445ec..349cc1f172c 100644 --- a/tfjs-core/src/io/browser_files_test.ts +++ b/tfjs-core/src/io/browser_files_test.ts @@ -23,7 +23,7 @@ import * as tf from '../index'; import {BROWSER_ENVS, describeWithFlags} from '../jasmine_util'; import {browserDownloads, BrowserDownloads, browserDownloadsRouter} from './browser_files'; import {WeightsManifestConfig, WeightsManifestEntry} from './types'; -import {concatenateArrayBuffers} from './io_utils'; +import {CompositeArrayBuffer} from './composite_array_buffer'; const modelTopology1: {} = { 'class_name': 'Sequential', @@ -311,7 +311,7 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => { expect(modelArtifacts.modelInitializer).toEqual({}); expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1); - expect(new Uint8Array(concatenateArrayBuffers(modelArtifacts.weightData))) + expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData))) .toEqual(new Uint8Array(weightData1)); }); @@ -352,7 +352,7 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => { const modelArtifacts = await filesHandler.load(); expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightSpecs); - expect(new Uint8Array(concatenateArrayBuffers(modelArtifacts.weightData))) + expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData))) .toEqual(new Uint8Array([ 1, 2, 3, 4, 10, 20, 30, 40 ])); @@ -420,7 +420,7 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => { expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs) .toEqual(weightSpecs1.concat(weightSpecs2)); - expect(new Uint8Array(concatenateArrayBuffers(modelArtifacts.weightData))) + expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData))) .toEqual(new Uint8Array([ 1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80 ])); @@ -488,7 +488,7 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => { expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs) .toEqual(weightSpecs1.concat(weightSpecs2)); - expect(new Uint8Array(concatenateArrayBuffers(modelArtifacts.weightData))) + expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData))) .toEqual(new Uint8Array([ 1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80 ])); diff --git a/tfjs-core/src/io/http.ts b/tfjs-core/src/io/http.ts index 78fb60f4b11..c30ce501dd3 100644 --- a/tfjs-core/src/io/http.ts +++ b/tfjs-core/src/io/http.ts @@ -24,7 +24,8 @@ import {env} from '../environment'; import {assert} from '../util'; -import {concatenateArrayBuffers, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts, getWeightSpecs} from './io_utils'; +import {getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts, getWeightSpecs} from './io_utils'; +import {CompositeArrayBuffer} from './composite_array_buffer'; import {IORouter, IORouterRegistry} from './router_registry'; import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types'; import {loadWeightsAsArrayBuffer} from './weights_loader'; @@ -112,7 +113,7 @@ export class HTTPRequest implements IOHandler { if (modelArtifacts.weightData != null) { // TODO(mattsoulanille): Support saving models over 2GB that exceed // Chrome's ArrayBuffer size limit. - const weightBuffer = concatenateArrayBuffers(modelArtifacts.weightData); + const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData); init.body.append( 'model.weights.bin', diff --git a/tfjs-core/src/io/http_test.ts b/tfjs-core/src/io/http_test.ts index 1e6d2f32c9f..18b868940be 100644 --- a/tfjs-core/src/io/http_test.ts +++ b/tfjs-core/src/io/http_test.ts @@ -18,7 +18,7 @@ import * as tf from '../index'; import {BROWSER_ENVS, CHROME_ENVS, describeWithFlags, NODE_ENVS} from '../jasmine_util'; import {HTTPRequest, httpRouter, parseUrl} from './http'; -import {concatenateArrayBuffers} from './io_utils'; +import {CompositeArrayBuffer} from './composite_array_buffer'; // Test data. const modelTopology1: {} = { @@ -162,8 +162,8 @@ describeWithFlags('http-load fetch', NODE_ENVS, () => { expect(modelArtifacts.generatedBy).toEqual('1.15'); expect(modelArtifacts.convertedBy).toEqual('1.3.1'); expect(modelArtifacts.userDefinedMetadata).toEqual({}); - expect(new Float32Array(concatenateArrayBuffers(modelArtifacts.weightData))) - .toEqual(floatData); + expect(new Float32Array(CompositeArrayBuffer.join( + modelArtifacts.weightData))).toEqual(floatData); }); it('throw exception if no fetch polyfill', () => { @@ -509,7 +509,7 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.userDefinedMetadata).toEqual({}); expect(modelArtifacts.modelInitializer).toEqual({}); - expect(new Float32Array(concatenateArrayBuffers(modelArtifacts + expect(new Float32Array(CompositeArrayBuffer.join(modelArtifacts .weightData))).toEqual(floatData); expect(Object.keys(requestInits).length).toEqual(2); // Assert that fetch is invoked with `window` as the context. @@ -553,7 +553,7 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { const modelArtifacts = await handler.load(); expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); - expect(new Float32Array(concatenateArrayBuffers(modelArtifacts + expect(new Float32Array(CompositeArrayBuffer.join(modelArtifacts .weightData))).toEqual(floatData); expect(Object.keys(requestInits).length).toEqual(2); expect(Object.keys(requestInits).length).toEqual(2); @@ -603,7 +603,7 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { const modelArtifacts = await handler.load(); expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); - expect(new Float32Array(concatenateArrayBuffers(modelArtifacts + expect(new Float32Array(CompositeArrayBuffer.join(modelArtifacts .weightData))).toEqual(new Float32Array([1, 3, 3, 7, 4])); }); @@ -648,7 +648,7 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.weightSpecs) .toEqual( weightsManifest[0].weights.concat(weightsManifest[1].weights)); - expect(new Float32Array(concatenateArrayBuffers( + expect(new Float32Array(CompositeArrayBuffer.join( modelArtifacts.weightData))) .toEqual(new Float32Array([1, 3, 3, 7, 4])); }); @@ -694,9 +694,9 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.weightSpecs) .toEqual( weightsManifest[0].weights.concat(weightsManifest[1].weights)); - expect(new Int32Array(concatenateArrayBuffers(modelArtifacts.weightData) + expect(new Int32Array(CompositeArrayBuffer.join(modelArtifacts.weightData) .slice(0, 12))).toEqual(new Int32Array([1, 3, 3])); - expect(new Uint8Array(concatenateArrayBuffers(modelArtifacts.weightData) + expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData) .slice(12, 14))).toEqual(new Uint8Array([7, 4])); }); @@ -757,9 +757,10 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.weightSpecs) .toEqual( weightsManifest[0].weights.concat(weightsManifest[1].weights)); - expect(new Int32Array(concatenateArrayBuffers(modelArtifacts.weightData) + expect(new Int32Array(CompositeArrayBuffer.join(modelArtifacts.weightData) .slice(0, 12))).toEqual(new Int32Array([1, 3, 3])); - expect(new Float32Array(concatenateArrayBuffers(modelArtifacts.weightData) + expect(new Float32Array(CompositeArrayBuffer + .join(modelArtifacts.weightData) .slice(12, 20))).toEqual(new Float32Array([-7, -4])); }); @@ -845,7 +846,7 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { const modelArtifacts = await handler.load(); expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); - expect(new Float32Array(concatenateArrayBuffers( + expect(new Float32Array(CompositeArrayBuffer.join( modelArtifacts.weightData))).toEqual(floatData); expect(Object.keys(requestInits).length).toEqual(2); expect(Object.keys(requestInits).length).toEqual(2); @@ -908,8 +909,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); - expect(new Float32Array(concatenateArrayBuffers(modelArtifacts.weightData))) - .toEqual(floatData); + expect(new Float32Array(CompositeArrayBuffer + .join(modelArtifacts.weightData))).toEqual(floatData); expect(fetchInputs).toEqual(['./model.json', './weightfile0']); expect(fetchInits.length).toEqual(2); diff --git a/tfjs-core/src/io/indexed_db_test.ts b/tfjs-core/src/io/indexed_db_test.ts index ed078cc21ad..2aaff05e21a 100644 --- a/tfjs-core/src/io/indexed_db_test.ts +++ b/tfjs-core/src/io/indexed_db_test.ts @@ -23,7 +23,7 @@ import * as tf from '../index'; import {BROWSER_ENVS, describeWithFlags, runWithLock} from '../jasmine_util'; import {expectArrayBuffersEqual} from '../test_util'; import {browserIndexedDB, BrowserIndexedDB, BrowserIndexedDBManager, deleteDatabase, indexedDBRouter} from './indexed_db'; -import {concatenateArrayBuffers} from './io_utils'; +import {CompositeArrayBuffer} from './composite_array_buffer'; describeWithFlags('IndexedDB', BROWSER_ENVS, () => { // Test data. @@ -122,7 +122,7 @@ describeWithFlags('IndexedDB', BROWSER_ENVS, () => { expect(loadedArtifacts.generatedBy).toEqual('TensorFlow.js v0.0.0'); expect(loadedArtifacts.convertedBy).toEqual(null); expect(loadedArtifacts.modelInitializer).toEqual({}); - expectArrayBuffersEqual(concatenateArrayBuffers( + expectArrayBuffersEqual(CompositeArrayBuffer.join( loadedArtifacts.weightData), weightData1); })); @@ -163,7 +163,7 @@ describeWithFlags('IndexedDB', BROWSER_ENVS, () => { expect(loadedArtifacts.modelTopology).toEqual(modelTopology1); expect(loadedArtifacts.weightSpecs).toEqual(weightSpecs1); expect(loadedArtifacts.weightData).toBeDefined(); - expectArrayBuffersEqual(concatenateArrayBuffers( + expectArrayBuffersEqual(CompositeArrayBuffer.join( loadedArtifacts.weightData), weightData1); })); diff --git a/tfjs-core/src/io/io_utils_test.ts b/tfjs-core/src/io/io_utils_test.ts index 6656ef43aa3..01e497c075b 100644 --- a/tfjs-core/src/io/io_utils_test.ts +++ b/tfjs-core/src/io/io_utils_test.ts @@ -620,6 +620,8 @@ describeWithFlags( }); describe('concatenateArrayBuffers', () => { + // TODO(mattSoulanille): Move these tests to CompositeArrayBuffer.join when + // concatenateArrayBuffers is removed. it('Concatenate 3 non-empty ArrayBuffers', () => { const buffer1 = new Uint8Array([1, 2, 3]); const buffer2 = new Uint8Array([11, 22, 33, 44]); diff --git a/tfjs-core/src/io/local_storage.ts b/tfjs-core/src/io/local_storage.ts index 14ca64ad117..f20206a9290 100644 --- a/tfjs-core/src/io/local_storage.ts +++ b/tfjs-core/src/io/local_storage.ts @@ -19,7 +19,8 @@ import '../flags'; import {env} from '../environment'; import {assert} from '../util'; -import {arrayBufferToBase64String, base64StringToArrayBuffer, concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils'; +import {arrayBufferToBase64String, base64StringToArrayBuffer, getModelArtifactsInfoForJSON} from './io_utils'; +import {CompositeArrayBuffer} from './composite_array_buffer'; import {IORouter, IORouterRegistry} from './router_registry'; import {IOHandler, ModelArtifacts, ModelArtifactsInfo, ModelJSON, ModelStoreManager, SaveResult} from './types'; @@ -176,7 +177,7 @@ export class BrowserLocalStorage implements IOHandler { // TODO(mattsoulanille): Support saving models over 2GB that exceed // Chrome's ArrayBuffer size limit. - const weightBuffer = concatenateArrayBuffers(modelArtifacts.weightData); + const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData); try { this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo)); diff --git a/tfjs-core/src/io/model_management_test.ts b/tfjs-core/src/io/model_management_test.ts index 156f22e4b8a..0f8a958ec7e 100644 --- a/tfjs-core/src/io/model_management_test.ts +++ b/tfjs-core/src/io/model_management_test.ts @@ -18,7 +18,7 @@ import * as tf from '../index'; import {CHROME_ENVS, describeWithFlags, runWithLock} from '../jasmine_util'; import {deleteDatabase} from './indexed_db'; -import {concatenateArrayBuffers} from './io_utils'; +import {CompositeArrayBuffer} from './composite_array_buffer'; import {purgeLocalStorageArtifacts} from './local_storage'; // Disabled for non-Chrome browsers due to: @@ -271,7 +271,7 @@ describeWithFlags('ModelManagement', CHROME_ENVS, () => { expect(loaded.weightSpecs).toEqual(weightSpecs1); expect(loaded.weightData).toBeDefined(); expect(new Uint8Array( - concatenateArrayBuffers(loaded.weightData))) + CompositeArrayBuffer.join(loaded.weightData))) .toEqual(new Uint8Array(weightData1)); done(); }) @@ -315,7 +315,7 @@ describeWithFlags('ModelManagement', CHROME_ENVS, () => { expect(loaded.modelTopology).toEqual(modelTopology1); expect(loaded.weightSpecs).toEqual(weightSpecs1); expect(new Uint8Array( - concatenateArrayBuffers(loaded.weightData))) + CompositeArrayBuffer.join(loaded.weightData))) .toEqual(new Uint8Array(weightData1)); done(); }) From 0568a301807f8c3982482affe8f8379034066e34 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Wed, 3 May 2023 15:35:07 -0700 Subject: [PATCH 11/11] Change concatenateArrayBuffers to CompositeArrayBuffer.join for test files --- tfjs-converter/src/executor/graph_model_test.ts | 10 +++++----- tfjs-layers/src/models_test.ts | 14 +++++++------- tfjs-node/src/io/file_system_test.ts | 6 +++--- tfjs-node/src/io/node_http_test.ts | 2 +- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tfjs-converter/src/executor/graph_model_test.ts b/tfjs-converter/src/executor/graph_model_test.ts index ab987cdbc73..2b6d4ca104e 100644 --- a/tfjs-converter/src/executor/graph_model_test.ts +++ b/tfjs-converter/src/executor/graph_model_test.ts @@ -586,7 +586,7 @@ describe('Model', () => { expect(handler.savedArtifacts.modelTopology).toEqual(CUSTOM_OP_MODEL); expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest); tfc.test_util.expectArraysClose( - new Int32Array(io.concatenateArrayBuffers( + new Int32Array(io.CompositeArrayBuffer.join( handler.savedArtifacts.weightData)), bias.dataSync()); }); }); @@ -617,8 +617,8 @@ describe('Model', () => { }); expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest); tfc.test_util.expectArraysClose( - new Int32Array(io.concatenateArrayBuffers( - handler.savedArtifacts.weightData)), bias.dataSync()); + new Int32Array(io.CompositeArrayBuffer.join( + handler.savedArtifacts.weightData)), bias.dataSync()); }); }); @@ -906,8 +906,8 @@ describe('Model', () => { }); expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest); tfc.test_util.expectArraysClose( - new Int32Array(io.concatenateArrayBuffers(handler.savedArtifacts - .weightData)), bias.dataSync()); + new Int32Array(io.CompositeArrayBuffer.join(handler.savedArtifacts + .weightData)), bias.dataSync()); }); }); diff --git a/tfjs-layers/src/models_test.ts b/tfjs-layers/src/models_test.ts index 4358873c753..e2db066cf49 100644 --- a/tfjs-layers/src/models_test.ts +++ b/tfjs-layers/src/models_test.ts @@ -1293,7 +1293,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(io.concatenateArrayBuffers(weightData).byteLength) + expect(new io.CompositeArrayBuffer(weightData).byteLength) .toEqual(4 * 8 + 4 * 1 + 4); // Load the model back, with the optimizer. @@ -1353,7 +1353,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(io.concatenateArrayBuffers(weightData) + expect(new io.CompositeArrayBuffer(weightData) .byteLength).toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); // Load the model back, with the optimizer. @@ -1413,7 +1413,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(io.concatenateArrayBuffers(weightData).byteLength) + expect(new io.CompositeArrayBuffer(weightData).byteLength) .toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); // Load the model back, with the optimizer. @@ -1471,7 +1471,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(io.concatenateArrayBuffers(weightData).byteLength) + expect(new io.CompositeArrayBuffer(weightData).byteLength) .toEqual(4 + 4 * 8 * 2 + 4 * 1 * 2); // Load the model back, with the optimizer. @@ -1534,7 +1534,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(io.concatenateArrayBuffers(weightData).byteLength) + expect(new io.CompositeArrayBuffer(weightData).byteLength) .toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); // Load the model back, with the optimizer. @@ -1593,7 +1593,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(io.concatenateArrayBuffers(weightData).byteLength) + expect(new io.CompositeArrayBuffer(weightData).byteLength) .toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); // Load the model back, with the optimizer. @@ -1651,7 +1651,7 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(io.concatenateArrayBuffers(weightData).byteLength) + expect(new io.CompositeArrayBuffer(weightData).byteLength) .toEqual(4 + 4 * 8 * 2 + 4 * 1 * 2); // Load the model back, with the optimizer. diff --git a/tfjs-node/src/io/file_system_test.ts b/tfjs-node/src/io/file_system_test.ts index 12d7bfd3f48..13627f55da5 100644 --- a/tfjs-node/src/io/file_system_test.ts +++ b/tfjs-node/src/io/file_system_test.ts @@ -150,7 +150,7 @@ describe('File system IOHandler', () => { expect(modelArtifacts.weightSpecs).toEqual(weightSpecs1); expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1); - expect(new Float32Array(tf.io.concatenateArrayBuffers( + expect(new Float32Array(tf.io.CompositeArrayBuffer.join( modelArtifacts.weightData))) .toEqual(new Float32Array([0, 0, 0, 0])); done(); @@ -217,7 +217,7 @@ describe('File system IOHandler', () => { } ]); tf.test_util.expectArraysClose( - new Float32Array(tf.io.concatenateArrayBuffers( + new Float32Array(tf.io.CompositeArrayBuffer.join( modelArtifacts.weightData)), new Float32Array([-1.1, -3.3, -3.3, -7.7])); }); @@ -343,7 +343,7 @@ describe('File system IOHandler', () => { } ]); tf.test_util.expectArraysClose( - new Float32Array(tf.io.concatenateArrayBuffers( + new Float32Array(tf.io.CompositeArrayBuffer.join( modelArtifacts.weightData)), new Float32Array([-1.1, -3.3, -3.3, -7.7])); }); diff --git a/tfjs-node/src/io/node_http_test.ts b/tfjs-node/src/io/node_http_test.ts index 0ad981cf34e..83b6e6140d4 100644 --- a/tfjs-node/src/io/node_http_test.ts +++ b/tfjs-node/src/io/node_http_test.ts @@ -136,7 +136,7 @@ describe('nodeHTTPRequest-load', () => { expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1); - expect(new Float32Array(tf.io.concatenateArrayBuffers( + expect(new Float32Array(tf.io.CompositeArrayBuffer.join( modelArtifacts.weightData))).toEqual(floatData); expect(requestInits).toEqual([