Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Breaking] Remove dmlc::Parameter #293

Merged
merged 1 commit into from
Jul 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmake/ExternalLibs.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ add_library(RapidJSON::rapidjson ALIAS rapidjson)

# Google C++ tests
if(BUILD_CPP_TEST)
find_package(GTest CONFIG)
find_package(GTest 1.11.0 CONFIG)
if(NOT GTEST_FOUND)
message(STATUS "Did not find Google Test in the system root. Fetching Google Test now...")
FetchContent_Declare(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG release-1.10.0
GIT_TAG release-1.11.0
)
FetchContent_MakeAvailable(googletest)
add_library(GTest::GTest ALIAS gtest)
Expand Down
29 changes: 9 additions & 20 deletions include/treelite/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,46 +72,35 @@ TREELITE_DLL int TreeliteAnnotationFree(AnnotationHandle handle);
* \{
*/
/*!
* \brief create a compiler with a given name
* \brief Create a compiler with a given name
* \param name name of compiler
* \param params_json_str JSON string representing the parameters for the compiler
* \param out created compiler
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteCompilerCreate(const char* name,
CompilerHandle* out);
TREELITE_DLL int TreeliteCompilerCreateV2(const char* name, const char* params_json_str,
CompilerHandle* out);
/*!
* \brief set a parameter for a compiler
* \param handle compiler
* \param name name of parameter
* \param value value of parameter
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteCompilerSetParam(CompilerHandle handle,
const char* name,
const char* value);
/*!
* \brief generate prediction code from a tree ensemble model. The code will
* \brief Generate prediction code from a tree ensemble model. The code will
* be C99 compliant. One header file (.h) will be generated, along with
* one or more source files (.c).
*
* Usage example:
* \code
* TreeliteCompilerGenerateCode(compiler, model, 1, "./my/model");
* TreeliteCompilerGenerateCodeV2(compiler, model, "./my/model");
* // files to generate: ./my/model/header.h, ./my/model/main.c
* // if parallel compilation is enabled:
* // ./my/model/header.h, ./my/model/main.c, ./my/model/tu0.c,
* // ./my/model/tu1.c, and so forth
* \endcode
* \param compiler handle for compiler
* \param model handle for tree ensemble model
* \param verbose whether to produce extra messages
* \param dirpath directory to store header and source files
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteCompilerGenerateCode(CompilerHandle compiler,
ModelHandle model,
int verbose,
const char* dirpath);
TREELITE_DLL int TreeliteCompilerGenerateCodeV2(CompilerHandle compiler,
ModelHandle model,
const char* dirpath);
/*!
* \brief delete compiler from memory
* \param handle compiler to remove
Expand Down
2 changes: 1 addition & 1 deletion include/treelite/c_api_error.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
#ifndef TREELITE_C_API_ERROR_H_
#define TREELITE_C_API_ERROR_H_

#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <treelite/c_api_common.h>
#include <stdexcept>

/*! \brief macro to guard beginning and end section of all functions */
#define API_BEGIN() try {
Expand Down
8 changes: 7 additions & 1 deletion include/treelite/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,19 @@ class Compiler {
* \return compiled model
*/
virtual compiler::CompiledModel Compile(const Model& model) = 0;
/*!
* \brief Query the parameters used to intiailize the compiler
* \return parameters used
*/
virtual compiler::CompilerParam QueryParam() const = 0;
/*!
* \brief create a compiler from given name
* \param name name of compiler
* \param param_json_string JSON string representing
* \return The created compiler
*/
static Compiler* Create(const std::string& name,
const compiler::CompilerParam& param);
const char* param_json_str);
};

