From 936300bb191654805d1f8f27d6c9538373c9d16a Mon Sep 17 00:00:00 2001 From: Ahmed Sabie Date: Wed, 3 Aug 2022 17:37:12 -0400 Subject: [PATCH 1/6] Add kernel RaggedGather for CPU and WebGL backend --- tfjs-backend-cpu/src/kernels/RaggedGather.ts | 59 +++++ .../src/kernels/RaggedGather_impl.ts | 226 ++++++++++++++++++ tfjs-backend-cpu/src/register_all_kernels.ts | 2 + tfjs-backend-cpu/src/shared.ts | 1 + tfjs-backend-webgl/src/kernel_utils/shared.ts | 2 + .../src/kernels/RaggedGather.ts | 58 +++++ .../src/register_all_kernels.ts | 2 + tfjs-core/src/kernel_names.ts | 9 + tfjs-core/src/ops/ops.ts | 1 + tfjs-core/src/ops/ragged_gather.ts | 75 ++++++ tfjs-core/src/ops/ragged_gather_test.ts | 179 ++++++++++++++ tfjs-node/src/run_tests.ts | 1 + 12 files changed, 615 insertions(+) create mode 100644 tfjs-backend-cpu/src/kernels/RaggedGather.ts create mode 100644 tfjs-backend-cpu/src/kernels/RaggedGather_impl.ts create mode 100644 tfjs-backend-webgl/src/kernels/RaggedGather.ts create mode 100644 tfjs-core/src/ops/ragged_gather.ts create mode 100644 tfjs-core/src/ops/ragged_gather_test.ts diff --git a/tfjs-backend-cpu/src/kernels/RaggedGather.ts b/tfjs-backend-cpu/src/kernels/RaggedGather.ts new file mode 100644 index 00000000000..f84fdaa08ea --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/RaggedGather.ts @@ -0,0 +1,59 @@ +/** + * @license + * Copyright 2022 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, KernelFunc, RaggedGather, RaggedGatherAttrs, RaggedGatherInputs, TensorInfo, TypedArray} from '@tensorflow/tfjs-core'; + +import {MathBackendCPU} from '../backend_cpu'; + +import {raggedGatherImpl} from './RaggedGather_impl'; + +export function raggedGather(args: { + inputs: RaggedGatherInputs, + backend: MathBackendCPU, + attrs: RaggedGatherAttrs +}): TensorInfo[] { + const {inputs, backend, attrs} = args; + const {paramsNestedSplits, paramsDenseValues, indices} = inputs; + const {outputRaggedRank} = attrs; + + const $paramsNestedSplits = paramsNestedSplits.map( + t => backend.data.get(t.dataId).values as TypedArray); + const $paramsNestedSplitsShapes = paramsNestedSplits.map(t => t.shape); + const $paramsDenseValues = + backend.data.get(paramsDenseValues.dataId).values as TypedArray; + const $indices = backend.data.get(indices.dataId).values as TypedArray; + + const [outputNestedSplits, outputDenseValues, outputDenseValuesShape] = + raggedGatherImpl( + $paramsNestedSplits, $paramsNestedSplitsShapes, $paramsDenseValues, + paramsDenseValues.shape, paramsDenseValues.dtype, $indices, + indices.shape, outputRaggedRank); + + const outputNestedSplitsTensors = outputNestedSplits.map( + (splits) => backend.makeTensorInfo([splits.length], 'int32', splits)); + + const outputDenseValuesTensor = backend.makeTensorInfo( + outputDenseValuesShape, paramsDenseValues.dtype, outputDenseValues); + + return outputNestedSplitsTensors.concat([outputDenseValuesTensor]); +} + +export const raggedGatherConfig: KernelConfig = { + kernelName: RaggedGather, + backendName: 'cpu', + kernelFunc: raggedGather as {} as KernelFunc, +}; diff --git a/tfjs-backend-cpu/src/kernels/RaggedGather_impl.ts b/tfjs-backend-cpu/src/kernels/RaggedGather_impl.ts new file mode 100644 index 00000000000..70e583630e3 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/RaggedGather_impl.ts @@ -0,0 +1,226 @@ +/** + * @license + * Copyright 2022 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {DataType, TypedArray, util} from '@tensorflow/tfjs-core'; + +function validateIndices( + indices: TypedArray, indicesShape: number[], numParams: number) { + indices.forEach((index: number, i: number) => { + if (index < 0 || index >= numParams) { + const locString = + util.indexToLoc( + i, indicesShape.length, util.computeStrides(indicesShape)) + .join(','); + throw new Error( + `indices[${locString}] = ${index} is not in [0, ${numParams})`); + } + }); +} + +function validateSplits( + paramsNestedSplits: TypedArray[], numParamsDenseValues: number) { + // Validate + for (let dim = 0; dim < paramsNestedSplits.length; ++dim) { + const splits = paramsNestedSplits[dim]; + const lastSplit = (dim === paramsNestedSplits.length - 1) ? + numParamsDenseValues : + paramsNestedSplits[dim + 1].length; + if (splits.length === 0) { + throw new Error('Ragged splits may not be empty'); + } + if (splits[0] < 0) { + throw new Error('Ragged splits must be non-negative'); + } + if (splits[splits.length - 1] > lastSplit) { + throw new Error('Ragged splits must not point past values'); + } + for (let i = 1; i < splits.length; ++i) { + if (splits[i - 1] > splits[i]) { + throw new Error('Ragged splits must be sorted'); + } + } + } +} + +// Construct the `splits` output tensors, encoded using a nested vector. +// Also find the slices of values that need to be copied, and store them +// in `valueSlices`. The total number of values that will be copied (which +// we need for allocating the output values tensor) is stored in `numValues`. +function makeSplits( + indices: TypedArray, indicesShape: number[], + paramsNestedSplits: TypedArray[], numParamsDenseValues: number) { + const valueSlices: Array<[number, number]> = []; + let numValues = 0; + + const numSplits = indicesShape.length - 1 + paramsNestedSplits.length; + const outSplits = new Array(numSplits).fill(null).map(() => [0]); + + validateSplits(paramsNestedSplits, numParamsDenseValues); + + // Add `splits` that come from all but the last dimension of the dense + // Tensor `indices`. In particular, for each dimension D, we add a + // splits tensor whose values are: + // range(reduceProd(splits.shape[:D]) + 1) * splits.shape[D+1] + // E.g., if indices.shape=[2, 3, 4] then we will add splits tensors: + // [0, 3, 6] # length=2+1, stride=3 + // [0, 4, 8, 12, 16, 20, 24] # length=2*3+1, stride=4 + let nrows = 1; + for (let dim = 0; dim < indicesShape.length - 1; ++dim) { + nrows *= indicesShape[dim]; + const rowLength = indicesShape[dim + 1]; + for (let i = 1; i < nrows + 1; ++i) { + outSplits[dim].push(i * rowLength); + } + } + + // Add `splits` that come from `paramsNestedSplits`. Starting with the + // outermost ragged dimension (i.e., the first `splits` tensor), we work + // our way in, finding the range of values that should be copied. As we + // go, we update the output `splits` for each dimension with the appropriate + // values. In particular, the *lengths* of the slices from `param_splits` + // should be copied to generate corresponding slice lengths in the output + // splits. E.g., if we are copying a ragged row with length 4, then we + // should add a new split point to outSplits that is 4 greater than the + // previous split point in outSplits. + for (let i = 0; i < indices.length; ++i) { + let start = indices[i]; + let limit = indices[i] + 1; + + // Copy splits. + for (let dim = 0; dim < paramsNestedSplits.length; ++dim) { + const splits = paramsNestedSplits[dim]; + const outDim = dim + indicesShape.length - 1; + if (outDim >= 0) { + const outSplitsOutDim = outSplits[outDim]; + const delta = + outSplitsOutDim[outSplitsOutDim.length - 1] - splits[start]; + for (let j = start; j < limit; ++j) { + outSplits[outDim].push(splits[j + 1] + delta); + } + } + start = splits[start]; + limit = splits[limit]; + } + if (limit !== start) { + valueSlices.push([start, limit]); + numValues += limit - start; + } + } + + return {outSplits, valueSlices, numValues}; +} + +function getSplits(outSplits: number[][]) { + const splitsOut: TypedArray[] = []; + for (let i = 0; i < outSplits.length; ++i) { + const numSplits = outSplits[i].length; + const splits = util.getArrayFromDType('int32', numSplits) as TypedArray; + splitsOut.push(splits); + + outSplits[i].forEach((value, j: number) => splits[j] = value); + } + + return splitsOut; +} + +function computeFlatOuterDims(orig: number[], numOutDims: number) { + const outDims = orig.slice(0, numOutDims); + while (outDims.length < numOutDims) { + outDims.push(1); + } + + for (let inDim = numOutDims; inDim < orig.length; inDim++) { + outDims[numOutDims - 1] *= orig[inDim]; + } + + return outDims; +} +// For each slice in `(start, limit)` in `valueSlices`, append +// `paramsDenseValues[start,...,limit] to `values`. `valueSize` indicates +// the number of scalars contained in each value paramsDenseValues[i]. +function writeValueSlices( + paramsDenseValues: TypedArray, paramsDenseValuesShape: number[], + valueSlices: Array<[number, number]>, valueSize: number, values: TypedArray, + valuesShape: number[]) { + const denseM = computeFlatOuterDims(paramsDenseValuesShape, 2)[1]; + const valuesM = computeFlatOuterDims(valuesShape, 2)[1]; + + let outPos = 0; + for (const slice of valueSlices) { + for (let i = slice[0]; i < slice[1]; ++i) { + for (let j = 0; j < valueSize; ++j) { + values[outPos * valuesM + j] = paramsDenseValues[i * denseM + j]; + } + ++outPos; + } + } +} + +function getValues( + paramsDenseValues: TypedArray, paramsDenseValuesShape: number[], + paramsDenseValuesDType: DataType, valueSlices: Array<[number, number]>, + numValues: number): [TypedArray, number[]] { + const valuesShape = paramsDenseValuesShape.slice(); + valuesShape[0] = numValues; + + const valuesOut = util.getArrayFromDType( + paramsDenseValuesDType, + util.sizeFromShape(valuesShape)) as TypedArray; + + const numElements = paramsDenseValues.length; + const valueSize = + numElements === 0 ? 0 : (numElements / paramsDenseValuesShape[0]); + writeValueSlices( + paramsDenseValues, paramsDenseValuesShape, valueSlices, valueSize, + valuesOut, valuesShape); + + return [valuesOut, valuesShape]; +} +export function raggedGatherImpl( + paramsNestedSplits: TypedArray[], paramsNestedSplitsShapes: number[][], + paramsDenseValues: TypedArray, paramsDenseValuesShape: number[], + paramsDenseValuesDType: DataType, indices: TypedArray, + indicesShape: number[], + outputRaggedRank: number): [TypedArray[], TypedArray, number[]] { + if (paramsNestedSplits.length === 0) { + throw new Error('paramsNestedSplits must be non empty'); + } + + if (paramsNestedSplitsShapes[0].length === 0) { + throw new Error('Split tensors must not be scalars'); + } + const numParams = paramsNestedSplitsShapes[0][0] - 1; + validateIndices(indices, indicesShape, numParams); + + if (paramsDenseValuesShape.length === 0) { + throw new Error('params.rank must be nonzero'); + } + const numParamsDenseValues = paramsDenseValuesShape[0]; + + // Calculate the `splits`, and store the value slices that we need to + // copy in `valueSlices`. + const {outSplits, valueSlices, numValues} = makeSplits( + indices, indicesShape, paramsNestedSplits, numParamsDenseValues); + + // Write the output tensors. + const outputNestedSplits = getSplits(outSplits); + const outputDenseValues = getValues( + paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, + valueSlices, numValues); + + return [outputNestedSplits, outputDenseValues[0], outputDenseValues[1]]; +} diff --git a/tfjs-backend-cpu/src/register_all_kernels.ts b/tfjs-backend-cpu/src/register_all_kernels.ts index c1e6d95ec35..19e9e142c77 100644 --- a/tfjs-backend-cpu/src/register_all_kernels.ts +++ b/tfjs-backend-cpu/src/register_all_kernels.ts @@ -131,6 +131,7 @@ import {padV2Config} from './kernels/PadV2'; import {powConfig} from './kernels/Pow'; import {preluConfig} from './kernels/Prelu'; import {prodConfig} from './kernels/Prod'; +import {raggedGatherConfig} from './kernels/RaggedGather'; import {raggedTensorToTensorConfig} from './kernels/RaggedTensorToTensor'; import {rangeConfig} from './kernels/Range'; import {realConfig} from './kernels/Real'; @@ -300,6 +301,7 @@ const kernelConfigs: KernelConfig[] = [ powConfig, preluConfig, prodConfig, + raggedGatherConfig, raggedTensorToTensorConfig, rangeConfig, realConfig, diff --git a/tfjs-backend-cpu/src/shared.ts b/tfjs-backend-cpu/src/shared.ts index 2c66060b084..9171c934a92 100644 --- a/tfjs-backend-cpu/src/shared.ts +++ b/tfjs-backend-cpu/src/shared.ts @@ -41,6 +41,7 @@ export {multiplyImpl} from './kernels/Multiply'; export {negImpl} from './kernels/Neg'; export {notEqualImpl} from './kernels/NotEqual'; export {prodImpl} from './kernels/Prod'; +export {raggedGatherImpl} from './kernels/RaggedGather_impl'; export {raggedTensorToTensorImpl} from './kernels/RaggedTensorToTensor_impl'; export {rangeImpl} from './kernels/Range_impl'; export {rsqrtImpl} from './kernels/Rsqrt'; diff --git a/tfjs-backend-webgl/src/kernel_utils/shared.ts b/tfjs-backend-webgl/src/kernel_utils/shared.ts index f37e51185be..cf89341ef5a 100644 --- a/tfjs-backend-webgl/src/kernel_utils/shared.ts +++ b/tfjs-backend-webgl/src/kernel_utils/shared.ts @@ -53,6 +53,7 @@ const { negImpl: negImplCPU, notEqualImpl: notEqualImplCPU, prodImpl: prodImplCPU, + raggedGatherImpl: raggedGatherImplCPU, raggedTensorToTensorImpl: raggedTensorToTensorImplCPU, rangeImpl: rangeImplCPU, rsqrtImpl: rsqrtImplCPU, @@ -101,6 +102,7 @@ export { negImplCPU, notEqualImplCPU, prodImplCPU, + raggedGatherImplCPU, raggedTensorToTensorImplCPU, scatterImplCPU, sigmoidImplCPU, diff --git a/tfjs-backend-webgl/src/kernels/RaggedGather.ts b/tfjs-backend-webgl/src/kernels/RaggedGather.ts new file mode 100644 index 00000000000..f000815d614 --- /dev/null +++ b/tfjs-backend-webgl/src/kernels/RaggedGather.ts @@ -0,0 +1,58 @@ +/** + * @license + * Copyright 2022 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, KernelFunc, RaggedGather, RaggedGatherAttrs, RaggedGatherInputs, TensorInfo, TypedArray} from '@tensorflow/tfjs-core'; + +import {MathBackendWebGL} from '../backend_webgl'; +import {raggedGatherImplCPU} from '../kernel_utils/shared'; + +export function raggedGather(args: { + inputs: RaggedGatherInputs, + backend: MathBackendWebGL, + attrs: RaggedGatherAttrs +}): TensorInfo[] { + const {inputs, backend, attrs} = args; + const {paramsNestedSplits, paramsDenseValues, indices} = inputs; + const {outputRaggedRank} = attrs; + + const $paramsNestedSplits = + paramsNestedSplits.map(t => backend.readSync(t.dataId) as TypedArray); + const $paramsNestedSplitsShapes = paramsNestedSplits.map(t => t.shape); + const $paramsDenseValues = + backend.readSync(paramsDenseValues.dataId) as TypedArray; + const $indices = backend.readSync(indices.dataId) as TypedArray; + + const [outputNestedSplits, outputDenseValues, outputDenseValuesShape] = + raggedGatherImplCPU( + $paramsNestedSplits, $paramsNestedSplitsShapes, $paramsDenseValues, + paramsDenseValues.shape, paramsDenseValues.dtype, $indices, + indices.shape, outputRaggedRank); + + const outputNestedSplitsTensors = outputNestedSplits.map( + (splits) => backend.makeTensorInfo([splits.length], 'int32', splits)); + + const outputDenseValuesTensor = backend.makeTensorInfo( + outputDenseValuesShape, paramsDenseValues.dtype, outputDenseValues); + + return outputNestedSplitsTensors.concat([outputDenseValuesTensor]); +} + +export const raggedGatherConfig: KernelConfig = { + kernelName: RaggedGather, + backendName: 'webgl', + kernelFunc: raggedGather as {} as KernelFunc, +}; diff --git a/tfjs-backend-webgl/src/register_all_kernels.ts b/tfjs-backend-webgl/src/register_all_kernels.ts index 274f37a86b3..63f611dc3ad 100644 --- a/tfjs-backend-webgl/src/register_all_kernels.ts +++ b/tfjs-backend-webgl/src/register_all_kernels.ts @@ -127,6 +127,7 @@ import {padV2Config} from './kernels/PadV2'; import {powConfig} from './kernels/Pow'; import {preluConfig} from './kernels/Prelu'; import {prodConfig} from './kernels/Prod'; +import {raggedGatherConfig} from './kernels/RaggedGather'; import {raggedTensorToTensorConfig} from './kernels/RaggedTensorToTensor'; import {rangeConfig} from './kernels/Range'; import {realConfig} from './kernels/Real'; @@ -295,6 +296,7 @@ const kernelConfigs: KernelConfig[] = [ powConfig, preluConfig, prodConfig, + raggedGatherConfig, raggedTensorToTensorConfig, rangeConfig, realConfig, diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index f67faf69dfc..147f73b8699 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -659,6 +659,14 @@ export interface ProdAttrs { keepDims: boolean; } +export const RaggedGather = 'RaggedGather'; +export type RaggedGatherInputs = { + paramsNestedSplits: TensorInfo[] +}&Pick; +export interface RaggedGatherAttrs { + outputRaggedRank: number; +} + export const RaggedTensorToTensor = 'RaggedTensorToTensor'; export type RaggedTensorToTensorInputs = Pick& @@ -667,6 +675,7 @@ export interface RaggedTensorToTensorAttrs { rowPartitionTypes: string[]; } + export const Range = 'Range'; export interface RangeAttrs { start: number; diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 58d463e08ee..240a2e89264 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -135,6 +135,7 @@ export {pow} from './pow'; export {prelu} from './prelu'; export {print} from './print'; export {prod} from './prod'; +export {raggedGather} from './ragged_gather'; export {raggedTensorToTensor} from './ragged_tensor_to_tensor'; export {rand} from './rand'; export {randomGamma} from './random_gamma'; diff --git a/tfjs-core/src/ops/ragged_gather.ts b/tfjs-core/src/ops/ragged_gather.ts new file mode 100644 index 00000000000..d20d1241c8e --- /dev/null +++ b/tfjs-core/src/ops/ragged_gather.ts @@ -0,0 +1,75 @@ +/** + * @license + * Copyright 2022 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {RaggedGather, RaggedGatherAttrs, RaggedGatherInputs} from '../kernel_names'; +import {Tensor} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import {op} from './operation'; + +/** + * Gather ragged slices from params axis 0 according to indices. + * + * @param paramsNestedSplits: A list of at least 1 Tensor with type 'int32' The + * nestedRowSplits tensors that define the row-partitioning for the params + * RaggedTensor input. + * @param paramsDenseValues: A Tensor. The flatValues for the params + * RaggedTensor. + * @param indices: A Tensor. Must be one of type: int32. Indices in the + * outermost dimension of params of the values that should be gathered. + * @param outputRaggedRank: An int that is >= 0. The ragged rank of the output + * RaggedTensor. outputNestedSplits will contain this number of rowSplits + * tensors. This value should equal indices.shape.ndims + params.raggedRank + * - 1. + * @return A map with the following properties: + * - outputNestedSplits: A list of outputRaggedRank Tensor objects with the + * same type as paramsNestedSplits. + * - outputDenseValues: A Tensor. Has the same type as paramsDenseValues. + * @doc {heading: 'Operations', subheading: 'Ragged'} + */ + +interface RaggedGatherMap { + outputNestedSplits: Tensor[]; + outputDenseValues: Tensor; +} + +function raggedGather_( + paramsNestedSplits: Tensor[], paramsDenseValues: Tensor|TensorLike, + indices: Tensor|TensorLike, outputRaggedRank: number): RaggedGatherMap { + const $paramsNestedSplits = paramsNestedSplits.map( + (t, i) => convertToTensor(t, `tensors${i}`, 'raggedGather', 'int32')); + const $paramsDenseValues = + convertToTensor(paramsDenseValues, 'paramsDenseValues', 'raggedGather'); + const $indices = convertToTensor(indices, 'indices', 'raggedGather', 'int32'); + + const inputs: RaggedGatherInputs = { + paramsNestedSplits: $paramsNestedSplits, + paramsDenseValues: $paramsDenseValues, + indices: $indices, + }; + const attrs: RaggedGatherAttrs = {outputRaggedRank}; + + const result: Tensor[] = + ENGINE.runKernel(RaggedGather, inputs as {}, attrs as {}); + return { + outputNestedSplits: result.slice(0, result.length - 1), + outputDenseValues: result[result.length - 1], + }; +} + +export const raggedGather = op({raggedGather_}); diff --git a/tfjs-core/src/ops/ragged_gather_test.ts b/tfjs-core/src/ops/ragged_gather_test.ts new file mode 100644 index 00000000000..3e6d1b9fb5f --- /dev/null +++ b/tfjs-core/src/ops/ragged_gather_test.ts @@ -0,0 +1,179 @@ +/** + * @license + * Copyright 2022 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose, expectArraysEqual} from '../test_util'; + +function runRaggedGather( + indicesShape: number[], indices: number[], paramsNestedSplits: number[][], + paramsDenseValuesShape: number[], paramsDenseValues: number[]) { + const paramsRaggedRank = paramsNestedSplits.length; + const numSplits = paramsRaggedRank + indicesShape.length - 1; + + const paramsNestedSplitsTensors = + paramsNestedSplits.map(values => tf.tensor1d(values, 'int32')); + const paramsDenseValuesTensor = + tf.tensor(paramsDenseValues, paramsDenseValuesShape); + const indicesTensor = tf.tensor(indices, indicesShape, 'int32'); + + const output = tf.raggedGather( + paramsNestedSplitsTensors, paramsDenseValuesTensor, indicesTensor, + numSplits); + + tf.dispose(paramsNestedSplitsTensors); + tf.dispose([paramsDenseValuesTensor, indicesTensor]); + + expect(output.outputDenseValues.dtype).toEqual('float32'); + + output.outputNestedSplits.forEach(splits => { + expect(splits.dtype).toEqual('int32'); + expect(splits.shape.length).toEqual(1); + }); + + return { + outputDenseValues: output.outputDenseValues.dataSync(), + outputDenseValuesShape: output.outputDenseValues.shape, + outputNestedSplits: + output.outputNestedSplits.map(splits => splits.dataSync()), + tensors: output.outputNestedSplits.concat([output.outputDenseValues]) + }; +} + +describeWithFlags('raggedGather ', ALL_ENVS, () => { + it('RaggedGather', async () => { + const result = runRaggedGather( + [4], [2, 1, 0, 3], [[0, 3, 3, 7, 9]], [9], + [.1, .2, .3, .4, .5, .6, .7, .8, .9]); + + expect(result.outputNestedSplits.length).toEqual(1); + expectArraysClose(result.outputNestedSplits[0], [0, 4, 4, 7, 9]); + + expectArraysClose( + result.outputDenseValues, [.4, .5, .6, .7, .1, .2, .3, .8, .9]); + expectArraysEqual(result.outputDenseValuesShape, [9]); + }); + + it('RaggedGather3DParams', async () => { + const result = runRaggedGather( + [5], [2, 1, 0, 2, 3], [[0, 1, 3, 3, 5, 6], [0, 0, 2, 3, 5, 8, 9]], [9], + [.1, .2, .3, .4, .5, .6, .7, .8, .9]); + + expect(result.outputNestedSplits.length).toEqual(2); + expectArraysClose(result.outputNestedSplits[0], [0, 0, 2, 3, 3, 5]); + expectArraysClose(result.outputNestedSplits[1], [0, 2, 3, 3, 5, 8]); + + expectArraysClose( + result.outputDenseValues, [.1, .2, .3, .4, .5, .6, .7, .8]); + expectArraysEqual(result.outputDenseValuesShape, [8]); + }); + + it('RaggedGather4DParams', async () => { + const result = runRaggedGather( + [4], [2, 1, 0, 2], [[0, 1, 3, 3], [0, 0, 3, 4]], [4, 2], + [1, 2, 3, 4, 5, 6, 7, 8]); + + expect(result.outputNestedSplits.length).toEqual(2); + expectArraysClose(result.outputNestedSplits[0], [0, 0, 2, 3, 3]); + expectArraysClose(result.outputNestedSplits[1], [0, 3, 4, 4]); + + expectArraysClose(result.outputDenseValues, [1, 2, 3, 4, 5, 6, 7, 8]); + expectArraysEqual(result.outputDenseValuesShape, [4, 2]); + }); + + it('RaggedGather2DIndices', async () => { + const result = runRaggedGather( + [2, 2], [2, 1, 0, 3], [[0, 3, 3, 7, 9]], [9], + [.1, .2, .3, .4, .5, .6, .7, .8, .9]); + + expect(result.outputNestedSplits.length).toEqual(2); + expectArraysClose(result.outputNestedSplits[0], [0, 2, 4]); + expectArraysClose(result.outputNestedSplits[1], [0, 4, 4, 7, 9]); + + expectArraysClose( + result.outputDenseValues, [.4, .5, .6, .7, .1, .2, .3, .8, .9]); + expectArraysEqual(result.outputDenseValuesShape, [9]); + }); + + it('RaggedGatherScalarIndices', async () => { + const result = runRaggedGather( + [], [2], [[0, 3, 3, 7, 9]], [9], [.1, .2, .3, .4, .5, .6, .7, .8, .9]); + + expect(result.outputNestedSplits.length).toEqual(0); + + expectArraysClose(result.outputDenseValues, [.4, .5, .6, .7]); + expectArraysEqual(result.outputDenseValuesShape, [4]); + }); + + it('OutOfBounds', async () => { + expect( + () => runRaggedGather( + [2], [2, 10], [[0, 3, 3, 7, 9]], [9], + [.1, .2, .3, .4, .5, .6, .7, .8, .9])) + .toThrowError('indices[1] = 10 is not in [0, 4)'); + }); + + it('InvalidSplitsNotSorted', async () => { + expect( + () => runRaggedGather( + [2], [0, 2], [[0, 3, 5, 2, 9]], [9], + [.1, .2, .3, .4, .5, .6, .7, .8, .9])) + .toThrowError('Ragged splits must be sorted'); + }); + + it('InvalidSplitsNegative', async () => { + expect( + () => runRaggedGather( + [2], [0, 2], [[-1, 3, 2, 7, 9]], [9], + [.1, .2, .3, .4, .5, .6, .7, .8, .9])) + .toThrowError('Ragged splits must be non-negative'); + }); + + it('InvalidSplitsEmpty', async () => { + expect(() => runRaggedGather([0], [], [[]], [0], [])) + .toThrowError('Ragged splits may not be empty'); + }); + + it('InvalidSplitsTooBig', async () => { + expect( + () => runRaggedGather( + [2], [0, 2], [[0, 20, 40, 80, 100]], [9], + [.1, .2, .3, .4, .5, .6, .7, .8, .9])) + .toThrowError('Ragged splits must not point past values'); + }); + + it('BadValuesShape', async () => { + expect(() => runRaggedGather([0], [], [[0]], [], [.1])) + .toThrowError('params.rank must be nonzero'); + }); + + it('does not have memory leak.', async () => { + const beforeDataIds = tf.engine().backend.numDataIds(); + + const result = runRaggedGather( + [4], [2, 1, 0, 3], [[0, 3, 3, 7, 9]], [9], + [.1, .2, .3, .4, .5, .6, .7, .8, .9]); + + const afterResDataIds = tf.engine().backend.numDataIds(); + expect(afterResDataIds).toEqual(beforeDataIds + 2); + + result.tensors.map(tensor => tensor.dispose()); + + const afterDisposeDataIds = tf.engine().backend.numDataIds(); + expect(afterDisposeDataIds).toEqual(beforeDataIds); + }); +}); diff --git a/tfjs-node/src/run_tests.ts b/tfjs-node/src/run_tests.ts index d7457b2cd94..2a4049660c2 100644 --- a/tfjs-node/src/run_tests.ts +++ b/tfjs-node/src/run_tests.ts @@ -148,6 +148,7 @@ const IGNORE_LIST: string[] = [ // Node kernel for einsum is yet to be implemented. // See: ttps://github.com/tensorflow/tfjs/issues/2349 'einsum', + 'raggedGather', 'raggedTensorToTensor', 'searchSorted', 'sparseFillEmptyRows', From ae63a0477817f41b9915766c1e8dd568b617d5ef Mon Sep 17 00:00:00 2001 From: Ahmed Sabie Date: Tue, 6 Sep 2022 13:50:19 -0400 Subject: [PATCH 2/6] Improve error message --- tfjs-backend-cpu/src/kernels/RaggedGather_impl.ts | 2 +- tfjs-core/src/kernel_names.ts | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tfjs-backend-cpu/src/kernels/RaggedGather_impl.ts b/tfjs-backend-cpu/src/kernels/RaggedGather_impl.ts index 70e583630e3..e7580b4973b 100644 --- a/tfjs-backend-cpu/src/kernels/RaggedGather_impl.ts +++ b/tfjs-backend-cpu/src/kernels/RaggedGather_impl.ts @@ -50,7 +50,7 @@ function validateSplits( } for (let i = 1; i < splits.length; ++i) { if (splits[i - 1] > splits[i]) { - throw new Error('Ragged splits must be sorted'); + throw new Error('Ragged splits must be sorted in ascending order'); } } } diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 147f73b8699..e713f502015 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -675,7 +675,6 @@ export interface RaggedTensorToTensorAttrs { rowPartitionTypes: string[]; } - export const Range = 'Range'; export interface RangeAttrs { start: number; From 2e747c045502e63cc74e717f992d30c53c642ea4 Mon Sep 17 00:00:00 2001 From: Ahmed Sabie Date: Tue, 6 Sep 2022 15:40:20 -0400 Subject: [PATCH 3/6] Mark function as async --- tfjs-core/src/ops/ragged_gather_test.ts | 28 ++++++++++++------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tfjs-core/src/ops/ragged_gather_test.ts b/tfjs-core/src/ops/ragged_gather_test.ts index 3e6d1b9fb5f..6d7424295ef 100644 --- a/tfjs-core/src/ops/ragged_gather_test.ts +++ b/tfjs-core/src/ops/ragged_gather_test.ts @@ -19,7 +19,7 @@ import * as tf from '../index'; import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; import {expectArraysClose, expectArraysEqual} from '../test_util'; -function runRaggedGather( +async function runRaggedGather( indicesShape: number[], indices: number[], paramsNestedSplits: number[][], paramsDenseValuesShape: number[], paramsDenseValues: number[]) { const paramsRaggedRank = paramsNestedSplits.length; @@ -46,7 +46,7 @@ function runRaggedGather( }); return { - outputDenseValues: output.outputDenseValues.dataSync(), + outputDenseValues: await output.outputDenseValues.data(), outputDenseValuesShape: output.outputDenseValues.shape, outputNestedSplits: output.outputNestedSplits.map(splits => splits.dataSync()), @@ -56,7 +56,7 @@ function runRaggedGather( describeWithFlags('raggedGather ', ALL_ENVS, () => { it('RaggedGather', async () => { - const result = runRaggedGather( + const result = await runRaggedGather( [4], [2, 1, 0, 3], [[0, 3, 3, 7, 9]], [9], [.1, .2, .3, .4, .5, .6, .7, .8, .9]); @@ -69,7 +69,7 @@ describeWithFlags('raggedGather ', ALL_ENVS, () => { }); it('RaggedGather3DParams', async () => { - const result = runRaggedGather( + const result = await runRaggedGather( [5], [2, 1, 0, 2, 3], [[0, 1, 3, 3, 5, 6], [0, 0, 2, 3, 5, 8, 9]], [9], [.1, .2, .3, .4, .5, .6, .7, .8, .9]); @@ -83,7 +83,7 @@ describeWithFlags('raggedGather ', ALL_ENVS, () => { }); it('RaggedGather4DParams', async () => { - const result = runRaggedGather( + const result = await runRaggedGather( [4], [2, 1, 0, 2], [[0, 1, 3, 3], [0, 0, 3, 4]], [4, 2], [1, 2, 3, 4, 5, 6, 7, 8]); @@ -96,7 +96,7 @@ describeWithFlags('raggedGather ', ALL_ENVS, () => { }); it('RaggedGather2DIndices', async () => { - const result = runRaggedGather( + const result = await runRaggedGather( [2, 2], [2, 1, 0, 3], [[0, 3, 3, 7, 9]], [9], [.1, .2, .3, .4, .5, .6, .7, .8, .9]); @@ -110,7 +110,7 @@ describeWithFlags('raggedGather ', ALL_ENVS, () => { }); it('RaggedGatherScalarIndices', async () => { - const result = runRaggedGather( + const result = await runRaggedGather( [], [2], [[0, 3, 3, 7, 9]], [9], [.1, .2, .3, .4, .5, .6, .7, .8, .9]); expect(result.outputNestedSplits.length).toEqual(0); @@ -121,7 +121,7 @@ describeWithFlags('raggedGather ', ALL_ENVS, () => { it('OutOfBounds', async () => { expect( - () => runRaggedGather( + () => await runRaggedGather( [2], [2, 10], [[0, 3, 3, 7, 9]], [9], [.1, .2, .3, .4, .5, .6, .7, .8, .9])) .toThrowError('indices[1] = 10 is not in [0, 4)'); @@ -129,7 +129,7 @@ describeWithFlags('raggedGather ', ALL_ENVS, () => { it('InvalidSplitsNotSorted', async () => { expect( - () => runRaggedGather( + () => await runRaggedGather( [2], [0, 2], [[0, 3, 5, 2, 9]], [9], [.1, .2, .3, .4, .5, .6, .7, .8, .9])) .toThrowError('Ragged splits must be sorted'); @@ -137,34 +137,34 @@ describeWithFlags('raggedGather ', ALL_ENVS, () => { it('InvalidSplitsNegative', async () => { expect( - () => runRaggedGather( + () => await runRaggedGather( [2], [0, 2], [[-1, 3, 2, 7, 9]], [9], [.1, .2, .3, .4, .5, .6, .7, .8, .9])) .toThrowError('Ragged splits must be non-negative'); }); it('InvalidSplitsEmpty', async () => { - expect(() => runRaggedGather([0], [], [[]], [0], [])) + expect(() => await runRaggedGather([0], [], [[]], [0], [])) .toThrowError('Ragged splits may not be empty'); }); it('InvalidSplitsTooBig', async () => { expect( - () => runRaggedGather( + () => await runRaggedGather( [2], [0, 2], [[0, 20, 40, 80, 100]], [9], [.1, .2, .3, .4, .5, .6, .7, .8, .9])) .toThrowError('Ragged splits must not point past values'); }); it('BadValuesShape', async () => { - expect(() => runRaggedGather([0], [], [[0]], [], [.1])) + expect(() => await runRaggedGather([0], [], [[0]], [], [.1])) .toThrowError('params.rank must be nonzero'); }); it('does not have memory leak.', async () => { const beforeDataIds = tf.engine().backend.numDataIds(); - const result = runRaggedGather( + const result = await runRaggedGather( [4], [2, 1, 0, 3], [[0, 3, 3, 7, 9]], [9], [.1, .2, .3, .4, .5, .6, .7, .8, .9]); From 03710e8b1a2aa47d00434e19db850a08855ea098 Mon Sep 17 00:00:00 2001 From: Ahmed Sabie Date: Tue, 6 Sep 2022 15:51:11 -0400 Subject: [PATCH 4/6] Mark function as async --- tfjs-core/src/ops/ragged_gather_test.ts | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tfjs-core/src/ops/ragged_gather_test.ts b/tfjs-core/src/ops/ragged_gather_test.ts index 6d7424295ef..739d85dd2b2 100644 --- a/tfjs-core/src/ops/ragged_gather_test.ts +++ b/tfjs-core/src/ops/ragged_gather_test.ts @@ -121,7 +121,7 @@ describeWithFlags('raggedGather ', ALL_ENVS, () => { it('OutOfBounds', async () => { expect( - () => await runRaggedGather( + () => runRaggedGather( [2], [2, 10], [[0, 3, 3, 7, 9]], [9], [.1, .2, .3, .4, .5, .6, .7, .8, .9])) .toThrowError('indices[1] = 10 is not in [0, 4)'); @@ -129,7 +129,7 @@ describeWithFlags('raggedGather ', ALL_ENVS, () => { it('InvalidSplitsNotSorted', async () => { expect( - () => await runRaggedGather( + () => runRaggedGather( [2], [0, 2], [[0, 3, 5, 2, 9]], [9], [.1, .2, .3, .4, .5, .6, .7, .8, .9])) .toThrowError('Ragged splits must be sorted'); @@ -137,27 +137,27 @@ describeWithFlags('raggedGather ', ALL_ENVS, () => { it('InvalidSplitsNegative', async () => { expect( - () => await runRaggedGather( + () => runRaggedGather( [2], [0, 2], [[-1, 3, 2, 7, 9]], [9], [.1, .2, .3, .4, .5, .6, .7, .8, .9])) .toThrowError('Ragged splits must be non-negative'); }); it('InvalidSplitsEmpty', async () => { - expect(() => await runRaggedGather([0], [], [[]], [0], [])) + expect(() => runRaggedGather([0], [], [[]], [0], [])) .toThrowError('Ragged splits may not be empty'); }); it('InvalidSplitsTooBig', async () => { expect( - () => await runRaggedGather( + () => runRaggedGather( [2], [0, 2], [[0, 20, 40, 80, 100]], [9], [.1, .2, .3, .4, .5, .6, .7, .8, .9])) .toThrowError('Ragged splits must not point past values'); }); it('BadValuesShape', async () => { - expect(() => await runRaggedGather([0], [], [[0]], [], [.1])) + expect(() => runRaggedGather([0], [], [[0]], [], [.1])) .toThrowError('params.rank must be nonzero'); }); From 05f6486324c4fdf991b201c548125853e0b8648b Mon Sep 17 00:00:00 2001 From: Ahmed Sabie Date: Wed, 7 Sep 2022 14:14:34 -0400 Subject: [PATCH 5/6] Fix expect call for async --- tfjs-core/src/ops/ragged_gather_test.ts | 43 +++++++++++-------------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/tfjs-core/src/ops/ragged_gather_test.ts b/tfjs-core/src/ops/ragged_gather_test.ts index 739d85dd2b2..dc2f3c0f4a9 100644 --- a/tfjs-core/src/ops/ragged_gather_test.ts +++ b/tfjs-core/src/ops/ragged_gather_test.ts @@ -120,45 +120,40 @@ describeWithFlags('raggedGather ', ALL_ENVS, () => { }); it('OutOfBounds', async () => { - expect( - () => runRaggedGather( - [2], [2, 10], [[0, 3, 3, 7, 9]], [9], - [.1, .2, .3, .4, .5, .6, .7, .8, .9])) - .toThrowError('indices[1] = 10 is not in [0, 4)'); + await expectAsync(runRaggedGather([2], [2, 10], [[0, 3, 3, 7, 9]], [9], [ + .1, .2, .3, .4, .5, .6, .7, .8, .9 + ])).toBeRejectedWithError('indices[1] = 10 is not in [0, 4)'); }); it('InvalidSplitsNotSorted', async () => { - expect( - () => runRaggedGather( - [2], [0, 2], [[0, 3, 5, 2, 9]], [9], - [.1, .2, .3, .4, .5, .6, .7, .8, .9])) - .toThrowError('Ragged splits must be sorted'); + await expectAsync(runRaggedGather( + [2], [0, 2], [[0, 3, 5, 2, 9]], [9], + [.1, .2, .3, .4, .5, .6, .7, .8, .9])) + .toBeRejectedWithError( + 'Ragged splits must be sorted in ascending order'); }); it('InvalidSplitsNegative', async () => { - expect( - () => runRaggedGather( - [2], [0, 2], [[-1, 3, 2, 7, 9]], [9], - [.1, .2, .3, .4, .5, .6, .7, .8, .9])) - .toThrowError('Ragged splits must be non-negative'); + await expectAsync(runRaggedGather([2], [0, 2], [[-1, 3, 2, 7, 9]], [9], [ + .1, .2, .3, .4, .5, .6, .7, .8, .9 + ])).toBeRejectedWithError('Ragged splits must be non-negative'); }); it('InvalidSplitsEmpty', async () => { - expect(() => runRaggedGather([0], [], [[]], [0], [])) - .toThrowError('Ragged splits may not be empty'); + await expectAsync(runRaggedGather([0], [], [[]], [0], [])) + .toBeRejectedWithError('Ragged splits may not be empty'); }); it('InvalidSplitsTooBig', async () => { - expect( - () => runRaggedGather( - [2], [0, 2], [[0, 20, 40, 80, 100]], [9], - [.1, .2, .3, .4, .5, .6, .7, .8, .9])) - .toThrowError('Ragged splits must not point past values'); + await expectAsync(runRaggedGather( + [2], [0, 2], [[0, 20, 40, 80, 100]], [9], + [.1, .2, .3, .4, .5, .6, .7, .8, .9])) + .toBeRejectedWithError('Ragged splits must not point past values'); }); it('BadValuesShape', async () => { - expect(() => runRaggedGather([0], [], [[0]], [], [.1])) - .toThrowError('params.rank must be nonzero'); + await expectAsync(runRaggedGather([0], [], [[0]], [], [.1])) + .toBeRejectedWithError('params.rank must be nonzero'); }); it('does not have memory leak.', async () => { From 59bc169ade7f81b45bbd7f076480494ecf1e4347 Mon Sep 17 00:00:00 2001 From: Ahmed Sabie Date: Wed, 7 Sep 2022 14:38:45 -0400 Subject: [PATCH 6/6] Fix datasync call --- tfjs-core/src/ops/ragged_gather_test.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tfjs-core/src/ops/ragged_gather_test.ts b/tfjs-core/src/ops/ragged_gather_test.ts index dc2f3c0f4a9..954f5285971 100644 --- a/tfjs-core/src/ops/ragged_gather_test.ts +++ b/tfjs-core/src/ops/ragged_gather_test.ts @@ -48,8 +48,8 @@ async function runRaggedGather( return { outputDenseValues: await output.outputDenseValues.data(), outputDenseValuesShape: output.outputDenseValues.shape, - outputNestedSplits: - output.outputNestedSplits.map(splits => splits.dataSync()), + outputNestedSplits: await Promise.all( + output.outputNestedSplits.map(splits => splits.data())), tensors: output.outputNestedSplits.concat([output.outputDenseValues]) }; }