diff --git a/tfjs-core/src/tensor_util_env.ts b/tfjs-core/src/tensor_util_env.ts index da71b094592..60fd1272b1e 100644 --- a/tfjs-core/src/tensor_util_env.ts +++ b/tfjs-core/src/tensor_util_env.ts @@ -17,7 +17,7 @@ import {ENGINE} from './engine'; import {env} from './environment'; -import {Tensor} from './tensor'; +import {getGlobalTensorClass, Tensor} from './tensor'; import {DataType, isWebGLData, isWebGPUData, TensorLike, WebGLData, WebGPUData} from './types'; import {assert, flatten, inferDtype, isTypedArray, toTypedArray} from './util'; import {bytesPerElement} from './util_base'; @@ -98,7 +98,7 @@ function assertDtype( export function convertToTensor( x: T|TensorLike, argName: string, functionName: string, parseAsDtype: DataType|'numeric'|'string_or_numeric' = 'numeric'): T { - if (x instanceof Tensor) { + if (x instanceof getGlobalTensorClass()) { assertDtype(parseAsDtype, x.dtype, argName, functionName); return x; }