Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webgpu] Float16Array polyfill for uniform #19307

Closed
wants to merge 42 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
5a95268
[js] Support Float16Array polyfill
axinging Jan 26, 2024
706edd4
Try conditional float16 polyfill
axinging Jan 29, 2024
5453a7a
NIt
axinging Jan 29, 2024
a8d4d77
Nit, removbe more
axinging Jan 29, 2024
ce46073
rEMOVE ONE MORE
axinging Jan 29, 2024
d6c86a3
Clean one more
axinging Jan 29, 2024
3b49dee
Remove more
axinging Jan 30, 2024
da77265
Move package to devDeps
axinging Jan 30, 2024
4396346
Remove isFloat16Array
axinging Jan 30, 2024
1b853ba
Remove else for f16
axinging Jan 30, 2024
cb3ba6c
format and rename padf16.jsonc
axinging Jan 30, 2024
1480c48
Revert webgl
axinging Jan 30, 2024
d63d39a
Add FLOAT16_TYPE_POLYFILL
axinging Jan 30, 2024
82485df
Fix lint
axinging Jan 30, 2024
f44bd9c
Merge branch 'main' into float16_polyfill_v2
axinging Jan 30, 2024
6d05812
Fix comment
axinging Jan 30, 2024
0d9e517
Support pad v11 or above
axinging Jan 31, 2024
453cfe9
Refine Float16ArrayType
axinging Jan 31, 2024
d281faf
Merge branch 'main' into float16_polyfill_v2
axinging Jan 31, 2024
2fc4f58
Fix tensor.ts build error
axinging Jan 31, 2024
6db94a6
Merge branch 'main' into float16_polyfill_v2
axinging Jan 31, 2024
4ae19d2
Support fallback to uint16
axinging Jan 31, 2024
e9a6cfa
Merge branch 'main' into float16_polyfill_v2
axinging Feb 19, 2024
453b52e
Merge branch 'main' into float16_polyfill_v2
axinging Feb 21, 2024
7e75c14
Remove FLoat16 in test-runner.ts
axinging Feb 21, 2024
861c472
Remove onnxjs change
axinging Feb 21, 2024
0082c71
Remove change in onnxjs
axinging Feb 21, 2024
b6f4cd5
Remove not necessary change in packaje.json
axinging Feb 22, 2024
cecada7
Merge branch 'main' into float16_polyfill_v2
axinging Feb 22, 2024
c9e90b1
Fix globalThis property missing
axinging Feb 22, 2024
10e9d65
Rename pad_f16 and add to test list
axinging Feb 23, 2024
7ab6a86
Merge branch 'main' into float16_polyfill_v2
axinging Feb 23, 2024
3619729
t
axinging Feb 23, 2024
c8c74d2
Fix lint
axinging Feb 23, 2024
4cb32ba
Revet webgl/test-conv-utils.ts
axinging Feb 23, 2024
3020722
Remove globalThis in test
axinging Feb 23, 2024
894c408
Nit, remove noVar
axinging Feb 23, 2024
aa474a8
Not float16 in test
axinging Feb 23, 2024
6795713
Remove float16 in common test
axinging Feb 26, 2024
432f311
use float16 in test
axinging Feb 27, 2024
44ee23f
Merge branch 'main' into float16_polyfill_v2
axinging Feb 27, 2024
a239ce2
Refine comment msg
axinging Feb 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions js/common/lib/tensor-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,7 @@ export class Tensor implements TensorInterface {
throw new TypeError(`Unsupported tensor type: ${arg0}.`);
}
if (Array.isArray(arg1)) {
if (arg0 === 'float16' && typedArrayConstructor === Uint16Array) {
// When no Float16Array polyfill is used, we cannot create 'float16' tensor from number array.
//
// Throw error here because when user try to use number array as data,
// e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call
// Uint16Array.from(arg1) which generates wrong data.
throw new TypeError(
'Creating a float16 tensor from number array is not supported. Please use Uint16Array as data.');
} else if (arg0 === 'uint64' || arg0 === 'int64') {
if (arg0 === 'uint64' || arg0 === 'int64') {
// use 'as any' here because:
// 1. TypeScript's check on type of 'Array.isArray()' does not work with readonly arrays.
// see https://github.com/microsoft/TypeScript/issues/17002
Expand Down
10 changes: 8 additions & 2 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -512,8 +512,14 @@ export class WebGpuBackend {
} else if (v.type === DataType.uint32) {
new Uint32Array(arrayBuffer, offset, data.length).set(data);
} else if (v.type === DataType.float16) {
// TODO: use Float16Array.
new Uint16Array(arrayBuffer, offset, data.length).set(data);
if (typeof Float16Array !== 'undefined') {
new Float16Array(arrayBuffer, offset, data.length).set(data);
} else {
// Fallback to Uint16Array when Float16Array polyfill is not available, unit test only.
// eslint-disable-next-line no-console
console.warn('Unit test only, please make sure the float16 data has been encoded as float 16 bits.');
new Uint16Array(arrayBuffer, offset, data.length).set(data);
}
} else if (v.type === DataType.float) {
new Float32Array(arrayBuffer, offset, data.length).set(data);
} else {
Expand Down
12 changes: 11 additions & 1 deletion js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import {Env} from 'onnxruntime-common';

import {OrtWasmModule} from '../binding/ort-wasm';
import {DataType, getTensorElementSize} from '../wasm-common';
import {DataType, Float16ArrayType, getTensorElementSize} from '../wasm-common';

import {WebGpuBackend} from './backend-webgpu';
import {LOG_DEBUG} from './log';
Expand All @@ -19,6 +19,16 @@ class TensorViewImpl implements TensorView {
private module: OrtWasmModule, public readonly dataType: number, public readonly data: number,
public readonly dims: readonly number[]) {}

getFloat16Array(): Float16ArrayType {
if (this.dataType !== DataType.float16) {
throw new Error('Invalid data type');
}
const elementCount = ShapeUtil.size(this.dims);
const float16ViewConstructor = typeof Float16Array !== 'undefined' ? Float16Array : Uint16Array;
return elementCount === 0 ? new float16ViewConstructor() :
new float16ViewConstructor(this.module.HEAP8.buffer, this.data, elementCount);
}

getFloat32Array(): Float32Array {
if (this.dataType !== DataType.float) {
throw new Error('Invalid data type');
Expand Down
7 changes: 6 additions & 1 deletion js/web/lib/wasm/jsep/tensor-view.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import {Tensor} from 'onnxruntime-common';

import {tensorTypeToTypedArrayConstructor} from '../wasm-common';
import {Float16ArrayType, tensorTypeToTypedArrayConstructor} from '../wasm-common';

export const createView = (dataBuffer: ArrayBuffer, type: Tensor.Type): Int32Array|Uint32Array|BigInt64Array|
BigUint64Array|Uint8Array|Float32Array|Float64Array|Int8Array|Int16Array|Uint16Array =>
Expand All @@ -17,6 +17,11 @@ export interface TensorView {
readonly dataType: number;
readonly dims: readonly number[];

/**
* get a Float16Array data view of the tensor data. tensor data must be on CPU.
*/
getFloat16Array(): Float16ArrayType;

/**
* get a Float32Array data view of the tensor data. tensor data must be on CPU.
*/
Expand Down
4 changes: 3 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/pad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,9 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr
const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes: PadAttributes): PadAttributes => {
if (inputs.length > 1) {
const bigInt64Pads = inputs[1].getBigInt64Array();
const value = (inputs.length >= 3 && inputs[2].data) ? inputs[2].getFloat32Array()[0] : 0.0;
const value = (inputs.length >= 3 && inputs[2].data) ?
(inputs[2].dataType === DataType.float16 ? inputs[2].getFloat16Array()[0] : inputs[2].getFloat32Array()[0]) :
0.0;

const inputRank = inputs[0].dims.length;
const updatePads = new Int32Array(2 * inputRank).fill(0);
Expand Down
1 change: 1 addition & 0 deletions js/web/lib/wasm/wasm-common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ declare global {
// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any
const Float16Array: any;
}
export type Float16ArrayType = InstanceType<typeof Float16Array>;

// This file includes common definitions. They do NOT have dependency on the WebAssembly instance.

Expand Down
3 changes: 2 additions & 1 deletion js/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@
"minimatch": "^7.4.2",
"minimist": "^1.2.8",
"numpy-parser": "^1.2.3",
"strip-json-comments": "^5.0.0"
"strip-json-comments": "^5.0.0",
"@petamoriken/float16": "^3.8.4"
},
"main": "dist/ort-web.node.js",
"exports": {
Expand Down
74 changes: 74 additions & 0 deletions js/web/test/data/ops/pad-f16.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
[
{
"name": "constant 2D float16",
"operator": "Pad",
"opset": { "domain": "", "version": 10 },
"attributes": [
{ "name": "mode", "data": "constant", "type": "string" },
{ "name": "value", "data": 15565, "type": "float" },
{ "name": "pads", "data": [3, 2, 2, 3], "type": "ints" }
],
"cases": [
{
"name": "[2,2]->[7,7]",
"inputs": [
{
"data": [1.0, 2.0, 3.0, 4.5],
"dims": [2, 2],
"type": "float16"
}
],
"outputs": [
{
"data": [
1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2,
1.2, 1.2, 1.0, 2.0, 1.2, 1.2, 1.2, 1.2, 1.2, 3.0, 4.5, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2,
1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2
],
"dims": [7, 7],
"type": "float16"
}
]
}
]
},
{
"name": "constant 2D float16",
"operator": "Pad",
"opset": { "domain": "", "version": 19 },
"attributes": [{ "name": "mode", "data": "constant", "type": "string" }],
"cases": [
{
"name": "[2,2]->[7,7]",
"inputs": [
{
"data": [1.0, 2.0, 3.0, 4.5],
"dims": [2, 2],
"type": "float16"
},
{
"data": [3, 2, 2, 3],
"dims": [4],
"type": "int64"
},
{
"data": [1.2],
"dims": [1],
"type": "float16"
}
],
"outputs": [
{
"data": [
1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2,
1.2, 1.2, 1.0, 2.0, 1.2, 1.2, 1.2, 1.2, 1.2, 3.0, 4.5, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2,
1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2
],
"dims": [7, 7],
"type": "float16"
}
]
}
]
}
]
1 change: 1 addition & 0 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -1381,6 +1381,7 @@
"gelu.jsonc",
"pad.jsonc",
"pad-big.jsonc",
"pad-f16.jsonc",
"pow.jsonc",
"pow_int32.jsonc",
"pow-big-number.jsonc",
Expand Down
16 changes: 12 additions & 4 deletions js/web/test/test-runner.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Float16Array} from '@petamoriken/float16';
import {expect} from 'chai';
import * as ort from 'onnxruntime-common';
import {extname} from 'path';
Expand Down Expand Up @@ -390,6 +391,7 @@ export class TensorResultValidator {
case 'string':
return this.strictEqual(actual.data, expected.data);

case 'float16':
case 'float32':
case 'float64':
return this.floatEqual(
Expand Down Expand Up @@ -890,8 +892,11 @@ async function runProtoOpTestcase(
const fetches: Record<string, Pick<ort.Tensor, 'dims'|'type'>> = {};
testCase.inputs.forEach((input, i) => {
if (input.data) {
let data: number[]|BigUint64Array|BigInt64Array = input.data;
if (input.type === 'uint64') {
let data: number[]|BigUint64Array|BigInt64Array|Uint16Array = input.data;
if (input.type === 'float16') {
const floata16Array = Float16Array.from(input.data);
data = new Uint16Array(floata16Array.buffer, 0, floata16Array.length);
} else if (input.type === 'uint64') {
data = BigUint64Array.from(input.data.map(BigInt));
} else if (input.type === 'int64') {
data = BigInt64Array.from(input.data.map(BigInt));
Expand All @@ -904,8 +909,11 @@ async function runProtoOpTestcase(
const expectedOutputNames: string[] = [];
testCase.outputs.forEach((output, i) => {
if (output.data) {
let data: number[]|BigUint64Array|BigInt64Array = output.data;
if (output.type === 'uint64') {
let data: number[]|BigUint64Array|BigInt64Array|Uint16Array = output.data;
if (output.type === 'float16') {
const floata16Array = Float16Array.from(output.data);
data = new Uint16Array(floata16Array.buffer, 0, floata16Array.length);
} else if (output.type === 'uint64') {
data = BigUint64Array.from(output.data.map(BigInt));
} else if (output.type === 'int64') {
data = BigInt64Array.from(output.data.map(BigInt));
Expand Down