Skip to content

Commit

Permalink
Added support for full set of XGBoost feature importance metrics (#52)
Browse files Browse the repository at this point in the history
* Added support for full set of XGBoost feature importance metrics

* Made changes requested in PR #52

* Tying toString in `XGBoostFeatureImportance` to the parent model and adding an importance method to `XGBoostExternalModel`.
  • Loading branch information
JackSullivan authored Sep 30, 2020
1 parent 835d874 commit 3b647a3
Show file tree
Hide file tree
Showing 4 changed files with 299 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.evaluation.LabelEvaluator;
import org.tribuo.classification.example.LabelledDataGenerator;
import org.tribuo.common.xgboost.XGBoostFeatureImportance;
import org.tribuo.common.xgboost.XGBoostModel;
import org.tribuo.data.text.TextDataSource;
import org.tribuo.data.text.TextFeatureExtractor;
Expand Down Expand Up @@ -178,6 +179,27 @@ public void testXGBoost(Pair<Dataset<Label>,Dataset<Label>> p) {
Assertions.assertFalse(features.isEmpty());
}

@Test
public void testFeatureImportanceSmokeTest() {
    // Smoke test only: every importance accessor is invoked and its result discarded,
    // so the test passes as long as none of the calls throws an exception.
    Dataset<Label> trainingData = LabelledDataGenerator.denseTrainTest().getA();
    XGBoostModel<Label> model = (XGBoostModel<Label>) t.train(trainingData);
    XGBoostFeatureImportance importance = model.getFeatureImportance().get(0);

    // Variants which return every feature.
    importance.getImportances();
    importance.getCover();
    importance.getGain();
    importance.getWeight();
    importance.getTotalCover();
    importance.getTotalGain();

    // Variants which are limited to the top features.
    final int topFeatures = 5;
    importance.getImportances(topFeatures);
    importance.getCover(topFeatures);
    importance.getGain(topFeatures);
    importance.getWeight(topFeatures);
    importance.getTotalCover(topFeatures);
    importance.getTotalGain(topFeatures);
}

@Test
public void testDenseData() {
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import java.util.PriorityQueue;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
* A {@link Model} which wraps around a XGBoost.Booster which was trained by a system other than Tribuo.
Expand Down Expand Up @@ -148,6 +149,16 @@ protected List<Prediction<T>> convertOutput(float[][] output, int[] numValidFeat
return converter.convertBatchOutput(outputIDInfo,Collections.singletonList(output),numValidFeatures,(Example<T>[])examples.toArray(new Example[0]));
}

/**
 * Builds the feature importance view for the wrapped XGBoost booster. The meaning of each metric is
 * described in the documentation of {@link XGBoostFeatureImportance}.
 * <p>
 * This implementation wraps the model's single booster, so the returned list always has exactly
 * one element.
 * @return A single-element list containing the feature importance object for this model.
 */
public List<XGBoostFeatureImportance> getFeatureImportance() {
    XGBoostFeatureImportance importance = new XGBoostFeatureImportance(model, this);
    return Collections.singletonList(importance);
}

@Override
public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
try {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
package org.tribuo.common.xgboost;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Model;

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static ml.dmlc.xgboost4j.java.Booster.FeatureImportanceType.*;

/**
 * Generate and collate feature importance information from the XGBoost model. This wraps the underlying functionality
 * of the XGBoost model, and should provide feature importance metrics compatible with those provided by XGBoost's R
 * and Python APIs. For a more detailed treatment of what the different importance metrics mean and how to interpret
 * them, see <a href="https://xgboost.readthedocs.io/en/latest/R-package/discoverYourData.html">here</a>. In brief:
 *
 * <ul>
 *     <li><b>Gain</b> measures the improvement in accuracy that a feature brings to the branches on which it appears.
 *     This represents the sum of situated marginal contributions that a given feature makes to each branching
 *     chain in which it appears.</li>
 *     <li><b>Cover</b> measures the number of examples a given feature discriminates across, relative to the total
 *     number of examples all features discriminate across.</li>
 *     <li><b>Weight</b> measures the number of times a feature occurs in the model. Due to the way the model builds
 *     trees, this value is skewed in favor of continuous features.</li>
 *     <li><b>Total Gain</b> is similar to gain, but not locally averaged by weight, and thus not skewed in the way
 *     that weight can be skewed.</li>
 *     <li><b>Total Cover</b> is similar to cover, but not locally averaged by weight, and thus not skewed in the way
 *     that weight can be skewed.</li>
 * </ul>
 */
public class XGBoostFeatureImportance {

    /**
     * An instance of feature importance values for a single feature. See {@link XGBoostFeatureImportance} for details
     * on interpreting the metrics.
     */
    public static class XGBoostFeatureImportanceInstance {

        private final String featureName;
        private final double gain;
        private final double cover;
        private final double weight;
        private final double totalGain;
        private final double totalCover;

        XGBoostFeatureImportanceInstance(String featureName, double gain, double cover, double weight, double totalGain, double totalCover) {
            this.featureName = featureName;
            this.gain = gain;
            this.cover = cover;
            this.weight = weight;
            this.totalGain = totalGain;
            this.totalCover = totalCover;
        }

        /**
         * @return The name of the feature these values describe.
         */
        public String getFeatureName() {
            return featureName;
        }

        /**
         * @return The gain of this feature.
         */
        public double getGain() {
            return gain;
        }

        /**
         * @return The cover of this feature.
         */
        public double getCover() {
            return cover;
        }

        /**
         * @return The weight of this feature.
         */
        public double getWeight() {
            return weight;
        }

        /**
         * @return The total gain of this feature.
         */
        public double getTotalGain() {
            return totalGain;
        }

        /**
         * @return The total cover of this feature.
         */
        public double getTotalCover() {
            return totalCover;
        }

        @Override
        public String toString() {
            return String.format("XGBoostFeatureImportanceRecord(feature=%s, gain=%.2f, cover=%.2f, weight=%.2f, totalGain=%.2f, totalCover=%.2f)",
                    featureName, gain, cover, weight, totalGain, totalCover);
        }
    }

    private final Booster booster;
    private final ImmutableFeatureMap featureMap;
    private final Model<?> model;

    /**
     * Wraps a booster so its importance scores can be reported using the feature names of the parent model.
     * @param booster The XGBoost booster to query for importance scores.
     * @param model The model which owns the booster; supplies the feature id map and the {@link #toString()} text.
     */
    XGBoostFeatureImportance(Booster booster, Model<?> model) {
        this.booster = booster;
        this.model = model;
        this.featureMap = model.getFeatureIDMap();
    }

    /**
     * Converts an XGBoost feature name into the feature name stored in the model's feature map.
     * The first character is dropped and the remainder is parsed as the integer feature id
     * (presumably names have the form {@code f<id>} - confirm against the booster's naming scheme).
     * @param xgbFeatName The feature name reported by XGBoost.
     * @return The name of the feature with that id in the feature map.
     */
    private String translateFeatureId(String xgbFeatName) {
        return featureMap.get(Integer.parseInt(xgbFeatName.substring(1))).getName();
    }

    /**
     * Queries the booster for one importance metric and orders the scores from largest to smallest.
     * @param importanceType The XGBoost importance type to query.
     * @return A stream of (XGBoost feature name, score) entries, sorted descending by score.
     * @throws IllegalStateException if XGBoost fails to compute the scores.
     */
    private Stream<Map.Entry<String, Double>> getImportanceStream(String importanceType) {
        try {
            return booster.getScore("", importanceType).entrySet().stream()
                    .sorted(Comparator.comparingDouble((Map.Entry<String, Double> e) -> e.getValue()).reversed());
        } catch (XGBoostError e) {
            throw new IllegalStateException("Error generating feature importance for " + importanceType + " caused by", e);
        }
    }

    /**
     * Collects a stream of scores into a map keyed by feature name, preserving the stream's order.
     * If two entries translate to the same feature name the first one encountered is kept.
     * @param str The (XGBoost feature name, score) entries.
     * @return An insertion-ordered map from feature name to score.
     */
    private LinkedHashMap<String, Double> coalesceImportanceStream(Stream<Map.Entry<String, Double>> str) {
        return str.collect(Collectors.toMap(e -> translateFeatureId(e.getKey()),
                Map.Entry::getValue, (e1, e2) -> e1, LinkedHashMap::new));
    }

    /**
     * Builds one record per feature in the supplied gain map, in that map's iteration order, looking up the
     * remaining metrics for each feature by name.
     * @param gain The gain scores, already ordered (and truncated if required) by the caller.
     * @return The importance records, in the iteration order of {@code gain}.
     */
    private List<XGBoostFeatureImportanceInstance> collateImportances(Map<String, Double> gain) {
        Map<String, Double> cover = coalesceImportanceStream(getImportanceStream(COVER));
        Map<String, Double> weight = coalesceImportanceStream(getImportanceStream(WEIGHT));
        Map<String, Double> totalGain = coalesceImportanceStream(getImportanceStream(TOTAL_GAIN));
        Map<String, Double> totalCover = coalesceImportanceStream(getImportanceStream(TOTAL_COVER));

        List<XGBoostFeatureImportanceInstance> records = new ArrayList<>(gain.size());
        for (Map.Entry<String, Double> entry : gain.entrySet()) {
            String featureName = entry.getKey();
            records.add(new XGBoostFeatureImportanceInstance(featureName,
                    entry.getValue(),
                    cover.get(featureName),
                    weight.get(featureName),
                    totalGain.get(featureName),
                    totalCover.get(featureName)));
        }
        return records;
    }

    /**
     * Gain measures the improvement in accuracy that a feature brings to the branches on which it appears.
     * This represents the sum of situated marginal contributions that a given feature makes to each branching
     * chain in which it appears.
     * @return Ordered map where the keys are feature names and the value is the gain, sorted descending
     * @throws IllegalStateException if XGBoost fails to compute the scores.
     */
    public LinkedHashMap<String, Double> getGain() {
        return coalesceImportanceStream(getImportanceStream(GAIN));
    }

    /**
     * Gain measures the improvement in accuracy that a feature brings to the branches on which it appears.
     * This represents the sum of situated marginal contributions that a given feature makes to each branching
     * chain in which it appears. Returns only the top numFeatures features.
     * @param numFeatures number of features to return
     * @return Ordered map where the keys are feature names and the value is the gain, sorted descending
     * @throws IllegalArgumentException if numFeatures is negative.
     * @throws IllegalStateException if XGBoost fails to compute the scores.
     */
    public LinkedHashMap<String, Double> getGain(int numFeatures) {
        return coalesceImportanceStream(getImportanceStream(GAIN).limit(numFeatures));
    }

    /**
     * Cover measures the number of examples a given feature discriminates across, relative to the total
     * number of examples all features discriminate across.
     * @return Ordered map where the keys are feature names and the value is the cover, sorted descending
     * @throws IllegalStateException if XGBoost fails to compute the scores.
     */
    public LinkedHashMap<String, Double> getCover() {
        return coalesceImportanceStream(getImportanceStream(COVER));
    }

    /**
     * Cover measures the number of examples a given feature discriminates across, relative to the total
     * number of examples all features discriminate across. Returns only the top numFeatures features.
     * @param numFeatures number of features to return
     * @return Ordered map where the keys are feature names and the value is the cover, sorted descending
     * @throws IllegalArgumentException if numFeatures is negative.
     * @throws IllegalStateException if XGBoost fails to compute the scores.
     */
    public LinkedHashMap<String, Double> getCover(int numFeatures) {
        return coalesceImportanceStream(getImportanceStream(COVER).limit(numFeatures));
    }

    /**
     * Weight measures the number of times a feature occurs in the model. Due to the way the model builds trees,
     * this value is skewed in favor of continuous features.
     * @return Ordered map where the keys are feature names and the value is the weight, sorted descending
     * @throws IllegalStateException if XGBoost fails to compute the scores.
     */
    public LinkedHashMap<String, Double> getWeight() {
        return coalesceImportanceStream(getImportanceStream(WEIGHT));
    }

    /**
     * Weight measures the number of times a feature occurs in the model. Due to the way the model builds trees,
     * this value is skewed in favor of continuous features. Returns only the top numFeatures features.
     * @param numFeatures number of features to return
     * @return Ordered map where the keys are feature names and the value is the weight, sorted descending
     * @throws IllegalArgumentException if numFeatures is negative.
     * @throws IllegalStateException if XGBoost fails to compute the scores.
     */
    public LinkedHashMap<String, Double> getWeight(int numFeatures) {
        return coalesceImportanceStream(getImportanceStream(WEIGHT).limit(numFeatures));
    }

    /**
     * Total Gain is similar to gain, but not locally averaged by weight, and thus not skewed in the way that
     * weight can be skewed.
     * @return Ordered map where the keys are feature names and the value is the total gain, sorted descending
     * @throws IllegalStateException if XGBoost fails to compute the scores.
     */
    public LinkedHashMap<String, Double> getTotalGain() {
        return coalesceImportanceStream(getImportanceStream(TOTAL_GAIN));
    }

    /**
     * Total Gain is similar to gain, but not locally averaged by weight, and thus not skewed in the way that
     * weight can be skewed. Returns only the top numFeatures features.
     * @param numFeatures number of features to return
     * @return Ordered map where the keys are feature names and the value is the total gain, sorted descending
     * @throws IllegalArgumentException if numFeatures is negative.
     * @throws IllegalStateException if XGBoost fails to compute the scores.
     */
    public LinkedHashMap<String, Double> getTotalGain(int numFeatures) {
        return coalesceImportanceStream(getImportanceStream(TOTAL_GAIN).limit(numFeatures));
    }

    /**
     * Total Cover is similar to cover, but not locally averaged by weight, and thus not skewed in the way that
     * weight can be skewed.
     * @return Ordered map where the keys are feature names and the value is the total cover, sorted descending
     * @throws IllegalStateException if XGBoost fails to compute the scores.
     */
    public LinkedHashMap<String, Double> getTotalCover() {
        return coalesceImportanceStream(getImportanceStream(TOTAL_COVER));
    }

    /**
     * Total Cover is similar to cover, but not locally averaged by weight, and thus not skewed in the way that
     * weight can be skewed. Returns only the top numFeatures features.
     * @param numFeatures number of features to return
     * @return Ordered map where the keys are feature names and the value is the total cover, sorted descending
     * @throws IllegalArgumentException if numFeatures is negative.
     * @throws IllegalStateException if XGBoost fails to compute the scores.
     */
    public LinkedHashMap<String, Double> getTotalCover(int numFeatures) {
        return coalesceImportanceStream(getImportanceStream(TOTAL_COVER).limit(numFeatures));
    }

    /**
     * Gets all the feature importances for all the features.
     * @return records of all importance metrics for each feature, sorted descending by gain.
     * @throws IllegalStateException if XGBoost fails to compute the scores.
     */
    public List<XGBoostFeatureImportanceInstance> getImportances() {
        return collateImportances(coalesceImportanceStream(getImportanceStream(GAIN)));
    }

    /**
     * Gets the feature importances for the top numFeatures features, as ranked by gain.
     * @param numFeatures number of features to return
     * @return records of all importance metrics for each returned feature, sorted descending by gain.
     * @throws IllegalArgumentException if numFeatures is negative.
     * @throws IllegalStateException if XGBoost fails to compute the scores.
     */
    public List<XGBoostFeatureImportanceInstance> getImportances(int numFeatures) {
        // Truncate the gain ranking first so only the requested features are collated.
        return collateImportances(coalesceImportanceStream(getImportanceStream(GAIN).limit(numFeatures)));
    }

    @Override
    public String toString() {
        return String.format("XGBoostFeatureImportance(model=%s)", model.toString());
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import java.util.PriorityQueue;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
* A {@link Model} which wraps around a XGBoost.Booster.
Expand Down Expand Up @@ -156,6 +157,16 @@ public Prediction<T> predict(Example<T> example) {
}
}

/**
 * Builds a feature importance view for every booster held by this model. The meaning of each metric is
 * described in the documentation of {@link XGBoostFeatureImportance}.
 * <p>
 * The returned list has one entry per booster, in the order the boosters are stored. Typically that is a
 * single entry covering the whole model; for multidimensional regression there is one entry per
 * dimension, in dimension order.
 * @return The feature importance object(s).
 */
public List<XGBoostFeatureImportance> getFeatureImportance() {
    List<XGBoostFeatureImportance> importances = models.stream()
            .map(booster -> new XGBoostFeatureImportance(booster, this))
            .collect(Collectors.toList());
    return importances;
}

@Override
public Map<String, List<Pair<String,Double>>> getTopFeatures(int n) {
try {
Expand Down

0 comments on commit 3b647a3

Please sign in to comment.