diff --git a/tfjs-core/src/ops/tensor_ops_util.ts b/tfjs-core/src/ops/tensor_ops_util.ts index 197ccf1e30d..c38783e254c 100644 --- a/tfjs-core/src/ops/tensor_ops_util.ts +++ b/tfjs-core/src/ops/tensor_ops_util.ts @@ -17,7 +17,7 @@ import {ENGINE} from '../engine'; import {Tensor} from '../tensor'; -import {TensorLike, TypedArray, WebGLData, WebGPUData} from '../types'; +import {isWebGLData, isWebGPUData, TensorLike, TypedArray, WebGLData, WebGPUData} from '../types'; import {DataType} from '../types'; import {assert, assertNonNegativeIntegerDimensions, flatten, inferDtype, isTypedArray, sizeFromShape, toTypedArray} from '../util'; @@ -33,16 +33,14 @@ export function makeTensor( `Please use tf.complex(real, imag).`); } - if (typeof values === 'object' && - ('texture' in values || - ('buffer' in values && !(values.buffer instanceof ArrayBuffer)))) { + if (isWebGPUData(values) || isWebGLData(values)) { if (dtype !== 'float32' && dtype !== 'int32') { throw new Error( `Creating tensor from GPU data only supports ` + `'float32'|'int32' dtype, while the dtype is ${dtype}.`); } return ENGINE.backend.createTensorFromGPUData( - values as WebGLData | WebGPUData, shape || inferredShape, dtype); + values, shape || inferredShape, dtype); } if (!isTypedArray(values) && !Array.isArray(values) && diff --git a/tfjs-core/src/tensor_util_env.ts b/tfjs-core/src/tensor_util_env.ts index e7be4297429..da71b094592 100644 --- a/tfjs-core/src/tensor_util_env.ts +++ b/tfjs-core/src/tensor_util_env.ts @@ -18,7 +18,7 @@ import {ENGINE} from './engine'; import {env} from './environment'; import {Tensor} from './tensor'; -import {DataType, TensorLike, WebGLData, WebGPUData} from './types'; +import {DataType, isWebGLData, isWebGPUData, TensorLike, WebGLData, WebGPUData} from './types'; import {assert, flatten, inferDtype, isTypedArray, toTypedArray} from './util'; import {bytesPerElement} from './util_base'; @@ -29,14 +29,12 @@ export function inferShape( if (isTypedArray(val)) { return dtype === 'string' ? [] : [val.length]; } - const isObject = typeof val === 'object'; - if (isObject) { - if ('texture' in val) { - const usedChannels = val.channels || 'RGBA'; - return [val.height, val.width * usedChannels.length]; - } else if ('buffer' in val && !(val.buffer instanceof ArrayBuffer)) { - return [val.buffer.size / (dtype == null ? 4 : bytesPerElement(dtype))]; - } + + if (isWebGLData(val)) { + const usedChannels = val.channels || 'RGBA'; + return [val.height, val.width * usedChannels.length]; + } else if (isWebGPUData(val)) { + return [val.buffer.size / (dtype == null ? 4 : bytesPerElement(dtype))]; } if (!Array.isArray(val)) { return []; // Scalar. diff --git a/tfjs-core/src/types.ts b/tfjs-core/src/types.ts index 2d3fe88ddad..0feeab5d7ff 100644 --- a/tfjs-core/src/types.ts +++ b/tfjs-core/src/types.ts @@ -196,3 +196,17 @@ export interface WebGPUData { buffer: GPUBuffer; zeroCopy?: boolean; } + +export function isWebGLData(values: unknown): values is WebGLData { + return values != null + && typeof values === 'object' + && 'texture' in values + && values.texture instanceof WebGLTexture; +} +export function isWebGPUData(values: unknown): values is WebGPUData { + return typeof GPUBuffer !== 'undefined' + && values != null + && typeof values === 'object' + && 'buffer' in values + && values.buffer instanceof GPUBuffer; +}