diff --git a/tfjs-converter/src/executor/graph_model_test.ts b/tfjs-converter/src/executor/graph_model_test.ts
index 7d3a3ade14f..2b6d4ca104e 100644
--- a/tfjs-converter/src/executor/graph_model_test.ts
+++ b/tfjs-converter/src/executor/graph_model_test.ts
@@ -586,7 +586,8 @@ describe('Model', () => {
       expect(handler.savedArtifacts.modelTopology).toEqual(CUSTOM_OP_MODEL);
       expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest);
       tfc.test_util.expectArraysClose(
-          new Int32Array(handler.savedArtifacts.weightData), bias.dataSync());
+          new Int32Array(io.CompositeArrayBuffer.join(
+              handler.savedArtifacts.weightData)), bias.dataSync());
     });
   });
 });
@@ -616,7 +617,8 @@ describe('Model', () => {
       });
       expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest);
       tfc.test_util.expectArraysClose(
-          new Int32Array(handler.savedArtifacts.weightData), bias.dataSync());
+          new Int32Array(io.CompositeArrayBuffer.join(
+              handler.savedArtifacts.weightData)), bias.dataSync());
     });
   });
 
@@ -904,7 +906,8 @@ describe('Model', () => {
       });
       expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest);
       tfc.test_util.expectArraysClose(
-          new Int32Array(handler.savedArtifacts.weightData), bias.dataSync());
+          new Int32Array(io.CompositeArrayBuffer.join(handler.savedArtifacts
+              .weightData)), bias.dataSync());
     });
   });
 
diff --git a/tfjs-converter/src/operations/executors/spy_ops.ts b/tfjs-converter/src/operations/executors/spy_ops.ts
index e8f9fb8e1bb..ae1ef993664 100644
--- a/tfjs-converter/src/operations/executors/spy_ops.ts
+++ b/tfjs-converter/src/operations/executors/spy_ops.ts
@@ -15,11 +15,23 @@
  * =============================================================================
  */
 
+// The opposite of Extract
+type Without<T, U> = T extends U ? never : T;
+
+// Do not spy on CompositeArrayBuffer because it is a class constructor.
+type NotSpiedOn = 'CompositeArrayBuffer';
+
 export type RecursiveSpy<T> =
-    T extends Function ? jasmine.Spy : {[K in keyof T]: RecursiveSpy<T[K]>};
+    T extends Function ? jasmine.Spy :
+    {[K in Without<keyof T, NotSpiedOn>]: RecursiveSpy<T[K]>} &
+    {[K in Extract<keyof T, NotSpiedOn>]: T[K]};
 
 export function spyOnAllFunctions<T>(obj: T): RecursiveSpy<T> {
   return Object.fromEntries(Object.entries(obj).map(([key, val]) => {
+    // TODO(mattSoulanille): Do not hard code this
+    if (key === 'CompositeArrayBuffer') {
+      return [key, val];
+    }
     if (val instanceof Function) {
       return [key, jasmine.createSpy(`${key} spy`, val).and.callThrough()];
     } else if (val instanceof Array) {
diff --git a/tfjs-core/src/io/browser_files.ts b/tfjs-core/src/io/browser_files.ts
index 90b11058b1c..816e00a0820 100644
--- a/tfjs-core/src/io/browser_files.ts
+++ b/tfjs-core/src/io/browser_files.ts
@@ -23,9 +23,10 @@ import '../flags';
 
 import {env} from '../environment';
 
-import {basename, concatenateArrayBuffers, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts} from './io_utils';
+import {basename, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts} from './io_utils';
 import {IORouter, IORouterRegistry} from './router_registry';
-import {IOHandler, ModelArtifacts, ModelJSON, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
+import {IOHandler, ModelArtifacts, ModelJSON, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types';
+import {CompositeArrayBuffer} from './composite_array_buffer';
 
 const DEFAULT_FILE_NAME_PREFIX = 'model';
 const DEFAULT_JSON_EXTENSION_NAME = '.json';
@@ -70,8 +71,13 @@ export class BrowserDownloads implements IOHandler {
           'Browser downloads are not supported in ' +
           'this environment since `document` is not present');
     }
+
+    // TODO(mattsoulanille): Support saving models over 2GB that exceed
+    // Chrome's ArrayBuffer size limit.
+    const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData);
+
     const weightsURL = window.URL.createObjectURL(new Blob(
-        [modelArtifacts.weightData], {type: 'application/octet-stream'}));
+        [weightBuffer], {type: 'application/octet-stream'}));
 
     if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
       throw new Error(
@@ -169,7 +175,7 @@ class BrowserFiles implements IOHandler {
   }
 
   private loadWeights(weightsManifest: WeightsManifestConfig): Promise<[
-    /* weightSpecs */ WeightsManifestEntry[], /* weightData */ ArrayBuffer
+    /* weightSpecs */ WeightsManifestEntry[], WeightData,
   ]> {
     const weightSpecs: WeightsManifestEntry[] = [];
     const paths: string[] = [];
@@ -185,7 +191,7 @@ class BrowserFiles implements IOHandler {
         paths.map(path => this.loadWeightsFile(path, pathToFile[path]));
 
     return Promise.all(promises).then(
-        buffers => [weightSpecs, concatenateArrayBuffers(buffers)]);
+        buffers => [weightSpecs, buffers]);
   }
 
   private loadWeightsFile(path: string, file: File): Promise<ArrayBuffer> {
diff --git a/tfjs-core/src/io/browser_files_test.ts b/tfjs-core/src/io/browser_files_test.ts
index d9f63f7f26c..349cc1f172c 100644
--- a/tfjs-core/src/io/browser_files_test.ts
+++ b/tfjs-core/src/io/browser_files_test.ts
@@ -23,6 +23,7 @@ import * as tf from '../index';
 import {BROWSER_ENVS, describeWithFlags} from '../jasmine_util';
 import {browserDownloads, BrowserDownloads, browserDownloadsRouter} from './browser_files';
 import {WeightsManifestConfig, WeightsManifestEntry} from './types';
+import {CompositeArrayBuffer} from './composite_array_buffer';
 
 const modelTopology1: {} = {
   'class_name': 'Sequential',
@@ -310,7 +311,7 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => {
     expect(modelArtifacts.modelInitializer).toEqual({});
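Aside (not part of the patch): the tests above now wrap weightData in io.CompositeArrayBuffer.join(...) because, after this change, weightData may be a single ArrayBuffer or an array of shards. A minimal sketch of that pattern, using the CompositeArrayBuffer and WeightData exports added further down in this diff:

// Sketch only: view weight data that may arrive as one buffer or as shards.
import {io} from '@tensorflow/tfjs-core';

function weightsAsInt32(weightData: io.WeightData): Int32Array {
  // join() accepts ArrayBuffer | ArrayBuffer[] (or undefined) and returns a
  // single contiguous ArrayBuffer.
  return new Int32Array(io.CompositeArrayBuffer.join(weightData));
}

const single = new Int32Array([1, 2, 3]).buffer;
const sharded = [new Int32Array([1, 2]).buffer, new Int32Array([3]).buffer];
weightsAsInt32(single);   // Int32Array [1, 2, 3]
weightsAsInt32(sharded);  // Int32Array [1, 2, 3]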
expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1); - expect(new Uint8Array(modelArtifacts.weightData)) + expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData))) .toEqual(new Uint8Array(weightData1)); }); @@ -351,9 +352,10 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => { const modelArtifacts = await filesHandler.load(); expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightSpecs); - expect(new Uint8Array(modelArtifacts.weightData)).toEqual(new Uint8Array([ - 1, 2, 3, 4, 10, 20, 30, 40 - ])); + expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData))) + .toEqual(new Uint8Array([ + 1, 2, 3, 4, 10, 20, 30, 40 + ])); }); it(`Two groups, four paths, reverseOrder=false`, async () => { @@ -418,9 +420,10 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => { expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs) .toEqual(weightSpecs1.concat(weightSpecs2)); - expect(new Uint8Array(modelArtifacts.weightData)).toEqual(new Uint8Array([ - 1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80 - ])); + expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData))) + .toEqual(new Uint8Array([ + 1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80 + ])); }); it(`Two groups, four paths, reverseOrder=true`, async () => { @@ -485,9 +488,10 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => { expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs) .toEqual(weightSpecs1.concat(weightSpecs2)); - expect(new Uint8Array(modelArtifacts.weightData)).toEqual(new Uint8Array([ - 1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80 - ])); + expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData))) + .toEqual(new Uint8Array([ + 1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80 + ])); }); it('Upload model topology only', async () => { diff --git a/tfjs-core/src/io/composite_array_buffer.ts b/tfjs-core/src/io/composite_array_buffer.ts index 6dc67da73f2..411fb074083 100644 --- a/tfjs-core/src/io/composite_array_buffer.ts +++ b/tfjs-core/src/io/composite_array_buffer.ts @@ -40,8 +40,22 @@ export class CompositeArrayBuffer { private bufferUniformSize?: number; public readonly byteLength: number; - constructor(buffers: ArrayBuffer | ArrayBuffer[] | TypedArray | + /** + * Concatenate a number of ArrayBuffers into one. + * + * @param buffers An array of ArrayBuffers to concatenate, or a single + * ArrayBuffer. + * @returns Result of concatenating `buffers` in order. + */ + static join(buffers?: ArrayBuffer[] | ArrayBuffer) { + return new CompositeArrayBuffer(buffers).slice(); + } + + constructor(buffers?: ArrayBuffer | ArrayBuffer[] | TypedArray | TypedArray[]) { + if (buffers == null) { + return; + } // Normalize the `buffers` input to be `ArrayBuffer[]`. if (!(buffers instanceof Array)) { buffers = [buffers]; @@ -85,6 +99,12 @@ export class CompositeArrayBuffer { } slice(start = 0, end = this.byteLength): ArrayBuffer { + // If there are no shards, then the CompositeArrayBuffer was initialized + // with no data. + if (this.shards.length === 0) { + return new ArrayBuffer(0); + } + // NaN is treated as zero for slicing. This matches ArrayBuffer's behavior. start = isNaN(Number(start)) ? 0 : start; end = isNaN(Number(end)) ? 
0 : end; @@ -117,8 +137,8 @@ export class CompositeArrayBuffer { const globalEnd = Math.min(end, shard.end); const localEnd = globalEnd - shard.start; - const outputSlice = new Uint8Array(shard.buffer.slice(localStart, - localEnd)); + const outputSlice = new Uint8Array(shard.buffer, localStart, + localEnd - localStart); outputArray.set(outputSlice, outputStart); sliced += outputSlice.length; diff --git a/tfjs-core/src/io/composite_array_buffer_test.ts b/tfjs-core/src/io/composite_array_buffer_test.ts index fa64532ad8c..0c88bc4ad42 100644 --- a/tfjs-core/src/io/composite_array_buffer_test.ts +++ b/tfjs-core/src/io/composite_array_buffer_test.ts @@ -80,7 +80,7 @@ describe('CompositeArrayBuffer', () => { }); } - it('can be passed an empty arraybuffer', () => { + it('can be created from an empty arraybuffer', () => { const array = new Uint8Array([]); const singleComposite = new CompositeArrayBuffer(array.buffer); expectArraysEqual(new Uint8Array(singleComposite.slice()), []); @@ -92,6 +92,18 @@ describe('CompositeArrayBuffer', () => { expectArraysEqual(new Uint8Array(singleComposite.slice()), array); }); + it('can be created from zero arrays', () => { + const singleComposite = new CompositeArrayBuffer([]); + expectArraysEqual(new Uint8Array(singleComposite.slice()), + new Uint8Array()); + }); + + it('can be created from undefined input', () => { + const singleComposite = new CompositeArrayBuffer(); + expectArraysEqual(new Uint8Array(singleComposite.slice()), + new Uint8Array()); + }); + it('treats NaN as zero when passed as the start of slice', () => { const array = new Uint8Array([1,2,3]); const composite = new CompositeArrayBuffer(array.buffer); diff --git a/tfjs-core/src/io/http.ts b/tfjs-core/src/io/http.ts index 5b3aab81fb9..c30ce501dd3 100644 --- a/tfjs-core/src/io/http.ts +++ b/tfjs-core/src/io/http.ts @@ -24,9 +24,10 @@ import {env} from '../environment'; import {assert} from '../util'; -import {concatenateArrayBuffers, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts, getWeightSpecs} from './io_utils'; +import {getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts, getWeightSpecs} from './io_utils'; +import {CompositeArrayBuffer} from './composite_array_buffer'; import {IORouter, IORouterRegistry} from './router_registry'; -import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types'; +import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types'; import {loadWeightsAsArrayBuffer} from './weights_loader'; const OCTET_STREAM_MIME_TYPE = 'application/octet-stream'; @@ -110,9 +111,13 @@ export class HTTPRequest implements IOHandler { 'model.json'); if (modelArtifacts.weightData != null) { + // TODO(mattsoulanille): Support saving models over 2GB that exceed + // Chrome's ArrayBuffer size limit. 
+ const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData); + init.body.append( 'model.weights.bin', - new Blob([modelArtifacts.weightData], {type: OCTET_STREAM_MIME_TYPE}), + new Blob([weightBuffer], {type: OCTET_STREAM_MIME_TYPE}), 'model.weights.bin'); } @@ -182,7 +187,7 @@ export class HTTPRequest implements IOHandler { } private async loadWeights(weightsManifest: WeightsManifestConfig): - Promise<[WeightsManifestEntry[], ArrayBuffer]> { + Promise<[WeightsManifestEntry[], WeightData]> { const weightPath = Array.isArray(this.path) ? this.path[1] : this.path; const [prefix, suffix] = parseUrl(weightPath); const pathPrefix = this.weightPathPrefix || prefix; @@ -210,7 +215,7 @@ export class HTTPRequest implements IOHandler { fetchFunc: this.fetch, onProgress: this.onProgress }); - return [weightSpecs, concatenateArrayBuffers(buffers)]; + return [weightSpecs, buffers]; } } diff --git a/tfjs-core/src/io/http_test.ts b/tfjs-core/src/io/http_test.ts index 4be84eb9eef..18b868940be 100644 --- a/tfjs-core/src/io/http_test.ts +++ b/tfjs-core/src/io/http_test.ts @@ -18,6 +18,7 @@ import * as tf from '../index'; import {BROWSER_ENVS, CHROME_ENVS, describeWithFlags, NODE_ENVS} from '../jasmine_util'; import {HTTPRequest, httpRouter, parseUrl} from './http'; +import {CompositeArrayBuffer} from './composite_array_buffer'; // Test data. const modelTopology1: {} = { @@ -161,7 +162,8 @@ describeWithFlags('http-load fetch', NODE_ENVS, () => { expect(modelArtifacts.generatedBy).toEqual('1.15'); expect(modelArtifacts.convertedBy).toEqual('1.3.1'); expect(modelArtifacts.userDefinedMetadata).toEqual({}); - expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); + expect(new Float32Array(CompositeArrayBuffer.join( + modelArtifacts.weightData))).toEqual(floatData); }); it('throw exception if no fetch polyfill', () => { @@ -507,7 +509,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.userDefinedMetadata).toEqual({}); expect(modelArtifacts.modelInitializer).toEqual({}); - expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); + expect(new Float32Array(CompositeArrayBuffer.join(modelArtifacts + .weightData))).toEqual(floatData); expect(Object.keys(requestInits).length).toEqual(2); // Assert that fetch is invoked with `window` as the context. 
expect(fetchSpy.calls.mostRecent().object).toEqual(window); @@ -550,7 +553,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { const modelArtifacts = await handler.load(); expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); - expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); + expect(new Float32Array(CompositeArrayBuffer.join(modelArtifacts + .weightData))).toEqual(floatData); expect(Object.keys(requestInits).length).toEqual(2); expect(Object.keys(requestInits).length).toEqual(2); expect(requestInits['./model.json'].headers['header_key_1']) @@ -599,8 +603,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { const modelArtifacts = await handler.load(); expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); - expect(new Float32Array(modelArtifacts.weightData)) - .toEqual(new Float32Array([1, 3, 3, 7, 4])); + expect(new Float32Array(CompositeArrayBuffer.join(modelArtifacts + .weightData))).toEqual(new Float32Array([1, 3, 3, 7, 4])); }); it('2 groups, 2 weight, 2 paths', async () => { @@ -644,8 +648,9 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.weightSpecs) .toEqual( weightsManifest[0].weights.concat(weightsManifest[1].weights)); - expect(new Float32Array(modelArtifacts.weightData)) - .toEqual(new Float32Array([1, 3, 3, 7, 4])); + expect(new Float32Array(CompositeArrayBuffer.join( + modelArtifacts.weightData))) + .toEqual(new Float32Array([1, 3, 3, 7, 4])); }); it('2 groups, 2 weight, 2 paths, Int32 and Uint8 Data', async () => { @@ -689,10 +694,10 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.weightSpecs) .toEqual( weightsManifest[0].weights.concat(weightsManifest[1].weights)); - expect(new Int32Array(modelArtifacts.weightData.slice(0, 12))) - .toEqual(new Int32Array([1, 3, 3])); - expect(new Uint8Array(modelArtifacts.weightData.slice(12, 14))) - .toEqual(new Uint8Array([7, 4])); + expect(new Int32Array(CompositeArrayBuffer.join(modelArtifacts.weightData) + .slice(0, 12))).toEqual(new Int32Array([1, 3, 3])); + expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData) + .slice(12, 14))).toEqual(new Uint8Array([7, 4])); }); it('topology only', async () => { @@ -752,10 +757,11 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.weightSpecs) .toEqual( weightsManifest[0].weights.concat(weightsManifest[1].weights)); - expect(new Int32Array(modelArtifacts.weightData.slice(0, 12))) - .toEqual(new Int32Array([1, 3, 3])); - expect(new Float32Array(modelArtifacts.weightData.slice(12, 20))) - .toEqual(new Float32Array([-7, -4])); + expect(new Int32Array(CompositeArrayBuffer.join(modelArtifacts.weightData) + .slice(0, 12))).toEqual(new Int32Array([1, 3, 3])); + expect(new Float32Array(CompositeArrayBuffer + .join(modelArtifacts.weightData) + .slice(12, 20))).toEqual(new Float32Array([-7, -4])); }); it('Missing modelTopology and weightsManifest leads to error', async () => { @@ -840,7 +846,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { const modelArtifacts = await handler.load(); expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); - expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); + expect(new Float32Array(CompositeArrayBuffer.join( + modelArtifacts.weightData))).toEqual(floatData); 
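Aside (not part of the patch): loadWeights now resolves to the raw list of fetched buffers, so callers depend on decodeWeights accepting WeightData directly. A rough sketch of that round trip, assuming a build that includes this change; the tensor names and split point are arbitrary:

import * as tf from '@tensorflow/tfjs-core';

async function roundTrip() {
  // Encode two named tensors into one binary blob plus its manifest entries.
  const {data, specs} = await tf.io.encodeWeights({
    w: tf.tensor1d([1, 2, 3], 'float32'),
    b: tf.tensor1d([4], 'float32'),
  });

  // Pretend the bytes arrived as two shards (e.g. two weight files). The
  // split may even fall inside a tensor; CompositeArrayBuffer handles that.
  const shards = [data.slice(0, 8), data.slice(8)];

  // decodeWeights now accepts ArrayBuffer | ArrayBuffer[] (WeightData).
  const tensors = tf.io.decodeWeights(shards, specs);
  tensors['w'].print();  // [1, 2, 3]
  tensors['b'].print();  // [4]
}
roundTrip();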
expect(Object.keys(requestInits).length).toEqual(2); expect(Object.keys(requestInits).length).toEqual(2); expect(requestInits['./model.json'].headers['header_key_1']) @@ -902,7 +909,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); - expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); + expect(new Float32Array(CompositeArrayBuffer + .join(modelArtifacts.weightData))).toEqual(floatData); expect(fetchInputs).toEqual(['./model.json', './weightfile0']); expect(fetchInits.length).toEqual(2); diff --git a/tfjs-core/src/io/indexed_db_test.ts b/tfjs-core/src/io/indexed_db_test.ts index 308fce2e464..2aaff05e21a 100644 --- a/tfjs-core/src/io/indexed_db_test.ts +++ b/tfjs-core/src/io/indexed_db_test.ts @@ -23,6 +23,7 @@ import * as tf from '../index'; import {BROWSER_ENVS, describeWithFlags, runWithLock} from '../jasmine_util'; import {expectArrayBuffersEqual} from '../test_util'; import {browserIndexedDB, BrowserIndexedDB, BrowserIndexedDBManager, deleteDatabase, indexedDBRouter} from './indexed_db'; +import {CompositeArrayBuffer} from './composite_array_buffer'; describeWithFlags('IndexedDB', BROWSER_ENVS, () => { // Test data. @@ -121,8 +122,9 @@ describeWithFlags('IndexedDB', BROWSER_ENVS, () => { expect(loadedArtifacts.generatedBy).toEqual('TensorFlow.js v0.0.0'); expect(loadedArtifacts.convertedBy).toEqual(null); expect(loadedArtifacts.modelInitializer).toEqual({}); - expectArrayBuffersEqual(loadedArtifacts.weightData, weightData1); - })); + expectArrayBuffersEqual(CompositeArrayBuffer.join( + loadedArtifacts.weightData), weightData1); + })); it('Save two models and load one', runWithLock(async () => { const weightData2 = new ArrayBuffer(24); @@ -160,7 +162,9 @@ describeWithFlags('IndexedDB', BROWSER_ENVS, () => { const loadedArtifacts = await handler1.load(); expect(loadedArtifacts.modelTopology).toEqual(modelTopology1); expect(loadedArtifacts.weightSpecs).toEqual(weightSpecs1); - expectArrayBuffersEqual(loadedArtifacts.weightData, weightData1); + expect(loadedArtifacts.weightData).toBeDefined(); + expectArrayBuffersEqual(CompositeArrayBuffer.join( + loadedArtifacts.weightData), weightData1); })); it('Loading nonexistent model fails', runWithLock(async () => { diff --git a/tfjs-core/src/io/io.ts b/tfjs-core/src/io/io.ts index 29383e451c5..49e9a1e2e06 100644 --- a/tfjs-core/src/io/io.ts +++ b/tfjs-core/src/io/io.ts @@ -25,13 +25,15 @@ import {browserHTTPRequest, http, isHTTPScheme} from './http'; import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsForJSON, getModelArtifactsForJSONSync, getModelArtifactsInfoForJSON, getWeightSpecs} from './io_utils'; import {fromMemory, fromMemorySync, withSaveHandler, withSaveHandlerSync} from './passthrough'; import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry'; -import {IOHandler, IOHandlerSync, LoadHandler, LoadOptions, ModelArtifacts, ModelArtifactsInfo, ModelJSON, ModelStoreManager, OnProgressCallback, RequestDetails, SaveConfig, SaveHandler, SaveResult, TrainingConfig, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types'; +import {IOHandler, IOHandlerSync, LoadHandler, LoadOptions, ModelArtifacts, ModelArtifactsInfo, ModelJSON, ModelStoreManager, OnProgressCallback, RequestDetails, SaveConfig, SaveHandler, SaveResult, TrainingConfig, 
WeightGroup, WeightsManifestConfig, WeightsManifestEntry, WeightData} from './types'; import {loadWeights, weightsLoaderFactory} from './weights_loader'; +import {CompositeArrayBuffer} from './composite_array_buffer'; export {copyModel, listModels, moveModel, removeModel} from './model_management'; export { browserFiles, browserHTTPRequest, + CompositeArrayBuffer, concatenateArrayBuffers, decodeWeights, encodeWeights, @@ -62,6 +64,7 @@ export { SaveHandler, SaveResult, TrainingConfig, + WeightData, WeightGroup, weightsLoaderFactory, WeightsManifestConfig, diff --git a/tfjs-core/src/io/io_utils.ts b/tfjs-core/src/io/io_utils.ts index ce20b7c9fee..fa9005a9ba8 100644 --- a/tfjs-core/src/io/io_utils.ts +++ b/tfjs-core/src/io/io_utils.ts @@ -21,7 +21,8 @@ import {NamedTensor, NamedTensorMap} from '../tensor_types'; import {TypedArray} from '../types'; import {sizeFromShape} from '../util'; -import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, ModelJSON, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types'; +import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, ModelJSON, WeightData, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types'; +import {CompositeArrayBuffer} from './composite_array_buffer'; /** Number of bytes reserved for the length of the string. (32bit integer). */ const NUM_BYTES_STRING_LENGTH = 4; @@ -101,8 +102,9 @@ export async function encodeWeights( * * This function is the reverse of `encodeWeights`. * - * @param buffer A flat ArrayBuffer carrying the binary values of the tensors - * concatenated in the order specified in `specs`. + * @param weightData A flat ArrayBuffer or an array of ArrayBuffers carrying the + * binary values of the tensors concatenated in the order specified in + * `specs`. * @param specs Specifications of the names, dtypes and shapes of the tensors * whose value are encoded by `buffer`. * @return A map from tensor name to tensor value, with the names corresponding @@ -110,8 +112,10 @@ export async function encodeWeights( * @throws Error, if any of the tensors has unsupported dtype. */ export function decodeWeights( - buffer: ArrayBuffer, specs: WeightsManifestEntry[]): NamedTensorMap { + weightData: WeightData, + specs: WeightsManifestEntry[]): NamedTensorMap { // TODO(adarob, cais): Support quantization. + const compositeBuffer = new CompositeArrayBuffer(weightData); const out: NamedTensorMap = {}; let float16Decode: (buffer: Uint16Array) => Float32Array | undefined; let offset = 0; @@ -145,7 +149,7 @@ export function decodeWeights( } const quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype]; const byteBuffer = - buffer.slice(offset, offset + size * quantizationSizeFactor); + compositeBuffer.slice(offset, offset + size * quantizationSizeFactor); const quantizedArray = (quantization.dtype === 'uint8') ? 
new Uint8Array(byteBuffer) : new Uint16Array(byteBuffer); @@ -186,15 +190,17 @@ export function decodeWeights( values = []; for (let i = 0; i < size; i++) { const byteLength = new Uint32Array( - buffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0]; + compositeBuffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0]; offset += NUM_BYTES_STRING_LENGTH; - const bytes = new Uint8Array(buffer.slice(offset, offset + byteLength)); + const bytes = new Uint8Array( + compositeBuffer.slice(offset, offset + byteLength)); (values as Uint8Array[]).push(bytes); offset += byteLength; } } else { const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype]; - const byteBuffer = buffer.slice(offset, offset + size * dtypeFactor); + const byteBuffer = compositeBuffer.slice(offset, + offset + size * dtypeFactor); if (dtype === 'float32') { values = new Float32Array(byteBuffer); @@ -285,7 +291,7 @@ const useNodeBuffer = typeof Buffer !== 'undefined' && */ export function stringByteLength(str: string): number { if (useNodeBuffer) { - return Buffer.byteLength(str); + return Buffer.byteLength(str, 'utf8'); } return new Blob([str]).size; } @@ -330,26 +336,15 @@ export function base64StringToArrayBuffer(str: string): ArrayBuffer { /** * Concatenate a number of ArrayBuffers into one. * - * @param buffers A number of array buffers to concatenate. + * @param buffers An array of ArrayBuffers to concatenate, or a single + * ArrayBuffer. * @returns Result of concatenating `buffers` in order. + * + * @deprecated Use tf.io.CompositeArrayBuffer.join() instead. */ -export function concatenateArrayBuffers(buffers: ArrayBuffer[]): ArrayBuffer { - if (buffers.length === 1) { - return buffers[0]; - } - - let totalByteLength = 0; - buffers.forEach((buffer: ArrayBuffer) => { - totalByteLength += buffer.byteLength; - }); - - const temp = new Uint8Array(totalByteLength); - let offset = 0; - buffers.forEach((buffer: ArrayBuffer) => { - temp.set(new Uint8Array(buffer), offset); - offset += buffer.byteLength; - }); - return temp.buffer; +export function concatenateArrayBuffers(buffers: ArrayBuffer[] + | ArrayBuffer): ArrayBuffer { + return CompositeArrayBuffer.join(buffers); } /** @@ -411,14 +406,14 @@ export function getModelJSONForModelArtifacts( * @param modelJSON Object containing the parsed JSON of `model.json` * @param weightSpecs The list of WeightsManifestEntry for the model. Must be * passed if the modelJSON has a weightsManifest. - * @param weightData An ArrayBuffer of weight data for the model corresponding - * to the weights in weightSpecs. Must be passed if the modelJSON has a - * weightsManifest. + * @param weightData An ArrayBuffer or array of ArrayBuffers of weight data for + * the model corresponding to the weights in weightSpecs. Must be passed if + * the modelJSON has a weightsManifest. * @returns A Promise of the `ModelArtifacts`, as described by the JSON file. 
*/ export function getModelArtifactsForJSONSync( modelJSON: ModelJSON, weightSpecs?: WeightsManifestEntry[], - weightData?: ArrayBuffer): ModelArtifacts { + weightData?: WeightData): ModelArtifacts { const modelArtifacts: ModelArtifacts = { modelTopology: modelJSON.modelTopology, @@ -468,10 +463,10 @@ export function getModelArtifactsForJSONSync( export async function getModelArtifactsForJSON( modelJSON: ModelJSON, loadWeights: (weightsManifest: WeightsManifestConfig) => Promise<[ - /* weightSpecs */ WeightsManifestEntry[], /* weightData */ ArrayBuffer + /* weightSpecs */ WeightsManifestEntry[], WeightData, ]>): Promise { let weightSpecs: WeightsManifestEntry[] | undefined; - let weightData: ArrayBuffer | undefined; + let weightData: WeightData | undefined; if (modelJSON.weightsManifest != null) { [weightSpecs, weightData] = await loadWeights(modelJSON.weightsManifest); @@ -502,7 +497,7 @@ export function getModelArtifactsInfoForJSON(modelArtifacts: ModelArtifacts): stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)), weightDataBytes: modelArtifacts.weightData == null ? 0 : - modelArtifacts.weightData.byteLength, + new CompositeArrayBuffer(modelArtifacts.weightData).byteLength, }; } diff --git a/tfjs-core/src/io/io_utils_test.ts b/tfjs-core/src/io/io_utils_test.ts index 6656ef43aa3..01e497c075b 100644 --- a/tfjs-core/src/io/io_utils_test.ts +++ b/tfjs-core/src/io/io_utils_test.ts @@ -620,6 +620,8 @@ describeWithFlags( }); describe('concatenateArrayBuffers', () => { + // TODO(mattSoulanille): Move these tests to CompositeArrayBuffer.join when + // concatenateArrayBuffers is removed. it('Concatenate 3 non-empty ArrayBuffers', () => { const buffer1 = new Uint8Array([1, 2, 3]); const buffer2 = new Uint8Array([11, 22, 33, 44]); diff --git a/tfjs-core/src/io/local_storage.ts b/tfjs-core/src/io/local_storage.ts index 2de2639c35c..f20206a9290 100644 --- a/tfjs-core/src/io/local_storage.ts +++ b/tfjs-core/src/io/local_storage.ts @@ -20,6 +20,7 @@ import {env} from '../environment'; import {assert} from '../util'; import {arrayBufferToBase64String, base64StringToArrayBuffer, getModelArtifactsInfoForJSON} from './io_utils'; +import {CompositeArrayBuffer} from './composite_array_buffer'; import {IORouter, IORouterRegistry} from './router_registry'; import {IOHandler, ModelArtifacts, ModelArtifactsInfo, ModelJSON, ModelStoreManager, SaveResult} from './types'; @@ -174,13 +175,17 @@ export class BrowserLocalStorage implements IOHandler { const modelArtifactsInfo: ModelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts); + // TODO(mattsoulanille): Support saving models over 2GB that exceed + // Chrome's ArrayBuffer size limit. 
+ const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData); + try { this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo)); this.LS.setItem(this.keys.topology, topology); this.LS.setItem(this.keys.weightSpecs, weightSpecs); this.LS.setItem( this.keys.weightData, - arrayBufferToBase64String(modelArtifacts.weightData)); + arrayBufferToBase64String(weightBuffer)); // Note that JSON.stringify doesn't write out keys that have undefined // values, so for some keys, we set undefined instead of a null-ish diff --git a/tfjs-core/src/io/model_management_test.ts b/tfjs-core/src/io/model_management_test.ts index db66e42367f..0f8a958ec7e 100644 --- a/tfjs-core/src/io/model_management_test.ts +++ b/tfjs-core/src/io/model_management_test.ts @@ -18,6 +18,7 @@ import * as tf from '../index'; import {CHROME_ENVS, describeWithFlags, runWithLock} from '../jasmine_util'; import {deleteDatabase} from './indexed_db'; +import {CompositeArrayBuffer} from './composite_array_buffer'; import {purgeLocalStorageArtifacts} from './local_storage'; // Disabled for non-Chrome browsers due to: @@ -268,7 +269,9 @@ describeWithFlags('ModelManagement', CHROME_ENVS, () => { .then(loaded => { expect(loaded.modelTopology).toEqual(modelTopology1); expect(loaded.weightSpecs).toEqual(weightSpecs1); - expect(new Uint8Array(loaded.weightData)) + expect(loaded.weightData).toBeDefined(); + expect(new Uint8Array( + CompositeArrayBuffer.join(loaded.weightData))) .toEqual(new Uint8Array(weightData1)); done(); }) @@ -311,7 +314,8 @@ describeWithFlags('ModelManagement', CHROME_ENVS, () => { .then(loaded => { expect(loaded.modelTopology).toEqual(modelTopology1); expect(loaded.weightSpecs).toEqual(weightSpecs1); - expect(new Uint8Array(loaded.weightData)) + expect(new Uint8Array( + CompositeArrayBuffer.join(loaded.weightData))) .toEqual(new Uint8Array(weightData1)); done(); }) diff --git a/tfjs-core/src/io/passthrough.ts b/tfjs-core/src/io/passthrough.ts index 5e9cd8636c6..e5416002d59 100644 --- a/tfjs-core/src/io/passthrough.ts +++ b/tfjs-core/src/io/passthrough.ts @@ -19,7 +19,7 @@ * IOHandlers that pass through the in-memory ModelArtifacts format. 
*/ -import {IOHandler, IOHandlerSync, LoadHandler, ModelArtifacts, SaveHandler, SaveResult, TrainingConfig, WeightsManifestEntry} from './types'; +import {IOHandler, IOHandlerSync, LoadHandler, ModelArtifacts, SaveHandler, SaveResult, TrainingConfig, WeightData, WeightsManifestEntry} from './types'; class PassthroughLoader implements IOHandlerSync { constructor(private readonly modelArtifacts?: ModelArtifacts) {} @@ -76,7 +76,7 @@ class PassthroughAsync implements IOHandler { */ export function fromMemory( modelArtifacts: {}|ModelArtifacts, weightSpecs?: WeightsManifestEntry[], - weightData?: ArrayBuffer, trainingConfig?: TrainingConfig): IOHandler { + weightData?: WeightData, trainingConfig?: TrainingConfig): IOHandler { const args = arguments as unknown as Parameters; return new PassthroughAsync(fromMemorySync(...args)); @@ -105,7 +105,7 @@ export function fromMemory( */ export function fromMemorySync( modelArtifacts: {}|ModelArtifacts, weightSpecs?: WeightsManifestEntry[], - weightData?: ArrayBuffer, trainingConfig?: TrainingConfig): IOHandlerSync { + weightData?: WeightData, trainingConfig?: TrainingConfig): IOHandlerSync { if (arguments.length === 1) { const isModelArtifacts = (modelArtifacts as ModelArtifacts).modelTopology != null || diff --git a/tfjs-core/src/io/types.ts b/tfjs-core/src/io/types.ts index ece6b97e8b5..2dc0893a82f 100644 --- a/tfjs-core/src/io/types.ts +++ b/tfjs-core/src/io/types.ts @@ -215,6 +215,8 @@ export declare interface TrainingConfig { loss_weights?: number[]|{[key: string]: number}; } +export type WeightData = ArrayBuffer | ArrayBuffer[]; + /** * The serialized artifacts of a model, including topology and weights. * @@ -248,10 +250,12 @@ export declare interface ModelArtifacts { weightSpecs?: WeightsManifestEntry[]; /** - * Binary buffer for all weight values concatenated in the order specified - * by `weightSpecs`. + * Binary buffer(s) for all weight values in the order specified by + * `weightSpecs`. This may be a single ArrayBuffer of all the weights + * concatenated together or an Array of ArrayBuffers containing the weights + * (weights may be sharded across multiple ArrayBuffers). */ - weightData?: ArrayBuffer; + weightData?: WeightData; /** * Hard-coded format name for models saved from TensorFlow.js or converted diff --git a/tfjs-layers/src/models.ts b/tfjs-layers/src/models.ts index 833ab887833..40d6e79a320 100644 --- a/tfjs-layers/src/models.ts +++ b/tfjs-layers/src/models.ts @@ -342,9 +342,9 @@ export async function loadLayersModelFromIOHandler( } function decodeModelAndOptimizerWeights( - buffer: ArrayBuffer, specs: io.WeightsManifestEntry[]): + weightData: io.WeightData, specs: io.WeightsManifestEntry[]): {modelWeights: NamedTensorMap, optimizerWeights: NamedTensor[]} { - const name2Tensor = io.decodeWeights(buffer, specs); + const name2Tensor = io.decodeWeights(weightData, specs); const modelWeights: NamedTensorMap = {}; const optimizerWeights: NamedTensor[] = []; specs.forEach(spec => { diff --git a/tfjs-layers/src/models_test.ts b/tfjs-layers/src/models_test.ts index 1f470ed1d7c..e2db066cf49 100644 --- a/tfjs-layers/src/models_test.ts +++ b/tfjs-layers/src/models_test.ts @@ -1293,7 +1293,8 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. 
const weightData = savedArtifacts.weightData; - expect(weightData.byteLength).toEqual(4 * 8 + 4 * 1 + 4); + expect(new io.CompositeArrayBuffer(weightData).byteLength) + .toEqual(4 * 8 + 4 * 1 + 4); // Load the model back, with the optimizer. const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts)); @@ -1352,7 +1353,8 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(weightData.byteLength).toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); + expect(new io.CompositeArrayBuffer(weightData) + .byteLength).toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); // Load the model back, with the optimizer. const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts)); @@ -1411,7 +1413,8 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(weightData.byteLength).toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); + expect(new io.CompositeArrayBuffer(weightData).byteLength) + .toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); // Load the model back, with the optimizer. const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts)); @@ -1468,7 +1471,8 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(weightData.byteLength).toEqual(4 + 4 * 8 * 2 + 4 * 1 * 2); + expect(new io.CompositeArrayBuffer(weightData).byteLength) + .toEqual(4 + 4 * 8 * 2 + 4 * 1 * 2); // Load the model back, with the optimizer. const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts)); @@ -1530,7 +1534,8 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(weightData.byteLength).toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); + expect(new io.CompositeArrayBuffer(weightData).byteLength) + .toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); // Load the model back, with the optimizer. const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts)); @@ -1588,7 +1593,8 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(weightData.byteLength).toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); + expect(new io.CompositeArrayBuffer(weightData).byteLength) + .toEqual(4 + 4 * 8 * 3 + 4 * 1 * 3); // Load the model back, with the optimizer. const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts)); @@ -1645,7 +1651,8 @@ describeMathCPUAndWebGL2('Saving+loading model with optimizer', () => { // The second part comes from the bias of the dense layer, which has 1 // element and is also 4 bytes. const weightData = savedArtifacts.weightData; - expect(weightData.byteLength).toEqual(4 + 4 * 8 * 2 + 4 * 1 * 2); + expect(new io.CompositeArrayBuffer(weightData).byteLength) + .toEqual(4 + 4 * 8 * 2 + 4 * 1 * 2); // Load the model back, with the optimizer. 
const model2 = await tfl.loadLayersModel(io.fromMemory(savedArtifacts)); diff --git a/tfjs-node/src/io/file_system.ts b/tfjs-node/src/io/file_system.ts index ad1148d68c0..74b4cc582d6 100644 --- a/tfjs-node/src/io/file_system.ts +++ b/tfjs-node/src/io/file_system.ts @@ -19,7 +19,7 @@ import * as tf from '@tensorflow/tfjs'; import * as fs from 'fs'; import {dirname, join, resolve} from 'path'; import {promisify} from 'util'; -import {getModelArtifactsInfoForJSON, toArrayBuffer} from './io_utils'; +import {toArrayBuffer} from './io_utils'; const stat = promisify(fs.stat); const writeFile = promisify(fs.writeFile); @@ -121,7 +121,7 @@ export class NodeFileSystem implements tf.io.IOHandler { // TODO(cais): Use explicit tf.io.ModelArtifactsInfo type below once it // is available. // tslint:disable-next-line:no-any - modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts) as any + modelArtifactsInfo: tf.io.getModelArtifactsInfoForJSON(modelArtifacts), }; } } diff --git a/tfjs-node/src/io/file_system_test.ts b/tfjs-node/src/io/file_system_test.ts index 7a918aa545b..13627f55da5 100644 --- a/tfjs-node/src/io/file_system_test.ts +++ b/tfjs-node/src/io/file_system_test.ts @@ -150,7 +150,8 @@ describe('File system IOHandler', () => { expect(modelArtifacts.weightSpecs).toEqual(weightSpecs1); expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1); - expect(new Float32Array(modelArtifacts.weightData)) + expect(new Float32Array(tf.io.CompositeArrayBuffer.join( + modelArtifacts.weightData))) .toEqual(new Float32Array([0, 0, 0, 0])); done(); }) @@ -216,7 +217,8 @@ describe('File system IOHandler', () => { } ]); tf.test_util.expectArraysClose( - new Float32Array(modelArtifacts.weightData), + new Float32Array(tf.io.CompositeArrayBuffer.join( + modelArtifacts.weightData)), new Float32Array([-1.1, -3.3, -3.3, -7.7])); }); @@ -341,7 +343,8 @@ describe('File system IOHandler', () => { } ]); tf.test_util.expectArraysClose( - new Float32Array(modelArtifacts.weightData), + new Float32Array(tf.io.CompositeArrayBuffer.join( + modelArtifacts.weightData)), new Float32Array([-1.1, -3.3, -3.3, -7.7])); }); diff --git a/tfjs-node/src/io/io_utils.ts b/tfjs-node/src/io/io_utils.ts index 74582196eb4..28608855664 100644 --- a/tfjs-node/src/io/io_utils.ts +++ b/tfjs-node/src/io/io_utils.ts @@ -15,8 +15,6 @@ * ============================================================================= */ -import * as tf from '@tensorflow/tfjs'; - /** * Convert an ArrayBuffer to a Buffer. */ @@ -51,30 +49,3 @@ export function toArrayBuffer(buf: Buffer|Buffer[]): ArrayBuffer { return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength); } } - -// TODO(cais): Use explicit tf.io.ModelArtifactsInfo return type below once it -// is available. -/** - * Populate ModelArtifactsInfo fields for a model with JSON topology. - * @param modelArtifacts - * @returns A ModelArtifactsInfo object. - */ -export function getModelArtifactsInfoForJSON( - modelArtifacts: tf.io.ModelArtifacts) { - if (modelArtifacts.modelTopology instanceof ArrayBuffer) { - throw new Error('Expected JSON model topology, received ArrayBuffer.'); - } - return { - dateSaved: new Date(), - modelTopologyType: 'JSON', - modelTopologyBytes: modelArtifacts.modelTopology == null ? - 0 : - Buffer.byteLength(JSON.stringify(modelArtifacts.modelTopology), 'utf8'), - weightSpecsBytes: modelArtifacts.weightSpecs == null ? - 0 : - Buffer.byteLength(JSON.stringify(modelArtifacts.weightSpecs), 'utf8'), - weightDataBytes: modelArtifacts.weightData == null ? 
- 0 : - modelArtifacts.weightData.byteLength, - }; -} diff --git a/tfjs-node/src/io/node_http_test.ts b/tfjs-node/src/io/node_http_test.ts index ff3a52b68a0..83b6e6140d4 100644 --- a/tfjs-node/src/io/node_http_test.ts +++ b/tfjs-node/src/io/node_http_test.ts @@ -136,7 +136,8 @@ describe('nodeHTTPRequest-load', () => { expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1); - expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); + expect(new Float32Array(tf.io.CompositeArrayBuffer.join( + modelArtifacts.weightData))).toEqual(floatData); expect(requestInits).toEqual([ {credentials: 'include', cache: 'no-cache'},
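Aside (not part of the patch): for IOHandler authors, the net effect of this change is that ModelArtifacts.weightData may now be sharded. A minimal sketch of a custom save handler written against the new type; the storage step is hypothetical and elided:

import * as tf from '@tensorflow/tfjs-core';

// Join sharded weight data before persisting, mirroring what
// BrowserDownloads, HTTPRequest, and BrowserLocalStorage do in this diff.
const saveHandler = tf.io.withSaveHandler(
    async (artifacts: tf.io.ModelArtifacts): Promise<tf.io.SaveResult> => {
      // weightData is ArrayBuffer | ArrayBuffer[] | undefined after this PR.
      const weights = tf.io.CompositeArrayBuffer.join(artifacts.weightData);
      // ... write artifacts.modelTopology, artifacts.weightSpecs, and
      // `weights` to the destination of your choice ...
      return {
        modelArtifactsInfo: tf.io.getModelArtifactsInfoForJSON(artifacts)
      };
    });

// Usage: const result = await model.save(saveHandler);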