Skip to content

Commit

Permalink
fixed model insights exception when features are excluded from sanity… (
Browse files Browse the repository at this point in the history
  • Loading branch information
leahmcguire authored Oct 4, 2018
1 parent c3cce7e commit 5fb59c7
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
6 changes: 3 additions & 3 deletions core/src/main/scala/com/salesforce/op/ModelInsights.scala
Original file line number Diff line number Diff line change
Expand Up @@ -608,8 +608,8 @@ case object ModelInsights {
.map {
case (fname, seq) =>
val ftype = allFeatures.find(_.name == fname)
.getOrElse(throw new RuntimeException(s"No raw feature with name $fname found in raw features"))
.typeName
.map(_.typeName)
.getOrElse("")
val distributions = rawFeatureDistributions.filter(_.name == fname)
FeatureInsights(featureName = fname, featureType = ftype, derivedFeatures = seq.map(_._2),
distributions = distributions)
Expand All @@ -626,7 +626,7 @@ case object ModelInsights {
getIfExists(corr.featuresIn.indexOf(name), corr.values).orElse {
val j = corr.nanCorrs.indexOf(name)
if (j >= 0) Option(Double.NaN)
else throw new RuntimeException(s"Column name $name does not exist in summary correlations")
else None
}
}

Expand Down
23 changes: 21 additions & 2 deletions core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,16 @@ package com.salesforce.op

import com.salesforce.op.evaluators.{EvalMetric, EvaluationMetrics}
import com.salesforce.op.features.Feature
import com.salesforce.op._
import com.salesforce.op.features.types.{PickList, Real, RealNN}
import com.salesforce.op.filters.FeatureDistribution
import com.salesforce.op.stages.impl.classification.{BinaryClassificationModelSelector, BinaryClassificationModelsToTry, OpLogisticRegression}
import com.salesforce.op.stages.impl.classification.{BinaryClassificationModelSelector, BinaryClassificationModelsToTry, MultiClassificationModelSelector, OpLogisticRegression}
import com.salesforce.op.stages.impl.preparators._
import com.salesforce.op.stages.impl.regression.{OpLinearRegression, RegressionModelSelector}
import com.salesforce.op.stages.impl.selector.ModelSelectorNames.EstimatorType
import com.salesforce.op.stages.impl.selector.ValidationType._
import com.salesforce.op.stages.impl.selector.{ModelEvaluation, ProblemType, SelectedModel, ValidationType}
import com.salesforce.op.stages.impl.tuning.{DataSplitter, SplitterSummary}
import com.salesforce.op.stages.impl.tuning.{DataCutter, DataSplitter, SplitterSummary}
import com.salesforce.op.test.PassengerSparkFixtureTest
import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
import org.apache.spark.ml.param.ParamMap
Expand Down Expand Up @@ -152,6 +153,24 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest {
insights.stageInfo.keys.size shouldEqual 8
}

it should "return model insights even when correlation is turned off for some features" in {
val featuresFinal = Seq(
description.vectorize(numHashes = 10, autoDetectLanguage = false, minTokenLength = 1, toLowercase = true),
stringMap.vectorize(cleanText = true, numHashes = 10)
).combine()
val featuresChecked = label.sanityCheck(featuresFinal, correlationExclusion = CorrelationExclusion.HashedText)
val prediction = MultiClassificationModelSelector
.withCrossValidation(seed = 42, splitter = Option(DataCutter(seed = 42, reserveTestFraction = 0.1)),
modelsAndParameters = models)
.setInput(label, featuresChecked)
.getOutput()
val workflow = new OpWorkflow().setResultFeatures(prediction).setParameters(params).setReader(dataReader)
val workflowModel = workflow.train()
val insights = workflowModel.modelInsights(prediction)
insights.features.size shouldBe 2
insights.features.flatMap(_.derivedFeatures).size shouldBe 23
}

it should "return feature insights with selector info and label info even when no models are found" in {
val insights = workflowModel.modelInsights(checked)
val ageInsights = insights.features.filter(_.featureName == age.name).head
Expand Down

0 comments on commit 5fb59c7

Please sign in to comment.