Skip to content

Commit

Permalink
[BYOC] DNNL C_SRC Fix (#14267)
Browse files Browse the repository at this point in the history
Symptom:
After setting DNNL="C_SRC" to generate C code, the compiler
reported an error that the function "BindToCallNodeArgs"
was missing and there was a problem with linking "dtype_dl2dnnl".
Another issue is the compile complain for ambiguous "GetRootCall".

After fixing the compilation issue, when executing relay.build,
an error occurred complaining "Check failed: const_var_ndarray_
.count(var) > 0 (0 vs. 0)"

Solution:
Update the cmake file and change the codegen.cc.
  • Loading branch information
huajsj authored Mar 10, 2023
1 parent 422ca28 commit 302cee9
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 25 deletions.
1 change: 1 addition & 0 deletions cmake/modules/contrib/DNNL.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ elseif(USE_DNNL STREQUAL "C_SRC")
find_library(EXTERN_LIBRARY_DNNL dnnl)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL})
tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/dnnl.cc
src/runtime/contrib/dnnl/dnnl_utils.cc
src/runtime/contrib/cblas/dnnl_blas.cc)
list(APPEND RUNTIME_SRCS ${DNNL_CONTRIB_SRC})
message(STATUS "Build with DNNL C source module: " ${EXTERN_LIBRARY_DNNL})
Expand Down
52 changes: 27 additions & 25 deletions src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,29 @@ namespace contrib {

using namespace backend;

/*!
* \brief Replace var expr which bind with args of call node
*
* \param args vector of expression (contains vars or constant nodes)
* \param cn call node which describe mapping of internal body vars with args
* \return updated vector of expressions
*/
static tvm::Array<Expr> BindToCallNodeArgs(const std::vector<Expr>& args, const CallNode* cn) {
tvm::Array<Expr> res;
for (const auto& arg : args) {
if (arg->IsInstance<ConstantNode>()) {
res.push_back(arg);
} else {
auto body_params = cn->op.as<FunctionNode>()->params;
auto found = std::find(body_params.begin(), body_params.end(), arg);
ICHECK(found != body_params.end());
auto idx = std::distance(body_params.begin(), found);
res.push_back(cn->args[idx]);
}
}
return res;
}

#ifndef USE_JSON_RUNTIME // C source runtime
inline size_t GetShape1DSize(const Type& type) {
const auto shape = GetShape(type);
Expand Down Expand Up @@ -203,7 +226,8 @@ class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public C

// Give the ndarray a unique name to ease the initialization of it at
// runtime.
std::string const_var_name = CreateConstVar(ext_func_id_, const_idx_);
std::string const_symbol = "dnnl_" + ext_func_id_;
std::string const_var_name = CreateConstVar(const_symbol, const_idx_);
const_vars_.push_back(const_var_name);
const_idx_++;

Expand Down Expand Up @@ -274,7 +298,8 @@ class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public C
return GenerateBody(conv_call, "dnnl_fused_conv2d_bias_relu", GetArgumentNames(caller),
Conv2d(conv_call));
} else if (pattern_name == "dnnl.conv2d_relu") {
const auto* conv_call = GetRootCall(callee->body.as<CallNode>(), 1, {"nn.conv2d", "nn.relu"});
const auto* conv_call = GetRootCall(callee->body.as<CallNode>(), 1,
(const std::vector<std::string>){"nn.conv2d", "nn.relu"});
return GenerateBody(conv_call, "dnnl_fused_conv2d_relu", GetArgumentNames(caller),
Conv2d(conv_call));
}
Expand Down Expand Up @@ -434,29 +459,6 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {

#else // DNNL JSON runtime

/*!
* \brief Replace var expr which bind with args of call node
*
* \param args vector of expression (contains vars or constant nodes)
* \param cn call node which describe mapping of internal body vars with args
* \return updated vector of expressions
*/
static tvm::Array<Expr> BindToCallNodeArgs(const std::vector<Expr>& args, const CallNode* cn) {
tvm::Array<Expr> res;
for (const auto& arg : args) {
if (arg->IsInstance<ConstantNode>()) {
res.push_back(arg);
} else {
auto body_params = cn->op.as<FunctionNode>()->params;
auto found = std::find(body_params.begin(), body_params.end(), arg);
ICHECK(found != body_params.end());
auto idx = std::distance(body_params.begin(), found);
res.push_back(cn->args[idx]);
}
}
return res;
}

/*! \brief Serializer to DNNL JSON runtime module */
class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
Expand Down

0 comments on commit 302cee9

Please sign in to comment.