diff --git a/tfjs-backend-cpu/src/kernels/OneHot.ts b/tfjs-backend-cpu/src/kernels/OneHot.ts index f3e97bcc401..a5127765900 100644 --- a/tfjs-backend-cpu/src/kernels/OneHot.ts +++ b/tfjs-backend-cpu/src/kernels/OneHot.ts @@ -25,7 +25,7 @@ export function oneHot( TensorInfo { const {inputs, backend, attrs} = args; const {indices} = inputs; - const {depth, onValue, offValue} = attrs; + const {dtype, depth, onValue, offValue} = attrs; assertNotComplex(indices, 'oneHot'); @@ -41,7 +41,7 @@ export function oneHot( } } - return backend.makeTensorInfo([...indices.shape, depth], 'int32', res); + return backend.makeTensorInfo([...indices.shape, depth], dtype, res); } export const oneHotConfig: KernelConfig = { diff --git a/tfjs-backend-wasm/src/kernels/OneHot.ts b/tfjs-backend-wasm/src/kernels/OneHot.ts index b0e29bd3138..932f7551424 100644 --- a/tfjs-backend-wasm/src/kernels/OneHot.ts +++ b/tfjs-backend-wasm/src/kernels/OneHot.ts @@ -37,9 +37,9 @@ function oneHot( args: {inputs: OneHotInputs, attrs: OneHotAttrs, backend: BackendWasm}) { const {inputs, backend, attrs} = args; const {indices} = inputs; - const {depth, onValue, offValue} = attrs; + const {dtype, depth, onValue, offValue} = attrs; - const out = backend.makeOutput([...indices.shape, depth], 'int32'); + const out = backend.makeOutput([...indices.shape, depth], dtype); const outId = backend.dataIdMap.get(out.dataId).id; const indicesData = backend.dataIdMap.get(indices.dataId); diff --git a/tfjs-backend-webgl/src/kernels/OneHot.ts b/tfjs-backend-webgl/src/kernels/OneHot.ts index 49cde5f9dc0..496b5ab2166 100644 --- a/tfjs-backend-webgl/src/kernels/OneHot.ts +++ b/tfjs-backend-webgl/src/kernels/OneHot.ts @@ -28,13 +28,13 @@ export const oneHot = (args: { }): TensorInfo => { const {inputs, backend, attrs} = args; const {indices} = inputs; - const {depth, onValue, offValue} = attrs; + const {dtype, depth, onValue, offValue} = attrs; const indicesSize = util.sizeFromShape(indices.shape); const program = new OneHotProgram(indicesSize, depth, onValue, offValue); const reshaped = reshape({inputs: {x: indices}, backend, attrs: {shape: [indicesSize]}}); - const result = backend.runWebGLProgram(program, [reshaped], indices.dtype); + const result = backend.runWebGLProgram(program, [reshaped], dtype); backend.disposeIntermediateTensorInfo(reshaped); const outShape = [...indices.shape, depth]; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 0af48807d8b..f67faf69dfc 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -627,6 +627,7 @@ export interface OneHotAttrs { depth: number; onValue: number; offValue: number; + dtype: DataType; } export const Pack = 'Pack'; diff --git a/tfjs-core/src/ops/one_hot.ts b/tfjs-core/src/ops/one_hot.ts index 4ac06df68a7..b4dffbc03ae 100644 --- a/tfjs-core/src/ops/one_hot.ts +++ b/tfjs-core/src/ops/one_hot.ts @@ -21,7 +21,7 @@ import {NamedAttrMap} from '../kernel_registry'; import {Tensor} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; -import {TensorLike} from '../types'; +import {DataType, TensorLike} from '../types'; import {op} from './operation'; @@ -29,16 +29,16 @@ import {op} from './operation'; * Creates a one-hot `tf.Tensor`. The locations represented by `indices` take * value `onValue` (defaults to 1), while all other locations take value * `offValue` (defaults to 0). If `indices` is rank `R`, the output has rank - * `R+1` with the last axis of size `depth`. + * `R+1` with the last axis of size `depth`. * `indices` used to encode prediction class must start from 0. For example, * if you have 3 classes of data, class 1 should be encoded as 0, class 2 - * should be 1, and class 3 should be 2. + * should be 1, and class 3 should be 2. * * ```js * tf.oneHot(tf.tensor1d([0, 1], 'int32'), 3).print(); * ``` * - * @param indices `tf.Tensor` of indices with dtype `int32`. Indices must + * @param indices `tf.Tensor` of indices with dtype `int32`. Indices must * start from 0. * @param depth The depth of the one hot dimension. * @param onValue A number used to fill in the output when the index matches @@ -49,15 +49,15 @@ import {op} from './operation'; * @doc {heading: 'Tensors', subheading: 'Creation'} */ function oneHot_( - indices: Tensor|TensorLike, depth: number, onValue = 1, - offValue = 0): Tensor { + indices: Tensor|TensorLike, depth: number, onValue = 1, offValue = 0, + dtype: DataType = 'int32'): Tensor { if (depth < 2) { throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`); } const $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32'); const inputs: OneHotInputs = {indices: $indices}; - const attrs: OneHotAttrs = {depth, onValue, offValue}; + const attrs: OneHotAttrs = {dtype, depth, onValue, offValue}; return ENGINE.runKernel( OneHot, inputs as unknown as NamedTensorMap, diff --git a/tfjs-core/src/ops/one_hot_test.ts b/tfjs-core/src/ops/one_hot_test.ts index bc54feab19a..4cdfde97474 100644 --- a/tfjs-core/src/ops/one_hot_test.ts +++ b/tfjs-core/src/ops/one_hot_test.ts @@ -100,6 +100,14 @@ describeWithFlags('oneHot', ALL_ENVS, () => { expect(res.dtype).toEqual(expectedType); }); + it('check specified output dtype', () => { + const expectedType = 'float32'; + const indices = tf.tensor1d([0, 1], 'int32'); + const res = tf.oneHot(indices, 2, 1, 0, 'float32'); + + expect(res.dtype).toEqual(expectedType); + }); + it('oneHot accepts a tensor-like object', async () => { const res = tf.oneHot([0, 1], 2); expect(res.shape).toEqual([2, 2]); diff --git a/tfjs-node/src/kernels/OneHot.ts b/tfjs-node/src/kernels/OneHot.ts index 331f4e64aa2..6168992fb4b 100644 --- a/tfjs-node/src/kernels/OneHot.ts +++ b/tfjs-node/src/kernels/OneHot.ts @@ -25,15 +25,15 @@ export const oneHotConfig: KernelConfig = { kernelFunc: (args) => { const {indices} = args.inputs as OneHotInputs; const backend = args.backend as NodeJSKernelBackend; - const {depth, onValue, offValue} = args.attrs as {} as OneHotAttrs; + const {dtype, depth, onValue, offValue} = args.attrs as {} as OneHotAttrs; const depthTensor = scalar(depth, 'int32'); - const onValueTensor = scalar(onValue, 'int32'); - const offValueTensor = scalar(offValue, 'int32'); + const onValueTensor = scalar(onValue, dtype); + const offValueTensor = scalar(offValue, dtype); const opAttrs = [ {name: 'axis', type: backend.binding.TF_ATTR_INT, value: -1}, - createTensorsTypeOpAttr('T', indices.dtype), + createTensorsTypeOpAttr('T', dtype), createTensorsTypeOpAttr('TI', indices.dtype) ];