Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webgpu] Float16Array polyfill for uniform #19307

Closed
wants to merge 42 commits into from
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
5a95268
[js] Support Float16Array polyfill
axinging Jan 26, 2024
706edd4
Try conditional float16 polyfill
axinging Jan 29, 2024
5453a7a
NIt
axinging Jan 29, 2024
a8d4d77
Nit, removbe more
axinging Jan 29, 2024
ce46073
rEMOVE ONE MORE
axinging Jan 29, 2024
d6c86a3
Clean one more
axinging Jan 29, 2024
3b49dee
Remove more
axinging Jan 30, 2024
da77265
Move package to devDeps
axinging Jan 30, 2024
4396346
Remove isFloat16Array
axinging Jan 30, 2024
1b853ba
Remove else for f16
axinging Jan 30, 2024
cb3ba6c
format and rename padf16.jsonc
axinging Jan 30, 2024
1480c48
Revert webgl
axinging Jan 30, 2024
d63d39a
Add FLOAT16_TYPE_POLYFILL
axinging Jan 30, 2024
82485df
Fix lint
axinging Jan 30, 2024
f44bd9c
Merge branch 'main' into float16_polyfill_v2
axinging Jan 30, 2024
6d05812
Fix comment
axinging Jan 30, 2024
0d9e517
Support pad v11 or above
axinging Jan 31, 2024
453cfe9
Refine Float16ArrayType
axinging Jan 31, 2024
d281faf
Merge branch 'main' into float16_polyfill_v2
axinging Jan 31, 2024
2fc4f58
Fix tensor.ts build error
axinging Jan 31, 2024
6db94a6
Merge branch 'main' into float16_polyfill_v2
axinging Jan 31, 2024
4ae19d2
Support fallback to uint16
axinging Jan 31, 2024
e9a6cfa
Merge branch 'main' into float16_polyfill_v2
axinging Feb 19, 2024
453b52e
Merge branch 'main' into float16_polyfill_v2
axinging Feb 21, 2024
7e75c14
Remove FLoat16 in test-runner.ts
axinging Feb 21, 2024
861c472
Remove onnxjs change
axinging Feb 21, 2024
0082c71
Remove change in onnxjs
axinging Feb 21, 2024
b6f4cd5
Remove not necessary change in packaje.json
axinging Feb 22, 2024
cecada7
Merge branch 'main' into float16_polyfill_v2
axinging Feb 22, 2024
c9e90b1
Fix globalThis property missing
axinging Feb 22, 2024
10e9d65
Rename pad_f16 and add to test list
axinging Feb 23, 2024
7ab6a86
Merge branch 'main' into float16_polyfill_v2
axinging Feb 23, 2024
3619729
t
axinging Feb 23, 2024
c8c74d2
Fix lint
axinging Feb 23, 2024
4cb32ba
Revet webgl/test-conv-utils.ts
axinging Feb 23, 2024
3020722
Remove globalThis in test
axinging Feb 23, 2024
894c408
Nit, remove noVar
axinging Feb 23, 2024
aa474a8
Not float16 in test
axinging Feb 23, 2024
6795713
Remove float16 in common test
axinging Feb 26, 2024
432f311
use float16 in test
axinging Feb 27, 2024
44ee23f
Merge branch 'main' into float16_polyfill_v2
axinging Feb 27, 2024
a239ce2
Refine comment msg
axinging Feb 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions js/common/lib/tensor-impl-type-mapping.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

import {Tensor} from './tensor.js';

export type SupportedTypedArrayConstructors = Float32ArrayConstructor|Uint8ArrayConstructor|Int8ArrayConstructor|
Uint16ArrayConstructor|Int16ArrayConstructor|Int32ArrayConstructor|BigInt64ArrayConstructor|Uint8ArrayConstructor|
Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor;
export type SupportedTypedArrayConstructors = Float32ArrayConstructor|Uint8ArrayConstructor|
Int8ArrayConstructor|Uint16ArrayConstructor|Int16ArrayConstructor|Int32ArrayConstructor|BigInt64ArrayConstructor|
Uint8ArrayConstructor|Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor;
export type SupportedTypedArray = InstanceType<SupportedTypedArrayConstructors>;

