Skip to content

Commit

Permalink
fix: refactor runtime and runtime-async packages to allocate/grow buffers on demand (#484)
Browse files Browse the repository at this point in the history

Fixes #471
  • Loading branch information
chrispcampbell authored May 20, 2024
1 parent 8de0d16 commit 5e1c686
Show file tree
Hide file tree
Showing 38 changed files with 1,812 additions and 487 deletions.
6 changes: 3 additions & 3 deletions packages/runtime-async/docs/functions/exposeModelWorker.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

Expose an object in the current worker thread that communicates with the
[`ModelRunner`](../../../runtime/docs/interfaces/ModelRunner.md) instance running in the main thread. The exposed worker
object will take care of running the `WasmModel` on the worker thread
and sending the outputs back to the main process.
object will take care of running the `RunnableModel` on the worker thread
and sending the outputs back to the main thread.

#### Parameters

| Name | Type | Description |
| :------ | :------ | :------ |
| `init` | () => `Promise`<[`WasmModelInitResult`](../../../runtime/docs/interfaces/WasmModelInitResult.md)\> | The function that initializes the `WasmModel` instance that is used in the worker thread. |
| `init` | () => `Promise`<`RunnableModel` \| [`WasmModelInitResult`](../../../runtime/docs/interfaces/WasmModelInitResult.md)\> | The function that initializes the `RunnableModel` instance that is used in the worker thread. |

#### Returns

Expand Down
58 changes: 15 additions & 43 deletions packages/runtime-async/src/runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import { BlobWorker, spawn, Thread, Transfer, Worker } from 'threads'

import type { ModelRunner } from '@sdeverywhere/runtime'
import { Outputs, updateOutputIndices } from '@sdeverywhere/runtime'
import { BufferedRunModelParams, Outputs } from '@sdeverywhere/runtime'

/**
* Initialize a `ModelRunner` that runs the model asynchronously in a worker thread.
Expand Down Expand Up @@ -57,29 +57,9 @@ async function spawnAsyncModelRunnerWithWorker(worker: Worker): Promise<ModelRun

// Wait for the worker to initialize the wasm model (in the worker thread)
const initResult = await modelWorker.initModel()
let ioBuffer: ArrayBuffer = initResult.ioBuffer

// The run time is stored in the first 8 bytes of the shared buffer
const runTimeOffsetInBytes = 0
const runTimeLengthInElements = 1
const runTimeLengthInBytes = runTimeLengthInElements * 8

// The inputs are stored after the run time in the shared buffer
const inputsOffsetInBytes = runTimeOffsetInBytes + runTimeLengthInBytes
const inputsLengthInElements: number = initResult.inputsLength
const inputsLengthInBytes = inputsLengthInElements * 8

// The outputs are stored after the inputs in the shared buffer
const outputsOffsetInBytes = inputsOffsetInBytes + inputsLengthInBytes
const outputsLengthInElements: number = initResult.outputsLength
const outputsLengthInBytes = outputsLengthInElements * 8

// The output indices are (optionally) stored after the outputs in the shared buffer
const indicesOffsetInBytes = outputsOffsetInBytes + outputsLengthInBytes
const indicesLengthInElements: number = initResult.outputIndicesLength

// The row length is the number of elements in each row of the outputs buffer
const outputRowLength: number = initResult.outputRowLength
// Maintain a `BufferedRunModelParams` instance that holds the I/O parameters
const params = new BufferedRunModelParams()

// Use a flag to ensure that only one request is made at a time
let running = false
Expand All @@ -101,37 +81,29 @@ async function spawnAsyncModelRunnerWithWorker(worker: Worker): Promise<ModelRun
running = true
}

// Capture the current set of input values into the reusable buffer
const inputsArray = new Float64Array(ioBuffer, inputsOffsetInBytes, inputsLengthInElements)
for (let i = 0; i < inputs.length; i++) {
inputsArray[i] = inputs[i].get()
}

// Update the output indices, if needed
if (indicesLengthInElements > 0) {
const outputSpecs = outputs.varSpecs || []
const indicesArray = new Int32Array(ioBuffer, indicesOffsetInBytes, indicesLengthInElements)
updateOutputIndices(indicesArray, outputSpecs)
}
// Update the I/O parameters
params.updateFromParams(inputs, outputs)

// Run the model in the worker. We pass the underlying `ArrayBuffer`
// instance back to the worker wrapped in a `Transfer` to make it
// no-copy transferable, and then the worker will return it back
// to us.
let ioBuffer: ArrayBuffer
try {
ioBuffer = await modelWorker.runModel(Transfer(ioBuffer))
ioBuffer = await modelWorker.runModel(Transfer(params.getEncodedBuffer()))
} finally {
running = false
}

// Save the model run time
const runTimeArray = new Float64Array(ioBuffer, runTimeOffsetInBytes, runTimeLengthInElements)
outputs.runTimeInMillis = runTimeArray[0]
// Once the buffer is transferred to the worker, the buffer in the
// `BufferedRunModelParams` becomes "detached" and is no longer usable.
// After the buffer is transferred back from the worker, we need to
// restore the state of the object to use the new buffer.
params.updateFromEncodedBuffer(ioBuffer)

// Capture the outputs array by copying the data into the given `Outputs`
// data structure
const outputsArray = new Float64Array(ioBuffer, outputsOffsetInBytes, outputsLengthInElements)
outputs.updateFromBuffer(outputsArray, outputRowLength)
// Copy the output values and elapsed time from the buffer to the
// `Outputs` instance
params.finalizeOutputs(outputs)

return outputs
},
Expand Down
151 changes: 54 additions & 97 deletions packages/runtime-async/src/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,122 +3,79 @@
import type { TransferDescriptor } from 'threads'
import { expose, Transfer } from 'threads/worker'

import type { WasmBuffer, WasmModel, WasmModelInitResult } from '@sdeverywhere/runtime'
import { perfElapsed, perfNow } from '@sdeverywhere/runtime'
import type { RunnableModel, WasmModelInitResult } from '@sdeverywhere/runtime'
import { BufferedRunModelParams } from '@sdeverywhere/runtime'

// TODO: To avoid breaking existing code that returns `WasmModelInitResult`
// from this init function, we allow it to return either `WasmModelInitResult`
// or the newer `RunnableModel`. We will remove the `WasmModelInitResult` part
// in a future set of changes.
/** @hidden */
let initWasmModel: () => Promise<WasmModelInitResult>
/** @hidden */
let wasmModel: WasmModel
/** @hidden */
let inputsWasmBuffer: WasmBuffer<Float64Array>
/** @hidden */
let outputsWasmBuffer: WasmBuffer<Float64Array>
let initRunnableModel: () => Promise<RunnableModel | WasmModelInitResult>

/** @hidden */
let outputIndicesWasmBuffer: WasmBuffer<Int32Array>
let runnableModel: RunnableModel

/**
* Maintain a `BufferedRunModelParams` instance that wraps the transferable buffer
* containing the I/O parameters.
* @hidden
*/
const params = new BufferedRunModelParams()

interface InitResult {
outputVarIds: string[]
startTime: number
endTime: number
saveFreq: number
inputsLength: number
outputsLength: number
outputIndicesLength: number
outputRowLength: number
ioBuffer: ArrayBuffer
}

/** @hidden */
const modelWorker = {
async initModel(): Promise<TransferDescriptor<InitResult>> {
if (wasmModel) {
throw new Error('WasmModel was already initialized')
async initModel(): Promise<InitResult> {
if (runnableModel) {
throw new Error('RunnableModel was already initialized')
}

// Initialize the wasm model and associated buffers
const wasmResult = await initWasmModel()

// Capture the `WasmModel` instance and `WasmBuffer` instances
wasmModel = wasmResult.model
inputsWasmBuffer = wasmResult.inputsBuffer
outputsWasmBuffer = wasmResult.outputsBuffer
outputIndicesWasmBuffer = wasmResult.outputIndicesBuffer

// Create a combined array that will hold a copy of the inputs and outputs
// wasm buffers; this buffer is no-copy transferable, whereas the wasm ones
// are not allowed to be transferred
const runTimeLength = 8
const inputsLength = inputsWasmBuffer.getArrayView().length
const outputsLength = outputsWasmBuffer.getArrayView().length
const outputIndicesLength = outputIndicesWasmBuffer?.getArrayView().length || 0
const totalLength = runTimeLength + inputsLength + outputsLength + outputIndicesLength
const ioArray = new Float64Array(totalLength)
// Initialize the runnable model
// TODO: To avoid breaking existing code that returns `WasmModelInitResult`
// from this init function, we allow it to return either `WasmModelInitResult`
// or the newer `RunnableModel`. We will remove the `WasmModelInitResult` part
// in a future set of changes.
const initResult = await initRunnableModel()
// eslint-disable-next-line @typescript-eslint/no-explicit-any
if ((initResult as any).model !== undefined) {
// The result is a `WasmModelInitResult`, so extract the `WasmModel` (which implements
// the `RunnableModel` interface)
// eslint-disable-next-line @typescript-eslint/no-explicit-any
runnableModel = (initResult as any).model as RunnableModel
} else {
// Otherwise, we assume the result is a `RunnableModel`
runnableModel = initResult as RunnableModel
}

// Transfer the underlying buffer to the runner
const ioBuffer = ioArray.buffer
const initResult: InitResult = {
outputVarIds: wasmResult.outputVarIds,
startTime: wasmModel.startTime,
endTime: wasmModel.endTime,
saveFreq: wasmModel.saveFreq,
inputsLength,
outputsLength,
outputIndicesLength,
outputRowLength: wasmModel.numSavePoints,
ioBuffer
// Transfer the model metadata to the runner
return {
outputVarIds: runnableModel.outputVarIds,
startTime: runnableModel.startTime,
endTime: runnableModel.endTime,
saveFreq: runnableModel.saveFreq,
outputRowLength: runnableModel.numSavePoints
}
return Transfer(initResult, [ioBuffer])
},

runModel(ioBuffer: ArrayBuffer): TransferDescriptor<ArrayBuffer> {
if (!wasmModel) {
throw new Error('WasmModel must be initialized before running the model in worker')
if (!runnableModel) {
throw new Error('RunnableModel must be initialized before running the model in worker')
}

// The run time is stored in the first 8 bytes
const runTimeOffsetInBytes = 0
const runTimeLengthInElements = 1
const runTimeLengthInBytes = runTimeLengthInElements * 8

// Copy the inputs into the wasm inputs buffer
const inputsWasmArray = inputsWasmBuffer.getArrayView()
const inputsOffsetInBytes = runTimeOffsetInBytes + runTimeLengthInBytes
const inputsLengthInElements = inputsWasmArray.length
const inputsLengthInBytes = inputsWasmArray.byteLength
const inputsBufferArray = new Float64Array(ioBuffer, inputsOffsetInBytes, inputsLengthInElements)
inputsWasmArray.set(inputsBufferArray)

// Copy the output indices into the wasm buffer, if needed
const outputsWasmArray = outputsWasmBuffer.getArrayView()
const outputsOffsetInBytes = runTimeLengthInBytes + inputsLengthInBytes
const outputsLengthInBytes = outputsWasmArray.byteLength
let useIndices = false
if (outputIndicesWasmBuffer) {
const indicesWasmArray = outputIndicesWasmBuffer.getArrayView()
const indicesLengthInElements = indicesWasmArray.length
const indicesOffsetInBytes = outputsOffsetInBytes + outputsLengthInBytes
const indicesBufferArray = new Int32Array(ioBuffer, indicesOffsetInBytes, indicesLengthInElements)
if (indicesBufferArray[0] !== 0) {
// Only use the indices if the first index is non-zero
indicesWasmArray.set(indicesBufferArray)
useIndices = true
}
}

// Run the model using the wasm buffers
const t0 = perfNow()
wasmModel.runModel(inputsWasmBuffer, outputsWasmBuffer, useIndices ? outputIndicesWasmBuffer : undefined)
const elapsed = perfElapsed(t0)

// Write the model run time to the buffer
const runTimeBufferArray = new Float64Array(ioBuffer, runTimeOffsetInBytes, runTimeLengthInElements)
runTimeBufferArray[0] = elapsed
// Update the `BufferedRunModelParams` to use the values in the buffer that was transferred
// from the runner to the worker
params.updateFromEncodedBuffer(ioBuffer)

// Copy the outputs from the wasm outputs buffer
const outputsLengthInElements = outputsWasmArray.length
const outputsBufferArray = new Float64Array(ioBuffer, outputsOffsetInBytes, outputsLengthInElements)
outputsBufferArray.set(outputsWasmArray)
// Run the model synchronously on the worker thread using those I/O parameters
runnableModel.runModel(params)

// Transfer the buffer back to the runner
return Transfer(ioBuffer)
Expand All @@ -128,16 +85,16 @@ const modelWorker = {
/**
* Expose an object in the current worker thread that communicates with the
* `ModelRunner` instance running in the main thread. The exposed worker
* object will take care of running the `WasmModel` on the worker thread
* and sending the outputs back to the main process.
* object will take care of running the `RunnableModel` on the worker thread
* and sending the outputs back to the main thread.
*
* @param init The function that initializes the `WasmModel` instance that
* @param init The function that initializes the `RunnableModel` instance that
* is used in the worker thread.
*/
export function exposeModelWorker(init: () => Promise<WasmModelInitResult>): void {
export function exposeModelWorker(init: () => Promise<RunnableModel | WasmModelInitResult>): void {
// Save the initializer, which will be used when the runner calls `initModel`
// on the worker
initWasmModel = init
initRunnableModel = init

// Expose the worker implementation to `threads.js`
expose(modelWorker)
Expand Down
Loading

0 comments on commit 5e1c686

Please sign in to comment.