-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support loading models with weights above 2GB on Chrome #7609
Support loading models with weights above 2GB on Chrome #7609
Conversation
fe844d2
to
b21b302
Compare
|
||
// TODO(cais): Use explicit tf.io.ModelArtifactsInfo return type below once it | ||
// is available. | ||
/** | ||
* Populate ModelArtifactsInfo fields for a model with JSON topology. | ||
* @param modelArtifacts | ||
* @returns A ModelArtifactsInfo object. | ||
*/ | ||
export function getModelArtifactsInfoForJSON( | ||
modelArtifacts: tf.io.ModelArtifacts) { | ||
if (modelArtifacts.modelTopology instanceof ArrayBuffer) { | ||
throw new Error('Expected JSON model topology, received ArrayBuffer.'); | ||
} | ||
return { | ||
dateSaved: new Date(), | ||
modelTopologyType: 'JSON', | ||
modelTopologyBytes: modelArtifacts.modelTopology == null ? | ||
0 : | ||
Buffer.byteLength(JSON.stringify(modelArtifacts.modelTopology), 'utf8'), | ||
weightSpecsBytes: modelArtifacts.weightSpecs == null ? | ||
0 : | ||
Buffer.byteLength(JSON.stringify(modelArtifacts.weightSpecs), 'utf8'), | ||
weightDataBytes: modelArtifacts.weightData == null ? | ||
0 : | ||
modelArtifacts.weightData.byteLength, | ||
}; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is duplicated in tfjs-core/src/io/io_utils.ts
@@ -285,7 +291,7 @@ const useNodeBuffer = typeof Buffer !== 'undefined' && | |||
*/ | |||
export function stringByteLength(str: string): number { | |||
if (useNodeBuffer) { | |||
return Buffer.byteLength(str); | |||
return Buffer.byteLength(str, 'utf8'); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tfjs-node used utf8
in its implementation, so I think it should also be here.
* @returns Result of concatenating `buffers` in order. | ||
*/ | ||
export function concatenateArrayBuffers(buffers: ArrayBuffer[]): ArrayBuffer { | ||
export function concatenateArrayBuffers(buffers: ArrayBuffer[] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we move this to be part of CompositeArrayBuffer
? like static method CompositeArrayBuffer.join(buffers: ArrayBuffer[])
or through public method new CompositeArrayBuffer(buffers).toArrayBuffer()
, which makes it easier to bridge CompositeArrayBuffer with native ArrayBuffer and pass CompositeArrayBuffer around in the future if needed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was considering that, and my original implementation actually used new CompositeArrayBuffer(buffers).slice()
, but I removed it in favor of concatenateArrayBuffers
because of an issue with the types in tfjs-converter tests (here was my fix for it in the spy_ops.ts
file, but it's a bit hacky).
I'm fine with using the converter spy_ops.ts fix if it'll make the core implementation cleaner. What do you think?
Edit: ...and we can add a toArrayBuffer
or static join
method instead of using .slice
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alternatively, I can move composite_array_buffer.ts
out of io/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I took a look at the usage of spyOnAllFunctions
in tests, and I think the test is something we should fix. A hacky way like what you did is probably fine.
In general, instead of automatically replace everything with spy using spyOnAllFunctions
, we should explicitly create an ioSpy object which only contains the function we want to spy, so that we can make the test more controllable and reliable. There are some stuffs exported in io apparently should not be spied, like getWeightSpecs
, which is a io helper function instead of a function to do io.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@chunnienc I've replaced concatenateArrayBuffers
with CompositeArrayBuffer.join
in tfjs-core and deprecated concatenateArrayBuffers
. We can't replace it in other packages yet because that would introduce a breaking change. Downstream packages could not be used with an earlier version of tfjs-core that does not implement CompositeArrayBuffer
(see #7273 for an example of why this is important). We can apply this change to all the packages in the next major release.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, it's fine to use it in tests, since users will never run those. I'll swap concatenateArrayBuffers
for CompositeArrayBuffer.join
in the test files.
0276e01
to
94ed22e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 4 of 18 files at r1, 1 of 1 files at r2, 13 of 13 files at r3, 4 of 4 files at r4, all commit messages.
Reviewable status: complete! 2 of 1 approvals obtained (waiting on @chunnienc)
Chrome ArrayBuffers throw allocation errors above 2GB in size. This makes it impossible to load TFJS models above this size in Chrome (even with weight sharding) because model loading involves concatenating all the weights into a single ArrayBuffer.
This PR avoids this concatenation. Instead of slicing the weight tensors out of a single concatenated ArrayBuffer, it keeps the weight buffers in their original shards and slices them using the CompositeArrayBuffer class created in #7598.
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.
This change is