Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add StaticRegexReplace Op #7379

Merged
merged 12 commits into from
Feb 15, 2023
37 changes: 37 additions & 0 deletions tfjs-backend-cpu/src/kernels/StaticRegexReplace.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {KernelConfig, StaticRegexReplace, StaticRegexReplaceAttrs} from '@tensorflow/tfjs-core';
import {createSimpleUnaryImpl} from '../utils/unary_impl';
import {unaryKernelFuncFromImpl} from '../utils/unary_utils';

/**
 * Element-wise CPU implementation of the StaticRegexReplace op.
 *
 * For every input string `x`, returns `x.replace(regex, attrs.rewrite)` where
 * `regex` is built from `attrs.pattern`. When `attrs.replaceGlobal` is truthy
 * the regex carries the `g` flag, so every match is replaced; otherwise only
 * the first match is replaced.
 *
 * The compiled RegExp is memoized on (pattern, flags), so a tensor with many
 * elements compiles the pattern once instead of once per element. An invalid
 * pattern still throws the SyntaxError raised by the RegExp constructor.
 */
export const staticRegexReplaceImpl = (() => {
  // Single-entry cache of the most recently compiled regex and its key.
  let cachedPattern: string|undefined;
  let cachedFlags: string|undefined;
  let cachedRegex: RegExp|undefined;

  return createSimpleUnaryImpl<string, string>((x: string, attrs) => {
    const {pattern, replaceGlobal, rewrite} =
        attrs as unknown as StaticRegexReplaceAttrs;
    const flags = replaceGlobal ? 'g' : '';

    let regex = cachedRegex;
    if (regex === undefined || cachedPattern !== pattern ||
        cachedFlags !== flags) {
      regex = new RegExp(pattern, flags);
      cachedRegex = regex;
      cachedPattern = pattern;
      cachedFlags = flags;
    }

    // Reusing the RegExp is safe: String.prototype.replace resets lastIndex
    // for global regexes and ignores it for non-global, non-sticky ones.
    return x.replace(regex, rewrite);
  });
})();

/**
 * Kernel registration for StaticRegexReplace on the 'cpu' backend. The kernel
 * function is derived from the element-wise implementation above.
 */
