From 0d6b97739581eddeb6f1dce55bf231398d5293c4 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 1 Aug 2014 11:21:17 -0700 Subject: [PATCH] support for multiclass output prob --- demo/multiclass_classification/train.py | 7 +++++++ regrank/xgboost_regrank.h | 4 ++-- regrank/xgboost_regrank_obj.h | 9 +++++++-- regrank/xgboost_regrank_obj.hpp | 23 ++++++++++++++++++++--- 4 files changed, 36 insertions(+), 7 deletions(-) diff --git a/demo/multiclass_classification/train.py b/demo/multiclass_classification/train.py index df5e112aa902..fabc43c45b7c 100755 --- a/demo/multiclass_classification/train.py +++ b/demo/multiclass_classification/train.py @@ -39,4 +39,11 @@ print ('predicting, classification error=%f' % (sum( int(pred[i]) != test_Y[i] for i in range(len(test_Y))) / float(len(test_Y)) )) +# do the same thing again, but output probabilities +param['objective'] = 'multi:softprob' +bst = xgb.train(param, xg_train, num_round, watchlist ); +# get prediction; this is a 1D array that needs to be reshaped to (nclass, ndata) +yprob = bst.predict( xg_test ).reshape( 6, test_Y.shape[0] ) +ylabel = np.argmax( yprob, axis=0) +print ('predicting, classification error=%f' % (sum( int(ylabel[i]) != test_Y[i] for i in range(len(test_Y))) / float(len(test_Y)) )) diff --git a/regrank/xgboost_regrank.h b/regrank/xgboost_regrank.h index fd6800ce49b0..256d2b0855ee 100644 --- a/regrank/xgboost_regrank.h +++ b/regrank/xgboost_regrank.h @@ -103,7 +103,7 @@ namespace xgboost{ */ inline void InitTrainer(void){ if( mparam.num_class != 0 ){ - if( name_obj_ != "multi:softmax" ){ + if( name_obj_ != "multi:softmax" && name_obj_ != "multi:softprob"){ name_obj_ = "multi:softmax"; printf("auto select objective=softmax to support multi-class classification\n" ); } @@ -206,7 +206,7 @@ namespace xgboost{ fprintf(fo, "[%d]", iter); for (size_t i = 0; i < evals.size(); ++i){ this->PredictRaw(preds_, *evals[i]); - obj_->PredTransform(preds_); + obj_->EvalTransform(preds_); evaluator_.Eval(fo, evname[i].c_str(), 
preds_, evals[i]->info); } fprintf(fo, "\n"); diff --git a/regrank/xgboost_regrank_obj.h b/regrank/xgboost_regrank_obj.h index 5851f6384a79..09b447a156a0 100644 --- a/regrank/xgboost_regrank_obj.h +++ b/regrank/xgboost_regrank_obj.h @@ -41,6 +41,11 @@ namespace xgboost{ * \param preds prediction values, saves to this vector as well */ virtual void PredTransform(std::vector &preds){} + /*! + * \brief transform prediction values; this is only called when Eval is called, and by default it redirects to PredTransform + * \param preds prediction values, saves to this vector as well + */ + virtual void EvalTransform(std::vector &preds){ this->PredTransform(preds); } }; }; @@ -114,8 +119,8 @@ namespace xgboost{ if( !strcmp("reg:logistic", name ) ) return new RegressionObj( LossType::kLogisticNeglik ); if( !strcmp("binary:logistic", name ) ) return new RegressionObj( LossType::kLogisticClassify ); if( !strcmp("binary:logitraw", name ) ) return new RegressionObj( LossType::kLogisticRaw ); - if( !strcmp("multi:softmax", name ) ) return new SoftmaxMultiClassObj(); - if( !strcmp("rank:pairwise", name ) ) return new PairwiseRankObj(); + if( !strcmp("multi:softmax", name ) ) return new SoftmaxMultiClassObj(0); + if( !strcmp("multi:softprob", name ) ) return new SoftmaxMultiClassObj(1); if( !strcmp("rank:pairwise", name ) ) return new PairwiseRankObj(); if( !strcmp("rank:softmax", name ) ) return new SoftmaxRankObj(); utils::Error("unknown objective function type"); diff --git a/regrank/xgboost_regrank_obj.hpp b/regrank/xgboost_regrank_obj.hpp index 7246b37e5f3e..b73c03c0cc9d 100644 --- a/regrank/xgboost_regrank_obj.hpp +++ b/regrank/xgboost_regrank_obj.hpp @@ -112,7 +112,7 @@ namespace xgboost{ // simple softmax multi-class classification class SoftmaxMultiClassObj : public IObjFunction{ public: - SoftmaxMultiClassObj(void){ + SoftmaxMultiClassObj(int output_prob):output_prob(output_prob){ nclass = 0; } virtual ~SoftmaxMultiClassObj(){} @@ -156,6 +156,13 @@ namespace xgboost{ } } virtual 
void PredTransform(std::vector &preds){ + this->Transform(preds, output_prob); + } + virtual void EvalTransform(std::vector &preds){ + this->Transform(preds, 0); + } + private: + inline void Transform(std::vector &preds, int prob){ utils::Assert( nclass != 0, "must set num_class to use softmax" ); utils::Assert( preds.size() % nclass == 0, "SoftmaxMultiClassObj: label size and pred size does not match" ); const unsigned ndata = static_cast(preds.size()/nclass); @@ -168,16 +175,26 @@ namespace xgboost{ for( int k = 0; k < nclass; ++ k ){ rec[k] = preds[j + k * ndata]; } - preds[j] = FindMaxIndex( rec ); + if( prob == 0 ){ + preds[j] = FindMaxIndex( rec ); + }else{ + Softmax( rec ); + for( int k = 0; k < nclass; ++ k ){ + preds[j + k * ndata] = rec[k]; + } + } } } - preds.resize( ndata ); + if( prob == 0 ){ + preds.resize( ndata ); + } } virtual const char* DefaultEvalMetric(void) { return "merror"; } private: int nclass; + int output_prob; }; };