[WebNN EP] Automatically move input CPU tensors to ml-tensor #23073

Open · wants to merge 5 commits into main
Changes from all commits
88 changes: 77 additions & 11 deletions js/web/lib/wasm/jsep/backend-webnn.ts
@@ -75,6 +75,19 @@ export class WebNNBackend {
* Current session id.
*/
private activeSessionId?: number;
/**
* Maps from session id to list of graph inputs.
*/
private sessionGraphInputs: Map<number, string[]> = new Map();
/**
* Temporary graph inputs for the current session.
* These inputs will be registered when the session is created.
*/
private temporaryGraphInputs: string[] = [];
/**
* Maps from session id to the temporary tensor ids created during a run.
*/
private temporarySessionTensorIds: Map<number, TensorId[]> = new Map();

constructor(env: Env) {
configureLogger(env.logLevel!, !!env.debug);
@@ -88,9 +101,24 @@ export class WebNNBackend {
}

public onRunStart(sessionId: number): void {
LOG_DEBUG('verbose', () => `[WebNN] onRunStart {sessionId: ${sessionId}}`);
this.activeSessionId = sessionId;
}

public onRunEnd(sessionId: number): void {
LOG_DEBUG('verbose', () => `[WebNN] onRunEnd {sessionId: ${sessionId}}`);
const tensorIds = this.temporarySessionTensorIds.get(sessionId);
if (!tensorIds) {
return;
}
for (const tensorId of tensorIds) {
LOG_DEBUG('verbose', () => `[WebNN] releasing temporary tensor {tensorId: ${tensorId}}`);
this.tensorManager.releaseTensorId(tensorId);
}
this.temporarySessionTensorIds.delete(sessionId);
this.activeSessionId = undefined;
}
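
Reviewer note: a minimal sketch of the run lifecycle these two hooks implement. The `backend` instance and the literal session id below are illustrative stand-ins, not part of this diff.

```ts
// Hedged sketch of the onRunStart/onRunEnd contract (illustrative ids).
const sessionId = 42;

backend.onRunStart(sessionId); // sets activeSessionId

// During the run, a CPU input promoted to an ml-tensor is recorded
// under this session via createTemporaryTensor (see below).
// ONNX data type 1 is FLOAT.
const tmpId = await backend.createTemporaryTensor(sessionId, 1, [1, 3, 224, 224]);

// onRunEnd releases every id in temporarySessionTensorIds for this
// session through tensorManager.releaseTensorId, then clears the entry.
backend.onRunEnd(sessionId);
```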

public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise<MLContext> {
if (optionsOrDevice instanceof GPUDevice) {
const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.gpuDevice === optionsOrDevice);
@@ -126,14 +154,6 @@
}
}

public get currentContext(): MLContext {
const mlContext = this.getMLContext(this.currentSessionId);
if (!mlContext) {
throw new Error(`No MLContext found for session ${this.currentSessionId}`);
}
return mlContext;
}

public registerMLContext(sessionId: number, mlContext: MLContext): void {
this.mlContextBySessionId.set(sessionId, mlContext);
let sessionIds = this.sessionIdsByMLContext.get(mlContext);
@@ -142,9 +162,15 @@
this.sessionIdsByMLContext.set(mlContext, sessionIds);
}
sessionIds.add(sessionId);

if (this.temporaryGraphInputs.length > 0) {
this.sessionGraphInputs.set(sessionId, this.temporaryGraphInputs);
this.temporaryGraphInputs = [];
}
}

public onReleaseSession(sessionId: number): void {
this.sessionGraphInputs.delete(sessionId);
const mlContext = this.mlContextBySessionId.get(sessionId);
if (!mlContext) {
// Current session is not a WebNN session.
@@ -177,6 +203,7 @@
}

public async ensureTensor(
sessionId: number | undefined,
tensorId: TensorId,
onnxDataType: DataType,
dimensions: number[],
@@ -186,7 +213,34 @@
if (!webnnDataType) {
throw new Error(`Unsupported ONNX data type: ${onnxDataType}`);
}
return this.tensorManager.ensureTensor(tensorId, webnnDataType, dimensions, copyOld);
return this.tensorManager.ensureTensor(
sessionId ?? this.currentSessionId,
tensorId,
webnnDataType,
dimensions,
copyOld,
);
}

public async createTemporaryTensor(
sessionId: number,
onnxDataType: DataType,
shape: readonly number[],
): Promise<TensorId> {
LOG_DEBUG('verbose', () => `[WebNN] createTemporaryTensor {onnxDataType: ${onnxDataType}, shape: ${shape}}`);
const dataType = onnxDataTypeToWebnnDataType.get(onnxDataType);
if (!dataType) {
throw new Error(`Unsupported ONNX data type: ${onnxDataType}`);
}
const tensorId = this.tensorManager.reserveTensorId();
await this.tensorManager.ensureTensor(sessionId, tensorId, dataType, shape, false);
const tensorIds = this.temporarySessionTensorIds.get(sessionId);
if (!tensorIds) {
this.temporarySessionTensorIds.set(sessionId, [tensorId]);
} else {
tensorIds.push(tensorId);
}
return tensorId;
}
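
Reviewer note: this method is the core of the PR title. A hedged sketch of how a CPU input would be moved to an ml-tensor; `cpuBytes`, `sessionId`, and `shape` are illustrative stand-ins for values the caller already has.

```ts
// Sketch: promote a CPU-resident graph input to an ml-tensor.
const tensorId = await backend.createTemporaryTensor(sessionId, onnxDataType, shape);

// Copy the raw CPU bytes into the reserved MLTensor.
backend.uploadTensor(tensorId, cpuBytes); // cpuBytes: Uint8Array

// No manual cleanup needed: the id was added to
// temporarySessionTensorIds and is released in onRunEnd.
```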

public uploadTensor(tensorId: TensorId, data: Uint8Array): void {
@@ -209,13 +263,13 @@
};
}

public registerMLTensor(tensor: MLTensor, onnxDataType: DataType, dimensions: number[]): TensorId {
public registerMLTensor(sessionId: number, tensor: MLTensor, onnxDataType: DataType, dimensions: number[]): TensorId {
const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType);
if (!webnnDataType) {
throw new Error(`Unsupported ONNX data type: ${onnxDataType}`);
}

const id = this.tensorManager.registerTensor(this.currentContext, tensor, webnnDataType, dimensions);
const id = this.tensorManager.registerTensor(sessionId, tensor, webnnDataType, dimensions);
LOG_DEBUG(
'verbose',
() =>
@@ -291,6 +345,18 @@
return builder.constant(desc, bufferView);
}

public registerGraphInput(inputName: string): void {
this.temporaryGraphInputs.push(inputName);
}

public isGraphInput(sessionId: number, inputName: string): boolean {
const inputNames = this.sessionGraphInputs.get(sessionId);
if (!inputNames) {
return false;
}
return inputNames.includes(inputName);
}
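
Reviewer note: graph-input registration is two-phase because no session id exists while the graph is still being built. A sketch of the assumed call order; the input names are illustrative.

```ts
// Phase 1: during graph compilation, names are parked in
// temporaryGraphInputs (session id not assigned yet).
backend.registerGraphInput('input_ids');
backend.registerGraphInput('attention_mask');

// Phase 2: registerMLContext moves them into sessionGraphInputs
// under the newly created session.
backend.registerMLContext(sessionId, mlContext);

backend.isGraphInput(sessionId, 'input_ids'); // true
backend.isGraphInput(sessionId, 'logits'); // false: never registered
```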

public flush(): void {
// Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations.
}
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/init.ts
@@ -287,8 +287,8 @@ export const init = async (
// jsepReleaseTensorId,
(tensorId: number) => backend.releaseTensorId(tensorId),
// jsepEnsureTensor
async (tensorId: number, onnxDataType: number, shape: number[], copyOld) =>
backend.ensureTensor(tensorId, onnxDataType, shape, copyOld),
async (sessionId: number | undefined, tensorId: number, onnxDataType: number, shape: number[], copyOld) =>
backend.ensureTensor(sessionId, tensorId, onnxDataType, shape, copyOld),
// jsepUploadTensor
(tensorId: number, data: Uint8Array) => {
backend.uploadTensor(tensorId, data);
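Reviewer note on the init.ts change: the jsepEnsureTensor binding now threads a session id through, and the backend falls back to `currentSessionId` when it is undefined (the `sessionId ?? this.currentSessionId` above). A sketch of both call shapes:

```ts
// Explicit session id (new callers):
await backend.ensureTensor(sessionId, tensorId, onnxDataType, shape, /* copyOld */ false);

// Legacy shape: undefined falls back to backend.currentSessionId.
await backend.ensureTensor(undefined, tensorId, onnxDataType, shape, false);
```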
30 changes: 21 additions & 9 deletions js/web/lib/wasm/jsep/webnn/tensor-manager.ts
@@ -27,6 +27,7 @@
* Ensure a MLTensor is created for the TensorId.
*/
ensureTensor(
sessionId: number,
tensorId: TensorId,
dataType: MLOperandDataType,
shape: readonly number[],
@@ -46,9 +47,9 @@
*/
releaseTensorsForSession(session: number): void;
/**
* Register an externally created MLTensor with a given MLContext and return a TensorId.
* Register an externally created MLTensor with a given session id and return a TensorId.
*/
registerTensor(mlContext: MLContext, mlTensor: MLTensor, dataType: MLOperandDataType, shape: number[]): TensorId;
registerTensor(sessionId: number, mlTensor: MLTensor, dataType: MLOperandDataType, shape: number[]): TensorId;
}
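
Reviewer note: the interface now takes a session id instead of an `MLContext`, so callers stop resolving contexts themselves; the manager does the lookup and throws if the session has no WebNN context. An illustrative before/after (data type and shape are made up):

```ts
// Before this PR (caller resolved the context):
// const id = tensorManager.registerTensor(mlContext, mlTensor, 'float32', [1, 1000]);

// After (manager resolves the context from the session id):
const id = tensorManager.registerTensor(sessionId, mlTensor, 'float32', [1, 1000]);
```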

let tensorGuid = 1;
@@ -176,6 +177,7 @@ class TensorIdTracker {
}

public async ensureTensor(
sessionId: number,
dataType: MLOperandDataType,
shape: readonly number[],
copyOld: boolean,
@@ -196,7 +198,7 @@

// eslint-disable-next-line no-bitwise
const usage = typeof MLTensorUsage == 'undefined' ? undefined : MLTensorUsage.READ | MLTensorUsage.WRITE;
this.wrapper = await this.tensorManager.getCachedTensor(dataType, shape, usage, true, true);
this.wrapper = await this.tensorManager.getCachedTensor(sessionId, dataType, shape, usage, true, true);

if (copyOld && this.activeUpload) {
this.wrapper.write(this.activeUpload);
@@ -254,6 +256,14 @@ class TensorManagerImpl implements TensorManager {

constructor(private backend: WebNNBackend) {}

public getMLContext(sessionId: number): MLContext {
const context = this.backend.getMLContext(sessionId);
if (!context) {
throw new Error('MLContext not found for session.');
}
return context;
}

public reserveTensorId(): TensorId {
const tensorId = createNewTensorId();
this.tensorTrackersById.set(tensorId, new TensorIdTracker(this));
@@ -272,6 +282,7 @@
}

public async ensureTensor(
sessionId: number,
tensorId: TensorId,
dataType: MLOperandDataType,
shape: number[],
@@ -288,7 +299,7 @@
if (!tensor) {
throw new Error('Tensor not found.');
}
return tensor.ensureTensor(dataType, shape, copyOld);
return tensor.ensureTensor(sessionId, dataType, shape, copyOld);
}

public upload(tensorId: TensorId, data: Uint8Array): void {
@@ -323,17 +334,18 @@
}

public registerTensor(
mlContext: MLContext,
sessionId: number,
mlTensor: MLTensor,
dataType: MLOperandDataType,
shape: readonly number[],
): TensorId {
const context = this.getMLContext(sessionId);
const tensorId = createNewTensorId();
// Defaulting to READ | WRITE if usage is not provided.
// eslint-disable-next-line no-bitwise
const wrapper = new TensorWrapper({
sessionId: this.backend.currentSessionId,
context: mlContext,
sessionId,
context,
tensor: mlTensor,
dataType,
shape,
@@ -347,13 +359,13 @@
* Get or create an MLTensor with the given data type and shape.
*/
public async getCachedTensor(
sessionId: number,
dataType: MLOperandDataType,
shape: readonly number[],
usage: MLTensorUsageFlags | undefined,
writable: boolean,
readable: boolean,
): Promise<TensorWrapper> {
const sessionId = this.backend.currentSessionId;
for (const [index, tensor] of this.freeTensors.entries()) {
if (tensor.sameTypeAndShape(dataType, shape)) {
LOG_DEBUG('verbose', () => `[WebNN] Reusing tensor {dataType: ${dataType}, shape: ${shape}}`);
@@ -362,7 +374,7 @@
return wrapper;
}
}
const context = this.backend.currentContext;
const context = this.getMLContext(sessionId);
LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`);
const tensor = await context.createTensor({
dataType,
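Reviewer note: getCachedTensor now takes the session id as a parameter instead of reading `this.backend.currentSessionId`, removing a dependency on mutable backend state. A hedged usage sketch with illustrative arguments:

```ts
// Sketch: acquire a cached (or newly created) tensor for a session.
const wrapper = await tensorManager.getCachedTensor(
  sessionId,
  'float32', // MLOperandDataType
  [1, 3, 224, 224], // shape
  undefined, // usage: MLTensorUsage may be undefined in newer specs
  true, // writable
  true, // readable
);
```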