Skip to content

Commit

Permalink
Merge pull request #1166 from khotilov/r_api_fix
Browse files Browse the repository at this point in the history
[R-package] C-API fix; attribute accessors
  • Loading branch information
tqchen committed May 7, 2016
2 parents b92e225 + 5a78118 commit 6e79ba8
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 19 deletions.
2 changes: 2 additions & 0 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ S3method(predict,xgb.Booster)
S3method(predict,xgb.Booster.handle)
S3method(setinfo,xgb.DMatrix)
S3method(slice,xgb.DMatrix)
export("xgb.attr<-")
export(getinfo)
export(print.xgb.DMatrix)
export(setinfo)
export(slice)
export(xgb.DMatrix)
export(xgb.DMatrix.save)
export(xgb.attr)
export(xgb.create.features)
export(xgb.cv)
export(xgb.dump)
Expand Down
74 changes: 74 additions & 0 deletions R-package/R/xgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,77 @@ predict.xgb.Booster.handle <- function(object, ...) {
ret <- predict(bst, ...)
return(ret)
}


#' Accessors for serializable attributes of a model.
#'
#' These methods allow to manipulate key-value attribute strings of an xgboost model.
#'
#' @param object Object of class \code{xgb.Booster} or \code{xgb.Booster.handle}.
#' @param which a non-empty character string specifying which attribute is to be accessed.
#' @param value a value of an attribute. Non-character values are converted to character.
#' When length of a \code{value} vector is more than one, only the first element is used.
#'
#' @details
#' Note that the xgboost model attributes are a separate concept from the attributes in R.
#' Specifically, they refer to key-value strings that can be attached to an xgboost model
#' and stored within the model's binary representation.
#' In contrast, any R-attribute assigned to an R-object of \code{xgb.Booster} class
#' would not be saved by \code{xgb.save}, since xgboost model is an external memory object
#' and its serialization is handled extrnally.
#'
#' Also note that the attribute setter would usually work more efficiently for \code{xgb.Booster.handle}
#' than for \code{xgb.Booster}, since only just a handle would need to be copied.
#'
#' @return
#' \code{xgb.attr} returns either a string value of an attribute
#' or \code{NULL} if an attribute wasn't stored in a model.
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#' train <- agaricus.train
#'
#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
#' eta = 1, nthread = 2, nround = 2, objective = "binary:logistic")
#'
#' xgb.attr(bst, "my_attribute") <- "my attribute value"
#' print(xgb.attr(bst, "my_attribute"))
#'
#' xgb.save(bst, 'xgb.model')
#' bst1 <- xgb.load('xgb.model')
#' print(xgb.attr(bst1, "my_attribute"))
#'
#' @rdname xgb.attr
#' @export
xgb.attr <- function(object, which) {
if (is.null(which) | nchar(as.character(which)[1]) == 0) stop("invalid attribute name")
handle = xgb.get.handle(object, "xgb.attr")
.Call("XGBoosterGetAttr_R", handle, as.character(which)[1], PACKAGE="xgboost")
}

#' @rdname xgb.attr
#' @export
`xgb.attr<-` <- function(object, which, value) {
if (is.null(which) | nchar(as.character(which)[1]) == 0) stop("invalid attribute name")
handle = xgb.get.handle(object, "xgb.attr")
# TODO: setting NULL value to remove an attribute
.Call("XGBoosterSetAttr_R", handle, as.character(which)[1], as.character(value)[1], PACKAGE="xgboost")
if (is(object, 'xgb.Booster') && !is.null(object$raw)) {
object$raw <- xgb.save.raw(object$handle)
}
object
}

# Return a valid handle out of either xgb.Booster.handle or xgb.Booster
# internal utility function
xgb.get.handle <- function(object, caller="") {
handle = switch(class(object),
xgb.Booster = object$handle,
xgb.Booster.handle = object,
stop(caller, ": argument must be either xgb.Booster or xgb.Booster.handle")
)
if (is.null(handle) | .Call("XGCheckNullPtr_R", handle, PACKAGE="xgboost")) {
stop(caller, ": invalid xgb.Booster.handle")
}
handle
}
53 changes: 53 additions & 0 deletions R-package/man/xgb.attr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

56 changes: 46 additions & 10 deletions R-package/src/xgboost_R.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,16 @@ SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
return ret;
}

void XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {
SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {
R_API_BEGIN();
CHECK_CALL(XGDMatrixSaveBinary(R_ExternalPtrAddr(handle),
CHAR(asChar(fname)),
asInteger(silent)));
R_API_END();
return R_NilValue;
}

void XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
R_API_BEGIN();
int len = length(array);
const char *name = CHAR(asChar(field));
Expand All @@ -167,6 +168,7 @@ void XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
BeginPtr(vec), len));
}
R_API_END();
return R_NilValue;
}

SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
Expand Down Expand Up @@ -227,23 +229,25 @@ SEXP XGBoosterCreate_R(SEXP dmats) {
return ret;
}

void XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) {
SEXP XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) {
R_API_BEGIN();
CHECK_CALL(XGBoosterSetParam(R_ExternalPtrAddr(handle),
CHAR(asChar(name)),
CHAR(asChar(val))));
CHAR(asChar(name)),
CHAR(asChar(val))));
R_API_END();
return R_NilValue;
}

void XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
SEXP XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
R_API_BEGIN();
CHECK_CALL(XGBoosterUpdateOneIter(R_ExternalPtrAddr(handle),
asInteger(iter),
R_ExternalPtrAddr(dtrain)));
R_API_END();
return R_NilValue;
}

void XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) {
SEXP XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) {
R_API_BEGIN();
CHECK_EQ(length(grad), length(hess))
<< "gradient and hess must have same length";
Expand All @@ -259,6 +263,7 @@ void XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) {
BeginPtr(tgrad), BeginPtr(thess),
len));
R_API_END();
return R_NilValue;
}

SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) {
Expand Down Expand Up @@ -305,24 +310,27 @@ SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask, SEXP ntree_lim
return ret;
}

void XGBoosterLoadModel_R(SEXP handle, SEXP fname) {
SEXP XGBoosterLoadModel_R(SEXP handle, SEXP fname) {
R_API_BEGIN();
CHECK_CALL(XGBoosterLoadModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname))));
R_API_END();
return R_NilValue;
}

void XGBoosterSaveModel_R(SEXP handle, SEXP fname) {
SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname) {
R_API_BEGIN();
CHECK_CALL(XGBoosterSaveModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname))));
R_API_END();
return R_NilValue;
}

void XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
R_API_BEGIN();
CHECK_CALL(XGBoosterLoadModelFromBuffer(R_ExternalPtrAddr(handle),
RAW(raw),
length(raw)));
R_API_END();
return R_NilValue;
}

SEXP XGBoosterModelToRaw_R(SEXP handle) {
Expand Down Expand Up @@ -360,3 +368,31 @@ SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats) {
return out;
}

SEXP XGBoosterGetAttr_R(SEXP handle, SEXP name) {
SEXP out;
R_API_BEGIN();
int success;
const char *val;
CHECK_CALL(XGBoosterGetAttr(R_ExternalPtrAddr(handle),
CHAR(asChar(name)),
&val,
&success));
if (success) {
out = PROTECT(allocVector(STRSXP, 1));
SET_STRING_ELT(out, 0, mkChar(val));
} else {
out = PROTECT(R_NilValue);
}
UNPROTECT(1);
R_API_END();
return out;
}

SEXP XGBoosterSetAttr_R(SEXP handle, SEXP name, SEXP val) {
R_API_BEGIN();
CHECK_CALL(XGBoosterSetAttr(R_ExternalPtrAddr(handle),
CHAR(asChar(name)),
CHAR(asChar(val))));
R_API_END();
return R_NilValue;
}
44 changes: 35 additions & 9 deletions R-package/src/xgboost_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,18 @@ XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset);
* \param handle a instance of data matrix
* \param fname file name
* \param silent print statistics when saving
* \return R_NilValue
*/
XGB_DLL void XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent);
XGB_DLL SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent);

/*!
* \brief set information to dmatrix
* \param handle a instance of data matrix
* \param field field name, can be label, weight
* \param array pointer to float vector
* \return R_NilValue
*/
XGB_DLL void XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array);
XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array);

/*!
* \brief get info vector from matrix
Expand Down Expand Up @@ -104,16 +106,18 @@ XGB_DLL SEXP XGBoosterCreate_R(SEXP dmats);
* \param handle handle
* \param name parameter name
* \param val value of parameter
* \return R_NilValue
*/
XGB_DLL void XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val);
XGB_DLL SEXP XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val);

/*!
* \brief update the model in one round using dtrain
* \param handle handle
* \param iter current iteration rounds
* \param dtrain training data
* \return R_NilValue
*/
XGB_DLL void XGBoosterUpdateOneIter_R(SEXP ext, SEXP iter, SEXP dtrain);
XGB_DLL SEXP XGBoosterUpdateOneIter_R(SEXP ext, SEXP iter, SEXP dtrain);

/*!
* \brief update the model, by directly specify gradient and second order gradient,
Expand All @@ -122,16 +126,17 @@ XGB_DLL void XGBoosterUpdateOneIter_R(SEXP ext, SEXP iter, SEXP dtrain);
* \param dtrain training data
* \param grad gradient statistics
* \param hess second order gradient statistics
* \return R_NilValue
*/
XGB_DLL void XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess);
XGB_DLL SEXP XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess);

/*!
* \brief get evaluation statistics for xgboost
* \param handle handle
* \param iter current iteration rounds
* \param dmats list of handles to dmatrices
* \param evname name of evaluation
* \return the string containing evaluation stati
* \return the string containing evaluation stats
*/
XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames);

Expand All @@ -147,21 +152,24 @@ XGB_DLL SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask, SEXP n
* \brief load model from existing file
* \param handle handle
* \param fname file name
* \return R_NilValue
*/
XGB_DLL void XGBoosterLoadModel_R(SEXP handle, SEXP fname);
XGB_DLL SEXP XGBoosterLoadModel_R(SEXP handle, SEXP fname);

/*!
* \brief save model into existing file
* \param handle handle
* \param fname file name
* \return R_NilValue
*/
XGB_DLL void XGBoosterSaveModel_R(SEXP handle, SEXP fname);
XGB_DLL SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname);

/*!
* \brief load model from raw array
* \param handle handle
* \return R_NilValue
*/
XGB_DLL void XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw);
XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw);

/*!
* \brief save model into R's raw array
Expand All @@ -177,4 +185,22 @@ XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle);
* \param with_stats whether dump statistics of splits
*/
XGB_DLL SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats);

/*!
* \brief get learner attribute value
* \param handle handle
* \param name attribute name
* \return character containing attribute value
*/
XGB_DLL SEXP XGBoosterGetAttr_R(SEXP handle, SEXP name);

/*!
* \brief set learner attribute value
* \param handle handle
* \param name attribute name
* \param val attribute value
* \return R_NilValue
*/
XGB_DLL SEXP XGBoosterSetAttr_R(SEXP handle, SEXP name, SEXP val);

#endif // XGBOOST_WRAPPER_R_H_ // NOLINT(*)
Loading

0 comments on commit 6e79ba8

Please sign in to comment.