From 4373a4529cc2e706b20390cf569bb6c42bad60f8 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 14 Feb 2023 11:34:28 -0800 Subject: [PATCH 01/11] Add StaticRegexReplace --- tfjs-backend-cpu/src/register_all_kernels.ts | 2 ++ tfjs-backend-cpu/src/utils/unary_types.ts | 3 ++- tfjs-backend-cpu/src/utils/unary_utils.ts | 25 +++++++++++++------- tfjs-core/src/base.ts | 2 +- tfjs-core/src/kernel_names.ts | 8 +++++++ tfjs-core/src/ops/ops.ts | 4 +++- tfjs-core/src/types.ts | 6 +++++ 7 files changed, 39 insertions(+), 11 deletions(-) diff --git a/tfjs-backend-cpu/src/register_all_kernels.ts b/tfjs-backend-cpu/src/register_all_kernels.ts index 0d572fee740..1217ad1fcb7 100644 --- a/tfjs-backend-cpu/src/register_all_kernels.ts +++ b/tfjs-backend-cpu/src/register_all_kernels.ts @@ -170,6 +170,7 @@ import {splitVConfig} from './kernels/SplitV'; import {sqrtConfig} from './kernels/Sqrt'; import {squareConfig} from './kernels/Square'; import {squaredDifferenceConfig} from './kernels/SquaredDifference'; +import {staticRegexReplaceConfig} from './kernels/StaticRegexReplace'; import {stepConfig} from './kernels/Step'; import {stridedSliceConfig} from './kernels/StridedSlice'; import {stringNGramsConfig} from './kernels/StringNGrams'; @@ -342,6 +343,7 @@ const kernelConfigs: KernelConfig[] = [ sqrtConfig, squareConfig, squaredDifferenceConfig, + staticRegexReplaceConfig, stepConfig, stridedSliceConfig, stringNGramsConfig, diff --git a/tfjs-backend-cpu/src/utils/unary_types.ts b/tfjs-backend-cpu/src/utils/unary_types.ts index 808ba083ff1..ca98bb48151 100644 --- a/tfjs-backend-cpu/src/utils/unary_types.ts +++ b/tfjs-backend-cpu/src/utils/unary_types.ts @@ -17,6 +17,7 @@ import {DataType, NamedAttrMap, TypedArray} from '@tensorflow/tfjs-core'; -export type SimpleUnaryOperation = (x: number, attrs?: NamedAttrMap) => number; +export type SimpleUnaryOperation = (x: I, attrs?: NamedAttrMap) => O; export type SimpleUnaryImpl = (values: TypedArray, dtype: DataType, attrs?: NamedAttrMap) => TypedArray; diff --git a/tfjs-backend-cpu/src/utils/unary_utils.ts b/tfjs-backend-cpu/src/utils/unary_utils.ts index ba2888396d9..34df10b233a 100644 --- a/tfjs-backend-cpu/src/utils/unary_utils.ts +++ b/tfjs-backend-cpu/src/utils/unary_utils.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {DataType, KernelFunc, TypedArray, UnaryInputs, util} from '@tensorflow/tfjs-core'; +import {backend_util, DataType, KernelFunc, TypedArray, UnaryInputs, util, DataTypeMap, DataTypeFor} from '@tensorflow/tfjs-core'; import {MathBackendCPU} from '../backend_cpu'; import {assertNotComplex} from '../cpu_util'; @@ -30,22 +30,31 @@ import {SimpleUnaryImpl, SimpleUnaryOperation} from './unary_types'; * result has the same dtype as the input. This is mainly used in certain * kernels that return bool type, such as isFinite, isInf, etc. */ -export function unaryKernelFunc( - name: string, op: SimpleUnaryOperation, dtype?: DataType): KernelFunc { +export function unaryKernelFunc( + name: string, op: SimpleUnaryOperation, dtype?: DataTypeFor): KernelFunc { return ({inputs, attrs, backend}) => { const {x} = inputs as UnaryInputs; assertNotComplex(x, name); - if (x.dtype === 'string' || dtype === 'string') { - throw new Error('unaryKernelFunc does not support string input/output'); - } const cpuBackend = backend as MathBackendCPU; - const values = cpuBackend.data.get(x.dataId).values as TypedArray; + let values = cpuBackend.data.get(x.dataId).values; + let decoded: DataTypeMap[keyof DataTypeMap]; + if (values instanceof Array) { + if (x.dtype !== 'string') { + throw new Error(`Tensor ${x} data contains an array of values but its ` + + `dtype is ${x.dtype} instead of 'string'`); + } + decoded = backend_util.fromUint8ToStringArray(values); + } else { + decoded = values; + } + const xSize = util.sizeFromShape(x.shape); const $dtype = dtype || x.dtype; const newValues = util.getArrayFromDType($dtype, xSize); for (let i = 0; i < xSize; ++i) { - newValues[i] = op(values[i], attrs); + newValues[i] = op(decoded[i] as I, attrs); } return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues); }; diff --git a/tfjs-core/src/base.ts b/tfjs-core/src/base.ts index e869ede5d25..bda9e5657a6 100644 --- a/tfjs-core/src/base.ts +++ b/tfjs-core/src/base.ts @@ -55,7 +55,7 @@ export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer'; export {SGDOptimizer} from './optimizers/sgd_optimizer'; export {DataToGPUOptions, DataToGPUWebGLOption, GPUData, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, Variable} from './tensor'; export {GradSaveFunc, NamedTensorMap, TensorContainer, TensorContainerArray, TensorContainerObject} from './tensor_types'; -export {BackendValues, DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ScalarLike, ShapeMap, sumOutType, TensorLike, TypedArray, upcastType, WebGLData, WebGPUData} from './types'; +export {BackendValues, DataType, DataTypeMap, DataTypeFor, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ScalarLike, ShapeMap, sumOutType, TensorLike, TypedArray, upcastType, WebGLData, WebGPUData} from './types'; export * from './ops/ops'; export {Reduction} from './ops/loss_ops_utils'; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 861940fd48d..8b9da6660b9 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -852,6 +852,14 @@ export type SquaredDifferenceInputs = BinaryInputs; export const Square = 'Square'; export type SquareInputs = Pick; +export const StaticRegexReplace = 'StaticRegexReplace'; +export type StaticRegexReplaceInputs = UnaryInputs; +export interface StaticRegexReplaceAttrs { + pattern: string; + rewrite: string; + replaceGlobal: boolean; +} + export const StridedSlice = 'StridedSlice'; export type StridedSliceInputs = Pick; export interface StridedSliceAttrs { diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 5ab2f501951..bc7fd75fb51 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -325,11 +325,13 @@ const sparse = { import {stringNGrams} from './string/string_n_grams'; import {stringSplit} from './string/string_split'; import {stringToHashBucketFast} from './string/string_to_hash_bucket_fast'; +import {staticRegexReplace} from './string/static_regex_replace'; // tslint:disable-next-line:variable-name const string = { stringNGrams, stringSplit, - stringToHashBucketFast + stringToHashBucketFast, + staticRegexReplace, }; // Second level exports. diff --git a/tfjs-core/src/types.ts b/tfjs-core/src/types.ts index 0feeab5d7ff..b104ee70bcc 100644 --- a/tfjs-core/src/types.ts +++ b/tfjs-core/src/types.ts @@ -56,6 +56,12 @@ export interface SingleValueMap { /** @docalias 'float32'|'int32'|'bool'|'complex64'|'string' */ export type DataType = keyof DataTypeMap; export type NumericDataType = 'float32'|'int32'|'bool'|'complex64'; + +export type DataTypeFor = + T extends number | boolean ? NumericDataType : + T extends string ? 'string' : + never; + export type TypedArray = Float32Array|Int32Array|Uint8Array; /** Tensor data used in tensor creation and user-facing API. */ export type DataValues = DataTypeMap[DataType]; From 79122a39a19b8ad3a2f6295a1299cc5464ae0ccf Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 14 Feb 2023 14:04:35 -0800 Subject: [PATCH 02/11] Simplify unary kernel helper functions --- tfjs-backend-cpu/src/utils/unary_impl.ts | 11 ++-- tfjs-backend-cpu/src/utils/unary_types.ts | 9 ++-- tfjs-backend-cpu/src/utils/unary_utils.ts | 54 ++++++++----------- .../src/kernel_utils/kernel_funcs_utils.ts | 4 +- tfjs-core/src/types.ts | 6 +-- tfjs-core/src/util_base.ts | 14 +---- tsconfig.json | 3 ++ 7 files changed, 44 insertions(+), 57 deletions(-) diff --git a/tfjs-backend-cpu/src/utils/unary_impl.ts b/tfjs-backend-cpu/src/utils/unary_impl.ts index 9313b8512c3..5d4cff9a81f 100644 --- a/tfjs-backend-cpu/src/utils/unary_impl.ts +++ b/tfjs-backend-cpu/src/utils/unary_impl.ts @@ -15,20 +15,21 @@ * ============================================================================= */ -import {NumericDataType, util} from '@tensorflow/tfjs-core'; +import {util} from '@tensorflow/tfjs-core'; import {SimpleUnaryImpl, SimpleUnaryOperation} from './unary_types'; /** * Template that creates implementation for unary op. */ -export function createSimpleUnaryImpl(op: SimpleUnaryOperation): - SimpleUnaryImpl { +export function createSimpleUnaryImpl(op: SimpleUnaryOperation): + SimpleUnaryImpl { return (values, dtype, attrs) => { const newValues = - util.getTypedArrayFromDType(dtype as NumericDataType, values.length); + util.getArrayFromDType(dtype, values.length); for (let i = 0; i < values.length; ++i) { - newValues[i] = op(values[i], attrs); + newValues[i] = op(values[i] as I, attrs); } return newValues; }; diff --git a/tfjs-backend-cpu/src/utils/unary_types.ts b/tfjs-backend-cpu/src/utils/unary_types.ts index ca98bb48151..96aab70dd10 100644 --- a/tfjs-backend-cpu/src/utils/unary_types.ts +++ b/tfjs-backend-cpu/src/utils/unary_types.ts @@ -15,9 +15,12 @@ * ============================================================================= */ -import {DataType, NamedAttrMap, TypedArray} from '@tensorflow/tfjs-core'; +import { DataTypeFor, DataTypeMap, NamedAttrMap } from '@tensorflow/tfjs-core'; export type SimpleUnaryOperation = (x: I, attrs?: NamedAttrMap) => O; -export type SimpleUnaryImpl = - (values: TypedArray, dtype: DataType, attrs?: NamedAttrMap) => TypedArray; + +export type SimpleUnaryImpl = + (values: ArrayLike, dtype: DataTypeFor, + attrs?: NamedAttrMap) => DataTypeMap[DataTypeFor] diff --git a/tfjs-backend-cpu/src/utils/unary_utils.ts b/tfjs-backend-cpu/src/utils/unary_utils.ts index 34df10b233a..aacf627161b 100644 --- a/tfjs-backend-cpu/src/utils/unary_utils.ts +++ b/tfjs-backend-cpu/src/utils/unary_utils.ts @@ -15,10 +15,11 @@ * ============================================================================= */ -import {backend_util, DataType, KernelFunc, TypedArray, UnaryInputs, util, DataTypeMap, DataTypeFor} from '@tensorflow/tfjs-core'; +import {backend_util, DataTypeFor, KernelFunc, UnaryInputs} from '@tensorflow/tfjs-core'; import {MathBackendCPU} from '../backend_cpu'; import {assertNotComplex} from '../cpu_util'; +import {createSimpleUnaryImpl} from './unary_impl'; import {SimpleUnaryImpl, SimpleUnaryOperation} from './unary_types'; @@ -33,31 +34,10 @@ import {SimpleUnaryImpl, SimpleUnaryOperation} from './unary_types'; export function unaryKernelFunc( name: string, op: SimpleUnaryOperation, dtype?: DataTypeFor): KernelFunc { - return ({inputs, attrs, backend}) => { - const {x} = inputs as UnaryInputs; - assertNotComplex(x, name); - const cpuBackend = backend as MathBackendCPU; - let values = cpuBackend.data.get(x.dataId).values; - let decoded: DataTypeMap[keyof DataTypeMap]; - if (values instanceof Array) { - if (x.dtype !== 'string') { - throw new Error(`Tensor ${x} data contains an array of values but its ` - + `dtype is ${x.dtype} instead of 'string'`); - } - decoded = backend_util.fromUint8ToStringArray(values); - } else { - decoded = values; - } + const impl = createSimpleUnaryImpl(op); - const xSize = util.sizeFromShape(x.shape); - const $dtype = dtype || x.dtype; - const newValues = util.getArrayFromDType($dtype, xSize); - for (let i = 0; i < xSize; ++i) { - newValues[i] = op(decoded[i] as I, attrs); - } - return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues); - }; + return unaryKernelFuncFromImpl(name, impl, dtype); } /** @@ -69,19 +49,29 @@ export function unaryKernelFunc( + name: string, unaryImpl: SimpleUnaryImpl, dtype?: DataTypeFor): KernelFunc { return ({inputs, attrs, backend}) => { const {x} = inputs as UnaryInputs; assertNotComplex(x, name); - if (x.dtype === 'string' || dtype === 'string') { - throw new Error('unaryKernelFunc does not support string input/output'); - } const cpuBackend = backend as MathBackendCPU; - const values = cpuBackend.data.get(x.dataId).values as TypedArray; - const $dtype = dtype || x.dtype; - const newValues = unaryImpl(values, $dtype, attrs); + let values = cpuBackend.data.get(x.dataId).values; + let decoded: {[index: number]: I, length: number}//DataTypeMap[keyof DataTypeMap]; + if (values instanceof Array) { + if (x.dtype !== 'string') { + throw new Error(`Tensor ${x} data contains an array of values but its ` + + `dtype is ${x.dtype} instead of 'string'`); + } + decoded = backend_util.fromUint8ToStringArray(values) as unknown as + ArrayLike; + } else { + decoded = values as unknown as ArrayLike; + } + + const $dtype = dtype || x.dtype as DataTypeFor; + const newValues = unaryImpl(decoded, $dtype, attrs); return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues); }; } diff --git a/tfjs-backend-webgl/src/kernel_utils/kernel_funcs_utils.ts b/tfjs-backend-webgl/src/kernel_utils/kernel_funcs_utils.ts index a74d7c2f503..73e0730727a 100644 --- a/tfjs-backend-webgl/src/kernel_utils/kernel_funcs_utils.ts +++ b/tfjs-backend-webgl/src/kernel_utils/kernel_funcs_utils.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {backend_util, BinaryInputs, DataType, env, KernelFunc, TypedArray, UnaryInputs, upcastType} from '@tensorflow/tfjs-core'; +import { backend_util, BinaryInputs, DataType, env, KernelFunc, TypedArray, UnaryInputs, upcastType} from '@tensorflow/tfjs-core'; import {MathBackendWebGL} from '../backend_webgl'; import {BinaryOpProgram} from '../binaryop_gpu'; @@ -36,7 +36,7 @@ type UnaryKernelFuncConfig = { opSnippet: string, packedOpSnippet?: string, cpuKernelImpl?: SimpleUnaryKernelImplCPU, - dtype?: DataType + dtype?: DataType, }; /** diff --git a/tfjs-core/src/types.ts b/tfjs-core/src/types.ts index b104ee70bcc..a31b60a013a 100644 --- a/tfjs-core/src/types.ts +++ b/tfjs-core/src/types.ts @@ -204,8 +204,8 @@ export interface WebGPUData { } export function isWebGLData(values: unknown): values is WebGLData { - return values != null - && typeof values === 'object' + return values != null + && typeof values === 'object' && 'texture' in values && values.texture instanceof WebGLTexture; } @@ -213,6 +213,6 @@ export function isWebGPUData(values: unknown): values is WebGPUData { return typeof GPUBuffer !== 'undefined' && values != null && typeof values === 'object' - && 'buffer' in values + && 'buffer' in values && values.buffer instanceof GPUBuffer; } diff --git a/tfjs-core/src/util_base.ts b/tfjs-core/src/util_base.ts index 1d81a0ab773..48f2a2e067d 100644 --- a/tfjs-core/src/util_base.ts +++ b/tfjs-core/src/util_base.ts @@ -412,17 +412,7 @@ export function squeezeShape(shape: number[], axis?: number[]): export function getTypedArrayFromDType( dtype: D, size: number): DataTypeMap[D] { - let values = null; - if (dtype == null || dtype === 'float32') { - values = new Float32Array(size); - } else if (dtype === 'int32') { - values = new Int32Array(size); - } else if (dtype === 'bool') { - values = new Uint8Array(size); - } else { - throw new Error(`Unknown data type ${dtype}`); - } - return values as DataTypeMap[D]; + return getArrayFromDType(dtype, size); } export function getArrayFromDType( @@ -435,7 +425,7 @@ export function getArrayFromDType( } else if (dtype === 'bool') { values = new Uint8Array(size); } else if (dtype === 'string') { - values = new Array<'string'>(size); + values = new Array(size); } else { throw new Error(`Unknown data type ${dtype}`); } diff --git a/tsconfig.json b/tsconfig.json index 04b43755d47..b5974d9ad87 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -44,6 +44,9 @@ ], "@tensorflow/tfjs-core/dist/*": [ "./tfjs-core/src/*", + ], + "@tensorflow/tfjs-backend-cpu/dist/*": [ + "./tfjs-backend-cpu/src/*", ] } } From 4352c2840cc07fb1096fb19ba24b3d577b51c1d3 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 14 Feb 2023 14:05:29 -0800 Subject: [PATCH 03/11] StaticRegexReplace implementation --- .../src/kernels/StaticRegexReplace.ts | 36 +++++++++++++++++ .../src/ops/string/static_regex_replace.ts | 37 ++++++++++++++++++ .../ops/string/static_regex_replace_test.ts | 39 +++++++++++++++++++ 3 files changed, 112 insertions(+) create mode 100644 tfjs-backend-cpu/src/kernels/StaticRegexReplace.ts create mode 100644 tfjs-core/src/ops/string/static_regex_replace.ts create mode 100644 tfjs-core/src/ops/string/static_regex_replace_test.ts diff --git a/tfjs-backend-cpu/src/kernels/StaticRegexReplace.ts b/tfjs-backend-cpu/src/kernels/StaticRegexReplace.ts new file mode 100644 index 00000000000..59e825fae1b --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/StaticRegexReplace.ts @@ -0,0 +1,36 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, StaticRegexReplace, StaticRegexReplaceAttrs} from '@tensorflow/tfjs-core'; +import {createSimpleUnaryImpl} from '../utils/unary_impl'; +import {unaryKernelFuncFromImpl} from '../utils/unary_utils'; + +const staticRegexReplaceImpl = createSimpleUnaryImpl((x: string, attrs) => { + const {pattern, replaceGlobal, rewrite} = attrs as unknown as StaticRegexReplaceAttrs; + // TODO(mattSoulanille): Don't create a regex each time. + debugger; + return x.replace(new RegExp(pattern, replaceGlobal ? 'g' : ''), rewrite); +}); + +const staticRegexReplace = + unaryKernelFuncFromImpl(StaticRegexReplace, staticRegexReplaceImpl); + +export const staticRegexReplaceConfig: KernelConfig = { + kernelName: StaticRegexReplace, + backendName: 'cpu', + kernelFunc: staticRegexReplace, +}; diff --git a/tfjs-core/src/ops/string/static_regex_replace.ts b/tfjs-core/src/ops/string/static_regex_replace.ts new file mode 100644 index 00000000000..24c6ffdb750 --- /dev/null +++ b/tfjs-core/src/ops/string/static_regex_replace.ts @@ -0,0 +1,37 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../../engine'; +import {StaticRegexReplace, StaticRegexReplaceAttrs} from '../../kernel_names'; +import {NamedAttrMap} from '../../kernel_registry'; +import {Tensor} from '../../tensor'; +import {convertToTensor} from '../../tensor_util_env'; +import {TensorLike} from '../../types'; +import {op} from '../operation'; + +function staticRegexReplace_( + input: Tensor | TensorLike, pattern: string, rewrite: string, + replaceGlobal=true): Tensor { + + const $input = convertToTensor(input, 'input', 'staticRegexReplace', + 'string'); + const attrs: StaticRegexReplaceAttrs = {pattern, rewrite, replaceGlobal}; + return ENGINE.runKernel(StaticRegexReplace, {x: $input}, + attrs as unknown as NamedAttrMap); +} + +export const staticRegexReplace = /* @__PURE__ */ op({staticRegexReplace_}); diff --git a/tfjs-core/src/ops/string/static_regex_replace_test.ts b/tfjs-core/src/ops/string/static_regex_replace_test.ts new file mode 100644 index 00000000000..563a37a06e1 --- /dev/null +++ b/tfjs-core/src/ops/string/static_regex_replace_test.ts @@ -0,0 +1,39 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../../index'; +import { DataTypeFor } from '../../index'; +import {ALL_ENVS, describeWithFlags} from '../../jasmine_util'; + +describeWithFlags('staticRegexReplace', ALL_ENVS, () => { + + it('replaces the first instance of a regex pattern', async () => { + const result = tf.string.staticRegexReplace( + ['this', 'is', 'a', 'test test'], 'test', 'result', false); + + expect(await result.data>()) + .toEqual(['this', 'is', 'a', 'result test']); + }); + + it('replaces a regex pattern globally', async () => { + const result = tf.string.staticRegexReplace( + ['this', 'is', 'a', 'test test'], 'test', 'result', true); + + expect(await result.data>()) + .toEqual(['this', 'is', 'a', 'result result']); + }); +}); From c8a3cbb544dfef93a5cc309bf09df1ff64779f9c Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 14 Feb 2023 14:34:09 -0800 Subject: [PATCH 04/11] Forward StaticRegexReplace on webgl to cpu --- .../src/kernels/StaticRegexReplace.ts | 7 +-- tfjs-backend-cpu/src/shared.ts | 1 + tfjs-backend-webgl/src/kernel_utils/shared.ts | 2 + .../src/kernels/StaticRegexReplace.ts | 46 +++++++++++++++++++ 4 files changed, 53 insertions(+), 3 deletions(-) create mode 100644 tfjs-backend-webgl/src/kernels/StaticRegexReplace.ts diff --git a/tfjs-backend-cpu/src/kernels/StaticRegexReplace.ts b/tfjs-backend-cpu/src/kernels/StaticRegexReplace.ts index 59e825fae1b..5e8c51d5a55 100644 --- a/tfjs-backend-cpu/src/kernels/StaticRegexReplace.ts +++ b/tfjs-backend-cpu/src/kernels/StaticRegexReplace.ts @@ -19,10 +19,11 @@ import {KernelConfig, StaticRegexReplace, StaticRegexReplaceAttrs} from '@tensor import {createSimpleUnaryImpl} from '../utils/unary_impl'; import {unaryKernelFuncFromImpl} from '../utils/unary_utils'; -const staticRegexReplaceImpl = createSimpleUnaryImpl((x: string, attrs) => { - const {pattern, replaceGlobal, rewrite} = attrs as unknown as StaticRegexReplaceAttrs; +export const staticRegexReplaceImpl = createSimpleUnaryImpl((x: string, attrs) => { + const {pattern, replaceGlobal, rewrite} = + attrs as unknown as StaticRegexReplaceAttrs; // TODO(mattSoulanille): Don't create a regex each time. - debugger; return x.replace(new RegExp(pattern, replaceGlobal ? 'g' : ''), rewrite); }); diff --git a/tfjs-backend-cpu/src/shared.ts b/tfjs-backend-cpu/src/shared.ts index 515376b1f8b..6a50d6884d3 100644 --- a/tfjs-backend-cpu/src/shared.ts +++ b/tfjs-backend-cpu/src/shared.ts @@ -55,6 +55,7 @@ export {sparseReshapeImpl} from './kernels/SparseReshape_impl'; export {sparseSegmentReductionImpl} from './kernels/SparseSegmentReduction_impl'; export {sqrtImpl} from './kernels/Sqrt'; export {squaredDifferenceImpl} from './kernels/SquaredDifference'; +export {staticRegexReplaceImpl} from './kernels/StaticRegexReplace'; export {stridedSliceImpl} from './kernels/StridedSlice_impl'; export {stringNGramsImpl} from './kernels/StringNGrams_impl'; export {stringSplitImpl} from './kernels/StringSplit_impl'; diff --git a/tfjs-backend-webgl/src/kernel_utils/shared.ts b/tfjs-backend-webgl/src/kernel_utils/shared.ts index e8c3111aa9b..b9a7a5a59f2 100644 --- a/tfjs-backend-webgl/src/kernel_utils/shared.ts +++ b/tfjs-backend-webgl/src/kernel_utils/shared.ts @@ -66,6 +66,7 @@ const { sparseReshapeImpl: sparseReshapeImplCPU, sparseSegmentReductionImpl: sparseSegmentReductionImplCPU, sqrtImpl: sqrtImplCPU, + staticRegexReplaceImpl: staticRegexReplaceImplCPU, stridedSliceImpl: stridedSliceImplCPU, stringNGramsImpl: stringNGramsImplCPU, stringSplitImpl: stringSplitImplCPU, @@ -114,6 +115,7 @@ export { sparseReshapeImplCPU, sparseSegmentReductionImplCPU, sqrtImplCPU, + staticRegexReplaceImplCPU, stridedSliceImplCPU, stringNGramsImplCPU, stringSplitImplCPU, diff --git a/tfjs-backend-webgl/src/kernels/StaticRegexReplace.ts b/tfjs-backend-webgl/src/kernels/StaticRegexReplace.ts new file mode 100644 index 00000000000..6062306378c --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/StaticRegexReplace.ts @@ -0,0 +1,46 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, KernelConfig, KernelFunc, NamedAttrMap, StaticRegexReplace, StaticRegexReplaceAttrs, StaticRegexReplaceInputs, TensorInfo} from '@tensorflow/tfjs-core'; +import {MathBackendWebGL} from '../backend_webgl'; +import {staticRegexReplaceImplCPU} from '../kernel_utils/shared'; + +export function staticRegexReplace(args: { + inputs: StaticRegexReplaceInputs, + backend: MathBackendWebGL, + attrs: StaticRegexReplaceAttrs, +}): TensorInfo { + const {inputs, backend, attrs} = args; + const {x} = inputs; + + if (x.dtype !== 'string') { + throw new Error('Input must be of datatype string'); + } + + const $x = backend.readSync(x.dataId) as Uint8Array[]; + + const stringInput = backend_util.fromUint8ToStringArray($x); + const output = staticRegexReplaceImplCPU(stringInput, 'string', attrs as unknown as NamedAttrMap); + + return backend.makeTensorInfo(x.shape, 'string', output); +} + +export const staticRegexReplaceConfig: KernelConfig = { + kernelName: StaticRegexReplace, + backendName: 'webgl', + kernelFunc: staticRegexReplace as unknown as KernelFunc, +} From c0067a13abd30552f8b5c7700d602cda07a3ae75 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 14 Feb 2023 15:00:27 -0800 Subject: [PATCH 05/11] Improve tests. Write doc comment --- tfjs-backend-cpu/src/utils/unary_types.ts | 2 +- .../src/register_all_kernels.ts | 2 ++ .../src/ops/string/static_regex_replace.ts | 19 +++++++++++++++++++ .../ops/string/static_regex_replace_test.ts | 15 +++++++++++---- 4 files changed, 33 insertions(+), 5 deletions(-) diff --git a/tfjs-backend-cpu/src/utils/unary_types.ts b/tfjs-backend-cpu/src/utils/unary_types.ts index 96aab70dd10..70226f0e1c1 100644 --- a/tfjs-backend-cpu/src/utils/unary_types.ts +++ b/tfjs-backend-cpu/src/utils/unary_types.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import { DataTypeFor, DataTypeMap, NamedAttrMap } from '@tensorflow/tfjs-core'; +import {DataTypeFor, DataTypeMap, NamedAttrMap} from '@tensorflow/tfjs-core'; export type SimpleUnaryOperation = (x: I, attrs?: NamedAttrMap) => O; diff --git a/tfjs-backend-webgl/src/register_all_kernels.ts b/tfjs-backend-webgl/src/register_all_kernels.ts index f5045d0f35c..9f8e8f4ff39 100644 --- a/tfjs-backend-webgl/src/register_all_kernels.ts +++ b/tfjs-backend-webgl/src/register_all_kernels.ts @@ -166,6 +166,7 @@ import {splitVConfig} from './kernels/SplitV'; import {sqrtConfig} from './kernels/Sqrt'; import {squareConfig} from './kernels/Square'; import {squaredDifferenceConfig} from './kernels/SquaredDifference'; +import {staticRegexReplaceConfig} from './kernels/StaticRegexReplace'; import {stepConfig} from './kernels/Step'; import {stridedSliceConfig} from './kernels/StridedSlice'; import {stringNGramsConfig} from './kernels/StringNGrams'; @@ -337,6 +338,7 @@ const kernelConfigs: KernelConfig[] = [ sqrtConfig, squareConfig, squaredDifferenceConfig, + staticRegexReplaceConfig, stepConfig, stridedSliceConfig, stringNGramsConfig, diff --git a/tfjs-core/src/ops/string/static_regex_replace.ts b/tfjs-core/src/ops/string/static_regex_replace.ts index 24c6ffdb750..8656a242913 100644 --- a/tfjs-core/src/ops/string/static_regex_replace.ts +++ b/tfjs-core/src/ops/string/static_regex_replace.ts @@ -23,6 +23,25 @@ import {convertToTensor} from '../../tensor_util_env'; import {TensorLike} from '../../types'; import {op} from '../operation'; +/** + * Replace the match of a `pattern` in `input` with `rewrite`. + * + * ```js + * const result = tf.string.staticRegexReplace( + * ['format this spacing better'], ' +', ' '); + * result.print(); // ['format this spacing better'] + * ``` + * @param input: A Tensor of type string. The text to be processed. + * @param pattern: A string. The regular expression to match the input. + * @param rewrite: A string. The rewrite to be applied to the matched + * expression. + * @param replaceGlobal: An optional bool. Defaults to True. If True, the + * replacement is global, otherwise the replacement is done only on the + * first match. + * @return A Tensor of type string. + * + * @doc {heading: 'Operations', subheading: 'String'} + */ function staticRegexReplace_( input: Tensor | TensorLike, pattern: string, rewrite: string, replaceGlobal=true): Tensor { diff --git a/tfjs-core/src/ops/string/static_regex_replace_test.ts b/tfjs-core/src/ops/string/static_regex_replace_test.ts index 563a37a06e1..a10ef1e4522 100644 --- a/tfjs-core/src/ops/string/static_regex_replace_test.ts +++ b/tfjs-core/src/ops/string/static_regex_replace_test.ts @@ -20,8 +20,7 @@ import { DataTypeFor } from '../../index'; import {ALL_ENVS, describeWithFlags} from '../../jasmine_util'; describeWithFlags('staticRegexReplace', ALL_ENVS, () => { - - it('replaces the first instance of a regex pattern', async () => { + it('replaces the first instance of a string', async () => { const result = tf.string.staticRegexReplace( ['this', 'is', 'a', 'test test'], 'test', 'result', false); @@ -29,11 +28,19 @@ describeWithFlags('staticRegexReplace', ALL_ENVS, () => { .toEqual(['this', 'is', 'a', 'result test']); }); - it('replaces a regex pattern globally', async () => { + it('replaces a string globally by default', async () => { const result = tf.string.staticRegexReplace( - ['this', 'is', 'a', 'test test'], 'test', 'result', true); + ['this', 'is', 'a', 'test test'], 'test', 'result'); expect(await result.data>()) .toEqual(['this', 'is', 'a', 'result result']); }); + + it('matches using regex', async () => { + const result = tf.string.staticRegexReplace( + ['This will have normal whitespace'], ' +', ' '); + + expect(await result.data>()) + .toEqual(['This will have normal whitespace']); + }); }); From 6426ec9680eef1811d32ce15c0a91697737504b4 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 14 Feb 2023 15:23:08 -0800 Subject: [PATCH 06/11] Add StaticRegexReplace to converter --- .../python/tensorflowjs/op_list/string.json | 30 ++++++++++++++++++- .../operations/executors/string_executor.ts | 8 +++++ .../executors/string_executor_test.ts | 24 +++++++++++++++ 3 files changed, 61 insertions(+), 1 deletion(-) diff --git a/tfjs-converter/python/tensorflowjs/op_list/string.json b/tfjs-converter/python/tensorflowjs/op_list/string.json index 0f7de73efd1..1df8caba142 100644 --- a/tfjs-converter/python/tensorflowjs/op_list/string.json +++ b/tfjs-converter/python/tensorflowjs/op_list/string.json @@ -1,4 +1,32 @@ [ + { + "tfOpName": "StaticRegexReplace", + "category": "string", + "inputs": [ + { + "start": 0, + "name": "input", + "type": "tensor" + } + ], + "attrs": [ + { + "tfName": "pattern", + "name": "pattern", + "type": "string" + }, + { + "tfName": "rewrite", + "name": "rewrite", + "type": "string" + }, + { + "tfName": "replace_global", + "name": "replaceGlobal", + "type": "bool" + } + ] + }, { "tfOpName": "StringNGrams", "category": "string", @@ -97,4 +125,4 @@ } ] } -] \ No newline at end of file +] diff --git a/tfjs-converter/src/operations/executors/string_executor.ts b/tfjs-converter/src/operations/executors/string_executor.ts index e57e2107047..85b5a494235 100644 --- a/tfjs-converter/src/operations/executors/string_executor.ts +++ b/tfjs-converter/src/operations/executors/string_executor.ts @@ -29,6 +29,14 @@ export const executeOp: InternalOpExecutor = (node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext, ops = tfOps): Tensor[] => { switch (node.op) { + case 'StaticRegexReplace': { + return [ops.string.staticRegexReplace( + getParamValue('input', node, tensorMap, context) as Tensor, + getParamValue('pattern', node, tensorMap, context) as string, + getParamValue('rewrite', node, tensorMap, context) as string, + getParamValue('replaceGlobal', node, tensorMap, context) as boolean, + )]; + } case 'StringNGrams': { const {nGrams, nGramsSplits} = ops.string.stringNGrams( getParamValue('data', node, tensorMap, context) as Tensor1D, diff --git a/tfjs-converter/src/operations/executors/string_executor_test.ts b/tfjs-converter/src/operations/executors/string_executor_test.ts index 492d0165953..55792db2a47 100644 --- a/tfjs-converter/src/operations/executors/string_executor_test.ts +++ b/tfjs-converter/src/operations/executors/string_executor_test.ts @@ -49,6 +49,30 @@ describe('string', () => { }); describe('executeOp', () => { + describe('StaticRegexReplace', () => { + it('should call tfOps.string.staticRegexReplace', async () => { + node.op = 'StaticRegexReplace'; + node.inputParams = { + input: createTensorAttr(0), + }; + node.attrParams = { + pattern: createStrAttr('foo'), + rewrite: createStrAttr('bar'), + replaceGlobal: createBoolAttr(true), + }; + node.inputNames = ['input']; + + const input = [tfOps.tensor1d(['a', 'b', 'foo', 'd'])]; + const result = executeOp(node, {input}, context, + spyOpsAsTfOps) as Tensor[]; + + expect(spyOps.string.staticRegexReplace) + .toHaveBeenCalledWith(input[0], 'foo', 'bar', true); + + test_util.expectArraysEqual( + await result[0].data(), ['a', 'b', 'bar', 'd']); + }); + }); describe('StringNGrams', () => { it('should call tfOps.string.stringNGrams', async () => { node.op = 'StringNGrams'; From 48fd478a824ab7f355f4bfffa402bb6b19f71c0f Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 14 Feb 2023 15:26:32 -0800 Subject: [PATCH 07/11] Fix lint --- tfjs-backend-cpu/src/utils/unary_impl.ts | 2 +- tfjs-backend-cpu/src/utils/unary_types.ts | 2 +- tfjs-backend-cpu/src/utils/unary_utils.ts | 11 +++++++---- tfjs-backend-webgl/src/kernels/StaticRegexReplace.ts | 5 +++-- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/tfjs-backend-cpu/src/utils/unary_impl.ts b/tfjs-backend-cpu/src/utils/unary_impl.ts index 5d4cff9a81f..fcf27ad2f2b 100644 --- a/tfjs-backend-cpu/src/utils/unary_impl.ts +++ b/tfjs-backend-cpu/src/utils/unary_impl.ts @@ -29,7 +29,7 @@ export function createSimpleUnaryImpl = (values: ArrayLike, dtype: DataTypeFor, - attrs?: NamedAttrMap) => DataTypeMap[DataTypeFor] + attrs?: NamedAttrMap) => DataTypeMap[DataTypeFor]; diff --git a/tfjs-backend-cpu/src/utils/unary_utils.ts b/tfjs-backend-cpu/src/utils/unary_utils.ts index aacf627161b..0a4c8459bbe 100644 --- a/tfjs-backend-cpu/src/utils/unary_utils.ts +++ b/tfjs-backend-cpu/src/utils/unary_utils.ts @@ -33,7 +33,8 @@ import {SimpleUnaryImpl, SimpleUnaryOperation} from './unary_types'; */ export function unaryKernelFunc( - name: string, op: SimpleUnaryOperation, dtype?: DataTypeFor): KernelFunc { + name: string, op: SimpleUnaryOperation, + dtype?: DataTypeFor): KernelFunc { const impl = createSimpleUnaryImpl(op); @@ -51,14 +52,16 @@ export function unaryKernelFunc( - name: string, unaryImpl: SimpleUnaryImpl, dtype?: DataTypeFor): KernelFunc { + name: string, unaryImpl: SimpleUnaryImpl, + dtype?: DataTypeFor): KernelFunc { + return ({inputs, attrs, backend}) => { const {x} = inputs as UnaryInputs; assertNotComplex(x, name); const cpuBackend = backend as MathBackendCPU; - let values = cpuBackend.data.get(x.dataId).values; - let decoded: {[index: number]: I, length: number}//DataTypeMap[keyof DataTypeMap]; + const values = cpuBackend.data.get(x.dataId).values; + let decoded: ArrayLike; if (values instanceof Array) { if (x.dtype !== 'string') { throw new Error(`Tensor ${x} data contains an array of values but its ` diff --git a/tfjs-backend-webgl/src/kernels/StaticRegexReplace.ts b/tfjs-backend-webgl/src/kernels/StaticRegexReplace.ts index 6062306378c..e40ef192cb7 100644 --- a/tfjs-backend-webgl/src/kernels/StaticRegexReplace.ts +++ b/tfjs-backend-webgl/src/kernels/StaticRegexReplace.ts @@ -34,7 +34,8 @@ export function staticRegexReplace(args: { const $x = backend.readSync(x.dataId) as Uint8Array[]; const stringInput = backend_util.fromUint8ToStringArray($x); - const output = staticRegexReplaceImplCPU(stringInput, 'string', attrs as unknown as NamedAttrMap); + const output = staticRegexReplaceImplCPU(stringInput, 'string', + attrs as unknown as NamedAttrMap); return backend.makeTensorInfo(x.shape, 'string', output); } @@ -43,4 +44,4 @@ export const staticRegexReplaceConfig: KernelConfig = { kernelName: StaticRegexReplace, backendName: 'webgl', kernelFunc: staticRegexReplace as unknown as KernelFunc, -} +}; From 17b6cd3b8ba0be81e1717d41fd45f494b6a6f50a Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 14 Feb 2023 16:12:37 -0800 Subject: [PATCH 08/11] Disable StaticRegexReplace in tfjs-node --- tfjs-node/src/run_tests.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/tfjs-node/src/run_tests.ts b/tfjs-node/src/run_tests.ts index e2e8029c47b..ec6d3680ac8 100644 --- a/tfjs-node/src/run_tests.ts +++ b/tfjs-node/src/run_tests.ts @@ -159,6 +159,7 @@ const IGNORE_LIST: string[] = [ 'sparseReshape', 'sparseSegmentMean', 'sparseSegmentSum', + 'staticRegexReplace', 'stringNGrams', 'stringSplit', 'stringToHashBucketFast', From 921e67d0c1ed70eacda4823a5d040a77600b3dfc Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 14 Feb 2023 16:31:32 -0800 Subject: [PATCH 09/11] Add StaticRegexReplace to node backend --- tfjs-node/src/kernels/StaticRegexReplace.ts | 44 +++++++++++++++++++++ tfjs-node/src/register_all_kernels.ts | 2 + tfjs-node/src/run_tests.ts | 1 - 3 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 tfjs-node/src/kernels/StaticRegexReplace.ts diff --git a/tfjs-node/src/kernels/StaticRegexReplace.ts b/tfjs-node/src/kernels/StaticRegexReplace.ts new file mode 100644 index 00000000000..73546526c67 --- /dev/null +++ b/tfjs-node/src/kernels/StaticRegexReplace.ts @@ -0,0 +1,44 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, StaticRegexReplace, StaticRegexReplaceAttrs, StaticRegexReplaceInputs} from '@tensorflow/tfjs'; + +import {NodeJSKernelBackend} from '../nodejs_kernel_backend'; + +export const staticRegexReplaceConfig: KernelConfig = { + kernelName: StaticRegexReplace, + backendName: 'tensorflow', + kernelFunc: (args) => { + const tensors = args.inputs as unknown as StaticRegexReplaceInputs; + const backend = args.backend as NodeJSKernelBackend; + const {pattern, rewrite, replaceGlobal} = + args.attrs as unknown as StaticRegexReplaceAttrs; + + const opAttrs = [ + {name: 'pattern', type: backend.binding.TF_ATTR_STRING, value: pattern}, + {name: 'rewrite', type: backend.binding.TF_ATTR_STRING, value: rewrite}, + { + name: 'replace_global', + type: backend.binding.TF_ATTR_BOOL, + value: replaceGlobal, + }, + ]; + + const inputs = [tensors.x]; + return backend.executeSingleOutput('StaticRegexReplace', opAttrs, inputs); + } +}; diff --git a/tfjs-node/src/register_all_kernels.ts b/tfjs-node/src/register_all_kernels.ts index 24e0c52e510..781316f892c 100644 --- a/tfjs-node/src/register_all_kernels.ts +++ b/tfjs-node/src/register_all_kernels.ts @@ -159,6 +159,7 @@ import {splitVConfig} from './kernels/SplitV'; import {sqrtConfig} from './kernels/Sqrt'; import {squareConfig} from './kernels/Square'; import {squaredDifferenceConfig} from './kernels/SquaredDifference'; +import {staticRegexReplaceConfig} from './kernels/StaticRegexReplace'; import {stepConfig} from './kernels/Step'; import {stridedSliceConfig} from './kernels/StridedSlice'; import {subConfig} from './kernels/Sub'; @@ -316,6 +317,7 @@ const kernelConfigs: KernelConfig[] = [ sqrtConfig, squareConfig, squaredDifferenceConfig, + staticRegexReplaceConfig, stepConfig, stridedSliceConfig, subConfig, diff --git a/tfjs-node/src/run_tests.ts b/tfjs-node/src/run_tests.ts index ec6d3680ac8..e2e8029c47b 100644 --- a/tfjs-node/src/run_tests.ts +++ b/tfjs-node/src/run_tests.ts @@ -159,7 +159,6 @@ const IGNORE_LIST: string[] = [ 'sparseReshape', 'sparseSegmentMean', 'sparseSegmentSum', - 'staticRegexReplace', 'stringNGrams', 'stringSplit', 'stringToHashBucketFast', From 43fae0cdd5c31a13b1b3255a901807914b510d04 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Tue, 14 Feb 2023 16:42:04 -0800 Subject: [PATCH 10/11] Simplify string tensor check --- tfjs-backend-cpu/src/utils/unary_utils.ts | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tfjs-backend-cpu/src/utils/unary_utils.ts b/tfjs-backend-cpu/src/utils/unary_utils.ts index 0a4c8459bbe..13431e4e96d 100644 --- a/tfjs-backend-cpu/src/utils/unary_utils.ts +++ b/tfjs-backend-cpu/src/utils/unary_utils.ts @@ -62,10 +62,9 @@ export function unaryKernelFuncFromImpl; - if (values instanceof Array) { - if (x.dtype !== 'string') { - throw new Error(`Tensor ${x} data contains an array of values but its ` - + `dtype is ${x.dtype} instead of 'string'`); + if (x.dtype === 'string') { + if (!(values instanceof Array)) { + throw new Error('String tensor\'s value was not an instance of Array'); } decoded = backend_util.fromUint8ToStringArray(values) as unknown as ArrayLike; From 9912f8fe518b7b6c8b6355106cd981c2cf38f90f Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Wed, 15 Feb 2023 10:27:50 -0800 Subject: [PATCH 11/11] Use Array.isArray --- tfjs-backend-cpu/src/utils/unary_utils.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-backend-cpu/src/utils/unary_utils.ts b/tfjs-backend-cpu/src/utils/unary_utils.ts index 13431e4e96d..1f913804f50 100644 --- a/tfjs-backend-cpu/src/utils/unary_utils.ts +++ b/tfjs-backend-cpu/src/utils/unary_utils.ts @@ -63,7 +63,7 @@ export function unaryKernelFuncFromImpl; if (x.dtype === 'string') { - if (!(values instanceof Array)) { + if (!Array.isArray(values)) { throw new Error('String tensor\'s value was not an instance of Array'); } decoded = backend_util.fromUint8ToStringArray(values) as unknown as