/*!
Expand Down
22 changes: 2 additions & 20 deletions include/treelite/compiler_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@
#ifndef TREELITE_COMPILER_PARAM_H_
#define TREELITE_COMPILER_PARAM_H_

#include <dmlc/parameter.h>
#include <string>
#include <limits>

namespace treelite {
namespace compiler {

/*! \brief parameters for tree compiler */
struct CompilerParam : public dmlc::Parameter<CompilerParam> {
struct CompilerParam {
/*!
* \defgroup compiler_param
* parameters for tree compiler
Expand Down Expand Up @@ -49,24 +48,7 @@ struct CompilerParam : public dmlc::Parameter<CompilerParam> {
int dump_array_as_elf;
/*! \} */

// declare parameters
DMLC_DECLARE_PARAMETER(CompilerParam) {
DMLC_DECLARE_FIELD(annotate_in).set_default("NULL")
.describe("Name of model annotation file");
DMLC_DECLARE_FIELD(quantize).set_lower_bound(0).set_default(0)
.describe("whether to quantize threshold points (0: no, >0: yes)");
DMLC_DECLARE_FIELD(parallel_comp).set_lower_bound(0).set_default(0)
.describe("option to enable parallel compilation;"
"if set to nonzero, the trees will be evely distributed"
"into [parallel_comp] files.");
DMLC_DECLARE_FIELD(verbose).set_default(0)
.describe("if >0, produce extra messages");
DMLC_DECLARE_FIELD(native_lib_name).set_default("predictor");
DMLC_DECLARE_FIELD(code_folding_req)
.set_default(std::numeric_limits<double>::infinity())
.set_lower_bound(0);
DMLC_DECLARE_FIELD(dump_array_as_elf).set_lower_bound(0).set_default(0);
}
static CompilerParam ParseFromJSON(const char* param_json_str);
};

} // namespace compiler
Expand Down
36 changes: 8 additions & 28 deletions python/treelite/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import collections
import shutil
import os
import json
from tempfile import TemporaryDirectory

import numpy as np
Expand Down Expand Up @@ -304,40 +305,19 @@ def compile(self, dirpath, params=None, compiler='ast_native', verbose=False):
``./model/header.h``, ``./my/model/main.c``
"""
compiler_handle = ctypes.c_void_p()
_check_call(_LIB.TreeliteCompilerCreate(c_str(compiler),
ctypes.byref(compiler_handle)))
_params = dict(params) if isinstance(params, list) else params
self._set_compiler_param(compiler_handle, _params or {})
_check_call(_LIB.TreeliteCompilerGenerateCode(
if verbose and _params:
_params['verbose'] = 1
params_json_str = json.dumps(_params) if _params else '{}'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think JSON handles this automatically.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, the None object gets serialized to null.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, the param is optional.

_check_call(_LIB.TreeliteCompilerCreateV2(c_str(compiler),
c_str(params_json_str),
ctypes.byref(compiler_handle)))
_check_call(_LIB.TreeliteCompilerGenerateCodeV2(
compiler_handle,
self.handle,
ctypes.c_int(1 if verbose else 0),
c_str(dirpath)))
_check_call(_LIB.TreeliteCompilerFree(compiler_handle))

@staticmethod
def _set_compiler_param(compiler_handle, params, value=None):
"""
Set parameter(s) for compiler

Parameters
----------
params: :py:class:`dict <python:dict>` / :py:class:`list <python:list>` / \
:py:class:`str <python:str>`
list of key-alue pairs, dict or simply string key
compiler_handle: object of type `ctypes.c_void_p`
handle to compiler
value: optional
value of the specified parameter, when params is a single string
"""
if isinstance(params, collections.abc.Mapping):
params = params.items()
elif isinstance(params, (str,)) and value is not None:
params = [(params, value)]
for key, val in params:
_check_call(_LIB.TreeliteCompilerSetParam(compiler_handle, c_str(key),
c_str(str(val))))

@classmethod
def from_xgboost(cls, booster):
"""
Expand Down
70 changes: 13 additions & 57 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,6 @@

using namespace treelite;

