Skip to content

Commit

Permalink
[js/webgpu] Support capture and replay for jsep (#18989)
Browse files Browse the repository at this point in the history
This PR expands the graph capture capability to JS EP, which is similar
to #16081. But for JS EP, we don't use the CUDA Graph, instead, we
records all gpu commands and replay them, which removes most of the cpu
overhead to avoid the the situation that gpu waiting for cpu.

mobilenetv2-12 becomes 3.7ms from 6ms on NV 3090 and becomes 3.38ms from
4.58ms on Intel A770.

All limitations are similar with CUDA EP:
1. Models with control-flow ops (i.e. If, Loop and Scan ops) are not
supported.
2. Usage of graph capture is limited to models where-in all ops in the
model can be partitioned to the JS EP or CPU EP and no memory copy
between them.
3. Shapes of inputs/outputs cannot change across inference calls.
4. IObinding is required.

The usage is like below:
Method 1: specify outputs buffers explicitly.
```
    const sessionOptions = {
        executionProviders: [
          {
            name: "webgpu",
          },
        ],
        enableGraphCapture: true,
      };
    const session = await ort.InferenceSession.create('./models/mobilenetv2-12.onnx', sessionOptions);

    // prepare the inputBuffer/outputBuffer
    ... ...

   const feeds = {
       'input': ort.Tensor.fromGpuBuffer(inputBuffer, { dataType: 'float32', dims })
   };

   const fetches = {
       'output': ort.Tensor.fromGpuBuffer(outputBuffer, { dataType: 'float32', dims: [1, 1000] })
   };

   let results = await session.run(feeds, fetches);  // The first run will begin to capture the graph.

   // update inputBuffer content
  ... ...
   results = = await session.run(feeds, fetches);  // The 2ed run and after will directly call replay to execute the graph.

  ... ...
   session.release();
```
Method 2: Don't specify outputs buffers explicitly. Internally, when
graph capture is enabled, it will set all outputs location to
'gpu-buffer'.
```
    const sessionOptions = {
        executionProviders: [
          {
            name: "webgpu",
          },
        ],
        enableGraphCapture: true,
      };
    const session = await ort.InferenceSession.create('./models/mobilenetv2-12.onnx', sessionOptions);

    // prepare the inputBuffer
    ... ...

   const feeds = {
       'input': ort.Tensor.fromGpuBuffer(inputBuffer, { dataType: 'float32', dims })
   };

   let results = await session.run(feeds);  // The first run will begin to capture the graph.

   // update inputBuffer content
  ... ...
   results = = await session.run(feeds);  // The 2ed run and after will directly call replay to execute the graph.

  ... ...
   session.release();
  • Loading branch information
qjia7 authored and fs-eire committed Mar 15, 2024
1 parent c61a8e5 commit 73640b6
Show file tree
Hide file tree
Showing 16 changed files with 438 additions and 135 deletions.
8 changes: 7 additions & 1 deletion js/common/lib/inference-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ export declare namespace InferenceSession {
optimizedModelFilePath?: string;

/**
* Wether enable profiling.
* Whether enable profiling.
*
* This setting is a placeholder for a future use.
*/
Expand Down Expand Up @@ -154,6 +154,12 @@ export declare namespace InferenceSession {
*/
preferredOutputLocation?: OnnxValueDataLocation|{readonly [outputName: string]: OnnxValueDataLocation};

/**
* Whether enable graph capture.
* This setting is available only in ONNXRuntime Web for WebGPU EP.
*/
enableGraphCapture?: boolean;

/**
* Store configurations for a session. See
* https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/
Expand Down
25 changes: 16 additions & 9 deletions js/web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ export declare namespace JSEP {
type ReleaseKernelFunction = (kernel: number) => void;
type RunFunction =
(kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array<Promise<string|null>>) => number;
type CaptureBeginFunction = () => void;
type CaptureEndFunction = () => void;
type ReplayFunction = () => void;
}

export interface OrtWasmModule extends EmscriptenModule {
Expand Down Expand Up @@ -128,7 +131,8 @@ export interface OrtWasmModule extends EmscriptenModule {
jsepInit?
(backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction,
download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction,
releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void;
releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction, captureBegin: JSEP.CaptureBeginFunction,
captureEnd: JSEP.CaptureEndFunction, replay: JSEP.ReplayFunction): void;

/**
* [exported from wasm] Specify a kernel's output when running OpKernel::Compute().
Expand Down Expand Up @@ -158,12 +162,6 @@ export interface OrtWasmModule extends EmscriptenModule {
* @returns the GPU data ID for the registered GPU buffer.
*/
jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number;
/**
* [exported from js_internal_api.js] Unregister all user GPU buffers for a session.
*
* @param sessionId - specify the session ID.
*/
jsepUnregisterBuffers?: (sessionId: number) => void;
/**
* [exported from js_internal_api.js] Get the GPU buffer by GPU data ID.
*
Expand All @@ -183,9 +181,18 @@ export interface OrtWasmModule extends EmscriptenModule {
(gpuBuffer: GPUBuffer, size: number,
type: Tensor.GpuBufferDataTypes) => () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
/**
* [exported from js_internal_api.js] Called when InferenceSession.run started.
* [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before
* _OrtRun[WithBinding]() is called.
* @param sessionId - specify the session ID.
*/
jsepOnRunStart: (sessionId: number) => void;
/**
* [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is
* called.
* @param sessionId - specify the session ID.
* @returns
*/
jsepOnRunStart: () => void;
jsepOnReleaseSession: (sessionId: number) => void;
// #endregion
}

Expand Down
100 changes: 96 additions & 4 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@ import {createView, TensorView} from './tensor-view';
import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager';
import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules';
import {ProgramManager} from './webgpu/program-manager';
import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, TimestampQuery} from './webgpu/types';
import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types';

interface CommandInfo {
readonly kernelId: number;
readonly computePipeline: GPUComputePipeline;
readonly bindGroup: GPUBindGroup;
readonly dispatchGroup: [number, number, number];
}

interface KernelInfo {
readonly kernelType: string;
Expand Down Expand Up @@ -103,6 +110,13 @@ export class WebGpuBackend {
*/
programManager: ProgramManager;

/**
* representing the session ID of which is currently being run.
* `null` means no session is being run.
* only valid when session.run is executed.
*/
currentSessionId: number|null = null;

/**
* representing the kernel ID of which is currently being computed (CPU code perspective).
* `null` means no kernel is being computed.
Expand Down Expand Up @@ -155,6 +169,16 @@ export class WebGpuBackend {
queryType: TimestampQuery;

env: Env;
sessionStatus: SessionState = 'default';
/**
* a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session.
*/
capturedCommandList: Map<number, CommandInfo[]> = new Map();

/**
* a SessionID -> PendingKernelInfo[] mapping for profiling.
*/
private capturedPendingKernels: Map<number, PendingKernelInfo[]> = new Map();

/**
* a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping.
Expand Down Expand Up @@ -228,6 +252,7 @@ export class WebGpuBackend {

getComputePassEncoder(): GPUComputePassEncoder {
if (!this.computePassEncoder) {
const commandEncoder = this.getCommandEncoder();
const computePassDescriptor: GPUComputePassDescriptor = {};

if (this.queryType === 'at-passes') {
Expand All @@ -238,7 +263,7 @@ export class WebGpuBackend {
};
}

this.computePassEncoder = this.getCommandEncoder().beginComputePass(computePassDescriptor);
this.computePassEncoder = commandEncoder.beginComputePass(computePassDescriptor);
}
return this.computePassEncoder;
}
Expand Down Expand Up @@ -494,14 +519,17 @@ export class WebGpuBackend {
() => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${
normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`);

