Category mapper: codegen (llvm#1130)
Codegen for CategoryMapper
Signed-off-by: Ettore Tiotto <[email protected]>
Ettore Tiotto authored Feb 2, 2022
1 parent 9f4b861 commit 80beae8
Showing 7 changed files with 376 additions and 234 deletions.
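For context, ONNX-ML's CategoryMapper converts each element of a tensor between int64 and string categories using the parallel cats_int64s / cats_strings attributes, falling back to default_string or default_int64 when an input value is not in the dictionary. The sketch below only illustrates that semantics (the function name and the std::unordered_map are illustrative, not the generated code, which instead bakes the dictionary into constant G/V tables and queries them through the find_index_* runtime helpers shown further down).

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Illustrative sketch of CategoryMapper in the int64 -> string direction.
std::vector<std::string> mapCategories(const std::vector<int64_t> &input,
    const std::vector<int64_t> &catsInt64s,
    const std::vector<std::string> &catsStrings,
    const std::string &defaultString) {
  // cats_int64s[i] corresponds to cats_strings[i].
  std::unordered_map<int64_t, std::string> dict;
  for (size_t i = 0; i < catsInt64s.size(); ++i)
    dict[catsInt64s[i]] = catsStrings[i];
  std::vector<std::string> output;
  output.reserve(input.size());
  for (int64_t v : input) {
    auto it = dict.find(v);
    output.push_back(it != dict.end() ? it->second : defaultString);
  }
  return output;
}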
413 changes: 251 additions & 162 deletions src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp
@@ -175,8 +175,8 @@ struct ONNXCategoryMapperOpLowering : public ConversionPattern {
 
     MemRefType type = MemRefType::get(
         {static_cast<int64_t>(V.size())}, builder.getIntegerType(32));
-    res.G = create.krnl.constant(type, "G", builder.getI32VectorAttr(G));
-    res.V = create.krnl.constant(type, "V", builder.getI32VectorAttr(V));
+    res.G = create.krnl.constant(type, "G", builder.getI32TensorAttr(G));
+    res.V = create.krnl.constant(type, "V", builder.getI32TensorAttr(V));
     res.len = create.math.constant(builder.getIntegerType(32), G.size());
     return res;
   };
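The change above swaps builder.getI32VectorAttr for builder.getI32TensorAttr when materializing the lookup tables G and V as krnl constants. Both Builder helpers wrap the same i32 payload in a DenseIntElementsAttr; they differ only in the shaped type attached to the attribute (vector<Nxi32> vs. tensor<Nxi32>), the tensor form presumably being what the constant lowering expects here. A minimal standalone sketch of that distinction (assumes an MLIR build; the data values are made up):

#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"

int main() {
  mlir::MLIRContext ctx;
  mlir::Builder builder(&ctx);
  // Same i32 data, different shaped type on the resulting attribute.
  llvm::SmallVector<int32_t> data = {3, -1, 0, 2};
  mlir::DenseIntElementsAttr asVector = builder.getI32VectorAttr(data);
  mlir::DenseIntElementsAttr asTensor = builder.getI32TensorAttr(data);
  asVector.getType().dump(); // prints: vector<4xi32>
  asTensor.getType().dump(); // prints: tensor<4xi32>
  return 0;
}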
3 changes: 0 additions & 3 deletions src/Dialect/ONNX/ONNXOps.cpp
@@ -4716,9 +4716,6 @@ static LogicalResult verify(ONNXCategoryMapperOp op) {
     return op.emitError("'default_string' attribute is missing.");
   if (elementType.isa<onnxmlir::StringType>() && !op.default_int64Attr())
     return op.emitError("'default_int64' attribute is missing.");
-  if (op.default_stringAttr() && op.default_int64Attr())
-    return op.emitError("Only one of 'default_int64' or 'default_string' "
-                        "attributes must be specified");
 
   return success();
 }
8 changes: 4 additions & 4 deletions src/Runtime/OMIndexLookup.inc
@@ -45,12 +45,12 @@ static inline uint32_t hash_int64(uint32_t hval, int64_t val) {
 #ifdef __cplusplus
 extern "C"
 #endif
-    uint32_t
+    uint64_t
     find_index_str(
         const char *str, int32_t G[], int32_t V[], int32_t dictSize) {
   assert(str && G && V && dictSize > 0);
   int32_t d = G[hash_string(0, str) % dictSize];
-  int32_t index = (d < 0) ? V[-d - 1] : V[hash_string(d, str) % dictSize];
+  int64_t index = (d < 0) ? V[-d - 1] : V[hash_string(d, str) % dictSize];
   assert(index >= 0 && index < dictSize);
   return index;
 }
@@ -62,11 +62,11 @@ extern "C"
 #ifdef __cplusplus
 extern "C"
 #endif
-    uint32_t
+    uint64_t
     find_index_i64(int64_t val, int32_t G[], int32_t V[], int32_t dictSize) {
   assert(G && V && dictSize > 0);
   int32_t d = G[hash_int64(0, val) % dictSize];
-  int32_t index = (d < 0) ? V[-d - 1] : V[hash_int64(d, val) % dictSize];
+  int64_t index = (d < 0) ? V[-d - 1] : V[hash_int64(d, val) % dictSize];
   assert(index >= 0 && index < dictSize);
   return index;
 }
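The two helpers above are the query side of a static two-level ("hash and displace", perfect-hash style) table: G stores one displacement per first-level bucket and V stores the category index per slot, so every dictionary key resolves with at most two hash evaluations and no probing. Below is a self-contained sketch of the same scheme; the hash function, the table builder, and all names are illustrative stand-ins, not the code onnx-mlir emits (the compiler bakes G and V into the model as the krnl constants shown in CategoryMapper.cpp above).

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Toy stand-in for the hash in OMIndexLookup.inc (illustrative only).
static uint32_t hashI64(uint32_t d, int64_t key) {
  uint32_t h = (d != 0) ? d : 0x811c9dc5u; // seed for the first-level pass
  for (int i = 0; i < 8; ++i) {
    h ^= static_cast<uint32_t>(key >> (8 * i)) & 0xffu;
    h *= 16777619u;
  }
  return h;
}

// Query side: same shape as find_index_i64 above.
static int64_t findIndexI64(
    int64_t val, const std::vector<int32_t> &G, const std::vector<int32_t> &V) {
  const int32_t n = static_cast<int32_t>(G.size());
  int32_t d = G[hashI64(0, val) % n];
  return (d < 0) ? V[-d - 1] : V[hashI64(d, val) % n];
}

// Build G (per-bucket displacement) and V (category index per slot) so that
// every dictionary key resolves with at most two hash evaluations.
static void buildTables(const std::vector<int64_t> &keys,
    std::vector<int32_t> &G, std::vector<int32_t> &V) {
  const int32_t n = static_cast<int32_t>(keys.size());
  G.assign(n, 0);
  V.assign(n, -1);
  // Group keys by their first-level bucket.
  std::vector<std::vector<int32_t>> buckets(n);
  for (int32_t i = 0; i < n; ++i)
    buckets[hashI64(0, keys[i]) % n].push_back(i);
  // Process buckets largest first.
  std::vector<int32_t> order(n);
  for (int32_t b = 0; b < n; ++b)
    order[b] = b;
  std::sort(order.begin(), order.end(),
      [&](int32_t a, int32_t b) { return buckets[a].size() > buckets[b].size(); });
  std::vector<bool> used(n, false);
  size_t next = 0;
  // Multi-key buckets: find a displacement d sending all keys to free slots.
  for (; next < order.size() && buckets[order[next]].size() > 1; ++next) {
    const std::vector<int32_t> &bucket = buckets[order[next]];
    bool placed = false;
    for (int32_t d = 1; d < 10000 && !placed; ++d) {
      std::vector<int32_t> slots;
      for (int32_t keyIdx : bucket) {
        int32_t slot = static_cast<int32_t>(hashI64(d, keys[keyIdx]) % n);
        if (used[slot] ||
            std::find(slots.begin(), slots.end(), slot) != slots.end())
          break;
        slots.push_back(slot);
      }
      if (slots.size() == bucket.size()) {
        for (size_t k = 0; k < bucket.size(); ++k) {
          used[slots[k]] = true;
          V[slots[k]] = bucket[k];
        }
        G[order[next]] = d;
        placed = true;
      }
    }
    assert(placed && "no suitable displacement found for this toy hash");
  }
  // Single-key buckets: park the key in a free slot and record the slot
  // directly as a negative displacement (decoded by V[-d - 1] above).
  int32_t freeSlot = 0;
  for (; next < order.size() && !buckets[order[next]].empty(); ++next) {
    while (used[freeSlot])
      ++freeSlot;
    used[freeSlot] = true;
    V[freeSlot] = buckets[order[next]][0];
    G[order[next]] = -freeSlot - 1;
  }
}

int main() {
  // Category dictionary: the i-th key maps to index i.
  std::vector<int64_t> categories = {7, 42, -3, 1000, 12345};
  std::vector<int32_t> G, V;
  buildTables(categories, G, V);
  for (size_t i = 0; i < categories.size(); ++i)
    assert(findIndexI64(categories[i], G, V) == static_cast<int64_t>(i));
  std::cout << "index of 1000: " << findIndexI64(1000, G, V) << "\n"; // prints 3
  return 0;
}

Note that a key absent from the dictionary still lands on some slot, so a real implementation additionally has to validate the hit (for example by comparing the stored key against the input) before falling back to default_int64 / default_string.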
63 changes: 43 additions & 20 deletions src/Runtime/PyExecutionSession.cpp
@@ -62,6 +62,7 @@ std::vector<py::array> PyExecutionSession::pyRun(
       dtype = ONNX_TYPE_INT32;
     else if (py::isinstance<py::array_t<std::int64_t>>(inputPyArray))
       dtype = ONNX_TYPE_INT64;
+    // string type missing
     else if (py::isinstance<py::array_t<bool>>(inputPyArray))
       dtype = ONNX_TYPE_BOOL;
     // Missing fp16 support.
@@ -71,6 +72,11 @@
       dtype = ONNX_TYPE_UINT32;
     else if (py::isinstance<py::array_t<std::uint64_t>>(inputPyArray))
       dtype = ONNX_TYPE_UINT64;
+    else if (py::isinstance<py::array_t<std::complex<float>>>(inputPyArray))
+      dtype = ONNX_TYPE_COMPLEX64;
+    else if (py::isinstance<py::array_t<std::complex<double>>>(inputPyArray))
+      dtype = ONNX_TYPE_COMPLEX128;
+    // Missing bfloat16 support
     else {
       std::cerr << "Numpy type not supported: " << inputPyArray.dtype()
                 << ".\n";
@@ -97,38 +103,55 @@ std::vector<py::array> PyExecutionSession::pyRun(
 
   // https://numpy.org/devdocs/user/basics.types.html
   py::dtype dtype;
-  if (omTensorGetDataType(omt) == (OM_DATA_TYPE)onnx::TensorProto::FLOAT)
+  switch (omTensorGetDataType(omt)) {
+  case (OM_DATA_TYPE)onnx::TensorProto::FLOAT:
     dtype = py::dtype("float32");
-  else if (omTensorGetDataType(omt) == (OM_DATA_TYPE)onnx::TensorProto::UINT8)
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::UINT8:
     dtype = py::dtype("uint8");
-  else if (omTensorGetDataType(omt) == (OM_DATA_TYPE)onnx::TensorProto::INT8)
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::INT8:
     dtype = py::dtype("int8");
-  else if (omTensorGetDataType(omt) ==
-           (OM_DATA_TYPE)onnx::TensorProto::UINT16)
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::UINT16:
     dtype = py::dtype("uint16");
-  else if (omTensorGetDataType(omt) == (OM_DATA_TYPE)onnx::TensorProto::INT16)
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::INT16:
     dtype = py::dtype("int16");
-  else if (omTensorGetDataType(omt) == (OM_DATA_TYPE)onnx::TensorProto::INT32)
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::INT32:
     dtype = py::dtype("int32");
-  else if (omTensorGetDataType(omt) == (OM_DATA_TYPE)onnx::TensorProto::INT64)
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::INT64:
     dtype = py::dtype("int64");
-  // TODO(tjingrant) wait for Tong's input for how to represent string.
-  else if (omTensorGetDataType(omt) == (OM_DATA_TYPE)onnx::TensorProto::BOOL)
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::STRING:
+    dtype = py::dtype("str");
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::BOOL:
     dtype = py::dtype("bool_");
-  else if (omTensorGetDataType(omt) ==
-           (OM_DATA_TYPE)onnx::TensorProto::FLOAT16)
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::FLOAT16:
     dtype = py::dtype("float32");
-  else if (omTensorGetDataType(omt) ==
-           (OM_DATA_TYPE)onnx::TensorProto::DOUBLE)
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::DOUBLE:
     dtype = py::dtype("float64");
-  else if (omTensorGetDataType(omt) ==
-           (OM_DATA_TYPE)onnx::TensorProto::UINT32)
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::UINT32:
     dtype = py::dtype("uint32");
-  else if (omTensorGetDataType(omt) ==
-           (OM_DATA_TYPE)onnx::TensorProto::UINT64)
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::UINT64:
     dtype = py::dtype("uint64");
-  else {
-    fprintf(stderr, "Unsupported ONNX type in OMTensor.");
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::COMPLEX64:
+    dtype = py::dtype("csingle");
+    break;
+  case (OM_DATA_TYPE)onnx::TensorProto::COMPLEX128:
+    dtype = py::dtype("cdouble");
+    break;
+  default:
+    std::cerr << "Unsupported ONNX type in OMTensor: "
+              << omTensorGetDataType(omt) << ".\n";
     exit(1);
   }
