Skip to content

Commit

Permalink
allow result features to be removed by raw feature filter (#458)
Browse files Browse the repository at this point in the history
  • Loading branch information
leahmcguire authored Jan 24, 2020
1 parent b7e07e3 commit a51212a
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 16 deletions.
52 changes: 40 additions & 12 deletions core/src/main/scala/com/salesforce/op/OpWorkflow.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import com.salesforce.op.utils.reflection.ReflectionUtils
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.stages.FitStagesUtil
import com.salesforce.op.utils.stages.FitStagesUtil.{CutDAG, FittedDAG, Layer, StagesDAG}
import enumeratum.{Enum, EnumEntry}
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Transformer}
import org.apache.spark.sql.{DataFrame, SparkSession}
Expand All @@ -61,6 +62,9 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore {
// raw feature filter stage which can be used in place of a reader
private[op] var rawFeatureFilter: Option[RawFeatureFilter[_]] = None

// result feature retention policy
private[op] var resultFeaturePolicy: ResultFeatureRetention = ResultFeatureRetention.Strict

/**
* Set stage and reader parameters from OpWorkflowParams object for run
*
Expand All @@ -78,7 +82,7 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore {
*
* By setting the final features the stages used to
* generate them can be traced back through the parent features and origin stages.
* The input is an tuple of features to support leaf feature generation (multiple endpoints in feature generation).
* The input is a tuple of features to support leaf feature generation (multiple endpoints in feature generation).
*
* @param features Final features generated by the workflow
*/
Expand All @@ -104,6 +108,7 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore {
this
}


/**
* Will set the blacklisted features variable and if list is non-empty it will
* @param features list of features to blacklist
Expand All @@ -112,17 +117,25 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore {
private[op] def setBlacklist(features: Array[OPFeature], distributions: Seq[FeatureDistribution]): Unit = {
// TODO: Figure out a way to keep track of raw features that aren't explicitly blacklisted, but can't be used
// TODO: because they're inputs into an explicitly blacklisted feature. Eg. "height" in ModelInsightsTest

/**
 * Enforces the configured result-feature retention policy against the current blacklist.
 * Throws an IllegalArgumentException when the policy is violated.
 *
 * @param resultFeatures final features requested from the workflow
 * @param blacklisted    features removed by the RawFeatureFilter
 */
def finalResultFeaturesCheck(resultFeatures: Array[OPFeature], blacklisted: List[OPFeature]): Unit = {
  resultFeaturePolicy match {
    case ResultFeatureRetention.Strict =>
      // Strict: no requested result feature may appear in the blacklist
      resultFeatures.filter(blacklisted.contains).foreach { removed =>
        throw new IllegalArgumentException(s"Blacklist of features (${blacklisted.map(_.name).mkString(", ")})" +
          s" from RawFeatureFilter contained the result feature ${removed.name}")
      }
    case ResultFeatureRetention.AtLeastOne =>
      // AtLeastOne: at least one requested result feature must survive blacklisting
      if (resultFeatures.forall(blacklisted.contains)) {
        throw new IllegalArgumentException(s"Blacklist of features" +
          s" (${blacklisted.map(_.name).mkString(", ")}) from RawFeatureFilter removed all result features")
      }
    case _ =>
      // Defensive: unreachable while the sealed trait has only Strict and AtLeastOne
      throw new IllegalArgumentException(s"result feature retention policy $resultFeaturePolicy not supported")
  }
}

blacklistedFeatures = features
if (blacklistedFeatures.nonEmpty) {
val allBlacklisted: MList[OPFeature] = MList(getBlacklist(): _*)
val allUpdated: MList[OPFeature] = MList.empty

val initialResultFeatures = getResultFeatures()
initialResultFeatures
.foreach{ f => if (allBlacklisted.contains(f)) throw new IllegalArgumentException(
s"Blacklist of features (${allBlacklisted.map(_.name).mkString(", ")})" +
s" from RawFeatureFilter contained the result feature ${f.name}" )
}
finalResultFeaturesCheck(initialResultFeatures, allBlacklisted.toList)

val initialStages = getStages() // ordered by DAG so dont need to recompute DAG
// for each stage remove anything blacklisted from the inputs and update any changed input features
Expand All @@ -138,16 +151,15 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore {
Try(stg.setInputFeatureArray(inputsChanged).setOutputFeatureName(oldOutput.name).getOutput()) match {
case Success(out) => allUpdated += out
case Failure(e) =>
if (initialResultFeatures.contains(oldOutput)) throw new RuntimeException(
s"Blacklist of features (${allBlacklisted.map(_.name).mkString(", ")}) \n" +
s" created by RawFeatureFilter contained features critical to the creation of required result" +
s" feature (${oldOutput.name}) though the path: \n ${oldOutput.prettyParentStages} \n", e)
else allBlacklisted += oldOutput
log.info(s"Issue updating inputs for stage $stg: $e")
allBlacklisted += oldOutput
finalResultFeaturesCheck(initialResultFeatures, allBlacklisted.toList)
}
}

// Update the whole DAG with the blacklisted features expunged
val updatedResultFeatures = initialResultFeatures
.filterNot(allBlacklisted.contains)
.map{ f => allUpdated.find(u => u.sameOrigin(f)).getOrElse(f) }
setResultFeatures(updatedResultFeatures: _*)
}
Expand Down Expand Up @@ -339,6 +351,9 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore {
val fittedStgs = fitStages(data = rawData, stagesToFit = stages, persistEveryKStages)
val newResultFtrs = resultFeatures.map(_.copyWithNewStages(fittedStgs))
fittedStgs -> newResultFtrs
} else if (rawFeatureFilter.nonEmpty) {
generateRawData()
stages -> resultFeatures
} else {
stages -> resultFeatures
}
Expand Down Expand Up @@ -534,8 +549,10 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore {
protectedJSFeatures: Array[OPFeature] = Array.empty,
textBinsFormula: (Summary, Int) => Int = RawFeatureFilter.textBinsFormula,
timePeriod: Option[TimePeriod] = None,
minScoringRows: Int = RawFeatureFilter.minScoringRowsDefault
minScoringRows: Int = RawFeatureFilter.minScoringRowsDefault,
resultFeatureRetentionPolicy: ResultFeatureRetention = ResultFeatureRetention.Strict
): this.type = {
resultFeaturePolicy = resultFeatureRetentionPolicy
val training = trainingReader.orElse(reader).map(_.asInstanceOf[Reader[T]])
require(training.nonEmpty, "Reader for training data must be provided either in withRawFeatureFilter or directly" +
"as the reader for the workflow")
Expand Down Expand Up @@ -563,3 +580,14 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore {
// scalastyle:on

}

/**
 * Policy controlling whether the workflow's requested result features may be removed
 * by the RawFeatureFilter (consumed in `setBlacklist` via `finalResultFeaturesCheck`).
 * NOTE(review): the previous scaladoc ("Methods of vectorizing text ... SmartTextVectorizer")
 * was a copy-paste error from an unrelated enum.
 */
sealed trait ResultFeatureRetention extends EnumEntry with Serializable

object ResultFeatureRetention extends Enum[ResultFeatureRetention] {
  val values = findValues
  // Strict: fail if any requested result feature is removed by the filter
  case object Strict extends ResultFeatureRetention
  // AtLeastOne: fail only if the filter removes every requested result feature
  case object AtLeastOne extends ResultFeatureRetention
}
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ object SmartTextVectorizer {
* @param valueCounts counts of feature values
* @param lengthCounts counts of token lengths
*/
private[op] case class TextStats(
private[op] case class TextStats
(
valueCounts: Map[String, Long],
lengthCounts: Map[Int, Long]
) extends JsonLike {
Expand Down
12 changes: 9 additions & 3 deletions core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,6 @@ class OpWorkflowTest extends FlatSpec with PassengerSparkFixtureTest {
}

it should "throw an error when it is not possible to remove blacklisted features" in {
val fv = Seq(age, gender, height, weight, description, boarded, stringMap, numericMap, booleanMap).transmogrify()
val survivedNum = survived.occurs()
val pred = BinaryClassificationModelSelector().setInput(survivedNum, fv).getOutput()
val wf = new OpWorkflow()
.setResultFeatures(whyNotNormed)
.withRawFeatureFilter(Option(dataReader), None)
Expand All @@ -219,6 +216,15 @@ class OpWorkflowTest extends FlatSpec with PassengerSparkFixtureTest {
error.getMessage.contains("creation of required result feature (height-weight_4-stagesApplied_Real")
}

// Verifies that ResultFeatureRetention.AtLeastOne permits blacklisting a subset of the
// requested result features, as long as at least one survives.
// Fix: test description was ungrammatical ("with the retention policy is set" -> "when ...").
it should "allow the removal of some final features when the retention policy is set to allow it" in {
  val wf = new OpWorkflow()
    .setResultFeatures(whyNotNormed, weight)
    .withRawFeatureFilter(Option(dataReader), None, resultFeatureRetentionPolicy = ResultFeatureRetention.AtLeastOne)

  // whyNotNormed depends on blacklisted raw features and is dropped; weight is untouched
  wf.setBlacklist(Array(age, gender, height, description, stringMap, numericMap), Seq.empty)
  wf.getResultFeatures().map(_.name) shouldEqual Seq(weight).map(_.name)
}

it should "be able to compute a partial dataset in both workflow and workflow model" in {
val fields =
List(KeyFieldName, height.name, weight.name, heightNormed.name, density.name,
Expand Down

0 comments on commit a51212a

Please sign in to comment.