export const staticRegexReplaceConfig: KernelConfig = {
  kernelName: StaticRegexReplace,
  backendName: 'cpu',
  kernelFunc:
      unaryKernelFuncFromImpl(StaticRegexReplace, staticRegexReplaceImpl),
};
2 changes: 2 additions & 0 deletions tfjs-backend-cpu/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ import {splitVConfig} from './kernels/SplitV';
import {sqrtConfig} from './kernels/Sqrt';
import {squareConfig} from './kernels/Square';
import {squaredDifferenceConfig} from './kernels/SquaredDifference';
import {staticRegexReplaceConfig} from './kernels/StaticRegexReplace';
import {stepConfig} from './kernels/Step';
import {stridedSliceConfig} from './kernels/StridedSlice';
import {stringNGramsConfig} from './kernels/StringNGrams';
Expand Down Expand Up @@ -342,6 +343,7 @@ const kernelConfigs: KernelConfig[] = [
sqrtConfig,
squareConfig,
squaredDifferenceConfig,
staticRegexReplaceConfig,
stepConfig,
stridedSliceConfig,
stringNGramsConfig,
Expand Down
1 change: 1 addition & 0 deletions tfjs-backend-cpu/src/shared.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ export {sparseReshapeImpl} from './kernels/SparseReshape_impl';
export {sparseSegmentReductionImpl} from './kernels/SparseSegmentReduction_impl';
export {sqrtImpl} from './kernels/Sqrt';
export {squaredDifferenceImpl} from './kernels/SquaredDifference';
export {staticRegexReplaceImpl} from './kernels/StaticRegexReplace';
export {stridedSliceImpl} from './kernels/StridedSlice_impl';
export {stringNGramsImpl} from './kernels/StringNGrams_impl';
export {stringSplitImpl} from './kernels/StringSplit_impl';
Expand Down
9 changes: 5 additions & 4 deletions tfjs-backend-cpu/src/utils/unary_impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,19 @@
* =============================================================================
*/

import {NumericDataType, util} from '@tensorflow/tfjs-core';
import {util} from '@tensorflow/tfjs-core';

import {SimpleUnaryImpl, SimpleUnaryOperation} from './unary_types';

/**
* Template that creates implementation for unary op.
*/
export function createSimpleUnaryImpl(op: SimpleUnaryOperation):
SimpleUnaryImpl {
export function createSimpleUnaryImpl<I extends number | string = number,
O extends number | string = number>(op: SimpleUnaryOperation<I, O>):
SimpleUnaryImpl<I, O> {
return (values, dtype, attrs) => {
const newValues =
util.getTypedArrayFromDType(dtype as NumericDataType, values.length);
util.getArrayFromDType(dtype, values.length);
for (let i = 0; i < values.length; ++i) {
newValues[i] = op(values[i], attrs);
}
Expand Down
12 changes: 8 additions & 4 deletions tfjs-backend-cpu/src/utils/unary_types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
* =============================================================================
*/

import {DataType, NamedAttrMap, TypedArray} from '@tensorflow/tfjs-core';
import {DataTypeFor, DataTypeMap, NamedAttrMap} from '@tensorflow/tfjs-core';

export type SimpleUnaryOperation = (x: number, attrs?: NamedAttrMap) => number;
export type SimpleUnaryImpl =
(values: TypedArray, dtype: DataType, attrs?: NamedAttrMap) => TypedArray;
/**
 * Scalar operation of a unary kernel: maps a single input element of type `I`
 * to a single output element of type `O`. Both type parameters are limited to
 * `number` or `string` and default to `number`.
 *
 * @param x The input element.
 * @param attrs Optional kernel attributes passed through to the operation.
 */
export type SimpleUnaryOperation<I extends number | string = number,
O extends number | string = number> = (x: I, attrs?: NamedAttrMap) => O;

/**
 * Array-level implementation of a unary kernel: consumes an array-like of
 * input elements of type `I` and produces the output container that
 * `DataTypeMap` associates with the output dtype of `O`. Both type parameters
 * default to `number | string`.
 *
 * @param values The input elements.
 * @param dtype The dtype of the output, constrained to a dtype for `O`.
 * @param attrs Optional kernel attributes passed through to the operation.
 */
export type SimpleUnaryImpl<I extends number | string = number | string,
O extends number | string = number | string> =
(values: ArrayLike<I>, dtype: DataTypeFor<O>,
attrs?: NamedAttrMap) => DataTypeMap[DataTypeFor<O>];
55 changes: 28 additions & 27 deletions tfjs-backend-cpu/src/utils/unary_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
* =============================================================================
*/

import {DataType, KernelFunc, TypedArray, UnaryInputs, util} from '@tensorflow/tfjs-core';
import {backend_util, DataTypeFor, KernelFunc, UnaryInputs} from '@tensorflow/tfjs-core';

import {MathBackendCPU} from '../backend_cpu';
import {assertNotComplex} from '../cpu_util';
import {createSimpleUnaryImpl} from './unary_impl';

import {SimpleUnaryImpl, SimpleUnaryOperation} from './unary_types';

Expand All @@ -30,25 +31,14 @@ import {SimpleUnaryImpl, SimpleUnaryOperation} from './unary_types';
* result has the same dtype as the input. This is mainly used in certain
* kernels that return bool type, such as isFinite, isInf, etc.
*/
export function unaryKernelFunc(
name: string, op: SimpleUnaryOperation, dtype?: DataType): KernelFunc {
return ({inputs, attrs, backend}) => {
const {x} = inputs as UnaryInputs;
assertNotComplex(x, name);
if (x.dtype === 'string' || dtype === 'string') {
throw new Error('unaryKernelFunc does not support string input/output');
}
export function unaryKernelFunc<I extends number | string = number,
O extends number | string = number>(
name: string, op: SimpleUnaryOperation<I, O>,
dtype?: DataTypeFor<O>): KernelFunc {

const cpuBackend = backend as MathBackendCPU;
const values = cpuBackend.data.get(x.dataId).values as TypedArray;
const xSize = util.sizeFromShape(x.shape);
const $dtype = dtype || x.dtype;
const newValues = util.getArrayFromDType($dtype, xSize);
for (let i = 0; i < xSize; ++i) {
newValues[i] = op(values[i], attrs);
}
return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues);
};
const impl = createSimpleUnaryImpl<I, O>(op);

return unaryKernelFuncFromImpl<I, O>(name, impl, dtype);
}

/**
Expand All @@ -60,19 +50,30 @@ export function unaryKernelFunc(
* result has the same dtype as the input. This is mainly used in certain
* kernels that return bool type, such as isFinite, isInf, etc.
*/
export function unaryKernelFuncFromImpl(
name: string, unaryImpl: SimpleUnaryImpl, dtype?: DataType): KernelFunc {
export function unaryKernelFuncFromImpl<I extends number | string = number,
O extends number | string = number>(
name: string, unaryImpl: SimpleUnaryImpl<I, O>,
dtype?: DataTypeFor<O>): KernelFunc {

return ({inputs, attrs, backend}) => {
const {x} = inputs as UnaryInputs;
assertNotComplex(x, name);
if (x.dtype === 'string' || dtype === 'string') {
throw new Error('unaryKernelFunc does not support string input/output');
}

const cpuBackend = backend as MathBackendCPU;
const values = cpuBackend.data.get(x.dataId).values as TypedArray;
const $dtype = dtype || x.dtype;
const newValues = unaryImpl(values, $dtype, attrs);
const values = cpuBackend.data.get(x.dataId).values;
let decoded: ArrayLike<I>;
if (x.dtype === 'string') {
if (!Array.isArray(values)) {
throw new Error('String tensor\'s value was not an instance of Array');
}
decoded = backend_util.fromUint8ToStringArray(values) as unknown as
ArrayLike<I>;
} else {
decoded = values as unknown as ArrayLike<I>;
}

const $dtype = dtype || x.dtype as DataTypeFor<O>;
const newValues = unaryImpl(decoded, $dtype, attrs);
return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues);
};
}
4 changes: 2 additions & 2 deletions tfjs-backend-webgl/src/kernel_utils/kernel_funcs_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {backend_util, BinaryInputs, DataType, env, KernelFunc, TypedArray, UnaryInputs, upcastType} from '@tensorflow/tfjs-core';
import { backend_util, BinaryInputs, DataType, env, KernelFunc, TypedArray, UnaryInputs, upcastType} from '@tensorflow/tfjs-core';

import {MathBackendWebGL} from '../backend_webgl';
import {BinaryOpProgram} from '../binaryop_gpu';
Expand All @@ -36,7 +36,7 @@ type UnaryKernelFuncConfig = {
opSnippet: string,
packedOpSnippet?: string,
cpuKernelImpl?: SimpleUnaryKernelImplCPU,
dtype?: DataType
dtype?: DataType,
};

/**
Expand Down
2 changes: 2 additions & 0 deletions tfjs-backend-webgl/src/kernel_utils/shared.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ const {
sparseReshapeImpl: sparseReshapeImplCPU,
sparseSegmentReductionImpl: sparseSegmentReductionImplCPU,
sqrtImpl: sqrtImplCPU,
staticRegexReplaceImpl: staticRegexReplaceImplCPU,
stridedSliceImpl: stridedSliceImplCPU,
stringNGramsImpl: stringNGramsImplCPU,
stringSplitImpl: stringSplitImplCPU,
Expand Down Expand Up @@ -114,6 +115,7 @@ export {
sparseReshapeImplCPU,
sparseSegmentReductionImplCPU,
sqrtImplCPU,
staticRegexReplaceImplCPU,
stridedSliceImplCPU,
stringNGramsImplCPU,
stringSplitImplCPU,
Expand Down
47 changes: 47 additions & 0 deletions tfjs-backend-webgl/src/kernels/StaticRegexReplace.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {backend_util, KernelConfig, KernelFunc, NamedAttrMap, StaticRegexReplace, StaticRegexReplaceAttrs, StaticRegexReplaceInputs, TensorInfo} from '@tensorflow/tfjs-core';
import {MathBackendWebGL} from '../backend_webgl';
import {staticRegexReplaceImplCPU} from '../kernel_utils/shared';

/**
 * StaticRegexReplace kernel function for the WebGL backend.
 *
 * Reads the input tensor's values synchronously from the backend, decodes
 * them to strings, delegates the replacement to the shared CPU
 * implementation, and wraps the result in a new string tensor with the same
 * shape as the input.
 *
 * @param args.inputs Holds the input tensor `x`.
 * @param args.backend The WebGL backend that owns `x`.
 * @param args.attrs The pattern, rewrite and replaceGlobal attributes.
 * @returns Info for the newly created string tensor.
 * @throws Error if `x` is not of dtype 'string'.
 */
export function staticRegexReplace(args: {
  inputs: StaticRegexReplaceInputs,
  backend: MathBackendWebGL,
  attrs: StaticRegexReplaceAttrs,
}): TensorInfo {
  const {x} = args.inputs;
  const webglBackend = args.backend;

  if (x.dtype !== 'string') {
    throw new Error('Input must be of datatype string');
  }

  const encoded = webglBackend.readSync(x.dataId) as Uint8Array[];
  const decoded = backend_util.fromUint8ToStringArray(encoded);

  // The CPU implementation takes generic named attributes.
  const namedAttrs = args.attrs as unknown as NamedAttrMap;
  const replaced = staticRegexReplaceImplCPU(decoded, 'string', namedAttrs);

  return webglBackend.makeTensorInfo(x.shape, 'string', replaced);
}

/**
 * Kernel registration for StaticRegexReplace on the 'webgl' backend.
 */
export const staticRegexReplaceConfig: KernelConfig = {
kernelName: StaticRegexReplace,
backendName: 'webgl',
// The strongly typed kernel function is cast to the generic KernelFunc type.
kernelFunc: staticRegexReplace as unknown as KernelFunc,
};
2 changes: 2 additions & 0 deletions tfjs-backend-webgl/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ import {splitVConfig} from './kernels/SplitV';
import {sqrtConfig} from './kernels/Sqrt';
import {squareConfig} from './kernels/Square';
import {squaredDifferenceConfig} from './kernels/SquaredDifference';
import {staticRegexReplaceConfig} from './kernels/StaticRegexReplace';
import {stepConfig} from './kernels/Step';
import {stridedSliceConfig} from './kernels/StridedSlice';
import {stringNGramsConfig} from './kernels/StringNGrams';
Expand Down Expand Up @@ -337,6 +338,7 @@ const kernelConfigs: KernelConfig[] = [
sqrtConfig,
squareConfig,
squaredDifferenceConfig,
staticRegexReplaceConfig,
stepConfig,
stridedSliceConfig,
stringNGramsConfig,
Expand Down
30 changes: 29 additions & 1 deletion tfjs-converter/python/tensorflowjs/op_list/string.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,32 @@
[
{
"tfOpName": "StaticRegexReplace",
"category": "string",
"inputs": [
{
"start": 0,
"name": "input",
"type": "tensor"
}
],
"attrs": [
{
"tfName": "pattern",
"name": "pattern",
"type": "string"
},
{
"tfName": "rewrite",
"name": "rewrite",
"type": "string"
},
{
"tfName": "replace_global",
"name": "replaceGlobal",
"type": "bool"
}
]
},
{
"tfOpName": "StringNGrams",
"category": "string",
Expand Down Expand Up @@ -97,4 +125,4 @@
}
]
}
]
]
8 changes: 8 additions & 0 deletions tfjs-converter/src/operations/executors/string_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ export const executeOp: InternalOpExecutor =
(node: Node, tensorMap: NamedTensorsMap,
context: ExecutionContext, ops = tfOps): Tensor[] => {
switch (node.op) {
case 'StaticRegexReplace': {
return [ops.string.staticRegexReplace(
getParamValue('input', node, tensorMap, context) as Tensor,
getParamValue('pattern', node, tensorMap, context) as string,
getParamValue('rewrite', node, tensorMap, context) as string,
getParamValue('replaceGlobal', node, tensorMap, context) as boolean,
)];
}
case 'StringNGrams': {
const {nGrams, nGramsSplits} = ops.string.stringNGrams(
getParamValue('data', node, tensorMap, context) as Tensor1D,
Expand Down
24 changes: 24 additions & 0 deletions tfjs-converter/src/operations/executors/string_executor_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,30 @@ describe('string', () => {
});

describe('executeOp', () => {
// Verifies that a 'StaticRegexReplace' graph node is dispatched to
// ops.string.staticRegexReplace with its input tensor and attributes.
describe('StaticRegexReplace', () => {
it('should call tfOps.string.staticRegexReplace', async () => {
// Configure the node with one tensor input and the three op attributes.
node.op = 'StaticRegexReplace';
node.inputParams = {
input: createTensorAttr(0),
};
node.attrParams = {
pattern: createStrAttr('foo'),
rewrite: createStrAttr('bar'),
replaceGlobal: createBoolAttr(true),
};
node.inputNames = ['input'];

const input = [tfOps.tensor1d(['a', 'b', 'foo', 'd'])];
const result = executeOp(node, {input}, context,
spyOpsAsTfOps) as Tensor[];

// The executor must forward the tensor and attribute values as arguments.
expect(spyOps.string.staticRegexReplace)
.toHaveBeenCalledWith(input[0], 'foo', 'bar', true);

// Only the element 'foo' is expected to change, becoming 'bar'.
test_util.expectArraysEqual(
await result[0].data(), ['a', 'b', 'bar', 'd']);
});
});
describe('StringNGrams', () => {
it('should call tfOps.string.stringNGrams', async () => {
node.op = 'StringNGrams';
Expand Down
2 changes: 1 addition & 1 deletion tfjs-core/src/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer';
export {SGDOptimizer} from './optimizers/sgd_optimizer';
export {DataToGPUOptions, DataToGPUWebGLOption, GPUData, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, Variable} from './tensor';
export {GradSaveFunc, NamedTensorMap, TensorContainer, TensorContainerArray, TensorContainerObject} from './tensor_types';
export {BackendValues, DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ScalarLike, ShapeMap, sumOutType, TensorLike, TypedArray, upcastType, WebGLData, WebGPUData} from './types';
export {BackendValues, DataType, DataTypeMap, DataTypeFor, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ScalarLike, ShapeMap, sumOutType, TensorLike, TypedArray, upcastType, WebGLData, WebGPUData} from './types';

export * from './ops/ops';
export {Reduction} from './ops/loss_ops_utils';
Expand Down
8 changes: 8 additions & 0 deletions tfjs-core/src/kernel_names.ts
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,14 @@ export type SquaredDifferenceInputs = BinaryInputs;
export const Square = 'Square';
export type SquareInputs = Pick<NamedTensorInfoMap, 'x'>;

/** Kernel name of the StaticRegexReplace op. */
export const StaticRegexReplace = 'StaticRegexReplace';
/** Inputs of the StaticRegexReplace kernel: a single tensor `x`. */
export type StaticRegexReplaceInputs = UnaryInputs;
/** Attributes of the StaticRegexReplace kernel. */
export interface StaticRegexReplaceAttrs {
/** Regular-expression pattern matched against each input string. */
pattern: string;
/** Replacement text substituted for the matched portion. */
rewrite: string;
/** Whether to replace every match instead of only the first one. */
replaceGlobal: boolean;
}

export const StridedSlice = 'StridedSlice';
export type StridedSliceInputs = Pick<NamedTensorInfoMap, 'x'>;
export interface StridedSliceAttrs {
Expand Down
Loading