// a runtime map that maps type string to TypedArray constructor. Should match Tensor.DataTypeMap.
Expand All @@ -14,7 +14,6 @@ export const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map<string, SupportedTy
['uint8', Uint8Array],
['int8', Int8Array],
['uint16', Uint16Array],
['float16', Uint16Array],
['int16', Int16Array],
['int32', Int32Array],
['bool', Uint8Array],
Expand All @@ -34,16 +33,23 @@ export const NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP = new Map<SupportedTypedArray
[Uint32Array, 'uint32'],
]);

// the following code allows delaying execution of BigInt checking. This allows lazy initialization for
// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt polyfill
// if available.
let isBigIntChecked = false;
export const checkBigInt = () => {
if (!isBigIntChecked) {
isBigIntChecked = true;
const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && typeof BigInt64Array.from === 'function';
const isBigUint64ArrayAvailable =
typeof BigUint64Array !== 'undefined' && typeof BigUint64Array.from === 'function';
// a dummy type declaration for Float16Array in case any polyfill is available.
declare global {
// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any
const Float16Array: any;
const isFloat16Array: any;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isFloat16Array is not going to be polyfilled in the global scope. we can still use a instanceof Float16Array to perform the check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}

// the following code allows delaying execution of BigInt/Float16Array checking. This allows lazy initialization for
// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt/Float16Array
// polyfill if available.
let isTypedArrayChecked = false;
export const checkTypedArray = () => {
if (!isTypedArrayChecked) {
isTypedArrayChecked = true;
const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && BigInt64Array.from;
const isBigUint64ArrayAvailable = typeof BigUint64Array !== 'undefined' && BigUint64Array.from;
const isFloat16ArrayAvailable = typeof Float16Array !== 'undefined' && Float16Array.from;

if (isBigInt64ArrayAvailable) {
NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('int64', BigInt64Array);
Expand All @@ -53,5 +59,12 @@ export const checkBigInt = () => {
NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('uint64', BigUint64Array);
NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigUint64Array, 'uint64');
}
if (isFloat16ArrayAvailable) {
NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Float16Array);
NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(Float16Array, 'float16');
} else {
// if Float16Array is not available, use 'Uint16Array' to store the data.
NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Uint16Array);
}
}
};
12 changes: 7 additions & 5 deletions js/common/lib/tensor-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {tensorToDataURL, tensorToImageData} from './tensor-conversion-impl.js';
import {TensorToDataUrlOptions, TensorToImageDataOptions} from './tensor-conversion.js';
import {tensorFromGpuBuffer, tensorFromImage, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js';
import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js';
import {checkBigInt, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js';
import {checkTypedArray, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js';
import {calculateSize, tensorReshape} from './tensor-utils-impl.js';
import {Tensor as TensorInterface} from './tensor.js';

Expand Down Expand Up @@ -67,8 +67,8 @@ export class Tensor implements TensorInterface {
arg0: TensorType|TensorDataType|readonly string[]|readonly boolean[]|CpuPinnedConstructorParameters|
TextureConstructorParameters|GpuBufferConstructorParameters,
arg1?: TensorDataType|readonly number[]|readonly string[]|readonly boolean[], arg2?: readonly number[]) {
// perform one-time check for BigInt support
checkBigInt();
// perform one-time check for BigInt/Float16Array support
checkTypedArray();

let type: TensorType;
let dims: readonly number[];
Expand Down Expand Up @@ -146,8 +146,8 @@ export class Tensor implements TensorInterface {
// Throw error here because when user try to use number array as data,
// e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call
// Uint16Array.from(arg1) which generates wrong data.
throw new TypeError(
'Creating a float16 tensor from number array is not supported. Please use Uint16Array as data.');
// eslint-disable-next-line @typescript-eslint/no-explicit-any
data = (typedArrayConstructor as any).from(arg1);
Copy link
Contributor

@fs-eire fs-eire Jan 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as explained in the old error message, you cannot do typedArrayConstructor.from because it does not work with Uint16Array. you need to check if typedArrayConstructor is Float16Array. You can just use the code from that of my PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thanks!

} else if (arg0 === 'uint64' || arg0 === 'int64') {
// use 'as any' here because:
// 1. TypeScript's check on type of 'Array.isArray()' does not work with readonly arrays.
Expand All @@ -168,6 +168,8 @@ export class Tensor implements TensorInterface {
}
} else if (arg1 instanceof typedArrayConstructor) {
data = arg1;
} else if (isFloat16Array !== undefined && isFloat16Array(arg1)) {
data = arg1 as InstanceType<typeof Float16Array>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does the code actually call into this if branch? I assume that map NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP should already set with key 'float16', so it should be handled in condition (arg1 instanceof typedArrayConstructor)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this is not necessary now. Thanks!

} else {
throw new TypeError(`A ${type} tensor's data must be type of ${typedArrayConstructor}`);
}
Expand Down
1 change: 1 addition & 0 deletions js/common/lib/tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ interface TypedTensorBase<T extends Tensor.Type> {
dispose(): void;
}

// type Float16ArrayType = InstanceType<typeof Float16Array>;
export declare namespace Tensor {
interface DataTypeMap {
float32: Float32Array;
Expand Down
5 changes: 4 additions & 1 deletion js/common/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,8 @@
"ONNXRuntime",
"ONNX Runtime"
],
"description": "ONNXRuntime JavaScript API library"
"description": "ONNXRuntime JavaScript API library",
"dependencies": {
"@petamoriken/float16": "^3.8.4"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add as devDependencies

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}
}
3 changes: 2 additions & 1 deletion js/common/test/unit-tests/common.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Float16Array} from '@petamoriken/float16';
import assert from 'assert/strict';
import {Tensor} from 'onnxruntime-common';

Expand Down Expand Up @@ -34,7 +35,7 @@ export const BIGINT_TYPES = [
/**
* float16 type, data represented by Uint16Array
*/
export const FLOAT16_TYPE = ['float16', Uint16Array, false] as const;
export const FLOAT16_TYPE = ['float16', Float16Array, false] as const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest not to update this part. Both paths (with and without polyfill) need to be tested. to test f16 polyfill, may need to add a separated target.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


/**
* A list of all numerical types.
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/onnxjs/backends/webgl/ops/pad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ const validateInputsV2 = (inputs: Tensor[]): void => {
if (!inputs || inputs.length !== 1) {
throw new Error('Pad requires 1 input');
}
if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') {
if (inputs[0].type !== 'float16' && inputs[0].type !== 'float32' && inputs[0].type !== 'float64') {
throw new Error('Invalid input type.');
}
};
Expand Down
4 changes: 2 additions & 2 deletions js/web/lib/onnxjs/backends/webgl/ops/resize-packed.ts
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ const prepareInputs = (inputs: Tensor[], attributes: UpsampleAttributes): [reado

const parseScalesData = (scale: Tensor, mode: string, isResize: boolean): number[] => {
const scales = Array.from(scale.floatData);
scalesValidation(scales, mode, isResize);
return scales;
scalesValidation(scales as number[], mode, isResize);
return scales as number[];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can keep onnxjs code unchanged as they are going to deprecated. we don't plan to support new features to webgl.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this may complains: $ npm run build

[email protected] prebuild
tsc -p . --noEmit && tsc -p lib/wasm/proxy-worker --noEmit

lib/onnxjs/backends/webgl/ops/resize-packed.ts:244:20 - error TS2345: Argument of type 'unknown[]' is not assignable to parameter of type 'number[]'.
Type 'unknown' is not assignable to type 'number'.

244 scalesValidation(scales, mode, isResize);
~~~~~~

lib/onnxjs/backends/webgl/ops/resize-packed.ts:245:3 - error TS2322: Type 'unknown[]' is not assignable to type 'number[]'.

245 return scales;
~~~~~~

Found 2 errors in the same file, starting at: lib/onnxjs/backends/webgl/ops/resize-packed.ts:244

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me try to make some modifications to unit test so that it no longer depends on onnxjs/tensor. Then you don't need to modify any file under onnxjs folder .

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#19358 should work with it

};

const parseScalesDataFromOutputSize =
Expand Down
10 changes: 8 additions & 2 deletions js/web/lib/onnxjs/tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ import {decodeUtf8String, ProtoUtil, ShapeUtil} from './util';

import ortFbs = onnxruntime.experimental.fbs;

type Float16ArrayType = InstanceType<typeof Float16Array>;
export declare namespace Tensor {
export interface DataTypeMap {
bool: Uint8Array;
float16: Float16ArrayType;
float32: Float32Array;
float64: Float64Array;
string: string[];
Expand All @@ -31,7 +33,7 @@ export declare namespace Tensor {
export type BooleanType = Tensor.DataTypeMap['bool'];
export type IntegerType = Tensor.DataTypeMap['int8']|Tensor.DataTypeMap['uint8']|Tensor.DataTypeMap['int16']|
Tensor.DataTypeMap['uint16']|Tensor.DataTypeMap['int32']|Tensor.DataTypeMap['uint32'];
export type FloatType = Tensor.DataTypeMap['float32']|Tensor.DataTypeMap['float64'];
export type FloatType = Tensor.DataTypeMap['float16']|Tensor.DataTypeMap['float32']|Tensor.DataTypeMap['float64'];
export type NumberType = BooleanType|IntegerType|FloatType;

export type Id = Guid;
Expand Down Expand Up @@ -93,6 +95,7 @@ export class Tensor {
*/
get floatData() {
switch (this.type) {
case 'float16':
case 'float32':
case 'float64':
return this.data as Tensor.FloatType;
Expand Down Expand Up @@ -188,7 +191,7 @@ export class Tensor {
} else {
if (cache !== undefined) {
const constructor = dataviewConstructor(type);
if (!(cache instanceof constructor)) {
if (!(cache instanceof constructor) && !isFloat16Array(cache)) {
throw new TypeError(`cache should be type ${constructor.name}`);
}
}
Expand Down Expand Up @@ -357,6 +360,7 @@ function sizeof(type: Tensor.DataType): number {
return 1;
case 'int16':
case 'uint16':
case 'float16':
return 2;
case 'int32':
case 'uint32':
Expand Down Expand Up @@ -412,6 +416,8 @@ function dataviewConstructor(type: Tensor.DataType) {
return Uint32Array;
case 'int64':
return BigInt64Array;
case 'float16':
return Float16Array; //typeof Float16Array !== 'undefined' && Float16Array.from ? Float16Array : Uint16Array;;
case 'float32':
return Float32Array;
case 'float64':
Expand Down
16 changes: 12 additions & 4 deletions js/web/lib/wasm/wasm-common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@

import {Tensor} from 'onnxruntime-common';

// a dummy type declaration for Float16Array in case any polyfill is available.
declare global {
var Float16Array: any;
var isFloat16Array: any;
}

// This file includes common definitions. They do NOT have dependency on the WebAssembly instance.

/**
Expand Down Expand Up @@ -112,12 +118,14 @@ export const getTensorElementSize = (dateType: number): number|
/**
* get typed array constructor by the given tensor type
*/
export const tensorTypeToTypedArrayConstructor = (type: Tensor.Type): Float32ArrayConstructor|Uint8ArrayConstructor|
Int8ArrayConstructor|Uint16ArrayConstructor|Int16ArrayConstructor|Int32ArrayConstructor|BigInt64ArrayConstructor|
Uint8ArrayConstructor|Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor => {
export const tensorTypeToTypedArrayConstructor =
(type: Tensor.Type): Float32ArrayConstructor|Uint8ArrayConstructor|Int8ArrayConstructor|
Uint16ArrayConstructor|Int16ArrayConstructor|Int32ArrayConstructor|BigInt64ArrayConstructor|Uint8ArrayConstructor|
Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor => {
switch (type) {
case 'float16':
return Uint16Array;
// allow Float16Array polyfill.
return typeof Float16Array !== 'undefined' && Float16Array.from ? Float16Array : Uint16Array;;
case 'float32':
return Float32Array;
case 'uint8':
Expand Down
1 change: 1 addition & 0 deletions js/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"version": "1.18.0",
"jsdelivr": "dist/ort.min.js",
"dependencies": {
"@petamoriken/float16": "^3.8.4",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add as devDependencies

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"flatbuffers": "^1.12.0",
"guid-typescript": "^1.0.9",
"long": "^5.2.3",
Expand Down
35 changes: 35 additions & 0 deletions js/web/test/data/ops/padf16.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
[
{
"name": "constant 2D",
"operator": "Pad",
"opset": { "domain": "", "version": 10 },
"attributes": [
{ "name": "mode", "data": "constant", "type": "string" },
{ "name": "value", "data": 1.2, "type": "float" },
{ "name": "pads", "data": [3, 2, 2, 3], "type": "ints" }
],
"cases": [
{
"name": "[2,2]->[7,7]",
"inputs": [
{
"data": [1.0, 2.0, 3.0, 4.0],
"dims": [2, 2],
"type": "float16"
}
],
"outputs": [
{
"data": [
1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2,
1.2, 1.2, 1.0, 2.0, 1.2, 1.2, 1.2, 1.2, 1.2, 3.0, 4.0, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2,
1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2
],
"dims": [7, 7],
"type": "float16"
}
]
}
]
}
]
5 changes: 5 additions & 0 deletions js/web/test/test-main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ import {Logger} from '../lib/onnxjs/instrument';

import {Test} from './test-types';

import {Float16Array, isFloat16Array} from '@petamoriken/float16';

globalThis.Float16Array = Float16Array;
globalThis.isFloat16Array = isFloat16Array;

if (ORT_WEB_TEST_CONFIG.model.some(testGroup => testGroup.tests.some(test => test.backend === 'cpu'))) {
// require onnxruntime-node
require('../../node');
Expand Down
11 changes: 7 additions & 4 deletions js/web/test/test-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {getTensorElementSize, isGpuBufferSupportedType, tensorDataTypeStringToEn

import {base64toBuffer, createMockGraph, readFile} from './test-shared';
import {Test} from './test-types';

type Float16ArrayType = InstanceType<typeof Float16Array>;
// the threshold that used to compare 2 float numbers. See above for TensorResultValidator.floatEqual().
const CPU_THRESHOLD_ABSOLUTE_ERROR = 1.0e-4;
const CPU_THRESHOLD_RELATIVE_ERROR = 1.000001;
Expand Down Expand Up @@ -393,11 +393,12 @@ export class TensorResultValidator {
case 'string':
return this.strictEqual(actual.stringData, expected.stringData);

case 'float16':
case 'float32':
case 'float64':
return this.floatEqual(
actual.numberData as number[] | Float32Array | Float64Array,
expected.numberData as number[] | Float32Array | Float64Array);
actual.numberData as number[] | Float16ArrayType | Float32Array | Float64Array,
expected.numberData as number[] | Float16ArrayType | Float32Array | Float64Array);

case 'uint8':
case 'int8':
Expand Down Expand Up @@ -425,7 +426,9 @@ export class TensorResultValidator {
return false;
}
}
floatEqual(actual: number[]|Float32Array|Float64Array, expected: number[]|Float32Array|Float64Array): boolean {
floatEqual(
actual: number[]|Float16ArrayType|Float32Array|Float64Array,
expected: number[]|Float16ArrayType|Float32Array|Float64Array): boolean {
if (actual.length !== expected.length) {
return false;
}
Expand Down
Loading