Skip to content

Commit

Permalink
[converter] added dtype support for oneHot for converter (#6782)
Browse files Browse the repository at this point in the history
* added dtype support for oneHot for converter

* update the doc for onehot
  • Loading branch information
pyu10055 authored Aug 23, 2022
1 parent 25cd4f4 commit 9a580c5
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 10 deletions.
5 changes: 2 additions & 3 deletions tfjs-converter/python/tensorflowjs/op_list/creation.json
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@
{
"tfName": "T",
"name": "dtype",
"type": "dtype",
"notSupported": true
"type": "dtype"
}
]
},
Expand Down Expand Up @@ -366,4 +365,4 @@
}
]
}
]
]
8 changes: 5 additions & 3 deletions tfjs-converter/src/operations/executors/creation_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ import {InternalOpExecutor, Node} from '../types';
import {getParamValue} from './utils';

export const executeOp: InternalOpExecutor =
(node: Node, tensorMap: NamedTensorsMap,
context: ExecutionContext, ops = tfOps): Tensor[] => {
(node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext,
ops = tfOps): Tensor[] => {
switch (node.op) {
case 'Fill': {
const shape =
Expand Down Expand Up @@ -64,7 +64,9 @@ export const executeOp: InternalOpExecutor =
getParamValue('onValue', node, tensorMap, context) as number;
const offValue =
getParamValue('offValue', node, tensorMap, context) as number;
return [ops.oneHot(indices, depth, onValue, offValue)];
const dtype =
getParamValue('dtype', node, tensorMap, context) as DataType;
return [ops.oneHot(indices, depth, onValue, offValue, dtype)];
}
case 'Ones': {
return [ops.ones(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ import * as creation from '../op_list/creation';
import {Node} from '../types';

import {executeOp} from './creation_executor';
import {RecursiveSpy, spyOnAllFunctions} from './spy_ops';
import {createDtypeAttr, createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttrFromIndex, createTensorAttr, validateParam} from './test_helper';
import {spyOnAllFunctions, RecursiveSpy} from './spy_ops';

describe('creation', () => {
let node: Node;
Expand Down Expand Up @@ -99,22 +99,25 @@ describe('creation', () => {
node.inputParams['depth'] = createNumberAttrFromIndex(1);
node.inputParams['onValue'] = createNumberAttrFromIndex(2);
node.inputParams['offValue'] = createNumberAttrFromIndex(3);
node.attrParams['dtype'] = createDtypeAttr('float32');
node.inputNames = ['input', 'input2', 'input3', 'input4'];
const input = [tfOps.tensor1d([0])];
const input3 = [tfOps.scalar(2)];
const input4 = [tfOps.scalar(3)];
spyOps.oneHot.and.returnValue({});
executeOp(node, {input, input2, input3, input4}, context,
spyOpsAsTfOps);
executeOp(
node, {input, input2, input3, input4}, context, spyOpsAsTfOps);

expect(spyOps.oneHot).toHaveBeenCalledWith(input[0], 1, 2, 3);
expect(spyOps.oneHot)
.toHaveBeenCalledWith(input[0], 1, 2, 3, 'float32');
});
it('should match json def', () => {
node.op = 'OneHot';
node.inputParams['indices'] = createTensorAttr(0);
node.inputParams['depth'] = createNumberAttrFromIndex(1);
node.inputParams['onValue'] = createNumberAttrFromIndex(2);
node.inputParams['offValue'] = createNumberAttrFromIndex(3);
node.attrParams['dtype'] = createDtypeAttr('float32');

expect(validateParam(node, creation.json)).toBeTruthy();
});
Expand Down
1 change: 1 addition & 0 deletions tfjs-core/src/ops/one_hot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import {op} from './operation';
* the location.
* @param offValue A number used to fill in the output when the index does
* not match the location.
* @param dtype The dtype of the output tensor, default to 'int32'.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
Expand Down

0 comments on commit 9a580c5

Please sign in to comment.