Avoid allocating a large arraybuffer when loading weights #7598
Conversation
@pyu10055 This isn't a full review request yet, but I'm interested in knowing which approach you think is better. Thanks!
If the size of each chunk is fixed (4MB) except for the last one, you can get start and end offsets in O(1) by division and modulo. No need to do a binary search or two-pointers approach on sorted data. In terms of API, I prefer option 2, which is a better implementation for separation of concerns. In the weight loader I just need to specify where to slice and don't need to worry about how to slice. You can implement a lazy/offline slicer if performance is a concern.
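(For illustration, a minimal TypeScript sketch of the fixed-chunk-size lookup suggested above; the constant and function names are hypothetical, and it assumes every chunk except possibly the last is exactly 4MB.)

```ts
// Hypothetical sketch of the suggested O(1) lookup, assuming every chunk
// except possibly the last one is exactly CHUNK_SIZE bytes.
const CHUNK_SIZE = 4 * 1024 * 1024;  // 4MB

function locateByte(byteIndex: number, chunks: ArrayBuffer[]):
    {chunkIndex: number, offsetInChunk: number} {
  // Integer division picks the chunk; modulo gives the offset within it.
  const chunkIndex = Math.floor(byteIndex / CHUNK_SIZE);
  const offsetInChunk = byteIndex % CHUNK_SIZE;
  if (chunkIndex >= chunks.length) {
    throw new Error(`Byte index ${byteIndex} is out of range.`);
  }
  return {chunkIndex, offsetInChunk};
}
```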
Unfortunately, there's no guarantee on the size of the chunks. We let people configure it when converting the model. They should all be the same size, but I'd even hesitate to assume that, since it seems a bit flaky. I agree with you and also prefer option 2. I'll implement it as a binsearch for now, and if we need better perf, I can try to automatically detect the chunk size or make it check chunks near the last one read before doing a full binsearch.
Looking at this again, it doesn't actually prevent us from storing the weights in a single ArrayBuffer. I'll leave this PR as-is and create a new one to fix this issue.
Force-pushed from 159fcd1 to 6a07547.
tfjs-core/src/io/weights_loader.ts (Outdated)

  buffer: ArrayBuffer,
};

export class CompositeArrayBuffer {
This will be used in another PR that enables large model weights to be stored in a list of ArrayBuffers. That's why it's exported here.
I'm sending this out for review since it's easier to review separately from the other part of the large weights fix. I'll submit the other part, which integrates this code with the rest of the codebase, in a separate PR.
Force-pushed from ee87834 to b86ee6e.
Reviewable status: 0 of 1 approvals obtained (waiting on @chunnienc and @mattsoulanille)
tfjs-core/src/io/weights_loader.ts, line 290 at r2 (raw file):

    }
  }
  return outputBuffer

Missing ;
tfjs-core/src/io/weights_loader.ts, line 292 at r2 (raw file):

    return outputBuffer
  }

  private search(byteIndex: number) {

This could be improved if the searches happen in order; I believe that is how our weights are set up.
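(One possible shape for that optimization, sketched for illustration only: remember the chunk hit by the previous lookup and check it and its neighbor before falling back to binary search. Names and structure are hypothetical, not the actual tfjs code.)

```ts
// Illustrative sketch only: sequential weight reads usually land in the same
// chunk as the previous read, or the next one, so check those first.
interface Chunk { start: number; end: number; buffer: ArrayBuffer; }

class ChunkLocator {
  private lastIndex = 0;

  constructor(private chunks: Chunk[]) {}

  find(byteIndex: number): number {
    // Fast path: the previously found chunk or its immediate neighbor.
    for (const i of [this.lastIndex, this.lastIndex + 1]) {
      const chunk = this.chunks[i];
      if (chunk && byteIndex >= chunk.start && byteIndex < chunk.end) {
        this.lastIndex = i;
        return i;
      }
    }
    // Slow path: binary search for the last chunk whose start <= byteIndex.
    let lo = 0;
    let hi = this.chunks.length - 1;
    while (lo < hi) {
      const mid = Math.ceil((lo + hi) / 2);
      if (byteIndex < this.chunks[mid].start) {
        hi = mid - 1;
      } else {
        lo = mid;
      }
    }
    this.lastIndex = lo;
    return lo;
  }
}
```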
tfjs-core/src/io/weights_loader.ts, line 245 at r6 (raw file):
Previously, mattsoulanille (Matthew Soulanille) wrote…
This will be used in another PR that enables large model weights to be stored in a list of ArrayBuffers. That's why it's exported here.
It might be good to put it in a separate file.
tfjs-core/src/io/weights_loader.ts, line 272 at r9 (raw file):

  let start = 0;
  for (let i = 0; i < buffers.length; i++) {

Start from 1?
tfjs-core/src/io/weights_loader.ts (Outdated)

    // Create the ranges, including their start and end points.
    const end = start + buffer.byteLength;
    this.ranges.push({buffer, start, end,});
nit: remove ',' or format to multiple lines
tfjs-core/src/io/weights_loader.ts (Outdated)

  }

  slice(start = 0, end = this.byteLength): ArrayBuffer {
    // NaN is treated as zero for slicing. This matches ArrayBuffer's behavior.
Convert start and end to Number with Number(...) before checking NaN?
Update:
I assume you added these NaN checks because you think there may be calls from JS, now or in the future, that ignore the TypeScript type check. In that case I'd suggest doing start = Number(start), since isNaN('123') returns false and ArrayBuffer.prototype.slice accepts numbers passed as strings.
I added these NaN checks because some of the tests were failing (they intentionally gave it no datatype, which I think eventually resulted in a NaN being passed to slice (since tfjs didn't know the byte length of the datatype), so I think the tests themselves are correct). I'd like this to match ArrayBuffer.slice as closely as possible, so I implemented your comment.
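(For reference, a simplified, hypothetical sketch of the argument handling being discussed; the helper name is made up, and the real method also has to copy the data and handle negative indices.)

```ts
// Simplified sketch only (not the exact tfjs code): coerce with Number()
// first so string inputs behave as they do for ArrayBuffer.prototype.slice,
// then treat NaN as zero, matching ArrayBuffer's behavior.
function normalizeSliceArgs(start: unknown, end: unknown,
                            byteLength: number): [number, number] {
  let s = start === undefined ? 0 : Number(start);
  let e = end === undefined ? byteLength : Number(end);
  if (isNaN(s)) { s = 0; }
  if (isNaN(e)) { e = 0; }
  // Clamp to the buffer's bounds (negative-index handling omitted here).
  s = Math.min(Math.max(s, 0), byteLength);
  e = Math.min(Math.max(e, 0), byteLength);
  return [s, e];
}
```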
tfjs-core/src/io/weights_loader.ts (Outdated)

@@ -245,3 +235,180 @@ export function weightsLoaderFactory(
   return weightsTensorMap;
 };
 }

type BufferRange = {
Naming: range -> chunk/shard/partition, and all related variable and function names.
Good point. That's a much better name. Fixed.
Reviewed 3 of 4 files at r11, 1 of 1 files at r12, all commit messages.
Reviewable status: complete! 2 of 1 approvals obtained (waiting on @chunnienc and @mattsoulanille)
* webgpu: Fix a bug in softmax (#7607)
* Avoid allocating a large arraybuffer when loading weights (#7598)

  The loadWeights function loads weights in 4MB chunks and then concatenates them into a single large ArrayBuffer. That ArrayBuffer is used for splitting the weights data back up into tensors. Allocating large ArrayBuffers (3.5GB) can be unstable on Chrome, so this PR avoids this allocation, instead slicing the weights out of the chunks manually.

  The implementation wraps the array of weights (stored as ArrayBuffer[]) in a new CompositeArrayBuffer class. This class implements slice by copying the desired range out of the buffer(s) that it overlaps with.

* Support using a list of ArrayBuffers as model weight data
* Avoid 'Array.flat()'
* Simplify some of the tests
* Do not export 'CompositeArrayBuffer' from tfjs-core
* Update doc for weightData
* Fix tfjs-node
* Remove unused import

Co-authored-by: Jiajia Qin <[email protected]>
Chrome ArrayBuffers throw allocation errors above 2GB in size. This makes it impossible to load TFJS models above this size in Chrome (even with weight sharding) because model loading involves concatenating all the weights into a single ArrayBuffer. This PR avoids this concatenation. Instead of slicing the weight tensors out of a single concatenated ArrayBuffer, it keeps the weight buffers in their original shards and slices them using the CompositeArrayBuffer class created in #7598.
The loadWeights function loads weights in 4MB chunks and then concatenates them into a single large ArrayBuffer. That ArrayBuffer is used for splitting the weights data back up into tensors. Allocating large ArrayBuffers (3.5GB) can be unstable on Chrome, so this PR avoids this allocation, instead slicing the weights out of the chunks manually.
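(A minimal, self-contained sketch of that chunk-wise slicing, for illustration; the function name is hypothetical, and the actual CompositeArrayBuffer class described below differs in structure and detail.)

```ts
// Illustrative only: copy the byte range [start, end) out of a list of
// chunks without ever concatenating them into one large ArrayBuffer.
function sliceFromChunks(chunks: ArrayBuffer[], start: number,
                         end: number): ArrayBuffer {
  const out = new Uint8Array(end - start);
  let chunkStart = 0;  // Global byte offset where the current chunk begins.
  let written = 0;
  for (const chunk of chunks) {
    const chunkEnd = chunkStart + chunk.byteLength;
    // Copy whatever part of this chunk overlaps [start, end), if any.
    const from = Math.max(start, chunkStart);
    const to = Math.min(end, chunkEnd);
    if (from < to) {
      out.set(new Uint8Array(chunk, from - chunkStart, to - from), written);
      written += to - from;
    }
    chunkStart = chunkEnd;
  }
  return out.buffer;
}
```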
The implementation wraps the array of weights (stored as ArrayBuffer[]) in a new CompositeArrayBuffer class. This class implements slice by copying the desired range out of the buffer(s) that it overlaps with.

To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.