if (this.queryType !== 'none') {
if (this.queryType !== 'none' || this.sessionStatus === 'capturing') {
const pendingKernelInfo: PendingKernelInfo = {
kernelId: this.currentKernelId!,
programName: artifact.programInfo.name,
inputTensorViews,
outputTensorViews,
};
this.pendingKernels.push(pendingKernelInfo);

const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!);
sessionPendingKernels!.push(pendingKernelInfo);
}

this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding);
Expand Down Expand Up @@ -672,7 +700,71 @@ export class WebGpuBackend {
}
}
}
onRunStart(): void {

captureBegin(): void {
LOG_DEBUG('info', 'captureBegin');
let sessionCommandList = this.capturedCommandList.get(this.currentSessionId!);
let sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!);
if (!sessionCommandList) {
sessionCommandList = [];
this.capturedCommandList.set(this.currentSessionId!, sessionCommandList);
sessionPendingKernels = [];
this.capturedPendingKernels.set(this.currentSessionId!, sessionPendingKernels);
}
// flush the left commands before we change the status.
this.flush();
this.sessionStatus = 'capturing';
}
captureEnd(): void {
LOG_DEBUG('info', 'captureEnd');
// flush the left commands before we change the status.
this.flush();
this.sessionStatus = 'default';
}
replay(): void {
LOG_DEBUG('info', 'replay');
this.sessionStatus = 'replaying';
const sessionCommandList = this.capturedCommandList.get(this.currentSessionId!);
const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!);
const length = sessionCommandList!.length;
this.pendingKernels = [];
for (let i = 0; i < length; i++) {
const computePassEncoder = this.getComputePassEncoder();
const command = sessionCommandList![i];
this.writeTimestamp(this.pendingDispatchNumber * 2);
computePassEncoder.setPipeline(command.computePipeline);
computePassEncoder.setBindGroup(0, command.bindGroup);
computePassEncoder.dispatchWorkgroups(...command.dispatchGroup);
this.writeTimestamp(this.pendingDispatchNumber * 2 + 1);
this.pendingDispatchNumber++;
if (this.queryType !== 'none') {
this.pendingKernels.push(sessionPendingKernels![i]);
}
if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') {
this.endComputePass();
}
if (this.pendingDispatchNumber >= this.maxDispatchNumber) {
this.flush();
}
}
// flush the left commands before we change the status.
this.flush();
this.sessionStatus = 'default';
}

