Skip to content

Commit

Permalink
Add distributions calculated in RawFeatureFilter to ModelInsights (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jauntbox authored and tovbinm committed Aug 30, 2018
1 parent be1748a commit a8eaf4b
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 18 deletions.
22 changes: 16 additions & 6 deletions core/src/main/scala/com/salesforce/op/ModelInsights.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ package com.salesforce.op
import com.salesforce.op.evaluators._
import com.salesforce.op.features._
import com.salesforce.op.features.types._
import com.salesforce.op.filters.FeatureDistribution
import com.salesforce.op.stages._
import com.salesforce.op.stages.impl.feature.TransmogrifierDefaults
import com.salesforce.op.stages.impl.preparators._
Expand Down Expand Up @@ -311,8 +312,14 @@ case class Discrete(domain: Seq[String], prob: Seq[Double]) extends LabelInfo
* @param featureName name of raw feature insights are about
* @param featureType type of raw feature insights are about
* @param derivedFeatures sequence containing insights for each feature derived from the raw feature
* @param distributions distribution information for the raw feature (if calculated in RawFeatureFilter)
*/
case class FeatureInsights(featureName: String, featureType: String, derivedFeatures: Seq[Insights])
case class FeatureInsights(
featureName: String,
featureType: String,
derivedFeatures: Seq[Insights],
distributions: Seq[FeatureDistribution] = Seq.empty
)

/**
* Summary of insights for a derived feature
Expand Down Expand Up @@ -405,7 +412,8 @@ case object ModelInsights {
rawFeatures: Array[features.OPFeature],
trainingParams: OpParams,
blacklistedFeatures: Array[features.OPFeature],
blacklistedMapKeys: Map[String, Set[String]]
blacklistedMapKeys: Map[String, Set[String]],
rawFeatureDistributions: Array[FeatureDistribution]
): ModelInsights = {
val sanityCheckers = stages.collect { case s: SanityCheckerModel => s }
val sanityChecker = sanityCheckers.lastOption
Expand All @@ -430,7 +438,6 @@ case object ModelInsights {
val label = model.map(_.getInputFeature[RealNN](0)).orElse(sanityChecker.map(_.getInputFeature[RealNN](0))).flatten
log.info(s"Found ${label.map(_.name + " as label").getOrElse("no label")} to fill in model insights")


// Recover the vector metadata
val vectorInput: Option[OpVectorMetadata] = {
def makeMeta(s: => OpPipelineStageParams) = Try(OpVectorMetadata(s.getInputSchema().last)).toOption
Expand All @@ -453,7 +460,7 @@ case object ModelInsights {
ModelInsights(
label = getLabelSummary(label, checkerSummary),
features = getFeatureInsights(vectorInput, checkerSummary, model, rawFeatures,
blacklistedFeatures, blacklistedMapKeys),
blacklistedFeatures, blacklistedMapKeys, rawFeatureDistributions),
selectedModelInfo = getModelInfo(model),
trainingParams = trainingParams,
stageInfo = getStageInfo(stages)
Expand Down Expand Up @@ -501,7 +508,8 @@ case object ModelInsights {
model: Option[Model[_]],
rawFeatures: Array[features.OPFeature],
blacklistedFeatures: Array[features.OPFeature],
blacklistedMapKeys: Map[String, Set[String]]
blacklistedMapKeys: Map[String, Set[String]],
rawFeatureDistributions: Array[FeatureDistribution]
): Seq[FeatureInsights] = {
val contributions = getModelContributions(model)

Expand Down Expand Up @@ -587,7 +595,9 @@ case object ModelInsights {
val ftype = allFeatures.find(_.name == fname)
.getOrElse(throw new RuntimeException(s"No raw feature with name $fname found in raw features"))
.typeName
FeatureInsights(featureName = fname, featureType = ftype, derivedFeatures = seq.map(_._2))
val distributions = rawFeatureDistributions.filter(_.name == fname)
FeatureInsights(featureName = fname, featureType = ftype, derivedFeatures = seq.map(_._2),
distributions = distributions)
}.toSeq
}

Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/com/salesforce/op/OpWorkflow.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore {
* @param distributions feature distributions calculated in raw feature filter
*/
private[op] def setBlacklist(features: Array[OPFeature], distributions: Seq[FeatureDistribution]): Unit = {
// TODO: Figure out a way to keep track of raw features that aren't explicitly blacklisted, but can't be used
// TODO: because they're inputs into an explicitly blacklisted feature. Eg. "height" in ModelInsightsTest
blacklistedFeatures = features
if (blacklistedFeatures.nonEmpty) {
val allBlacklisted: MList[OPFeature] = MList(getBlacklist(): _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams
val parentStageIds = feature.traverse[Set[String]](Set.empty[String])((s, f) => s + f.originStage.uid)
val modelStages = stages.filter(s => parentStageIds.contains(s.uid))
ModelInsights.extractFromStages(modelStages, rawFeatures, trainingParams,
blacklistedFeatures, blacklistedMapKeys)
blacklistedFeatures, blacklistedMapKeys, rawFeatureDistributions)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ import org.slf4j.LoggerFactory
/**
* Specialized stage that will load up data and compute distributions and empty counts on raw features.
* This information is then used to compute which raw features should be excluded from the workflow DAG
* Note: Currently, raw features that aren't explicitly blacklisted, but are not used because they are inputs to
* explicitly blacklisted features are not present as raw features in the model, nor in ModelInsights. However, they
* are accessible from an OpWorkflowModel via getRawFeatureDistributions().
*
* @param trainingReader reader to get the training data
* @param scoreReader reader to get the scoring data for comparison (optional - if not present will exclude based on
* training data features only)
Expand Down
46 changes: 35 additions & 11 deletions core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ package com.salesforce.op
import com.salesforce.op.evaluators.{EvalMetric, EvaluationMetrics}
import com.salesforce.op.features.Feature
import com.salesforce.op.features.types.{PickList, Real, RealNN}
import com.salesforce.op.filters.FeatureDistribution
import com.salesforce.op.stages.impl.classification.{BinaryClassificationModelSelector, BinaryClassificationModelsToTry, OpLogisticRegression}
import com.salesforce.op.stages.impl.preparators._
import com.salesforce.op.stages.impl.regression.{OpLinearRegression, RegressionModelSelector}
Expand Down Expand Up @@ -110,6 +111,14 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest {

lazy val workflowModel = workflow.train()

lazy val modelWithRFF = new OpWorkflow()
.setResultFeatures(predWithMaps)
.setParameters(params)
.withRawFeatureFilter(Option(dataReader), Option(simpleReader), bins = 10, minFillRate = 0.0,
maxFillDifference = 1.0, maxFillRatioDiff = Double.PositiveInfinity,
maxJSDivergence = 1.0, maxCorrelation = 0.4)
.train()

val rawNames = Set(age.name, weight.name, height.name, genderPL.name, description.name)

Spec[ModelInsights] should "throw an error when you try to get insights on a raw feature" in {
Expand Down Expand Up @@ -318,21 +327,14 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest {
}

it should "have feature insights for features that are removed by the raw feature filter" in {
val insights = modelWithRFF.modelInsights(predWithMaps)

val model = new OpWorkflow()
.setResultFeatures(predWithMaps)
.setParameters(params)
.withRawFeatureFilter(Option(dataReader), Option(simpleReader), bins = 10, minFillRate = 0.0,
maxFillDifference = 1.0, maxFillRatioDiff = Double.PositiveInfinity,
maxJSDivergence = 1.0, maxCorrelation = 0.4)
.train()
val insights = model.modelInsights(predWithMaps)
model.blacklistedFeatures should contain theSameElementsAs Array(age, description, genderPL, weight)
modelWithRFF.blacklistedFeatures should contain theSameElementsAs Array(age, description, genderPL, weight)
val heightIn = insights.features.find(_.featureName == age.name).get
heightIn.derivedFeatures.size shouldBe 1
heightIn.derivedFeatures.head.excluded shouldBe Some(true)

model.blacklistedMapKeys should contain theSameElementsAs Map(numericMap.name -> Set("Female"))
modelWithRFF.blacklistedMapKeys should contain theSameElementsAs Map(numericMap.name -> Set("Female"))
val mapDerivedIn = insights.features.find(_.featureName == numericMap.name).get.derivedFeatures
val droppedMapDerivedIn = mapDerivedIn.filter(_.derivedFeatureName == "Female")
mapDerivedIn.size shouldBe 3
Expand Down Expand Up @@ -409,7 +411,8 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest {

it should "correctly extract the FeatureInsights from the sanity checker summary and vector metadata" in {
val featureInsights = ModelInsights.getFeatureInsights(
Option(meta), Option(summary), None, Array(f1, f0), Array.empty, Map.empty[String, Set[String]]
Option(meta), Option(summary), None, Array(f1, f0), Array.empty, Map.empty[String, Set[String]],
Array.empty[FeatureDistribution]
)
featureInsights.size shouldBe 2

Expand Down Expand Up @@ -474,4 +477,25 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest {
f0InDer3.variance shouldBe Some(3.3)
}

it should "include raw feature distribution information when RawFeatureFilter is used" in {
val wfRawFeatureDistributions = modelWithRFF.getRawFeatureDistributions()
val wfDistributionsGrouped = wfRawFeatureDistributions.groupBy(_.name)

/*
Currently, raw features that aren't explicitly blacklisted, but are not used because they are inputs to
explicitly blacklisted features are not present as raw features in the model, nor in ModelInsights. For example,
weight is explicitly blacklisted here, which means that height will not be added as a raw feature even though
it's not explicitly blacklisted itself.
*/
val insights = modelWithRFF.modelInsights(predWithMaps)
insights.features.foreach(f =>
f.distributions should contain theSameElementsAs wfDistributionsGrouped.getOrElse(f.featureName, Array.empty)
)
}

it should "not include raw feature distribution information when RawFeatureFilter is not used" in {
val insights = workflowModel.modelInsights(pred)
insights.features.foreach(f => f.distributions shouldBe empty)
}

}

0 comments on commit a8eaf4b

Please sign in to comment.