Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support loading models with weights above 2GB on Chrome #7609

Merged
merged 14 commits into from
May 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions tfjs-converter/src/executor/graph_model_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,8 @@ describe('Model', () => {
expect(handler.savedArtifacts.modelTopology).toEqual(CUSTOM_OP_MODEL);
expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest);
tfc.test_util.expectArraysClose(
new Int32Array(handler.savedArtifacts.weightData), bias.dataSync());
new Int32Array(io.CompositeArrayBuffer.join(
handler.savedArtifacts.weightData)), bias.dataSync());
});
});
});
Expand Down Expand Up @@ -616,7 +617,8 @@ describe('Model', () => {
});
expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest);
tfc.test_util.expectArraysClose(
new Int32Array(handler.savedArtifacts.weightData), bias.dataSync());
new Int32Array(io.CompositeArrayBuffer.join(
handler.savedArtifacts.weightData)), bias.dataSync());
});
});

Expand Down Expand Up @@ -904,7 +906,8 @@ describe('Model', () => {
});
expect(handler.savedArtifacts.weightSpecs).toEqual(weightsManifest);
tfc.test_util.expectArraysClose(
new Int32Array(handler.savedArtifacts.weightData), bias.dataSync());
new Int32Array(io.CompositeArrayBuffer.join(handler.savedArtifacts
.weightData)), bias.dataSync());
});
});

Expand Down
14 changes: 13 additions & 1 deletion tfjs-converter/src/operations/executors/spy_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,23 @@
* =============================================================================
*/

// The opposite of Extract<T, U>
type Without<T, U> = T extends U ? never : T;

// Do not spy on CompositeArrayBuffer because it is a class constructor.
type NotSpiedOn = 'CompositeArrayBuffer';

export type RecursiveSpy<T> =
T extends Function ? jasmine.Spy : {[K in keyof T]: RecursiveSpy<T[K]>};
T extends Function ? jasmine.Spy :
{[K in Without<keyof T, NotSpiedOn>]: RecursiveSpy<T[K]>} &
{[K in Extract<keyof T, NotSpiedOn>]: T[K]};

export function spyOnAllFunctions<T>(obj: T): RecursiveSpy<T> {
return Object.fromEntries(Object.entries(obj).map(([key, val]) => {
// TODO(mattSoulanille): Do not hard code this
if (key === 'CompositeArrayBuffer') {
return val;
}
if (val instanceof Function) {
return [key, jasmine.createSpy(`${key} spy`, val).and.callThrough()];
} else if (val instanceof Array) {
Expand Down
16 changes: 11 additions & 5 deletions tfjs-core/src/io/browser_files.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
import '../flags';
import {env} from '../environment';

import {basename, concatenateArrayBuffers, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts} from './io_utils';
import {basename, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts} from './io_utils';
import {IORouter, IORouterRegistry} from './router_registry';
import {IOHandler, ModelArtifacts, ModelJSON, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {IOHandler, ModelArtifacts, ModelJSON, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {CompositeArrayBuffer} from './composite_array_buffer';

const DEFAULT_FILE_NAME_PREFIX = 'model';
const DEFAULT_JSON_EXTENSION_NAME = '.json';
Expand Down Expand Up @@ -70,8 +71,13 @@ export class BrowserDownloads implements IOHandler {
'Browser downloads are not supported in ' +
'this environment since `document` is not present');
}

// TODO(mattsoulanille): Support saving models over 2GB that exceed
// Chrome's ArrayBuffer size limit.
const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData);

const weightsURL = window.URL.createObjectURL(new Blob(
[modelArtifacts.weightData], {type: 'application/octet-stream'}));
[weightBuffer], {type: 'application/octet-stream'}));

if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
throw new Error(
Expand Down Expand Up @@ -169,7 +175,7 @@ class BrowserFiles implements IOHandler {
}

private loadWeights(weightsManifest: WeightsManifestConfig): Promise<[
/* weightSpecs */ WeightsManifestEntry[], /* weightData */ ArrayBuffer
/* weightSpecs */ WeightsManifestEntry[], WeightData,
]> {
const weightSpecs: WeightsManifestEntry[] = [];
const paths: string[] = [];
Expand All @@ -185,7 +191,7 @@ class BrowserFiles implements IOHandler {
paths.map(path => this.loadWeightsFile(path, pathToFile[path]));

return Promise.all(promises).then(
buffers => [weightSpecs, concatenateArrayBuffers(buffers)]);
buffers => [weightSpecs, buffers]);
}

private loadWeightsFile(path: string, file: File): Promise<ArrayBuffer> {
Expand Down
24 changes: 14 additions & 10 deletions tfjs-core/src/io/browser_files_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import * as tf from '../index';
import {BROWSER_ENVS, describeWithFlags} from '../jasmine_util';
import {browserDownloads, BrowserDownloads, browserDownloadsRouter} from './browser_files';
import {WeightsManifestConfig, WeightsManifestEntry} from './types';
import {CompositeArrayBuffer} from './composite_array_buffer';

const modelTopology1: {} = {
'class_name': 'Sequential',
Expand Down Expand Up @@ -310,7 +311,7 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => {
expect(modelArtifacts.modelInitializer).toEqual({});
expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1);

expect(new Uint8Array(modelArtifacts.weightData))
expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData)))
.toEqual(new Uint8Array(weightData1));
});

Expand Down Expand Up @@ -351,9 +352,10 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => {
const modelArtifacts = await filesHandler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs).toEqual(weightSpecs);
expect(new Uint8Array(modelArtifacts.weightData)).toEqual(new Uint8Array([
1, 2, 3, 4, 10, 20, 30, 40
]));
expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData)))
.toEqual(new Uint8Array([
1, 2, 3, 4, 10, 20, 30, 40
]));
});

