Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for full set of XGBoost feature importance metrics #52

Merged
merged 3 commits into from
Sep 30, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.evaluation.LabelEvaluator;
import org.tribuo.classification.example.LabelledDataGenerator;
import org.tribuo.common.xgboost.XGBoostFeatureImportance;
import org.tribuo.common.xgboost.XGBoostModel;
import org.tribuo.data.text.TextDataSource;
import org.tribuo.data.text.TextFeatureExtractor;
Expand Down Expand Up @@ -178,6 +179,27 @@ public void testXGBoost(Pair<Dataset<Label>,Dataset<Label>> p) {
Assertions.assertFalse(features.isEmpty());
}

@Test
public void testFeatureImportanceSmokeTest() {
// we're just testing that not actually throws an exception
XGBoostModel<Label> m = (XGBoostModel<Label>)t.train(LabelledDataGenerator.denseTrainTest().getA());

XGBoostFeatureImportance i = m.getFeatureImportance().get(0);
i.getImportances();
i.getCover();
i.getGain();
i.getWeight();
i.getTotalCover();
i.getTotalGain();

i.getImportances(5);
i.getCover(5);
i.getGain(5);
i.getWeight(5);
i.getTotalCover(5);
i.getTotalGain(5);
}

@Test
public void testDenseData() {
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,16 @@ public class XGBoostFeatureImportance {
* An instance of feature importance values for a single feature. See {@link XGBoostFeatureImportance} for details
* on interpreting the metrics.
*/
public static class XGBoostFeatureImportanceRecord {
public static class XGBoostFeatureImportanceInstance {

private String featureName;
private final String featureName;
private final double gain;
private final double cover;
private final double weight;
private final double totalGain;
private final double totalCover;

XGBoostFeatureImportanceRecord(String featureName, double gain, double cover, double weight, double totalGain, double totalCover) {
XGBoostFeatureImportanceInstance(String featureName, double gain, double cover, double weight, double totalGain, double totalCover) {
this.featureName = featureName;
this.gain = gain;
this.cover = cover;
Expand Down Expand Up @@ -89,8 +89,8 @@ public String toString() {
}
}

private Booster booster;
private ImmutableFeatureMap featureMap;
private final Booster booster;
private final ImmutableFeatureMap featureMap;

XGBoostFeatureImportance(Booster booster, ImmutableFeatureMap featureMap) {
this.booster = booster;
Expand Down Expand Up @@ -210,14 +210,15 @@ public LinkedHashMap<String, Double> getTotalCover(int numFeatures) {
}

/**
* @return records of all importance metrics for each feature, sorted by gain.
*/
public List<XGBoostFeatureImportanceRecord> getImportances() {
public List<XGBoostFeatureImportanceInstance> getImportances() {
Map<String, LinkedHashMap<String, Double>> importanceByType = Stream.of(GAIN, COVER, WEIGHT, TOTAL_GAIN, TOTAL_COVER)
.map(importanceType -> new AbstractMap.SimpleEntry<>(importanceType, coalesceImportanceStream(getImportanceStream(importanceType))))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
// this is already sorted by gain
List<String> features = new ArrayList<>(importanceByType.get(GAIN).keySet());
return features.stream().map(featureName -> new XGBoostFeatureImportanceRecord(featureName,
return features.stream().map(featureName -> new XGBoostFeatureImportanceInstance(featureName,
importanceByType.get(GAIN).get(featureName),
importanceByType.get(COVER).get(featureName),
importanceByType.get(WEIGHT).get(featureName),
Expand All @@ -230,13 +231,13 @@ public List<XGBoostFeatureImportanceRecord> getImportances() {
* @param numFeatures number of features to return
* @return records of all importance metrics for each feature, sorted by gain.
*/
public List<XGBoostFeatureImportanceRecord> getImportances(int numFeatures) {
public List<XGBoostFeatureImportanceInstance> getImportances(int numFeatures) {
Map<String, LinkedHashMap<String, Double>> importanceByType = Stream.of(GAIN, COVER, WEIGHT, TOTAL_GAIN, TOTAL_COVER)
.map(importanceType -> new AbstractMap.SimpleEntry<>(importanceType, coalesceImportanceStream(getImportanceStream(importanceType))))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
// this is already sorted by gain
List<String> features = new ArrayList<>(importanceByType.get(GAIN).keySet()).subList(0, Math.min(importanceByType.get(GAIN).keySet().size(), numFeatures));
return features.stream().map(featureName -> new XGBoostFeatureImportanceRecord(featureName,
return features.stream().map(featureName -> new XGBoostFeatureImportanceInstance(featureName,
importanceByType.get(GAIN).get(featureName),
importanceByType.get(COVER).get(featureName),
importanceByType.get(WEIGHT).get(featureName),
Expand All @@ -246,8 +247,6 @@ public List<XGBoostFeatureImportanceRecord> getImportances(int numFeatures) {
}

public String toString() {
return "XGBoostFeatureImportance(" + getImportances(5).stream()
.map(XGBoostFeatureImportanceRecord::toString)
.collect(Collectors.joining(",\n\t")) + ")";
return String.format("XGBoostFeatureImportance(booster=%s, featureIdMap=%s)", booster.toString(), featureMap.toString());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might be better to return a toString that refers to the model provenance? The booster and featureIdMap toStrings aren't very helpful.

}
}