namespace {

struct CompilerHandleImpl {
std::string name;
std::vector<std::pair<std::string, std::string>> cfg;
std::unique_ptr<Compiler> compiler;
explicit CompilerHandleImpl(const std::string& name)
: name(name), cfg(), compiler(nullptr) {}
~CompilerHandleImpl() = default;
};

} // anonymous namespace

int TreeliteAnnotateBranch(
ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle* out) {
API_BEGIN();
Expand Down Expand Up @@ -63,66 +50,35 @@ int TreeliteAnnotationFree(AnnotationHandle handle) {
API_END();
}

int TreeliteCompilerCreate(const char* name,
CompilerHandle* out) {
int TreeliteCompilerCreateV2(const char* name, const char* params_json_str, CompilerHandle* out) {
API_BEGIN();
std::unique_ptr<CompilerHandleImpl> compiler{new CompilerHandleImpl(name)};
std::unique_ptr<Compiler> compiler{Compiler::Create(name, params_json_str)};
*out = static_cast<CompilerHandle>(compiler.release());
API_END();
}

int TreeliteCompilerSetParam(CompilerHandle handle,
const char* name,
const char* value) {
API_BEGIN();
CompilerHandleImpl* impl = static_cast<CompilerHandleImpl*>(handle);
auto& cfg_ = impl->cfg;
std::string name_(name);
std::string value_(value);
// check for duplicate parameters
auto it = std::find_if(cfg_.begin(), cfg_.end(),
[&name_](const std::pair<std::string, std::string>& x) {
return x.first == name_;
});
if (it == cfg_.end()) {
cfg_.emplace_back(name_, value_);
} else {
it->second = value;
}
API_END();
}

int TreeliteCompilerGenerateCode(CompilerHandle compiler,
ModelHandle model,
int verbose,
const char* dirpath) {
int TreeliteCompilerGenerateCodeV2(CompilerHandle compiler,
ModelHandle model,
const char* dirpath) {
API_BEGIN();
if (verbose > 0) { // verbose enabled
int ret = TreeliteCompilerSetParam(compiler, "verbose",
std::to_string(verbose).c_str());
if (ret < 0) { // SetParam failed
return ret;
}
}
const Model* model_ = static_cast<Model*>(model);
CompilerHandleImpl* impl = static_cast<CompilerHandleImpl*>(compiler);
Compiler* compiler_ = static_cast<Compiler*>(compiler);
CHECK(model_);
CHECK(compiler_);
compiler::CompilerParam param = compiler_->QueryParam();

// create directory named dirpath
const std::string& dirpath_(dirpath);
filesystem::CreateDirectoryIfNotExist(dirpath);

compiler::CompilerParam cparam;
cparam.Init(impl->cfg, dmlc::parameter::kAllMatch);

/* compile model */
impl->compiler.reset(Compiler::Create(impl->name, cparam));
auto compiled_model = impl->compiler->Compile(*model_);
if (verbose > 0) {
auto compiled_model = compiler_->Compile(*model_);
if (param.verbose > 0) {
LOG(INFO) << "Code generation finished. Writing code to files...";
}

for (const auto& it : compiled_model.files) {
if (verbose > 0) {
if (param.verbose > 0) {
LOG(INFO) << "Writing file " << it.first << "...";
}
const std::string filename_full = dirpath_ + "/" + it.first;
Expand All @@ -138,7 +94,7 @@ int TreeliteCompilerGenerateCode(CompilerHandle compiler,

int TreeliteCompilerFree(CompilerHandle handle) {
API_BEGIN();
delete static_cast<CompilerHandleImpl*>(handle);
delete static_cast<Compiler*>(handle);
API_END();
}

Expand Down
4 changes: 4 additions & 0 deletions src/compiler/ast_native.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ class ASTNativeCompiler : public Compiler {
});
}

CompilerParam QueryParam() const override {
return param;
}

private:
CompilerParam param;
int num_feature_;
Expand Down
Loading