it(`Two groups, four paths, reverseOrder=false`, async () => {
Expand Down Expand Up @@ -418,9 +420,10 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => {
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs)
.toEqual(weightSpecs1.concat(weightSpecs2));
expect(new Uint8Array(modelArtifacts.weightData)).toEqual(new Uint8Array([
1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80
]));
expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData)))
.toEqual(new Uint8Array([
1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80
]));
});

it(`Two groups, four paths, reverseOrder=true`, async () => {
Expand Down Expand Up @@ -485,9 +488,10 @@ describeWithFlags('browserFiles', BROWSER_ENVS, () => {
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs)
.toEqual(weightSpecs1.concat(weightSpecs2));
expect(new Uint8Array(modelArtifacts.weightData)).toEqual(new Uint8Array([
1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80
]));
expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData)))
.toEqual(new Uint8Array([
1, 3, 5, 7, 10, 30, 50, 70, 2, 4, 6, 8, 20, 40, 60, 80
]));
});

it('Upload model topology only', async () => {
Expand Down
26 changes: 23 additions & 3 deletions tfjs-core/src/io/composite_array_buffer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,22 @@ export class CompositeArrayBuffer {
private bufferUniformSize?: number;
public readonly byteLength: number;

constructor(buffers: ArrayBuffer | ArrayBuffer[] | TypedArray |
/**
* Concatenate a number of ArrayBuffers into one.
*
* @param buffers An array of ArrayBuffers to concatenate, or a single
* ArrayBuffer.
* @returns Result of concatenating `buffers` in order.
*/
static join(buffers?: ArrayBuffer[] | ArrayBuffer) {
return new CompositeArrayBuffer(buffers).slice();
}

constructor(buffers?: ArrayBuffer | ArrayBuffer[] | TypedArray |
TypedArray[]) {
if (buffers == null) {
return;
}
// Normalize the `buffers` input to be `ArrayBuffer[]`.
if (!(buffers instanceof Array)) {
buffers = [buffers];
Expand Down Expand Up @@ -85,6 +99,12 @@ export class CompositeArrayBuffer {
}

slice(start = 0, end = this.byteLength): ArrayBuffer {
// If there are no shards, then the CompositeArrayBuffer was initialized
// with no data.
if (this.shards.length === 0) {
return new ArrayBuffer(0);
}

// NaN is treated as zero for slicing. This matches ArrayBuffer's behavior.
start = isNaN(Number(start)) ? 0 : start;
end = isNaN(Number(end)) ? 0 : end;
Expand Down Expand Up @@ -117,8 +137,8 @@ export class CompositeArrayBuffer {
const globalEnd = Math.min(end, shard.end);
const localEnd = globalEnd - shard.start;

const outputSlice = new Uint8Array(shard.buffer.slice(localStart,
localEnd));
const outputSlice = new Uint8Array(shard.buffer, localStart,
localEnd - localStart);
outputArray.set(outputSlice, outputStart);
sliced += outputSlice.length;

Expand Down
14 changes: 13 additions & 1 deletion tfjs-core/src/io/composite_array_buffer_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ describe('CompositeArrayBuffer', () => {
});
}

it('can be passed an empty arraybuffer', () => {
it('can be created from an empty arraybuffer', () => {
const array = new Uint8Array([]);
const singleComposite = new CompositeArrayBuffer(array.buffer);
expectArraysEqual(new Uint8Array(singleComposite.slice()), []);
Expand All @@ -92,6 +92,18 @@ describe('CompositeArrayBuffer', () => {
expectArraysEqual(new Uint8Array(singleComposite.slice()), array);
});

it('can be created from zero arrays', () => {
const singleComposite = new CompositeArrayBuffer([]);
expectArraysEqual(new Uint8Array(singleComposite.slice()),
new Uint8Array());
});

it('can be created from undefined input', () => {
const singleComposite = new CompositeArrayBuffer();
expectArraysEqual(new Uint8Array(singleComposite.slice()),
new Uint8Array());
});

it('treats NaN as zero when passed as the start of slice', () => {
const array = new Uint8Array([1,2,3]);
const composite = new CompositeArrayBuffer(array.buffer);
Expand Down
15 changes: 10 additions & 5 deletions tfjs-core/src/io/http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
import {env} from '../environment';

import {assert} from '../util';
import {concatenateArrayBuffers, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts, getWeightSpecs} from './io_utils';
import {getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts, getWeightSpecs} from './io_utils';
import {CompositeArrayBuffer} from './composite_array_buffer';
import {IORouter, IORouterRegistry} from './router_registry';
import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightData, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {loadWeightsAsArrayBuffer} from './weights_loader';

const OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
Expand Down Expand Up @@ -110,9 +111,13 @@ export class HTTPRequest implements IOHandler {
'model.json');

if (modelArtifacts.weightData != null) {
// TODO(mattsoulanille): Support saving models over 2GB that exceed
// Chrome's ArrayBuffer size limit.
const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData);

init.body.append(
'model.weights.bin',
new Blob([modelArtifacts.weightData], {type: OCTET_STREAM_MIME_TYPE}),
new Blob([weightBuffer], {type: OCTET_STREAM_MIME_TYPE}),
'model.weights.bin');
}

Expand Down Expand Up @@ -182,7 +187,7 @@ export class HTTPRequest implements IOHandler {
}

private async loadWeights(weightsManifest: WeightsManifestConfig):
Promise<[WeightsManifestEntry[], ArrayBuffer]> {
Promise<[WeightsManifestEntry[], WeightData]> {
const weightPath = Array.isArray(this.path) ? this.path[1] : this.path;
const [prefix, suffix] = parseUrl(weightPath);
const pathPrefix = this.weightPathPrefix || prefix;
Expand Down Expand Up @@ -210,7 +215,7 @@ export class HTTPRequest implements IOHandler {
fetchFunc: this.fetch,
onProgress: this.onProgress
});
return [weightSpecs, concatenateArrayBuffers(buffers)];
return [weightSpecs, buffers];
}
}

Expand Down
42 changes: 25 additions & 17 deletions tfjs-core/src/io/http_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import * as tf from '../index';
import {BROWSER_ENVS, CHROME_ENVS, describeWithFlags, NODE_ENVS} from '../jasmine_util';
import {HTTPRequest, httpRouter, parseUrl} from './http';
import {CompositeArrayBuffer} from './composite_array_buffer';

// Test data.
const modelTopology1: {} = {
Expand Down Expand Up @@ -161,7 +162,8 @@ describeWithFlags('http-load fetch', NODE_ENVS, () => {
expect(modelArtifacts.generatedBy).toEqual('1.15');
expect(modelArtifacts.convertedBy).toEqual('1.3.1');
expect(modelArtifacts.userDefinedMetadata).toEqual({});
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
expect(new Float32Array(CompositeArrayBuffer.join(
modelArtifacts.weightData))).toEqual(floatData);
});

it('throw exception if no fetch polyfill', () => {
Expand Down Expand Up @@ -507,7 +509,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
expect(modelArtifacts.userDefinedMetadata).toEqual({});
expect(modelArtifacts.modelInitializer).toEqual({});

expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
expect(new Float32Array(CompositeArrayBuffer.join(modelArtifacts
.weightData))).toEqual(floatData);
expect(Object.keys(requestInits).length).toEqual(2);
// Assert that fetch is invoked with `window` as the context.
expect(fetchSpy.calls.mostRecent().object).toEqual(window);
Expand Down Expand Up @@ -550,7 +553,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
const modelArtifacts = await handler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
expect(new Float32Array(CompositeArrayBuffer.join(modelArtifacts
.weightData))).toEqual(floatData);
expect(Object.keys(requestInits).length).toEqual(2);
expect(Object.keys(requestInits).length).toEqual(2);
expect(requestInits['./model.json'].headers['header_key_1'])
Expand Down Expand Up @@ -599,8 +603,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
const modelArtifacts = await handler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
expect(new Float32Array(modelArtifacts.weightData))
.toEqual(new Float32Array([1, 3, 3, 7, 4]));
expect(new Float32Array(CompositeArrayBuffer.join(modelArtifacts
.weightData))).toEqual(new Float32Array([1, 3, 3, 7, 4]));
});

it('2 groups, 2 weight, 2 paths', async () => {
Expand Down Expand Up @@ -644,8 +648,9 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
expect(modelArtifacts.weightSpecs)
.toEqual(
weightsManifest[0].weights.concat(weightsManifest[1].weights));
expect(new Float32Array(modelArtifacts.weightData))
.toEqual(new Float32Array([1, 3, 3, 7, 4]));
expect(new Float32Array(CompositeArrayBuffer.join(
modelArtifacts.weightData)))
.toEqual(new Float32Array([1, 3, 3, 7, 4]));
});

it('2 groups, 2 weight, 2 paths, Int32 and Uint8 Data', async () => {
Expand Down Expand Up @@ -689,10 +694,10 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
expect(modelArtifacts.weightSpecs)
.toEqual(
weightsManifest[0].weights.concat(weightsManifest[1].weights));
expect(new Int32Array(modelArtifacts.weightData.slice(0, 12)))
.toEqual(new Int32Array([1, 3, 3]));
expect(new Uint8Array(modelArtifacts.weightData.slice(12, 14)))
.toEqual(new Uint8Array([7, 4]));
expect(new Int32Array(CompositeArrayBuffer.join(modelArtifacts.weightData)
.slice(0, 12))).toEqual(new Int32Array([1, 3, 3]));
expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData)
.slice(12, 14))).toEqual(new Uint8Array([7, 4]));
});

it('topology only', async () => {
Expand Down Expand Up @@ -752,10 +757,11 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
expect(modelArtifacts.weightSpecs)
.toEqual(
weightsManifest[0].weights.concat(weightsManifest[1].weights));
expect(new Int32Array(modelArtifacts.weightData.slice(0, 12)))
.toEqual(new Int32Array([1, 3, 3]));
expect(new Float32Array(modelArtifacts.weightData.slice(12, 20)))
.toEqual(new Float32Array([-7, -4]));
expect(new Int32Array(CompositeArrayBuffer.join(modelArtifacts.weightData)
.slice(0, 12))).toEqual(new Int32Array([1, 3, 3]));
expect(new Float32Array(CompositeArrayBuffer
.join(modelArtifacts.weightData)
.slice(12, 20))).toEqual(new Float32Array([-7, -4]));
});

it('Missing modelTopology and weightsManifest leads to error', async () => {
Expand Down Expand Up @@ -840,7 +846,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
const modelArtifacts = await handler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
expect(new Float32Array(CompositeArrayBuffer.join(
modelArtifacts.weightData))).toEqual(floatData);
expect(Object.keys(requestInits).length).toEqual(2);
expect(Object.keys(requestInits).length).toEqual(2);
expect(requestInits['./model.json'].headers['header_key_1'])
Expand Down Expand Up @@ -902,7 +909,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1);
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
expect(new Float32Array(CompositeArrayBuffer
.join(modelArtifacts.weightData))).toEqual(floatData);

expect(fetchInputs).toEqual(['./model.json', './weightfile0']);
expect(fetchInits.length).toEqual(2);
Expand Down
Loading