From 003f537a281e1ff1c6b789c6d8b48f5d0733bc27 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 16 Feb 2024 14:30:13 -0800 Subject: [PATCH] [js/webgpu] allow uint8 tensors for webgpu --- js/common/lib/tensor-impl.ts | 2 +- js/common/lib/tensor.ts | 2 +- js/web/lib/wasm/wasm-common.ts | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index e3e2b9c728556..de18126a9d0ae 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -103,7 +103,7 @@ export class Tensor implements TensorInterface { } case 'gpu-buffer': { if ((type !== 'float32' && type !== 'float16' && type !== 'int32' && type !== 'int64' && type !== 'uint32' && - type !== 'bool')) { + type !== 'uint8' && type !== 'bool')) { throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`); } this.gpuBufferData = arg0.gpuBuffer; diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts index 6c08d1fe8e057..d5da33640dc7d 100644 --- a/js/common/lib/tensor.ts +++ b/js/common/lib/tensor.ts @@ -135,7 +135,7 @@ export declare namespace Tensor { /** * supported data types for constructing a tensor from a WebGPU buffer */ - export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'bool'; + export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'uint8'|'bool'; /** * represent where the tensor data is stored diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index b9eff45e890c4..93910af1f1bf0 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -169,7 +169,8 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro * Check whether the given tensor type is supported by GPU buffer */ export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => type === 'float32' || - type === 'int32' || type === 'int64' || type === 'bool' || type === 'float16' || type === 'uint32'; + type === 'float16' || type === 'int32' || type === 'int64' || type === 'uint32' || type === 'uint8' || + type === 'bool'; /** * Map string data location to integer value