onReleaseSession(sessionId: number): void {
this.unregisterBuffers(sessionId);
if (this.capturedCommandList.has(sessionId)) {
this.capturedCommandList.delete(sessionId);
}
if (this.capturedPendingKernels.has(sessionId)) {
this.capturedPendingKernels.delete(sessionId);
}
this.gpuDataManager.onReleaseSession(sessionId);
}

onRunStart(sessionId: number): void {
this.currentSessionId = sessionId;
this.setQueryType();
}
}
8 changes: 7 additions & 1 deletion js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -201,5 +201,11 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte
contextDataOffset}`);
const context = new ComputeContextImpl(module, backend, contextDataOffset);
return backend.computeKernel(kernel, context, errors);
});
},
// jsepCaptureBegin
() => backend.captureBegin(),
// jsepCaptureEnd
() => backend.captureEnd(),
// jsepReplay
() => backend.replay());
};
74 changes: 62 additions & 12 deletions js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,15 @@ export interface GpuDataManager {
unregisterExternalBuffer(buffer: GPUBuffer): void;

/**
* destroy all gpu buffers. Call this when the session.release is called.
* destroy all gpu buffers.
*/
dispose(): void;

/**
* release session related data.
* @param sessionId - specify the session ID.
*/
onReleaseSession(sessionId: number): void;
}

interface StorageCacheValue {
Expand Down Expand Up @@ -139,13 +145,18 @@ class GpuDataManagerImpl implements GpuDataManager {
// The external buffers registered users for IO Binding.
private externalBuffers: Map<GPUBuffer, GpuDataId>;

// The pendingBuffers for capture graph.
// a SessionID -> GPUBuffer[] mapping.
private capturedPendingBuffers: Map<number, GPUBuffer[]>;

constructor(private backend: WebGpuBackend) {
this.storageCache = new Map();
this.freeBuffers = new Map();
this.freeUniformBuffers = new Map();
this.buffersForUploadingPending = [];
this.buffersPending = [];
this.externalBuffers = new Map();
this.capturedPendingBuffers = new Map();
}

upload(id: GpuDataId, data: Uint8Array): void {
Expand Down Expand Up @@ -220,6 +231,9 @@ class GpuDataManagerImpl implements GpuDataManager {
() => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${
id}, buffer is the same, skip.`);
return id;
} else if (this.backend.capturedCommandList.has(this.backend.currentSessionId!)) {
throw new Error(`Registering a different external buffer under graph capture mode is not supported yet.
Please use the previous external buffer!`);
}
this.externalBuffers.delete(previousBuffer);
} else {
Expand Down Expand Up @@ -312,20 +326,39 @@ class GpuDataManagerImpl implements GpuDataManager {
buffer.destroy();
}
this.buffersForUploadingPending = [];
for (const buffer of this.buffersPending) {
// eslint-disable-next-line no-bitwise
if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) {
// Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing.
this.freeBuffers.get(buffer.size)!.push(buffer);

if (this.buffersPending.length === 0) {
return;
}

if (this.backend.sessionStatus === 'default') {
for (const buffer of this.buffersPending) {
// eslint-disable-next-line no-bitwise
} else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) {
// Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing.
this.freeUniformBuffers.get(buffer.size)!.push(buffer);
} else {
buffer.destroy();
if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) {
// Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing.
this.freeBuffers.get(buffer.size)!.push(buffer);
// eslint-disable-next-line no-bitwise
} else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) {
// Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing.
this.freeUniformBuffers.get(buffer.size)!.push(buffer);
} else {
buffer.destroy();
}
}
this.buffersPending = [];
} else {
// Don't release intermediate tensors in non-default mode.
// TODO: reuse the storage buffers in non-default mode.
let capturedBuffers = this.capturedPendingBuffers.get(this.backend.currentSessionId!);
if (!capturedBuffers) {
capturedBuffers = [];
this.capturedPendingBuffers.set(this.backend.currentSessionId!, capturedBuffers);
}
for (const buffer of this.buffersPending) {
capturedBuffers.push(buffer);
}
this.buffersPending = [];
}
this.buffersPending = [];
}

