Skip to content

Commit

Permalink
Add unique (#7469)
Browse files Browse the repository at this point in the history
  • Loading branch information
chunnienc authored Mar 13, 2023
1 parent ff6739d commit 8cd5451
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 7 deletions.
13 changes: 7 additions & 6 deletions tfjs-backend-cpu/src/kernels/Unique_impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ export function uniqueImpl(

// A map from unique elements (their string representations) to their values
// in "indices" (below).
const uniqueElements: {[key: string]: number} = {};
const uniqueElements = new Map<string, number>();
// The indices of each unique element in the original tensor along the given
// axis. It is 1D and has the same size as the given axis.
const indices = new Int32Array(shape[$axis]);
Expand All @@ -119,11 +119,12 @@ export function uniqueImpl(
}

// Dedup and update various indices.
if (uniqueElements[element] !== undefined) {
indices[i] = uniqueElements[element];
const existingIndex = uniqueElements.get(element);
if (existingIndex != null) {
indices[i] = existingIndex;
} else {
const uniqueIndex = Object.keys(uniqueElements).length;
uniqueElements[element] = uniqueIndex;
const uniqueIndex = uniqueElements.size;
uniqueElements.set(element, uniqueIndex);
indices[i] = uniqueIndex;
uniqueIndices.push(i);
}
Expand All @@ -133,7 +134,7 @@ export function uniqueImpl(
// (uniqueIndices). Extract them from input buffer and store them in the
// output buffer.
const outputTmpShape = newShape.slice();
outputTmpShape[1] = Object.keys(uniqueElements).length;
outputTmpShape[1] = uniqueElements.size;
const outputBuffer = new TensorBuffer(outputTmpShape, dtype);
uniqueIndices.forEach((uniqueElementIndex, i) => {
for (let m = 0; m < newShape[0]; m++) {
Expand Down
5 changes: 4 additions & 1 deletion tfjs-backend-wasm/src/kernel_utils/shared.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ import {stringNGramsImpl as stringNGramsImplCPU} from '@tensorflow/tfjs-backend-
import {stringSplitImpl as stringSplitImplCPU} from '@tensorflow/tfjs-backend-cpu/dist/shared';
// tslint:disable-next-line: no-imports-from-dist
import {stringToHashBucketFastImpl as stringToHashBucketFastImplCPU} from '@tensorflow/tfjs-backend-cpu/dist/shared';
// tslint:disable-next-line: no-imports-from-dist
import {uniqueImpl as uniqueImplCPU} from '@tensorflow/tfjs-backend-cpu/dist/shared';

export {
concatImplCPU,
rangeImplCPU,
sliceImplCPU,
stringNGramsImplCPU,
stringSplitImplCPU,
stringToHashBucketFastImplCPU
stringToHashBucketFastImplCPU,
uniqueImplCPU,
};
44 changes: 44 additions & 0 deletions tfjs-backend-wasm/src/kernels/Unique.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {KernelConfig, KernelFunc, TensorInfo, Unique, UniqueAttrs, UniqueInputs} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';
import {uniqueImplCPU} from '../kernel_utils/shared';

function unique(
args: {inputs: UniqueInputs, attrs: UniqueAttrs, backend: BackendWasm}):
TensorInfo[] {
const {inputs, attrs, backend} = args;
const {axis} = attrs;
const {x} = inputs;

const {outputValues, outputShape, indices} =
uniqueImplCPU(backend.readSync(x.dataId), axis, x.shape, x.dtype);

return [
backend.makeOutput(
outputShape, x.dtype, /*memoryOffset=*/undefined, outputValues),
backend.makeOutput(
[indices.length], 'int32', /*memoryOffset=*/undefined, indices),
];
}

export const uniqueConfig: KernelConfig = {
kernelName: Unique,
backendName: 'wasm',
kernelFunc: unique as unknown as KernelFunc,
};
2 changes: 2 additions & 0 deletions tfjs-backend-wasm/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ import {tileConfig} from './kernels/Tile';
import {topKConfig} from './kernels/TopK';
import {transformConfig} from './kernels/Transform';
import {transposeConfig} from './kernels/Transpose';
import {uniqueConfig} from './kernels/Unique';
import {unpackConfig} from './kernels/Unpack';
import {zerosLikeConfig} from './kernels/ZerosLike';

Expand Down Expand Up @@ -310,6 +311,7 @@ const kernelConfigs: KernelConfig[] = [
topKConfig,
transformConfig,
transposeConfig,
uniqueConfig,
unpackConfig,
zerosLikeConfig
];
Expand Down
1 change: 1 addition & 0 deletions tfjs-backend-wasm/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ const TEST_FILTERS: TestFilter[] = [
{include: 'bincount'},
{include: 'expm1 '},
{include: 'multinomial'},
{include: 'unique'},
];

const customInclude = (testName: string) => {
Expand Down

0 comments on commit 8cd5451

Please sign in to comment.