-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Avoid allocating a large arraybuffer when loading weights #7598
Changes from 19 commits
f7be3c6
076c8a2
bca64fd
0d26418
0051fd7
c5b4f8e
49ffd05
05a9831
0c4e517
a530b92
0294b53
f81638f
6a07547
61caf8c
029ec8f
b1fe5eb
b86ee6e
7e88a89
8287312
f2fa6f8
91db31a
26c51df
ec9a382
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ | |
import {env} from '../environment'; | ||
|
||
import {NamedTensorMap} from '../tensor_types'; | ||
import {TypedArray} from '../types'; | ||
import * as util from '../util'; | ||
import {decodeWeights} from './io_utils'; | ||
import {monitorPromisesProgress} from './progress'; | ||
|
@@ -212,24 +213,13 @@ export function weightsLoaderFactory( | |
groupIndicesToFetch.forEach(i => { | ||
const numBuffers = manifest[i].paths.length; | ||
|
||
let groupBytes = 0; | ||
for (let i = 0; i < numBuffers; i++) { | ||
groupBytes += buffers[bufferIndexOffset + i].byteLength; | ||
} | ||
|
||
// Create a buffer for the whole group. | ||
const groupBuffer = new ArrayBuffer(groupBytes); | ||
const groupByteBuffer = new Uint8Array(groupBuffer); | ||
let groupBufferOffset = 0; | ||
for (let i = 0; i < numBuffers; i++) { | ||
const buffer = new Uint8Array(buffers[bufferIndexOffset + i]); | ||
groupByteBuffer.set(buffer, groupBufferOffset); | ||
groupBufferOffset += buffer.byteLength; | ||
} | ||
const weightsBuffer = new CompositeArrayBuffer( | ||
buffers.slice(bufferIndexOffset, bufferIndexOffset + numBuffers)); | ||
|
||
const weightsEntries = groupWeightsToFetch[i]; | ||
|
||
weightsEntries.forEach(weightsEntry => { | ||
const byteBuffer = groupBuffer.slice( | ||
const byteBuffer = weightsBuffer.slice( | ||
weightsEntry.groupOffset, | ||
weightsEntry.groupOffset + weightsEntry.sizeBytes); | ||
const nameToTensorMap = | ||
|
@@ -245,3 +235,180 @@ export function weightsLoaderFactory( | |
return weightsTensorMap; | ||
}; | ||
} | ||
|
||
type BufferRange = { | ||
start: number, | ||
end: number, | ||
buffer: ArrayBuffer, | ||
}; | ||
|
||
export class CompositeArrayBuffer { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will be used in another PR that enables large model weights to be stored in a list of ArrayBuffers. That's why it's exported here. |
||
private ranges: BufferRange[] = []; | ||
private previousRangeIndex = 0; | ||
private bufferUniformSize?: number; | ||
public readonly byteLength: number; | ||
|
||
constructor(buffers: ArrayBuffer | ArrayBuffer[] | TypedArray | ||
| TypedArray[]) { | ||
// Normalize the `buffers` input to be `ArrayBuffer[]`. | ||
if (!(buffers instanceof Array)) { | ||
buffers = [buffers]; | ||
} | ||
buffers = buffers.map((bufferOrTypedArray) => { | ||
if (util.isTypedArray(bufferOrTypedArray)) { | ||
return bufferOrTypedArray.buffer; | ||
} | ||
return bufferOrTypedArray; | ||
}); | ||
|
||
// Skip setting up ranges if there are no buffers. | ||
if (buffers.length === 0) { | ||
return; | ||
} | ||
|
||
this.bufferUniformSize = buffers[0].byteLength; | ||
let start = 0; | ||
|
||
for (let i = 0; i < buffers.length; i++) { | ||
const buffer = buffers[i]; | ||
// Check that all buffers except the last one have the same length. | ||
if (i !== buffers.length - 1 && | ||
buffer.byteLength !== this.bufferUniformSize) { | ||
// Unset the buffer uniform size, since the buffer sizes are not | ||
// uniform. | ||
this.bufferUniformSize = undefined; | ||
} | ||
|
||
// Create the ranges, including their start and end points. | ||
const end = start + buffer.byteLength; | ||
this.ranges.push({buffer, start, end,}); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: remove ',' or format to multiple lines |
||
start = end; | ||
} | ||
|
||
// Set the byteLenghth | ||
if (this.ranges.length === 0) { | ||
this.byteLength = 0; | ||
} | ||
this.byteLength = this.ranges[this.ranges.length - 1].end; | ||
} | ||
|
||
slice(start = 0, end = this.byteLength): ArrayBuffer { | ||
// NaN is treated as zero for slicing. This matches ArrayBuffer's behavior. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. convert start and end to Number with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added these NaN checks because some of the tests were failing (they intentionally gave it no datatype, which I think eventually resulted in a NaN being passed to slice (since tfjs didn't know the byte length of the datatype), so I think the tests themselves are correct). I'd like this to match |
||
start = isNaN(start) ? 0 : start; | ||
end = isNaN(end) ? 0 : end; | ||
|
||
// Fix the bounds to within the array. | ||
start = Math.max(0, start); | ||
end = Math.min(this.byteLength, end); | ||
if (end <= start) { | ||
return new ArrayBuffer(0); | ||
} | ||
|
||
const startRangeIndex = this.findRangeForByte(start); | ||
if (startRangeIndex === -1) { | ||
// This should not happen since the start and end indices are always | ||
// within 0 and the composite array's length. | ||
throw new Error(`Could not find start range for byte ${start}`); | ||
} | ||
|
||
const size = end - start; | ||
const outputBuffer = new ArrayBuffer(size); | ||
const outputArray = new Uint8Array(outputBuffer); | ||
let sliced = 0; | ||
for (let i = startRangeIndex; i < this.ranges.length; i++) { | ||
const range = this.ranges[i]; | ||
|
||
const globalStart = start + sliced; | ||
const localStart = globalStart - range.start; | ||
const outputStart = sliced; | ||
|
||
const globalEnd = Math.min(end, range.end); | ||
const localEnd = globalEnd - range.start; | ||
|
||
const outputSlice = new Uint8Array(range.buffer.slice(localStart, | ||
localEnd)); | ||
outputArray.set(outputSlice, outputStart); | ||
sliced += outputSlice.length; | ||
|
||
if (end < range.end) { | ||
break; | ||
} | ||
} | ||
return outputBuffer; | ||
} | ||
|
||
/** | ||
* Get the index of the range that contains the byte at `byteIndex`. | ||
*/ | ||
private findRangeForByte(byteIndex: number): number { | ||
if (this.ranges.length === 0 || byteIndex < 0 || | ||
byteIndex >= this.byteLength) { | ||
return -1; | ||
} | ||
|
||
// If the buffers have a uniform size, compute the range directly. | ||
if (this.bufferUniformSize != null) { | ||
this.previousRangeIndex = Math.floor(byteIndex / this.bufferUniformSize); | ||
return this.previousRangeIndex; | ||
} | ||
|
||
// If the buffers don't have a uniform size, we need to search for the | ||
// range. That means we need a function to check where the byteIndex lies | ||
// relative to a given range. | ||
function check(range: BufferRange) { | ||
if (byteIndex < range.start) { | ||
return -1; | ||
} | ||
if (byteIndex >= range.end) { | ||
return 1; | ||
} | ||
return 0; | ||
} | ||
|
||
// For efficiency, try the previous range first. | ||
if (check(this.ranges[this.previousRangeIndex]) === 0) { | ||
return this.previousRangeIndex; | ||
} | ||
|
||
// Otherwise, use a generic search function. | ||
// This should almost never end up being used in practice since the weight | ||
// entries should always be in order. | ||
const index = search(this.ranges, check); | ||
if (index === -1) { | ||
return -1; | ||
} | ||
|
||
this.previousRangeIndex = index; | ||
return this.previousRangeIndex; | ||
} | ||
} | ||
|
||
/** | ||
* Search for an element of a sorted array. | ||
* | ||
* @param sortedArray The sorted array to search | ||
* @param compare A function to compare the current value against the searched | ||
* value. Return 0 on a match, negative if the searched value is less than | ||
* the value passed to the function, and positive if the searched value is | ||
* greater than the value passed to the function. | ||
* @returns The index of the element, or -1 if it's not in the array. | ||
*/ | ||
function search<T>(sortedArray: T[], compare: (t: T) => number): number { | ||
// Binary search | ||
let min = 0; | ||
let max = sortedArray.length; | ||
|
||
while (min <= max) { | ||
const middle = Math.floor((max - min) / 2) + min; | ||
const side = compare(sortedArray[middle]); | ||
|
||
if (side === 0) { | ||
return middle; | ||
} else if (side < 0) { | ||
max = middle; | ||
} else { | ||
min = middle + 1; | ||
} | ||
} | ||
return -1; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Naming: range -> chunk/shard/partition
And all related variable and function names
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. That's a much better name. Fixed.