Skip to content

Commit

Permalink
[TMVA] Merge RBDT with FastForest
Browse files Browse the repository at this point in the history
Consolidate RBDT as specified in the ROOT plan of work 2024.

The backends of RBDT are replaced with a single new backend based on
the logic from the FastForest library:
https://github.com/guitargeek/XGBoost-FastForest

The logic in that library was originally taken from the GBRForest in
CMSSW:
https://github.com/cms-sw/cmssw/blob/master/CommonTools/MVAUtils/interface/GBRForestTools.h

The interface remains the same, except that the template parameter
specifying the backend is gone.

This change adds support for unbalanced trees.
  • Loading branch information
guitargeek committed Apr 10, 2024
1 parent a517e8d commit bf532a4
Show file tree
Hide file tree
Showing 18 changed files with 527 additions and 962 deletions.
3 changes: 2 additions & 1 deletion bindings/pyroot/pythonizations/python/ROOT/_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,11 @@ def TMVA(self):
hasRDF = "dataframe" in gROOT.GetConfigFeatures()
if hasRDF:
try:
from ._pythonization._tmva import inject_rbatchgenerator, _AsRTensor
from ._pythonization._tmva import inject_rbatchgenerator, _AsRTensor, SaveXGBoost

inject_rbatchgenerator(ns)
ns.Experimental.AsRTensor = _AsRTensor
ns.Experimental.SaveXGBoost = SaveXGBoost
except:
raise Exception("Failed to pythonize the namespace TMVA")
del type(self).TMVA
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def inject_rbatchgenerator(ns):

# TODO: should this only be made available when xgboost is installed?
# A guard is probably not needed here, since this code only runs when xgboost is present.
from ._tree_inference import SaveXGBoost, pythonize_tree_inference
from ._tree_inference import SaveXGBoost


# list of python classes that are used to pythonize TMVA classes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,11 @@ def Compute(self, x):
# As fall-through we go to the original compute function and use the error-handling from cppyy
return self._OriginalCompute(x)

def RBDTInit(self, *args, **kwargs):
    """Warn about the xgboost limitations of RBDT, then run the original constructor.

    Intended to be installed as ``__init__`` on a class whose previous
    constructor has been saved as ``_original_init``.

    Args:
        *args: Positional arguments forwarded unchanged to ``_original_init``.
        **kwargs: Keyword arguments forwarded unchanged to ``_original_init``.

    Returns:
        Whatever ``self._original_init`` returns.
    """
    # Imported lazily so the module itself has no import-time dependency on it.
    import warnings

    message = (
        "Usage of xgboost models through RBDT is known to be limited and may "
        "lead to unexpected behaviour. Proceed with caution if the input model "
        "was obtained via `SaveXGBoost`. See https://github.com/root-project/root/issues/15197 "
        "for more details."
    )
    # stacklevel=2 attributes the warning to the code constructing the object
    # rather than to this wrapper.
    warnings.warn(message, UserWarning, stacklevel=2)

    return self._original_init(*args, **kwargs)


@pythonization("RBDT", ns="TMVA::Experimental", is_prefix=True)
def pythonize_rbdt(klass):
    """Replace the constructor and ``Compute`` of an RBDT class with Python wrappers.

    The previous implementations stay reachable under ``_original_init`` and
    ``_OriginalCompute``, which the wrappers call internally.

    NOTE(review): ``is_prefix=True`` presumably applies this to every class in
    ``TMVA::Experimental`` whose name starts with ``RBDT`` - confirm against
    the ``pythonization`` decorator.

    Args:
        klass: Class to be pythonized.
    """
    # Save the original constructor before overriding it; RBDTInit emits a
    # warning and then delegates to `_original_init`.
    klass._original_init = klass.__init__
    klass.__init__ = RBDTInit

    # Same pattern for Compute: the wrapper falls through to `_OriginalCompute`.
    klass._OriginalCompute = klass.Compute
    klass.Compute = Compute
Original file line number Diff line number Diff line change
@@ -1,23 +1,45 @@
# Author: Stefan Wunsch CERN 09/2019
# Author : Stefan Wunsch CERN 09 / 2019

################################################################################
# Copyright (C) 1995-2019, Rene Brun and Fons Rademakers. #
# All rights reserved. #
# #
# For the licensing terms see $ROOTSYS/LICENSE. #
# For the list of contributors see $ROOTSYS/README/CREDITS. #
# Copyright(C) 1995 - 2019, Rene Brun and Fons Rademakers.#
# All rights reserved.#
# #
# For the licensing terms see $ROOTSYS / LICENSE.#
# For the list of contributors see $ROOTSYS / README / CREDITS.#
################################################################################

