Serialization for FeatureGeneratorStage #300

Closed
core/src/main/scala/com/salesforce/op/OpWorkflowModelReader.scala
@@ -138,6 +138,8 @@ class OpWorkflowModelReader(val workflowOpt: Option[OpWorkflow]) extends MLReader
val originalStage = workflow.stages.find(_.uid == stageUid)
originalStage match {
case Some(os) => new OpPipelineStageReader(os).loadFromJson(j, path = path).asInstanceOf[OPStage]
case None if stageUid.startsWith("FeatureGeneratorStage_") =>
Review comment (Collaborator):
Use classOf[FeatureGeneratorStage].getSimpleName instead of a hard-coded string.
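An illustrative sketch of that suggestion (not part of this PR's diff). FeatureGeneratorStage takes type parameters, hence the wildcards; the derived prefix "FeatureGeneratorStage" still matches uids that start with "FeatureGeneratorStage_":

// Sketch only: derive the uid prefix from the class instead of hard-coding it.
val featureGenPrefix = classOf[FeatureGeneratorStage[_, _]].getSimpleName // "FeatureGeneratorStage"
// ...which would make the new case read:
// case None if stageUid.startsWith(featureGenPrefix) =>
//   new OpPipelineStageReader(Seq()).loadFromJson(j, path).asInstanceOf[OPStage]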

new OpPipelineStageReader(Seq()).loadFromJson(j, path).asInstanceOf[OPStage]
case None => throw new RuntimeException(s"Workflow does not contain a stage with uid: $stageUid")
}
}
22 changes: 20 additions & 2 deletions core/src/main/scala/com/salesforce/op/OpWorkflowModelWriter.scala
@@ -32,7 +32,7 @@ package com.salesforce.op

import com.salesforce.op.features.FeatureJsonHelper
import com.salesforce.op.filters.RawFeatureFilterResults
import com.salesforce.op.stages.{OpPipelineStageBase, OpPipelineStageWriter}
import com.salesforce.op.stages.{FeatureGeneratorStage, OPStage, OpPipelineStageBase, OpPipelineStageWriter}
import enumeratum._
import org.apache.hadoop.fs.Path
import org.apache.spark.ml.util.MLWriter
@@ -98,13 +98,22 @@ class OpWorkflowModelWriter(val model: OpWorkflowModel) extends MLWriter {
* @return array of serialized stages
*/
private def stagesJArray(path: String): JArray = {
val stages: Seq[OpPipelineStageBase] = model.stages
val stages: Seq[OpPipelineStageBase] = getFeatureGenStages(model.stages) ++ model.stages
val stagesJson: Seq[JObject] = stages
.map(_.write.asInstanceOf[OpPipelineStageWriter].writeToJson(path))
.filter(_.children.nonEmpty)
JArray(stagesJson.toList)
}

private def getFeatureGenStages(stages: Seq[OPStage]): Seq[OpPipelineStageBase] = {
for {
stage <- stages
inputFeatures <- stage.getInputFeatures()
orgStage = inputFeatures.originStage
if orgStage.isInstanceOf[FeatureGeneratorStage[_, _]]
} yield orgStage
}

/**
* Gets all features to be serialized.
*
@@ -134,14 +143,23 @@ private[op] object OpWorkflowModelReadWriteShared {
*/
object FieldNames extends Enum[FieldNames] {
val values = findValues

case object Uid extends FieldNames("uid")

case object ResultFeaturesUids extends FieldNames("resultFeaturesUids")

case object BlacklistedFeaturesUids extends FieldNames("blacklistedFeaturesUids")

case object Stages extends FieldNames("stages")

case object AllFeatures extends FieldNames("allFeatures")

case object Parameters extends FieldNames("parameters")

case object TrainParameters extends FieldNames("trainParameters")

case object RawFeatureFilterResultsFieldName extends FieldNames("rawFeatureFilterResults")

}

}
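Illustration (not part of the diff): with getFeatureGenStages, the writer now serializes the origin FeatureGeneratorStage of every stage input in addition to the model's own stages. This is what the reader change above relies on, and it is why the expected stage counts in the tests below increase (1 to 3 and 2 to 5). A rough sketch with invented names:

// Hypothetical example: a model with a single fitted stage that consumes two raw features.
val modelStages: Seq[OPStage] = Seq(fittedVectorizer)    // 1 stage held by the model
val generatorStages = getFeatureGenStages(modelStages)   // 2 origin FeatureGeneratorStages of its inputs
val toWrite = generatorStages ++ modelStages             // 3 stages end up in the serialized "stages" array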
core/src/test/scala/com/salesforce/op/OpWorkflowModelReaderWriterTest.scala
@@ -52,7 +52,7 @@ import org.scalatest.{BeforeAndAfterEach, FlatSpec}
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._

import OpWorkflowModelReaderWriterTest._

@RunWith(classOf[JUnitRunner])
class OpWorkflowModelReaderWriterTest
@@ -145,7 +145,7 @@ class OpWorkflowModelReaderWriterTest
}

trait SwSingleStageFlow {
val vec = FeatureBuilder.OPVector[Passenger].extract(_ => OPVector.empty).asPredictor
val vec = FeatureBuilder.OPVector[Passenger].extract(emptyVectFnc).asPredictor
val scaler = new StandardScaler().setWithStd(false).setWithMean(false)
val schema = FeatureSparkTypes.toStructType(vec)
val data = spark.createDataFrame(List(Row(Vectors.dense(1.0))).asJava, schema)
Expand All @@ -172,7 +172,7 @@ class OpWorkflowModelReaderWriterTest

it should "have a single stage" in new SingleStageFlow {
val stagesM = (jsonModel \ Stages.entryName).extract[JArray]
stagesM.values.size shouldBe 1
stagesM.values.size shouldBe 3
}

it should "have 3 features" in new SingleStageFlow {
@@ -193,7 +193,7 @@

"MultiStage OpWorkflowWriter" should "recover all relevant stages" in new MultiStageFlow {
val stagesM = (jsonModel \ Stages.entryName).extract[JArray]
stagesM.values.size shouldBe 2
stagesM.values.size shouldBe 5
}

it should "recover all relevant features" in new MultiStageFlow {
@@ -379,4 +379,6 @@ trait UIDReset {

object OpWorkflowModelReaderWriterTest {
def mapFnc0: OPVector => Real = v => Real(v.value.toArray.headOption)

def emptyVectFnc: (Passenger => OPVector) = _ => OPVector.empty
}
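Presumably the inline lambda was moved into the companion object (mirroring mapFnc0 above) so the extract function does not capture the enclosing test class, which matters once the origin FeatureGeneratorStage itself is serialized. A minimal sketch of the pattern, with a hypothetical object name:

// Keep extract functions in a top-level object so they do not close over the test instance.
object ExtractFns {
  def emptyVect: Passenger => OPVector = _ => OPVector.empty
}
val vec = FeatureBuilder.OPVector[Passenger].extract(ExtractFns.emptyVect).asPredictor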
1 change: 1 addition & 0 deletions core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala
@@ -44,6 +44,7 @@ import com.salesforce.op.stages.impl.tuning._
import com.salesforce.op.test.{Passenger, PassengerSparkFixtureTest, TestFeatureBuilder}
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
import org.apache.log4j.Level
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{BooleanParam, ParamMap}
import org.apache.spark.ml.tuning.ParamGridBuilder