diff --git a/tfjs-backend-cpu/src/kernels/Unique_impl.ts b/tfjs-backend-cpu/src/kernels/Unique_impl.ts index e1e104ca07d..ab9bfd1717b 100644 --- a/tfjs-backend-cpu/src/kernels/Unique_impl.ts +++ b/tfjs-backend-cpu/src/kernels/Unique_impl.ts @@ -92,7 +92,7 @@ export function uniqueImpl( // A map from unique elements (their string representations) to their values // in "indices" (below). - const uniqueElements: {[key: string]: number} = {}; + const uniqueElements = new Map(); // The indices of each unique element in the original tensor along the given // axis. It is 1D and has the same size as the given axis. const indices = new Int32Array(shape[$axis]); @@ -119,11 +119,12 @@ export function uniqueImpl( } // Dedup and update various indices. - if (uniqueElements[element] !== undefined) { - indices[i] = uniqueElements[element]; + const existingIndex = uniqueElements.get(element); + if (existingIndex != null) { + indices[i] = existingIndex; } else { - const uniqueIndex = Object.keys(uniqueElements).length; - uniqueElements[element] = uniqueIndex; + const uniqueIndex = uniqueElements.size; + uniqueElements.set(element, uniqueIndex); indices[i] = uniqueIndex; uniqueIndices.push(i); } @@ -133,7 +134,7 @@ export function uniqueImpl( // (uniqueIndices). Extract them from input buffer and store them in the // output buffer. const outputTmpShape = newShape.slice(); - outputTmpShape[1] = Object.keys(uniqueElements).length; + outputTmpShape[1] = uniqueElements.size; const outputBuffer = new TensorBuffer(outputTmpShape, dtype); uniqueIndices.forEach((uniqueElementIndex, i) => { for (let m = 0; m < newShape[0]; m++) { diff --git a/tfjs-backend-wasm/src/kernel_utils/shared.ts b/tfjs-backend-wasm/src/kernel_utils/shared.ts index 003ecaf361f..3b6f4202266 100644 --- a/tfjs-backend-wasm/src/kernel_utils/shared.ts +++ b/tfjs-backend-wasm/src/kernel_utils/shared.ts @@ -29,6 +29,8 @@ import {stringNGramsImpl as stringNGramsImplCPU} from '@tensorflow/tfjs-backend- import {stringSplitImpl as stringSplitImplCPU} from '@tensorflow/tfjs-backend-cpu/dist/shared'; // tslint:disable-next-line: no-imports-from-dist import {stringToHashBucketFastImpl as stringToHashBucketFastImplCPU} from '@tensorflow/tfjs-backend-cpu/dist/shared'; +// tslint:disable-next-line: no-imports-from-dist +import {uniqueImpl as uniqueImplCPU} from '@tensorflow/tfjs-backend-cpu/dist/shared'; export { concatImplCPU, @@ -36,5 +38,6 @@ export { sliceImplCPU, stringNGramsImplCPU, stringSplitImplCPU, - stringToHashBucketFastImplCPU + stringToHashBucketFastImplCPU, + uniqueImplCPU, }; diff --git a/tfjs-backend-wasm/src/kernels/Unique.ts b/tfjs-backend-wasm/src/kernels/Unique.ts new file mode 100644 index 00000000000..4ea90069c15 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Unique.ts @@ -0,0 +1,44 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {KernelConfig, KernelFunc, TensorInfo, Unique, UniqueAttrs, UniqueInputs} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; +import {uniqueImplCPU} from '../kernel_utils/shared'; + +function unique( + args: {inputs: UniqueInputs, attrs: UniqueAttrs, backend: BackendWasm}): + TensorInfo[] { + const {inputs, attrs, backend} = args; + const {axis} = attrs; + const {x} = inputs; + + const {outputValues, outputShape, indices} = + uniqueImplCPU(backend.readSync(x.dataId), axis, x.shape, x.dtype); + + return [ + backend.makeOutput( + outputShape, x.dtype, /*memoryOffset=*/undefined, outputValues), + backend.makeOutput( + [indices.length], 'int32', /*memoryOffset=*/undefined, indices), + ]; +} + +export const uniqueConfig: KernelConfig = { + kernelName: Unique, + backendName: 'wasm', + kernelFunc: unique as unknown as KernelFunc, +}; diff --git a/tfjs-backend-wasm/src/register_all_kernels.ts b/tfjs-backend-wasm/src/register_all_kernels.ts index 4d31b216e22..5e7b6d321f8 100644 --- a/tfjs-backend-wasm/src/register_all_kernels.ts +++ b/tfjs-backend-wasm/src/register_all_kernels.ts @@ -162,6 +162,7 @@ import {tileConfig} from './kernels/Tile'; import {topKConfig} from './kernels/TopK'; import {transformConfig} from './kernels/Transform'; import {transposeConfig} from './kernels/Transpose'; +import {uniqueConfig} from './kernels/Unique'; import {unpackConfig} from './kernels/Unpack'; import {zerosLikeConfig} from './kernels/ZerosLike'; @@ -310,6 +311,7 @@ const kernelConfigs: KernelConfig[] = [ topKConfig, transformConfig, transposeConfig, + uniqueConfig, unpackConfig, zerosLikeConfig ]; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 62beecc8005..618519875bb 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -409,6 +409,7 @@ const TEST_FILTERS: TestFilter[] = [ {include: 'bincount'}, {include: 'expm1 '}, {include: 'multinomial'}, + {include: 'unique'}, ]; const customInclude = (testName: string) => {