from .. import pythonization
import cppyy

import json

def SaveXGBoost(self, xgb_model, key_name, output_path, num_inputs, tmp_path="/tmp", threshold_dtype="float"):
import warnings
warnings.warn(
("Usage of xgboost models through RBDT is known to be limited and may "
"lead to unexpected behaviour. See https://github.com/root-project/root/issues/15197 "
"for more details."), UserWarning, stacklevel=2)
def get_basescore(model):
    """Get the base score from an XGBoost sklearn estimator.

    The value is read from the booster's JSON configuration at
    ``learner -> learner_model_param -> base_score``.
    Adapted from XGBoost unit test code.

    See also:
      * https://github.com/dmlc/xgboost/blob/a99bb38bd2762e35e6a1673a0c11e09eddd8e723/python-package/xgboost/testing/updater.py#L13
      * https://github.com/dmlc/xgboost/issues/9347
      * https://discuss.xgboost.ai/t/how-to-get-base-score-from-trained-booster/3192

    Args:
        model: Object exposing ``get_booster().save_config()`` that returns the
            booster configuration as a JSON string.

    Returns:
        float: The base score.

    Raises:
        ValueError: If the stored value is neither a number nor a one-element
            array of numbers.
    """
    config = json.loads(model.get_booster().save_config())
    raw_score = config["learner"]["learner_model_param"]["base_score"]

    try:
        return float(raw_score)
    except ValueError:
        # NOTE(review): newer XGBoost releases appear to serialise base_score as
        # a one-element array string such as "[5E-1]" - confirm against the
        # XGBoost version in use. Accept that form; anything else re-raises the
        # original conversion error.
        try:
            values = json.loads(raw_score)
        except ValueError:
            values = None
        if isinstance(values, list) and len(values) == 1:
            return float(values[0])
        raise


