[wasm] Add Multinomial kernel (#7468)
* Add multinomial kernel

* Add test

* Fix mt19937 seed type

* Fix discrete_distribution weight type

* Fix fill

* Ignore seed test in tfjs-node
chunnienc authored Mar 13, 2023
1 parent e85305e commit ff6739d
Showing 9 changed files with 166 additions and 2 deletions.
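
For context, this is the user-facing path the new kernel serves: tf.multinomial dispatched to the wasm backend. A minimal usage sketch (assuming the @tensorflow/tfjs-backend-wasm package is installed and its wasm binary is served alongside the app):

import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-wasm';

async function run() {
  // Registering the wasm backend is a side effect of the import above;
  // setBackend resolves once the wasm binary has loaded.
  await tf.setBackend('wasm');

  // Unnormalized logits over 4 events; draw 10 samples with a fixed seed.
  const logits = tf.tensor1d([1, 2, 3, 4]);
  const samples = tf.multinomial(logits, 10, /* seed= */ 42);
  samples.print();  // int32 event indices in [0, 3]
}

run();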
10 changes: 10 additions & 0 deletions tfjs-backend-wasm/src/cc/BUILD.bazel
@@ -386,6 +386,7 @@ tfjs_cc_library(
":Min",
":Minimum",
":MirrorPad",
":Multinomial",
":Multiply",
":Neg",
":NonMaxSuppressionV3",
@@ -1183,6 +1184,15 @@ tfjs_cc_library(
],
)

tfjs_cc_library(
name = "Multinomial",
srcs = ["kernels/Multinomial.cc"],
deps = [
":backend",
":shape",
],
)

tfjs_cc_library(
name = "Neg",
srcs = ["kernels/Neg.cc"],
62 changes: 62 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/Multinomial.cc
@@ -0,0 +1,62 @@
/* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include <random>

#include "tfjs-backend-wasm/src/cc/backend.h"
#include "tfjs-backend-wasm/src/cc/shape.h"

namespace tfjs::wasm {

// We use C-style API to interface with Javascript.
extern "C" {

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif

// REQUIRES
// - Tensor `probabilities` is produced from normalized `logits`.
// - Tensor `probabilities` must have dtype float32.
// - Tensor `out` must have dtype int32.
// - Tensor `probabilities` must have shape [batch_size, num_events].
// - Tensor `out` must have shape [batch_size, num_samples].
void Multinomial(const int probabilities_id, const int batch_size,
const int num_events, const int num_samples, const float seed,
const int out_id) {
const TensorInfo& prob_info = backend::get_tensor_info(probabilities_id);
TensorInfo& out_info = backend::get_tensor_info_out(out_id);
Shape<int, 2> probs_shape({batch_size, num_events});
Shape<int, 2> out_shape({batch_size, num_samples});
const float* probs_buf = prob_info.f32();
int* out_buf = out_info.i32_write();

std::mt19937 gen(*reinterpret_cast<const int32_t*>(&seed));
for (int b = 0; b < batch_size; ++b) {
const float* weights_begin = probs_buf + probs_shape.offset({b, 0});
const float* weights_end = weights_begin + num_events;
std::discrete_distribution<int32_t> distribution(weights_begin,
weights_end);
for (int i = 0; i < num_samples; ++i) {
out_buf[out_shape.offset({b, i})] = distribution(gen);
}
}
}

} // extern "C"
} // namespace tfjs::wasm
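
The kernel seeds a std::mt19937 engine with the bit pattern of the float seed and then, for each batch row, draws sample indices from std::discrete_distribution, which picks event i with probability proportional to its weight. For readers unfamiliar with that distribution, here is an illustrative TypeScript sketch of the equivalent inverse-CDF sampling (not part of the commit; nextUniform is a placeholder for a seeded PRNG returning values in [0, 1)):

// Draws one event index with probability proportional to weights[i].
function sampleDiscrete(weights: Float32Array, nextUniform: () => number): number {
  let total = 0;
  for (const w of weights) {
    total += w;
  }
  // Scale a uniform draw to the total mass, then walk the cumulative sum.
  const u = nextUniform() * total;
  let cumulative = 0;
  for (let i = 0; i < weights.length; ++i) {
    cumulative += weights[i];
    if (u < cumulative) {
      return i;
    }
  }
  return weights.length - 1;  // guard against floating-point round-off
}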
80 changes: 80 additions & 0 deletions tfjs-backend-wasm/src/kernels/Multinomial.ts
@@ -0,0 +1,80 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {KernelConfig, KernelFunc, Multinomial, MultinomialAttrs, MultinomialInputs, TensorInfo} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';
import {softmax} from './Softmax';

let wasmMultinomial: (
probabilitiesId: number, batchSize: number, numEvents: number,
numSamples: number, seed: number, outId: number) => void;

function setup(backend: BackendWasm) {
wasmMultinomial = backend.wasm.cwrap(Multinomial, null, [
'number', // probabilitiesId
'number', // batchSize
'number', // numEvents
'number', // numSamples
'number', // seed
'number', // outId
]);
}

export function multinomial(args: {
inputs: MultinomialInputs,
attrs: MultinomialAttrs,
backend: BackendWasm,
}): TensorInfo {
const {inputs, backend, attrs} = args;
const {logits} = inputs;
const {numSamples, seed, normalized} = attrs;

if (logits.dtype !== 'float32') {
throw new Error(
`Tensor logits must have dtype float32, got ${logits.dtype}`);
}

const probabilities = normalized ? logits : softmax({
inputs: {logits},
backend,
attrs: {dim: logits.shape.length - 1},
});

const [batchSize, numEvents] = probabilities.shape;
const out = backend.makeOutput([batchSize, numSamples], 'int32');

wasmMultinomial(
backend.dataIdMap.get(probabilities.dataId).id,
batchSize,
numEvents,
numSamples,
seed,
backend.dataIdMap.get(out.dataId).id,
);
if (!normalized) {
backend.disposeData(probabilities.dataId);
}
return out;
}

export const multinomialConfig: KernelConfig = {
kernelName: Multinomial,
backendName: 'wasm',
setupFunc: setup,
kernelFunc: multinomial as unknown as KernelFunc
};
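
Note the normalized attribute: when it is false the wrapper runs softmax over the last axis and disposes the intermediate tensor after the wasm call; when it is true the input is used as probabilities directly. A small sketch of both call paths through the public API (continuing from the usage sketch above, with tf imported and the wasm backend set):

// Default path: unnormalized logits, softmax applied inside the wrapper.
const logits = tf.tensor2d([[1, 2, 3, 4]]);
const a = tf.multinomial(logits, 5, 42);

// Pre-normalized path: pass probabilities and set normalized = true,
// which skips the internal softmax and its temporary tensor.
const probs = tf.softmax(logits);
const b = tf.multinomial(probs, 5, 42, /* normalized= */ true);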
2 changes: 1 addition & 1 deletion tfjs-backend-wasm/src/kernels/Softmax.ts
@@ -31,7 +31,7 @@ function setup(backend: BackendWasm): void {
]);
}

function softmax(
export function softmax(
args: {backend: BackendWasm, inputs: SoftmaxInputs, attrs: SoftmaxAttrs}):
TensorInfo {
const {backend, inputs: {logits}, attrs: {dim}} = args;
2 changes: 2 additions & 0 deletions tfjs-backend-wasm/src/register_all_kernels.ts
@@ -100,6 +100,7 @@ import {meanConfig} from './kernels/Mean';
import {minConfig} from './kernels/Min';
import {minimumConfig} from './kernels/Minimum';
import {mirrorPadConfig} from './kernels/MirrorPad';
import {multinomialConfig} from './kernels/Multinomial';
import {multiplyConfig} from './kernels/Multiply';
import {negConfig} from './kernels/Neg';
import {nonMaxSuppressionV3Config} from './kernels/NonMaxSuppressionV3';
@@ -247,6 +248,7 @@ const kernelConfigs: KernelConfig[] = [
minConfig,
minimumConfig,
mirrorPadConfig,
multinomialConfig,
multiplyConfig,
negConfig,
nonMaxSuppressionV3Config,
1 change: 1 addition & 0 deletions tfjs-backend-wasm/src/setup_test.ts
@@ -408,6 +408,7 @@ const TEST_FILTERS: TestFilter[] = [
{include: 'linspace'},
{include: 'bincount'},
{include: 'expm1 '},
{include: 'multinomial'},
];

const customInclude = (testName: string) => {
4 changes: 3 additions & 1 deletion tfjs-core/src/ops/fill.ts
@@ -20,6 +20,7 @@ import {Fill, FillAttrs} from '../kernel_names';
import {NamedAttrMap} from '../kernel_registry';
import {Tensor} from '../tensor';
import {DataType, Rank, ShapeMap} from '../types';
import {inferDtype} from '../util';
import {assertNonNegativeIntegerDimensions} from '../util_base';

/**
@@ -32,14 +33,15 @@ import {assertNonNegativeIntegerDimensions} from '../util_base';
* @param shape An array of integers defining the output tensor shape.
* @param value The scalar value to fill the tensor with.
* @param dtype The type of an element in the resulting tensor. Defaults to
* 'float'.
* 'float32' if the given param value is a number, otherwise 'string'.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function fill<R extends Rank>(
shape: ShapeMap[R], value: number|string, dtype?: DataType): Tensor<R> {
assertNonNegativeIntegerDimensions(shape);

dtype = dtype || inferDtype(value);
const attrs: FillAttrs = {shape, value, dtype};

return ENGINE.runKernel(Fill, {}, attrs as unknown as NamedAttrMap);
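
The fill fix makes the op infer dtype from the fill value when none is given, so a string value now produces a 'string' tensor instead of relying on a hard-coded default. An illustrative sketch of the resulting behavior (inferred from the diff, not part of the commit):

import * as tf from '@tensorflow/tfjs';

const nums = tf.fill([2, 2], 3);      // inferred dtype: 'float32'
const strs = tf.fill([2, 2], 'abc');  // inferred dtype: 'string'
console.log(nums.dtype, strs.dtype);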
6 changes: 6 additions & 0 deletions tfjs-core/src/ops/multinomial_test.ts
@@ -114,6 +114,12 @@ describeWithFlags('multinomial', ALL_ENVS, () => {
expectArraysClose(outcomeProbs, [0, 1], EPSILON);
});

it('creates the same data given the same seed', async () => {
const res1 = tf.multinomial([1, 2, 3, 4], NUM_SAMPLES, SEED);
const res2 = tf.multinomial([1, 2, 3, 4], NUM_SAMPLES, SEED);
expectArraysClose(await res1.data(), await res2.data());
});

function computeProbs(
events: Float32Array|Uint8Array|Int32Array, numOutcomes: number) {
const counts = [];
1 change: 1 addition & 0 deletions tfjs-node/src/run_tests.ts
@@ -172,6 +172,7 @@ const IGNORE_LIST: string[] = [
// upperBound and lowerBound use SearchSorted, which is unsupported
'upperBound',
'lowerBound',
'multinomial test-tensorflow {} creates the same data given the same seed',
];

if (process.platform === 'win32') {
