Skip to content

Commit

Permalink
Serialize blacklisted map keys with the model + updated access on workflow/model members (#320)
Browse files Browse the repository at this point in the history
  • Loading branch information
tovbinm authored May 15, 2019
1 parent 18a9243 commit 6f55dee
Show file tree
Hide file tree
Showing 13 changed files with 140 additions and 81 deletions.
3 changes: 2 additions & 1 deletion core/src/main/scala/com/salesforce/op/OpWorkflow.scala
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,8 @@ class OpWorkflow(val uid: String = UID[OpWorkflow]) extends OpWorkflowCore {
* needed to generate the features not included in the fitted model
*/
def withModelStages(model: OpWorkflowModel): this.type = {
  // Combine this workflow's result features with the model's, then rebind every
  // feature to the fitted stages taken from the model (via the public accessor,
  // since the stages member is no longer directly accessible).
  val newResultFeatures =
    (resultFeatures ++ model.getResultFeatures()).map(_.copyWithNewStages(model.getStages()))
  setResultFeatures(newResultFeatures: _*)
}

Expand Down
51 changes: 41 additions & 10 deletions core/src/main/scala/com/salesforce/op/OpWorkflowCore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,31 +57,31 @@ private[op] trait OpWorkflowCore {
def uid: String

// whether the CV/TV is performed on the workflow level
protected var isWorkflowCVEnabled = false

// the data reader for the workflow or model
protected var reader: Option[Reader[_]] = None

// final features from workflow, used to find stages of the workflow
protected var resultFeatures: Array[OPFeature] = Array[OPFeature]()

// raw features generated after data is read in and aggregated
protected var rawFeatures: Array[OPFeature] = Array[OPFeature]()

// features that have been blacklisted from use in dag
protected var blacklistedFeatures: Array[OPFeature] = Array[OPFeature]()

// map keys that were blacklisted from use in dag
protected var blacklistedMapKeys: Map[String, Set[String]] = Map[String, Set[String]]()

// raw feature filter results calculated in raw feature filter
protected var rawFeatureFilterResults: RawFeatureFilterResults = RawFeatureFilterResults()

// stages of the workflow
protected var stages: Array[OPStage] = Array[OPStage]()

// command line parameters for the workflow stages and readers
protected var parameters = new OpParams()

private[op] def setStages(value: Array[OPStage]): this.type = {
stages = value
Expand All @@ -102,10 +102,16 @@ private[op] trait OpWorkflowCore {
*/
@Experimental
final def withWorkflowCV: this.type = {
  // flip the flag so CV/TV is performed at the workflow level rather than per-stage
  isWorkflowCVEnabled = true
  this
}

/**
 * Whether the cross-validation/train-validation-split will be done at workflow level
 *
 * @return true if the cross-validation will be done at workflow level, false otherwise
 */
final def isWorkflowCV: Boolean = isWorkflowCVEnabled

/**
* Set data reader that will be used to generate data frame for stages
Expand All @@ -119,6 +125,15 @@ private[op] trait OpWorkflowCore {
this
}

/**
 * Get data reader that will be used to generate data frame for stages
 *
 * @throws IllegalArgumentException if no reader has been set on this workflow
 * @return reader for workflow
 */
final def getReader(): Reader[_] = reader match {
  case Some(r) => r
  case None => throw new IllegalArgumentException("Reader is not set")
}

/**
* Set input dataset which contains columns corresponding to the raw features used in the workflow
* The type of the dataset (Dataset[T]) must match the type of the FeatureBuilders[T] used to generate
Expand Down Expand Up @@ -161,13 +176,29 @@ private[op] trait OpWorkflowCore {
*/
final def getStages(): Array[OPStage] = stages

/**
 * Get the raw features generated by the workflow
 * (the features produced after the input data is read in and aggregated)
 *
 * @return raw features for workflow
 */
final def getRawFeatures(): Array[OPFeature] = rawFeatures

/**
 * Get the final features generated by the workflow
 * (these are the features used to find the stages of the workflow)
 *
 * @return result features for workflow
 */
final def getResultFeatures(): Array[OPFeature] = resultFeatures

/**
 * Get all the features that potentially are generated by the workflow: raw, intermediate and result features
 *
 * @return all the features that potentially are generated by the workflow: raw, intermediate and result features
 */
final def getAllFeatures(): Array[OPFeature] = {
  // intermediate features are the inputs of every stage in the workflow
  val intermediate = getStages().flatMap(stage => stage.getInputFeatures())
  (getRawFeatures() ++ intermediate ++ getResultFeatures()).distinct
}

/**
* Get the list of raw features which have been blacklisted
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams
val parentStageIds = feature.traverse[Set[String]](Set.empty[String])((s, f) => s + f.originStage.uid)
val modelStages = stages.filter(s => parentStageIds.contains(s.uid))
ModelInsights.extractFromStages(modelStages, rawFeatures, trainingParams,
blacklistedFeatures, blacklistedMapKeys, rawFeatureFilterResults)
getBlacklist(), getBlacklistMapKeys(), getRawFeatureFilterResults())
}
}

Expand Down
19 changes: 14 additions & 5 deletions core/src/main/scala/com/salesforce/op/OpWorkflowModelReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,27 +92,36 @@ class OpWorkflowModelReader(val workflowOpt: Option[OpWorkflow]) extends MLReade
model <- Try(new OpWorkflowModel(uid = (json \ Uid.entryName).extract[String], trainParams))
(stages, resultFeatures) <- Try(resolveFeaturesAndStages(workflow, json, path))
blacklist <- Try(resolveBlacklist(workflow, json))
blacklistMapKeys <- Try(resolveBlacklistMapKeys(json))
results <- resolveRawFeatureFilterResults(json)
} yield model
.setStages(stages.filterNot(_.isInstanceOf[FeatureGeneratorStage[_, _]]))
.setFeatures(resultFeatures)
.setParameters(params)
.setBlacklist(blacklist)
.setBlacklistMapKeys(blacklistMapKeys)
.setRawFeatureFilterResults(results)
}

/**
 * Resolve the blacklisted feature uids from the serialized model json against the
 * features known to the workflow (raw, blacklisted, stage inputs and results).
 *
 * @param workflow the workflow used to load the model
 * @param json     serialized model json
 * @return blacklisted features, or empty when the json has no blacklist entry
 */
private def resolveBlacklist(workflow: OpWorkflow, json: JValue): Array[OPFeature] = {
  if ((json \ BlacklistedFeaturesUids.entryName) != JNothing) { // for backwards compatibility
    val blacklistIds = (json \ BlacklistedFeaturesUids.entryName).extract[JArray].arr
    // collect every feature the workflow can reference (accessed via the public getters)
    val allFeatures = workflow.getRawFeatures() ++ workflow.getBlacklist() ++
      workflow.getStages().flatMap(_.getInputFeatures()) ++
      workflow.getResultFeatures()
    blacklistIds.flatMap(uid => allFeatures.find(_.uid == uid.extract[String])).toArray
  } else {
    Array.empty[OPFeature]
  }
}

/**
 * Resolve the blacklisted map keys from the serialized model json.
 *
 * @param json serialized model json
 * @return map of feature name to blacklisted key set, empty when the entry is absent
 */
private def resolveBlacklistMapKeys(json: JValue): Map[String, Set[String]] = {
  val extracted = (json \ BlacklistedMapKeys.entryName).extractOpt[Map[String, List[String]]]
  extracted.fold(Map.empty[String, Set[String]]) { blackMapKeys =>
    blackMapKeys.map { case (featureName, keys) => featureName -> keys.toSet }
  }
}

private def resolveFeaturesAndStages
(
workflow: OpWorkflow,
Expand All @@ -135,14 +144,14 @@ class OpWorkflowModelReader(val workflowOpt: Option[OpWorkflow]) extends MLReade
val recoveredStages = stagesJs.flatMap { j =>
val stageUidOpt = (j \ Uid.entryName).extractOpt[String]
stageUidOpt.map { stageUid =>
val originalStage = workflow.stages.find(_.uid == stageUid)
val originalStage = workflow.getStages().find(_.uid == stageUid)
originalStage match {
case Some(os) => new OpPipelineStageReader(os).loadFromJson(j, path = path).asInstanceOf[OPStage]
case None => throw new RuntimeException(s"Workflow does not contain a stage with uid: $stageUid")
}
}
}
val generators = workflow.rawFeatures.map(_.originStage)
val generators = workflow.getRawFeatures().map(_.originStage)
generators ++ recoveredStages
}

Expand Down
19 changes: 11 additions & 8 deletions core/src/main/scala/com/salesforce/op/OpWorkflowModelWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,23 @@ class OpWorkflowModelWriter(val model: OpWorkflowModel) extends MLWriter {
(FN.Uid.entryName -> model.uid) ~
(FN.ResultFeaturesUids.entryName -> resultFeaturesJArray) ~
(FN.BlacklistedFeaturesUids.entryName -> blacklistFeaturesJArray()) ~
(FN.BlacklistedMapKeys.entryName -> blacklistMapKeys()) ~
(FN.Stages.entryName -> stagesJArray(path)) ~
(FN.AllFeatures.entryName -> allFeaturesJArray) ~
(FN.Parameters.entryName -> model.parameters.toJson(pretty = false)) ~
(FN.Parameters.entryName -> model.getParameters().toJson(pretty = false)) ~
(FN.TrainParameters.entryName -> model.trainingParams.toJson(pretty = false)) ~
(FN.RawFeatureFilterResultsFieldName.entryName ->
RawFeatureFilterResults.toJson(model.getRawFeatureFilterResults()))
}

// Serializes the uids of the model's result features as a json array.
private def resultFeaturesJArray(): JArray =
  JArray(model.getResultFeatures().map(_.uid).map(JString).toList)

// Serializes the uids of the model's blacklisted features as a json array.
private def blacklistFeaturesJArray(): JArray =
  JArray(model.getBlacklist().map(_.uid).map(JString).toList)

// Serializes the model's blacklisted map keys as a json object of feature name -> key array.
private def blacklistMapKeys(): JObject = {
  val fields = model.getBlacklistMapKeys().map { case (featureName, keys) =>
    featureName -> JArray(keys.map(JString).toList)
  }
  JObject(fields.toList)
}

/**
* Serialize all the workflow model stages
Expand All @@ -98,7 +102,7 @@ class OpWorkflowModelWriter(val model: OpWorkflowModel) extends MLWriter {
* @return array of serialized stages
*/
private def stagesJArray(path: String): JArray = {
val stages: Seq[OpPipelineStageBase] = model.stages
val stages: Seq[OpPipelineStageBase] = model.getStages()
val stagesJson: Seq[JObject] = stages
.map(_.write.asInstanceOf[OpPipelineStageWriter].writeToJson(path))
.filter(_.children.nonEmpty)
Expand All @@ -111,10 +115,8 @@ class OpWorkflowModelWriter(val model: OpWorkflowModel) extends MLWriter {
* @note Features should be topologically sorted
* @return all features to be serialized
*/
private def allFeaturesJArray: JArray =
  // getAllFeatures() already returns the distinct union of raw, stage-input and result features
  JArray(model.getAllFeatures().map(FeatureJsonHelper.toJson).toList)

}

Expand All @@ -137,6 +139,7 @@ private[op] object OpWorkflowModelReadWriteShared {
case object Uid extends FieldNames("uid")
case object ResultFeaturesUids extends FieldNames("resultFeaturesUids")
case object BlacklistedFeaturesUids extends FieldNames("blacklistedFeaturesUids")
case object BlacklistedMapKeys extends FieldNames("blacklistedMapKeys")
case object Stages extends FieldNames("stages")
case object AllFeatures extends FieldNames("allFeatures")
case object Parameters extends FieldNames("parameters")
Expand Down
4 changes: 2 additions & 2 deletions core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -428,12 +428,12 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
it should "have feature insights for features that are removed by the raw feature filter" in {
val insights = modelWithRFF.modelInsights(predWithMaps)

modelWithRFF.blacklistedFeatures should contain theSameElementsAs Array(age, description, genderPL, weight)
modelWithRFF.getBlacklist() should contain theSameElementsAs Array(age, description, genderPL, weight)
val heightIn = insights.features.find(_.featureName == age.name).get
heightIn.derivedFeatures.size shouldBe 1
heightIn.derivedFeatures.head.excluded shouldBe Some(true)

modelWithRFF.blacklistedMapKeys should contain theSameElementsAs Map(numericMap.name -> Set("Female"))
modelWithRFF.getBlacklistMapKeys() should contain theSameElementsAs Map(numericMap.name -> Set("Female"))
val mapDerivedIn = insights.features.find(_.featureName == numericMap.name).get.derivedFeatures
val droppedMapDerivedIn = mapDerivedIn.filter(_.derivedFeatureName == "Female")
mapDerivedIn.size shouldBe 3
Expand Down
Loading

0 comments on commit 6f55dee

Please sign in to comment.