From 8bd16dd8524711e9c38769c83c4a39f27f787eb3 Mon Sep 17 00:00:00 2001 From: John Sullivan Date: Tue, 29 Sep 2020 16:55:13 -0400 Subject: [PATCH 1/3] Added support for full set of XGBoost feature importance metrics --- .../xgboost/XGBoostFeatureImportance.java | 253 ++++++++++++++++++ .../tribuo/common/xgboost/XGBoostModel.java | 11 + 2 files changed, 264 insertions(+) create mode 100644 Common/XGBoost/src/main/java/org/tribuo/common/xgboost/XGBoostFeatureImportance.java diff --git a/Common/XGBoost/src/main/java/org/tribuo/common/xgboost/XGBoostFeatureImportance.java b/Common/XGBoost/src/main/java/org/tribuo/common/xgboost/XGBoostFeatureImportance.java new file mode 100644 index 000000000..c6b7cd605 --- /dev/null +++ b/Common/XGBoost/src/main/java/org/tribuo/common/xgboost/XGBoostFeatureImportance.java @@ -0,0 +1,253 @@ +package org.tribuo.common.xgboost; + +import ml.dmlc.xgboost4j.java.Booster; +import ml.dmlc.xgboost4j.java.XGBoostError; +import org.tribuo.ImmutableFeatureMap; + +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static ml.dmlc.xgboost4j.java.Booster.FeatureImportanceType.*; + +/** + * Generate and collate feature importance information from the XGBoost model. This wraps the underlying functionality + * of the XGBoost model, and should provide feature importance metrics compatible with those provided by XGBoost's R + * and Python APIs. For a more treatment of what the different importance metrics mean and how to interpret them, see + * here. In brief + * + * + */ +public class XGBoostFeatureImportance { + + /** + * An instance of feature importance values for a single feature. See {@link XGBoostFeatureImportance} for details + * on interpreting the metrics. + */ + public static class XGBoostFeatureImportanceRecord { + + private String featureName; + private final double gain; + private final double cover; + private final double weight; + private final double totalGain; + private final double totalCover; + + XGBoostFeatureImportanceRecord(String featureName, double gain, double cover, double weight, double totalGain, double totalCover) { + this.featureName = featureName; + this.gain = gain; + this.cover = cover; + this.weight = weight; + this.totalGain = totalGain; + this.totalCover = totalCover; + } + + public String getFeatureName() { + return featureName; + } + + public double getGain() { + return gain; + } + + public double getCover() { + return cover; + } + + public double getWeight() { + return weight; + } + + public double getTotalGain() { + return totalGain; + } + + public double getTotalCover() { + return totalCover; + } + + public String toString() { + return String.format("XGBoostFeatureImportanceRecord(feature=%s, gain=%.2f, cover=%.2f, weight=%.2f, totalGain=%.2f, totalCover=%.2f)", + featureName, gain, cover, weight, totalGain, totalCover); + } + } + + private Booster booster; + private ImmutableFeatureMap featureMap; + + XGBoostFeatureImportance(Booster booster, ImmutableFeatureMap featureMap) { + this.booster = booster; + this.featureMap = featureMap; + } + + private String translateFeatureId(String xgbFeatName) { + return featureMap.get(Integer.parseInt(xgbFeatName.substring(1))).getName(); + } + + private Stream> getImportanceStream(String importanceType) { + try { + return booster.getScore("", importanceType).entrySet().stream() + .sorted(Comparator.comparingDouble((Map.Entry e) -> e.getValue()).reversed()); + } catch (XGBoostError e) { + throw new IllegalStateException("Error generating feature importance for " + importanceType + " caused by", e); + } + } + + private LinkedHashMap coalesceImportanceStream(Stream> str) { + return str.collect(Collectors.toMap(e -> translateFeatureId(e.getKey()), + Map.Entry::getValue, (e1, e2) -> e1, LinkedHashMap::new)); + } + + /** + * Gain measures the improvement in accuracy that a feature brings to the branches on which it appears. + * This represents the sum of situated marginal contributions that a given feature makes to the each branching + * chain in which it appears. + * @return Ordered map where the keys are feature names and the value is the gain, sorted descending + */ + public LinkedHashMap getGain() { + return coalesceImportanceStream(getImportanceStream(GAIN)); + } + + /** + * + * Gain measures the improvement in accuracy that a feature brings to the branches on which it appears. + * This represents the sum of situated marginal contributions that a given feature makes to the each branching + * chain in which it appears. Returns only the top numFeatures features. + * @param numFeatures number of features to return + * @return Ordered map where the keys are feature names and the value is the gain, sorted descending + */ + public LinkedHashMap getGain(int numFeatures) { + return coalesceImportanceStream(getImportanceStream(GAIN).limit(numFeatures)); + } + + /** + * Cover measures the number of examples a given feature discriminates across, relative to the total + * number of examples all features discriminate across. + * @return Ordered map where the keys are feature names and the value is the cover, sorted descending + */ + public LinkedHashMap getCover() { + return coalesceImportanceStream(getImportanceStream(COVER)); + } + /** + * + * Cover measures the number of examples a given feature discriminates across, relative to the total. + * number of examples all features discriminate across. Returns only the top numFeatures features. + * @param numFeatures number of features to return + * @return Ordered map where the keys are feature names and the value is the cover, sorted descending + */ + public LinkedHashMap getCover(int numFeatures) { + return coalesceImportanceStream(getImportanceStream(COVER).limit(numFeatures)); + } + + /** + * Weight measures the number a times a feature occurs in the model. Due to the way the model builds trees, + * this value is skewed in favor of continuous features. + * @return Ordered map where the keys are feature names and the value is the weight, sorted descending + */ + public LinkedHashMap getWeight() { + return coalesceImportanceStream(getImportanceStream(WEIGHT)); + } + /** + * Weight measures the number a times a feature occurs in the model. Due to the way the model builds trees, + * this value is skewed in favor of continuous features. Returns only the top numFeatures features. + * @param numFeatures number of features to return + * @return Ordered map where the keys are feature names and the value is the weight, sorted descending + */ + public LinkedHashMap getWeight(int numFeatures) { + return coalesceImportanceStream(getImportanceStream(WEIGHT).limit(numFeatures)); + } + + /** + * Total Gain is similar to gain, but not locally averaged by weight, and thus not skewed in the way that + * weight can be skewed. + * @return Ordered map where the keys are feature names and the value is the total gain, sorted descending + */ + public LinkedHashMap getTotalGain() { + return coalesceImportanceStream(getImportanceStream(TOTAL_GAIN)); + } + /** + * Total Gain is similar to gain, but not locally averaged by weight, and thus not skewed in the way that + * weight can be skewed. Returns only top numFeatures features. + * @param numFeatures number of features to return + * @return Ordered map where the keys are feature names and the value is the total gain, sorted descending + */ + public LinkedHashMap getTotalGain(int numFeatures) { + return coalesceImportanceStream(getImportanceStream(TOTAL_GAIN).limit(numFeatures)); + } + + /** + * Total Cover is similar to cover, but not locally averaged by weight, and thus not skewed in the way that + * weight can be skewed. + * @return Ordered map where the keys are feature names and the value is the total gain, sorted descending + */ + public LinkedHashMap getTotalCover() { + return coalesceImportanceStream(getImportanceStream(TOTAL_COVER)); + } + /** + * Total Cover is similar to cover, but not locally averaged by weight, and thus not skewed in the way that + * weight can be skewed. Returns only top numFeatures features. + * @return Ordered map where the keys are feature names and the value is the total gain, sorted descending + */ + public LinkedHashMap getTotalCover(int numFeatures) { + return coalesceImportanceStream(getImportanceStream(TOTAL_COVER).limit(numFeatures)); + } + + /** + */ + public List getImportances() { + Map> importanceByType = Stream.of(GAIN, COVER, WEIGHT, TOTAL_GAIN, TOTAL_COVER) + .map(importanceType -> new AbstractMap.SimpleEntry<>(importanceType, coalesceImportanceStream(getImportanceStream(importanceType)))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + // this is already sorted by gain + List features = new ArrayList<>(importanceByType.get(GAIN).keySet()); + return features.stream().map(featureName -> new XGBoostFeatureImportanceRecord(featureName, + importanceByType.get(GAIN).get(featureName), + importanceByType.get(COVER).get(featureName), + importanceByType.get(WEIGHT).get(featureName), + importanceByType.get(TOTAL_GAIN).get(featureName), + importanceByType.get(TOTAL_COVER).get(featureName))) + .collect(Collectors.toList()); + } + + /** + * @param numFeatures number of features to return + * @return records of all importance metrics for each feature, sorted by gain. + */ + public List getImportances(int numFeatures) { + Map> importanceByType = Stream.of(GAIN, COVER, WEIGHT, TOTAL_GAIN, TOTAL_COVER) + .map(importanceType -> new AbstractMap.SimpleEntry<>(importanceType, coalesceImportanceStream(getImportanceStream(importanceType)))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + // this is already sorted by gain + List features = new ArrayList<>(importanceByType.get(GAIN).keySet()).subList(0, Math.min(importanceByType.get(GAIN).keySet().size(), numFeatures)); + return features.stream().map(featureName -> new XGBoostFeatureImportanceRecord(featureName, + importanceByType.get(GAIN).get(featureName), + importanceByType.get(COVER).get(featureName), + importanceByType.get(WEIGHT).get(featureName), + importanceByType.get(TOTAL_GAIN).get(featureName), + importanceByType.get(TOTAL_COVER).get(featureName))) + .collect(Collectors.toList()); + } + + public String toString() { + return "XGBoostFeatureImportance(" + getImportances(5).stream() + .map(XGBoostFeatureImportanceRecord::toString) + .collect(Collectors.joining(",\n\t")) + ")"; + } +} diff --git a/Common/XGBoost/src/main/java/org/tribuo/common/xgboost/XGBoostModel.java b/Common/XGBoost/src/main/java/org/tribuo/common/xgboost/XGBoostModel.java index 2345fc22f..4c19870d7 100644 --- a/Common/XGBoost/src/main/java/org/tribuo/common/xgboost/XGBoostModel.java +++ b/Common/XGBoost/src/main/java/org/tribuo/common/xgboost/XGBoostModel.java @@ -46,6 +46,7 @@ import java.util.PriorityQueue; import java.util.logging.Level; import java.util.logging.Logger; +import java.util.stream.Collectors; /** * A {@link Model} which wraps around a XGBoost.Booster. @@ -156,6 +157,16 @@ public Prediction predict(Example example) { } } + /** + * Creates objects to report feature importance metrics for XGBoost. See the documentation of {@link XGBoostFeatureImportance} + * for more information on what those metrics mean. Typically this list will contain a single instance for the entire + * model. For multidimensional regression the list will have one entry per dimension, in dimension order. + * @return The feature importance object(s). + */ + public List getFeatureImportance() { + return models.stream().map(b -> new XGBoostFeatureImportance(b, featureIDMap)).collect(Collectors.toList()); + } + @Override public Map>> getTopFeatures(int n) { try { From 3271313a7fea728702441ecb99845a67d5e1e6ed Mon Sep 17 00:00:00 2001 From: John Sullivan Date: Wed, 30 Sep 2020 16:40:46 -0400 Subject: [PATCH 2/3] Made changes requested in PR #52 --- .../classification/xgboost/TestXGBoost.java | 22 ++++++++++++++++++ .../xgboost/XGBoostFeatureImportance.java | 23 +++++++++---------- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/Classification/XGBoost/src/test/java/org/tribuo/classification/xgboost/TestXGBoost.java b/Classification/XGBoost/src/test/java/org/tribuo/classification/xgboost/TestXGBoost.java index d156e65fa..065cae5e4 100644 --- a/Classification/XGBoost/src/test/java/org/tribuo/classification/xgboost/TestXGBoost.java +++ b/Classification/XGBoost/src/test/java/org/tribuo/classification/xgboost/TestXGBoost.java @@ -28,6 +28,7 @@ import org.tribuo.classification.evaluation.LabelEvaluation; import org.tribuo.classification.evaluation.LabelEvaluator; import org.tribuo.classification.example.LabelledDataGenerator; +import org.tribuo.common.xgboost.XGBoostFeatureImportance; import org.tribuo.common.xgboost.XGBoostModel; import org.tribuo.data.text.TextDataSource; import org.tribuo.data.text.TextFeatureExtractor; @@ -178,6 +179,27 @@ public void testXGBoost(Pair,Dataset