def SaveXGBoost(xgb_model, key_name, output_path, num_inputs):
"""
Saves the XGBoost model to a ROOT file as a TMVA::Experimental::RBDT object.
Args:
xgb_model: The trained XGBoost model.
key_name (str): The name to use for storing the RBDT in the output file.
output_path (str): The path to save the output file.
num_inputs (int): The number of input features used in the model.
Raises:
Exception: If the XGBoost model has an unsupported objective.
"""
# Extract objective
objective_map = {
"multi:softprob": "softmax", # Naming the objective softmax is more common today
Expand All @@ -34,99 +56,25 @@ def SaveXGBoost(self, xgb_model, key_name, output_path, num_inputs, tmp_path="/t
)
objective = cppyy.gbl.std.string(objective_map[model_objective])

# Extract max depth of the trees
max_depth = xgb_model.max_depth

# Determine number of outputs
if "reg:" in model_objective:
num_outputs = 1
elif "binary:" in model_objective:
num_outputs = 1
else:
num_outputs = xgb_model.n_classes_
num_outputs = xgb_model.n_classes_ if "multi:" in model_objective else 1

# Dump XGB model to the tmp folder as json file
import os
import uuid
# Dump XGB model as json file
xgb_model.get_booster().dump_model(output_path, dump_format="json")

tmp_path = os.path.join(tmp_path, str(uuid.uuid4()) + ".json")
xgb_model.get_booster().dump_model(tmp_path, dump_format="json")
with open(output_path, "r") as json_file:
forest = json.load(json_file)

import json
# Dump XGB model as txt file
xgb_model.get_booster().dump_model(output_path)

with open(tmp_path, "r") as json_file:
forest = json.load(json_file)
features = cppyy.gbl.std.vector["std::string"]([f"f{i}" for i in range(num_inputs)])
bdt = cppyy.gbl.TMVA.Experimental.RBDT.LoadText(output_path, features, num_outputs)

bdt.logistic_ = objective == "logistic"

bs = get_basescore(xgb_model)
bdt.baseScore_ = cppyy.gbl.std.log(bs / (1.0 - bs)) if bdt.logistic_ else bs

# Determine whether the model has a bias parameter and write bias trees
if hasattr(xgb_model, "base_score") and "reg:" in model_objective:
bias = xgb_model.base_score
if not bias == 0.0:
forest += [{"leaf": bias}] * num_outputs
# print(str(forest).replace("u'", "'").replace("'", '"'))

# Extract parameters from json and write to arrays
num_trees = len(forest)
len_inputs = 2 ** max_depth - 1
inputs = cppyy.gbl.std.vector["int"](len_inputs * num_trees, -1)
len_thresholds = 2 ** (max_depth + 1) - 1
thresholds = cppyy.gbl.std.vector[threshold_dtype](len_thresholds * num_trees)

def fill_arrays(node, index, inputs_base, thresholds_base):
# Set leaf score as threshold value if this node is a leaf
if "leaf" in node:
thresholds[thresholds_base + index] = node["leaf"]
return

# Set input index
input_ = int(node["split"].replace("f", ""))
inputs[inputs_base + index] = input_

# Set threshold value
thresholds[thresholds_base + index] = node["split_condition"]

# Find next left (no) and right (yes) node
if node["children"][0]["nodeid"] == node["yes"]:
yes, no = 1, 0
else:
yes, no = 0, 1

# Fill values from the child nodes
fill_arrays(node["children"][no], 2 * index + 1, inputs_base, thresholds_base)
fill_arrays(node["children"][yes], 2 * index + 2, inputs_base, thresholds_base)

for i_tree, tree in enumerate(forest):
fill_arrays(tree, 0, len_inputs * i_tree, len_thresholds * i_tree)

# Determine to which output node a tree belongs
outputs = cppyy.gbl.std.vector["int"](num_trees)
if num_outputs != 1:
for i in range(num_trees):
outputs[i] = int(i % num_outputs)

# Store arrays in a ROOT file in a folder with the given key name
# TODO: Write single values as simple integers and not vectors.
f = cppyy.gbl.TFile(output_path, "RECREATE")
f.mkdir(key_name)
d = f.Get(key_name)
d.WriteObjectAny(inputs, "std::vector<int>", "inputs")
d.WriteObjectAny(outputs, "std::vector<int>", "outputs")
d.WriteObjectAny(thresholds, "std::vector<" + threshold_dtype + ">", "thresholds")
d.WriteObjectAny(objective, "std::string", "objective")
max_depth_ = cppyy.gbl.std.vector["int"](1, max_depth)
d.WriteObjectAny(max_depth_, "std::vector<int>", "max_depth")
num_trees_ = cppyy.gbl.std.vector["int"](1, num_trees)
d.WriteObjectAny(num_trees_, "std::vector<int>", "num_trees")
num_inputs_ = cppyy.gbl.std.vector["int"](1, num_inputs)
d.WriteObjectAny(num_inputs_, "std::vector<int>", "num_inputs")
num_outputs_ = cppyy.gbl.std.vector["int"](1, num_outputs)
d.WriteObjectAny(num_outputs_, "std::vector<int>", "num_outputs")
f.Write()
f.Close()


@pythonization("SaveXGBoost", ns="TMVA::Experimental")
def pythonize_tree_inference(klass):
# Parameters:
# klass: class to be pythonized

klass.__init__ = SaveXGBoost
with cppyy.gbl.TFile.Open(output_path, "RECREATE") as tFile:
tFile.WriteObject(bdt, key_name)
4 changes: 0 additions & 4 deletions tmva/tmva/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -462,10 +462,6 @@ ROOT_STANDARD_LIBRARY_PACKAGE(TMVAUtils
TMVA/RBatchGenerator.hxx
TMVA/RBatchLoader.hxx
TMVA/RChunkLoader.hxx
TMVA/TreeInference/PythonHelpers.hxx
TMVA/TreeInference/BranchlessTree.hxx
TMVA/TreeInference/Forest.hxx
TMVA/TreeInference/Objectives.hxx

SOURCES

Expand Down
5 changes: 2 additions & 3 deletions tmva/tmva/inc/LinkDefUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@

#ifdef R__HAS_DATAFRAME
// BDT inference
#pragma link C++ class TMVA::Experimental::RBDT<TMVA::Experimental::BranchlessForest<float>>+;
#pragma link C++ class TMVA::Experimental::RBDT<TMVA::Experimental::BranchlessJittedForest<float>>+;
#pragma link C++ class TMVA::Experimental::RBDT+;
#endif

// RTensor will have its own streamer function
#pragma link C++ class TMVA::Experimental::RTensor<float,std::vector<float>>-;

#endif
#endif
Loading

0 comments on commit bf532a4

Please sign in to comment.