From 9a580c52d172c8d2a91a914efbe6ca936793fea6 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Tue, 23 Aug 2022 09:58:23 -0700 Subject: [PATCH] [converter] added dtype support for oneHot for converter (#6782) * added dtype support for oneHot for converter * update the doc for onehot --- .../python/tensorflowjs/op_list/creation.json | 5 ++--- .../src/operations/executors/creation_executor.ts | 8 +++++--- .../operations/executors/creation_executor_test.ts | 11 +++++++---- tfjs-core/src/ops/one_hot.ts | 1 + 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/tfjs-converter/python/tensorflowjs/op_list/creation.json b/tfjs-converter/python/tensorflowjs/op_list/creation.json index 9ad1996f6bc..10991469340 100644 --- a/tfjs-converter/python/tensorflowjs/op_list/creation.json +++ b/tfjs-converter/python/tensorflowjs/op_list/creation.json @@ -88,8 +88,7 @@ { "tfName": "T", "name": "dtype", - "type": "dtype", - "notSupported": true + "type": "dtype" } ] }, @@ -366,4 +365,4 @@ } ] } -] \ No newline at end of file +] diff --git a/tfjs-converter/src/operations/executors/creation_executor.ts b/tfjs-converter/src/operations/executors/creation_executor.ts index f4e83345578..8fda4f1d7d5 100644 --- a/tfjs-converter/src/operations/executors/creation_executor.ts +++ b/tfjs-converter/src/operations/executors/creation_executor.ts @@ -26,8 +26,8 @@ import {InternalOpExecutor, Node} from '../types'; import {getParamValue} from './utils'; export const executeOp: InternalOpExecutor = - (node: Node, tensorMap: NamedTensorsMap, - context: ExecutionContext, ops = tfOps): Tensor[] => { + (node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext, + ops = tfOps): Tensor[] => { switch (node.op) { case 'Fill': { const shape = @@ -64,7 +64,9 @@ export const executeOp: InternalOpExecutor = getParamValue('onValue', node, tensorMap, context) as number; const offValue = getParamValue('offValue', node, tensorMap, context) as number; - return [ops.oneHot(indices, depth, onValue, offValue)]; + const dtype = + getParamValue('dtype', node, tensorMap, context) as DataType; + return [ops.oneHot(indices, depth, onValue, offValue, dtype)]; } case 'Ones': { return [ops.ones( diff --git a/tfjs-converter/src/operations/executors/creation_executor_test.ts b/tfjs-converter/src/operations/executors/creation_executor_test.ts index b9ceefd4708..b3609980062 100644 --- a/tfjs-converter/src/operations/executors/creation_executor_test.ts +++ b/tfjs-converter/src/operations/executors/creation_executor_test.ts @@ -22,8 +22,8 @@ import * as creation from '../op_list/creation'; import {Node} from '../types'; import {executeOp} from './creation_executor'; +import {RecursiveSpy, spyOnAllFunctions} from './spy_ops'; import {createDtypeAttr, createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttrFromIndex, createTensorAttr, validateParam} from './test_helper'; -import {spyOnAllFunctions, RecursiveSpy} from './spy_ops'; describe('creation', () => { let node: Node; @@ -99,15 +99,17 @@ describe('creation', () => { node.inputParams['depth'] = createNumberAttrFromIndex(1); node.inputParams['onValue'] = createNumberAttrFromIndex(2); node.inputParams['offValue'] = createNumberAttrFromIndex(3); + node.attrParams['dtype'] = createDtypeAttr('float32'); node.inputNames = ['input', 'input2', 'input3', 'input4']; const input = [tfOps.tensor1d([0])]; const input3 = [tfOps.scalar(2)]; const input4 = [tfOps.scalar(3)]; spyOps.oneHot.and.returnValue({}); - executeOp(node, {input, input2, input3, input4}, context, - spyOpsAsTfOps); + executeOp( + node, {input, input2, input3, input4}, context, spyOpsAsTfOps); - expect(spyOps.oneHot).toHaveBeenCalledWith(input[0], 1, 2, 3); + expect(spyOps.oneHot) + .toHaveBeenCalledWith(input[0], 1, 2, 3, 'float32'); }); it('should match json def', () => { node.op = 'OneHot'; @@ -115,6 +117,7 @@ describe('creation', () => { node.inputParams['depth'] = createNumberAttrFromIndex(1); node.inputParams['onValue'] = createNumberAttrFromIndex(2); node.inputParams['offValue'] = createNumberAttrFromIndex(3); + node.attrParams['dtype'] = createDtypeAttr('float32'); expect(validateParam(node, creation.json)).toBeTruthy(); }); diff --git a/tfjs-core/src/ops/one_hot.ts b/tfjs-core/src/ops/one_hot.ts index b4dffbc03ae..e77fd15ce20 100644 --- a/tfjs-core/src/ops/one_hot.ts +++ b/tfjs-core/src/ops/one_hot.ts @@ -45,6 +45,7 @@ import {op} from './operation'; * the location. * @param offValue A number used to fill in the output when the index does * not match the location. + * @param dtype The dtype of the output tensor, default to 'int32'. * * @doc {heading: 'Tensors', subheading: 'Creation'} */