dispose() {
Expand All @@ -344,9 +377,26 @@ class GpuDataManagerImpl implements GpuDataManager {
storage.gpuData.buffer.destroy();
});

this.capturedPendingBuffers.forEach((buffers) => {
buffers.forEach(buffer => {
buffer.destroy();
});
});
this.storageCache = new Map();
this.freeBuffers = new Map();
this.freeUniformBuffers = new Map();
this.capturedPendingBuffers = new Map();
}

onReleaseSession(sessionId: number) {
// release the captured pending buffers.
const pendingBuffers = this.capturedPendingBuffers.get(sessionId);
if (pendingBuffers) {
pendingBuffers.forEach(buffer => {
buffer.destroy();
});
this.capturedPendingBuffers.delete(sessionId);
}
}
}

Expand Down
15 changes: 13 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/program-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ export class ProgramManager {
const device = this.backend.device;
const computePassEncoder = this.backend.getComputePassEncoder();
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2);
computePassEncoder.setPipeline(buildArtifact.computePipeline);
const entries = [];
for (const input of inputs) {
entries.push({binding: entries.length, resource: {buffer: input.buffer}});
Expand All @@ -51,8 +50,20 @@ export class ProgramManager {
}
const bindGroup = device.createBindGroup(
{layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries, label: buildArtifact.programInfo.name});
computePassEncoder.setBindGroup(0, bindGroup);

if (this.backend.sessionStatus === 'capturing') {
const commandInfo = {
kernelId: this.backend.currentKernelId!,
computePipeline: buildArtifact.computePipeline,
bindGroup,
dispatchGroup
};
const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!);
sessionCommandList!.push(commandInfo);
}

computePassEncoder.setPipeline(buildArtifact.computePipeline);
computePassEncoder.setBindGroup(0, bindGroup);
computePassEncoder.dispatchWorkgroups(...dispatchGroup);
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1);
this.backend.pendingDispatchNumber++;
Expand Down
Loading

0 comments on commit 73640b6

Please sign in to comment.