diff --git a/core/src/main/scala/com/salesforce/op/OpWorkflow.scala b/core/src/main/scala/com/salesforce/op/OpWorkflow.scala index d1ebb22ddf..27459eb4ca 100644 --- a/core/src/main/scala/com/salesforce/op/OpWorkflow.scala +++ b/core/src/main/scala/com/salesforce/op/OpWorkflow.scala @@ -41,6 +41,7 @@ import com.salesforce.op.utils.reflection.ReflectionUtils import com.salesforce.op.utils.spark.RichDataset._ import com.salesforce.op.utils.stages.FitStagesUtil import com.salesforce.op.utils.stages.FitStagesUtil.{CutDAG, FittedDAG, Layer, StagesDAG} +import enumeratum.{Enum, EnumEntry} import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Transformer} import org.apache.spark.sql.{DataFrame, SparkSession} @@ -61,6 +62,9 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore { // raw feature filter stage which can be used in place of a reader private[op] var rawFeatureFilter: Option[RawFeatureFilter[_]] = None + // result feature retention policy + private[op] var resultFeaturePolicy: ResultFeatureRetention = ResultFeatureRetention.Strict + /** * Set stage and reader parameters from OpWorkflowParams object for run * @@ -78,7 +82,7 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore { * * By setting the final features the stages used to * generate them can be traced back through the parent features and origin stages. - * The input is an tuple of features to support leaf feature generation (multiple endpoints in feature generation). + * The input is a tuple of features to support leaf feature generation (multiple endpoints in feature generation). * * @param features Final features generated by the workflow */ @@ -104,6 +108,7 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore { this } + /** * Will set the blacklisted features variable and if list is non-empty it will * @param features list of features to blacklist @@ -112,17 +117,25 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore { private[op] def setBlacklist(features: Array[OPFeature], distributions: Seq[FeatureDistribution]): Unit = { // TODO: Figure out a way to keep track of raw features that aren't explicitly blacklisted, but can't be used // TODO: because they're inputs into an explicitly blacklisted feature. Eg. "height" in ModelInsightsTest + + def finalResultFeaturesCheck(resultFeatures: Array[OPFeature], blacklisted: List[OPFeature]): Unit = { + if (resultFeaturePolicy == ResultFeatureRetention.Strict) { + resultFeatures.foreach{ f => if (blacklisted.contains(f)) { + throw new IllegalArgumentException(s"Blacklist of features (${blacklisted.map(_.name).mkString(", ")})" + + s" from RawFeatureFilter contained the result feature ${f.name}") } } + } else if (resultFeaturePolicy == ResultFeatureRetention.AtLeastOne) { + if (resultFeatures.forall(blacklisted.contains)) throw new IllegalArgumentException(s"Blacklist of features" + + s" (${blacklisted.map(_.name).mkString(", ")}) from RawFeatureFilter removed all result features") + } else throw new IllegalArgumentException(s"result feature retention policy $resultFeaturePolicy not supported") + } + blacklistedFeatures = features if (blacklistedFeatures.nonEmpty) { val allBlacklisted: MList[OPFeature] = MList(getBlacklist(): _*) val allUpdated: MList[OPFeature] = MList.empty val initialResultFeatures = getResultFeatures() - initialResultFeatures - .foreach{ f => if (allBlacklisted.contains(f)) throw new IllegalArgumentException( - s"Blacklist of features (${allBlacklisted.map(_.name).mkString(", ")})" + - s" from RawFeatureFilter contained the result feature ${f.name}" ) - } + finalResultFeaturesCheck(initialResultFeatures, allBlacklisted.toList) val initialStages = getStages() // ordered by DAG so dont need to recompute DAG // for each stage remove anything blacklisted from the inputs and update any changed input features @@ -138,16 +151,15 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore { Try(stg.setInputFeatureArray(inputsChanged).setOutputFeatureName(oldOutput.name).getOutput()) match { case Success(out) => allUpdated += out case Failure(e) => - if (initialResultFeatures.contains(oldOutput)) throw new RuntimeException( - s"Blacklist of features (${allBlacklisted.map(_.name).mkString(", ")}) \n" + - s" created by RawFeatureFilter contained features critical to the creation of required result" + - s" feature (${oldOutput.name}) though the path: \n ${oldOutput.prettyParentStages} \n", e) - else allBlacklisted += oldOutput + log.info(s"Issue updating inputs for stage $stg: $e") + allBlacklisted += oldOutput + finalResultFeaturesCheck(initialResultFeatures, allBlacklisted.toList) } } // Update the whole DAG with the blacklisted features expunged val updatedResultFeatures = initialResultFeatures + .filterNot(allBlacklisted.contains) .map{ f => allUpdated.find(u => u.sameOrigin(f)).getOrElse(f) } setResultFeatures(updatedResultFeatures: _*) } @@ -339,6 +351,9 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore { val fittedStgs = fitStages(data = rawData, stagesToFit = stages, persistEveryKStages) val newResultFtrs = resultFeatures.map(_.copyWithNewStages(fittedStgs)) fittedStgs -> newResultFtrs + } else if (rawFeatureFilter.nonEmpty) { + generateRawData() + stages -> resultFeatures } else { stages -> resultFeatures } @@ -534,8 +549,10 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore { protectedJSFeatures: Array[OPFeature] = Array.empty, textBinsFormula: (Summary, Int) => Int = RawFeatureFilter.textBinsFormula, timePeriod: Option[TimePeriod] = None, - minScoringRows: Int = RawFeatureFilter.minScoringRowsDefault + minScoringRows: Int = RawFeatureFilter.minScoringRowsDefault, + resultFeatureRetentionPolicy: ResultFeatureRetention = ResultFeatureRetention.Strict ): this.type = { + resultFeaturePolicy = resultFeatureRetentionPolicy val training = trainingReader.orElse(reader).map(_.asInstanceOf[Reader[T]]) require(training.nonEmpty, "Reader for training data must be provided either in withRawFeatureFilter or directly" + "as the reader for the workflow") @@ -563,3 +580,14 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore { // scalastyle:on } + +/** + * Methods of vectorizing text (eg. to be chosen by statistics computed in SmartTextVectorizer) + */ +sealed trait ResultFeatureRetention extends EnumEntry with Serializable + +object ResultFeatureRetention extends Enum[ResultFeatureRetention] { + val values = findValues + case object Strict extends ResultFeatureRetention + case object AtLeastOne extends ResultFeatureRetention +} diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizer.scala index d75e42bf36..ea5235fe44 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizer.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/SmartTextVectorizer.scala @@ -179,7 +179,8 @@ object SmartTextVectorizer { * @param valueCounts counts of feature values * @param lengthCounts counts of token lengths */ -private[op] case class TextStats( +private[op] case class TextStats +( valueCounts: Map[String, Long], lengthCounts: Map[Int, Long] ) extends JsonLike { diff --git a/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala b/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala index 3a51c1b1f7..2b8d88e63d 100644 --- a/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala +++ b/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala @@ -206,9 +206,6 @@ class OpWorkflowTest extends FlatSpec with PassengerSparkFixtureTest { } it should "throw an error when it is not possible to remove blacklisted features" in { - val fv = Seq(age, gender, height, weight, description, boarded, stringMap, numericMap, booleanMap).transmogrify() - val survivedNum = survived.occurs() - val pred = BinaryClassificationModelSelector().setInput(survivedNum, fv).getOutput() val wf = new OpWorkflow() .setResultFeatures(whyNotNormed) .withRawFeatureFilter(Option(dataReader), None) @@ -219,6 +216,15 @@ class OpWorkflowTest extends FlatSpec with PassengerSparkFixtureTest { error.getMessage.contains("creation of required result feature (height-weight_4-stagesApplied_Real") } + it should "allow the removal of some final features with the retention policy is set to allow it" in { + val wf = new OpWorkflow() + .setResultFeatures(whyNotNormed, weight) + .withRawFeatureFilter(Option(dataReader), None, resultFeatureRetentionPolicy = ResultFeatureRetention.AtLeastOne) + + wf.setBlacklist(Array(age, gender, height, description, stringMap, numericMap), Seq.empty) + wf.getResultFeatures().map(_.name) shouldEqual Seq(weight).map(_.name) + } + it should "be able to compute a partial dataset in both workflow and workflow model" in { val fields = List(KeyFieldName, height.name, weight.name, heightNormed.name, density.name,