[tmva][pymva] Fixes for latest version of scikit-learn
lmoneta committed Sep 6, 2023
1 parent dff89cf commit db4013f
Showing 2 changed files with 13 additions and 12 deletions.
14 changes: 7 additions & 7 deletions tmva/pymva/src/MethodPyGTB.cxx
@@ -65,7 +65,7 @@ MethodPyGTB::MethodPyGTB(const TString &jobName,
                           DataSetInfo &dsi,
                           const TString &theOption) :
    PyMethodBase(jobName, Types::kPyGTB, methodTitle, dsi, theOption),
-   fLoss("deviance"),
+   fLoss("log_loss"),
    fLearningRate(0.1),
    fNestimators(100),
    fSubsample(1.0),
@@ -85,7 +85,7 @@ MethodPyGTB::MethodPyGTB(const TString &jobName,
 //_______________________________________________________________________
 MethodPyGTB::MethodPyGTB(DataSetInfo &theData, const TString &theWeightFile)
    : PyMethodBase(Types::kPyGTB, theData, theWeightFile),
-     fLoss("deviance"),
+     fLoss("log_loss"),
      fLearningRate(0.1),
      fNestimators(100),
      fSubsample(1.0),
@@ -122,9 +122,9 @@ void MethodPyGTB::DeclareOptions()
 {
    MethodBase::DeclareCompatibilityOptions();

-   DeclareOptionRef(fLoss, "Loss", "{'deviance', 'exponential'}, optional (default='deviance')\
-         loss function to be optimized. 'deviance' refers to\
-         deviance (= logistic regression) for classification\
+   DeclareOptionRef(fLoss, "Loss", "{'log_loss', 'exponential'}, optional (default='log_loss')\
+         loss function to be optimized. 'log_loss' refers to\
+         logistic loss for classification\
          with probabilistic outputs. For loss 'exponential' gradient\
          boosting recovers the AdaBoost algorithm.");

@@ -197,9 +197,9 @@ void MethodPyGTB::DeclareOptions()
 // Check options and load them to local python namespace
 void MethodPyGTB::ProcessOptions()
 {
-   if (fLoss != "deviance" && fLoss != "exponential") {
+   if (fLoss != "log_loss" && fLoss != "exponential") {
       Log() << kFATAL << Form("Loss = %s ... that does not work!", fLoss.Data())
-            << " The options are 'deviance' or 'exponential'." << Endl;
+            << " The options are 'log_loss' or 'exponential'." << Endl;
    }
    pLoss = Eval(Form("'%s'", fLoss.Data()));
    PyDict_SetItemString(fLocalNS, "loss", pLoss);
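The MethodPyGTB changes above track scikit-learn's rename of the gradient-boosting loss 'deviance' to 'log_loss' (deprecated in scikit-learn 1.1, removed in 1.3). Below is a minimal Python sketch of the option that PyGTB ultimately forwards to scikit-learn; the data arrays and variable names are illustrative only and not part of the commit.

# Sketch only: shows the scikit-learn option behind the 'Loss' change above.
# Assumes scikit-learn >= 1.3, where loss='deviance' is no longer accepted.
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier

X = np.random.rand(200, 4)           # illustrative features
y = np.random.randint(0, 2, 200)     # illustrative binary labels

# 'log_loss' is the probabilistic (logistic) loss formerly called 'deviance';
# 'exponential' is still accepted and recovers AdaBoost.
clf = GradientBoostingClassifier(loss="log_loss", learning_rate=0.1,
                                 n_estimators=100, subsample=1.0)
clf.fit(X, y)
print(clf.predict_proba(X[:3]))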
11 changes: 6 additions & 5 deletions tmva/pymva/src/MethodPyRandomForest.cxx
@@ -69,7 +69,7 @@ MethodPyRandomForest::MethodPyRandomForest(const TString &jobName,
      fMinSamplesSplit(2),
      fMinSamplesLeaf(1),
      fMinWeightFractionLeaf(0),
-     fMaxFeatures("'auto'"),
+     fMaxFeatures("'sqrt'"),
      fMaxLeafNodes("None"),
      fBootstrap(kTRUE),
      fOobScore(kFALSE),
@@ -90,7 +90,7 @@ MethodPyRandomForest::MethodPyRandomForest(DataSetInfo &theData, const TString &
      fMinSamplesSplit(2),
      fMinSamplesLeaf(1),
      fMinWeightFractionLeaf(0),
-     fMaxFeatures("'auto'"),
+     fMaxFeatures("'sqrt'"),
      fMaxLeafNodes("None"),
      fBootstrap(kTRUE),
      fOobScore(kFALSE),
@@ -234,7 +234,8 @@ void MethodPyRandomForest::ProcessOptions()
    pMinWeightFractionLeaf = Eval(Form("%f", fMinWeightFractionLeaf));
    PyDict_SetItemString(fLocalNS, "minWeightFractionLeaf", pMinWeightFractionLeaf);

-   if (fMaxFeatures == "auto" || fMaxFeatures == "sqrt" || fMaxFeatures == "log2"){
+   if (fMaxFeatures == "auto") fMaxFeatures = "sqrt"; // change in API from v 1.11
+   if (fMaxFeatures == "sqrt" || fMaxFeatures == "log2"){
       fMaxFeatures = Form("'%s'", fMaxFeatures.Data());
    }
    pMaxFeatures = Eval(fMaxFeatures);
@@ -428,9 +429,9 @@ std::vector<Double_t> MethodPyRandomForest::GetMvaValues(Long64_t firstEvt, Long

       Py_DECREF(pEvent);
       Py_DECREF(result);

    if (logProgress) {
-      Log() << kINFO
+      Log() << kINFO
            << "Elapsed time for evaluation of " << nEvents << " events: "
            << timer.GetElapsedTime() << " " << Endl;
    }
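Similarly, the MethodPyRandomForest changes above reflect scikit-learn dropping max_features='auto' for random forests (deprecated in 1.1, removed in 1.3); for classifiers 'auto' behaved like 'sqrt', so the old default is mapped to 'sqrt'. A minimal Python sketch of that fallback, assuming a recent scikit-learn; the data and the max_features variable are illustrative only.

# Sketch only: mirrors the 'auto' -> 'sqrt' fallback done in ProcessOptions() above.
import numpy as np
from sklearn.ensemble import RandomForestClassifier

X = np.random.rand(200, 4)           # illustrative features
y = np.random.randint(0, 2, 200)     # illustrative binary labels

max_features = "auto"                # e.g. an option string from an old TMVA setup
if max_features == "auto":           # recent scikit-learn rejects 'auto'
    max_features = "sqrt"            # equivalent choice for classification forests

clf = RandomForestClassifier(n_estimators=100, max_features=max_features,
                             min_samples_split=2, min_samples_leaf=1,
                             bootstrap=True, oob_score=False)
clf.fit(X, y)
print(clf.score(X, y))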
