Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

made case class to deal with model selector metadata #39

Merged
merged 14 commits into from
Aug 9, 2018
Merged
37 changes: 19 additions & 18 deletions core/src/main/scala/com/salesforce/op/ModelInsights.scala
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ case class ModelInsights
* @return json string
*/
def toJson(pretty: Boolean = true): String = {
implicit val formats = ModelInsights.SerFormats
implicit val formats = ModelInsights.SerializationFormats
if (pretty) writePretty(this) else write(this)
}

Expand Down Expand Up @@ -362,23 +362,24 @@ case class Insights
case object ModelInsights {
@transient protected lazy val log = LoggerFactory.getLogger(this.getClass)

val SerFormats: Formats = Serialization.formats(FullTypeHints(List(
classOf[Continuous], classOf[Discrete],
classOf[DataBalancerSummary], classOf[DataCutterSummary], classOf[DataSplitterSummary],
classOf[SingleMetric], classOf[MultiMetrics], classOf[BinaryClassificationMetrics], classOf[ThresholdMetrics],
classOf[MultiClassificationMetrics], classOf[RegressionMetrics]
))) +
EnumEntrySerializer.json4s[ValidationType](ValidationType) +
EnumEntrySerializer.json4s[ProblemType](ProblemType) +
new SpecialDoubleSerializer +
new CustomSerializer[EvalMetric](_ =>
( {
case JString(s) => EvalMetric.withNameInsensitive(s)
}, {
case x: EvalMetric => JString(x.entryName)
}
)
/**
 * json4s [[Formats]] used to (de)serialize [[ModelInsights]]: full type hints for the
 * polymorphic members plus custom serializers for enum entries, special doubles and eval metrics.
 */
val SerializationFormats: Formats = {
  // Concrete classes that need a full type hint so json4s can round-trip polymorphic fields
  val hintedClasses = List(
    classOf[Continuous], classOf[Discrete],
    classOf[DataBalancerSummary], classOf[DataCutterSummary], classOf[DataSplitterSummary],
    classOf[SingleMetric], classOf[MultiMetrics], classOf[BinaryClassificationMetrics], classOf[ThresholdMetrics],
    classOf[MultiClassificationMetrics], classOf[RegressionMetrics]
  )
  // EvalMetric round-trips as its entry name; reads are case-insensitive
  val evalMetricSer = new CustomSerializer[EvalMetric](_ =>
    (
      { case JString(name) => EvalMetric.withNameInsensitive(name) },
      { case metric: EvalMetric => JString(metric.entryName) }
    )
  )
  // NOTE: serializer order is preserved from the original — in json4s, later additions win
  Serialization.formats(FullTypeHints(hintedClasses)) +
    EnumEntrySerializer.json4s[ValidationType](ValidationType) +
    EnumEntrySerializer.json4s[ProblemType](ProblemType) +
    new SpecialDoubleSerializer +
    evalMetricSer
}

/**
* Read ModelInsights from a json
Expand All @@ -387,7 +388,7 @@ case object ModelInsights {
* @return Try[ModelInsights]
*/
def fromJson(json: String): Try[ModelInsights] = {
implicit val formats: Formats = SerFormats
implicit val formats: Formats = SerializationFormats
Try { read[ModelInsights](json) }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,48 +132,6 @@ trait EvaluationMetrics extends JsonLike {
def toMetadata: Metadata = this.toMap.toMetadata
}

private[op] object EvaluationMetrics {

  /**
   * Decode metric values from a JSON string.
   *
   * @param className simple class name of the concrete metrics type the JSON encodes
   * @param json      encoded metrics
   * @return the decoded [[EvaluationMetrics]] instance
   * @throws IllegalArgumentException if the JSON cannot be parsed as the requested type
   */
  def fromJson(className: String, json: String): EvaluationMetrics = {
    // BUG FIX: was s"... from ${json.mkString(",")}" — calling mkString(",") on a String
    // joins its *characters* with commas ("abc" -> "a,b,c"), garbling the error message.
    def error(c: Class[_]) = throw new IllegalArgumentException(
      s"Could not extract metrics of type $c from $json"
    )
    className match {
      case n if n == classOf[MultiMetrics].getSimpleName =>
        JsonUtils.fromString[Map[String, Map[String, Any]]](json).map { d =>
          val asMetrics = d.flatMap { case (_, values) =>
            values.map {
              case (nm: String, mp: Map[String, Any]@unchecked) =>
                // Re-encode each nested map so it can be decoded into its strongly typed
                // metrics class. Gross but it works — TODO try to find a better way.
                val valsJson = JsonUtils.toJsonString(mp)
                nm match {
                  case OpEvaluatorNames.Binary.humanFriendlyName =>
                    nm -> JsonUtils.fromString[BinaryClassificationMetrics](valsJson).get
                  case OpEvaluatorNames.Multi.humanFriendlyName =>
                    nm -> JsonUtils.fromString[MultiClassificationMetrics](valsJson).get
                  case OpEvaluatorNames.Regression.humanFriendlyName =>
                    nm -> JsonUtils.fromString[RegressionMetrics](valsJson).get
                  case _ => nm -> JsonUtils.fromString[SingleMetric](valsJson).get
                }
            }
          }
          MultiMetrics(asMetrics)
        }.getOrElse(error(classOf[MultiMetrics]))
      case n if n == classOf[BinaryClassificationMetrics].getSimpleName =>
        JsonUtils.fromString[BinaryClassificationMetrics](json).getOrElse(error(classOf[BinaryClassificationMetrics]))
      case n if n == classOf[MultiClassificationMetrics].getSimpleName =>
        JsonUtils.fromString[MultiClassificationMetrics](json).getOrElse(error(classOf[MultiClassificationMetrics]))
      case n if n == classOf[RegressionMetrics].getSimpleName =>
        JsonUtils.fromString[RegressionMetrics](json).getOrElse(error(classOf[RegressionMetrics]))
      case n if n == classOf[SingleMetric].getSimpleName =>
        JsonUtils.fromString[SingleMetric](json).getOrElse(error(classOf[SingleMetric]))
      // Unknown type name: same mkString fix applied here as in error() above
      case n => throw new IllegalArgumentException(s"Could not extract metrics of type $n from $json")
    }
  }
}

/**
* Base Interface for OpEvaluator to be used in Evaluator creation. Can be used for both OP and spark
* eval (so with workflows and cross validation).
Expand Down Expand Up @@ -326,6 +284,9 @@ trait EvalMetric extends EnumEntry with Serializable {

}

/**
* Eval metric companion object
*/
object EvalMetric {

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docs please

def withNameInsensitive(name: String): EvalMetric = {
Expand Down Expand Up @@ -385,6 +346,9 @@ sealed abstract class RegressionEvalMetric
val humanFriendlyName: String
) extends EvalMetric

/**
* Regression Metrics
*/
object RegressionEvalMetrics extends Enum[RegressionEvalMetric] {
val values: Seq[RegressionEvalMetric] = findValues
case object RootMeanSquaredError extends RegressionEvalMetric("rmse", "root mean square error")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import com.salesforce.op.utils.spark.RichMetadata._
import enumeratum._
import org.apache.spark.sql.types.{Metadata, MetadataBuilder}
import com.salesforce.op.stages.impl.selector.ModelSelectorSummary._
import com.salesforce.op.utils.json.JsonUtils

/**
* This is used to store all information about fitting and model selection generated by the model selector class
Expand Down Expand Up @@ -164,7 +165,7 @@ case object ModelSelectorSummary {
val modelName: String = wrapped.get[String](ModelName)
val modelType: String = wrapped.get[String](ModelTypeName)
val Array(metName, metJson) = wrapped.get[Array[String]](MetricValues)
val metricValues: EvaluationMetrics = EvaluationMetrics.fromJson(metName, metJson)
val metricValues: EvaluationMetrics = evalMetFromJson(metName, metJson)
val modelParameters: Map[String, Any] = wrapped.get[Metadata](ModelParameters).wrapped.underlyingMap

ModelEvaluation(
Expand All @@ -186,7 +187,7 @@ case object ModelSelectorSummary {
.wrapped.underlyingMap
val dataPrepResults: Option[SplitterSummary] =
if (wrapped.contains(DataPrepResults)) {
Option(SplitterSummary.fromMap(wrapped.get[Metadata](DataPrepResults).wrapped.underlyingMap))
SplitterSummary.fromMetadata(wrapped.get[Metadata](DataPrepResults)).toOption
} else None
val evaluationMetric: EvalMetric = EvalMetric.withNameInsensitive(wrapped.get[String](EvaluationMetric))
val problemType: ProblemType = ProblemType.withName(wrapped.get[String](ProblemTypeName))
Expand All @@ -196,11 +197,11 @@ case object ModelSelectorSummary {
val validationResults: Seq[ModelEvaluation] = wrapped.get[Array[Metadata]](ValidationResults)
.map(modelEvalFromMetadata)
val Array(metName, metJson) = wrapped.get[Array[String]](TrainEvaluation)
val trainEvaluation: EvaluationMetrics = EvaluationMetrics.fromJson(metName, metJson)
val trainEvaluation: EvaluationMetrics = evalMetFromJson(metName, metJson)
val holdoutEvaluation: Option[EvaluationMetrics] =
if (wrapped.contains(HoldoutEvaluation)) {
val Array(metNameHold, metJsonHold) = wrapped.get[Array[String]](HoldoutEvaluation)
Option(EvaluationMetrics.fromJson(metNameHold, metJsonHold))
Option(evalMetFromJson(metNameHold, metJsonHold))
} else None

ModelSelectorSummary(
Expand All @@ -218,6 +219,47 @@ case object ModelSelectorSummary {
holdoutEvaluation = holdoutEvaluation)

}

/**
* Decode metric values from JSON string
*
* @param json encoded metrics
*/
/**
 * Decode metric values from a JSON string.
 *
 * @param className simple class name of the concrete metrics type the JSON encodes
 * @param json      encoded metrics
 * @return the decoded [[EvaluationMetrics]] instance
 * @throws IllegalArgumentException if the JSON cannot be parsed as the requested type
 */
private def evalMetFromJson(className: String, json: String): EvaluationMetrics = {
  // BUG FIX: was s"... from ${json.mkString(",")}" — calling mkString(",") on a String
  // joins its *characters* with commas ("abc" -> "a,b,c"), garbling the error message.
  def error(c: Class[_]) = throw new IllegalArgumentException(
    s"Could not extract metrics of type $c from $json"
  )
  className match {
    case n if n == classOf[MultiMetrics].getSimpleName =>
      JsonUtils.fromString[Map[String, Map[String, Any]]](json).map { d =>
        val asMetrics = d.flatMap { case (_, values) =>
          values.map {
            case (nm: String, mp: Map[String, Any]@unchecked) =>
              // Re-encode each nested map so it can be decoded into its strongly typed
              // metrics class. Gross but it works — TODO try to find a better way.
              val valsJson = JsonUtils.toJsonString(mp)
              nm match {
                case OpEvaluatorNames.Binary.humanFriendlyName =>
                  nm -> JsonUtils.fromString[BinaryClassificationMetrics](valsJson).get
                case OpEvaluatorNames.Multi.humanFriendlyName =>
                  nm -> JsonUtils.fromString[MultiClassificationMetrics](valsJson).get
                case OpEvaluatorNames.Regression.humanFriendlyName =>
                  nm -> JsonUtils.fromString[RegressionMetrics](valsJson).get
                case _ => nm -> JsonUtils.fromString[SingleMetric](valsJson).get
              }
          }
        }
        MultiMetrics(asMetrics)
      }.getOrElse(error(classOf[MultiMetrics]))
    case n if n == classOf[BinaryClassificationMetrics].getSimpleName =>
      JsonUtils.fromString[BinaryClassificationMetrics](json).getOrElse(error(classOf[BinaryClassificationMetrics]))
    case n if n == classOf[MultiClassificationMetrics].getSimpleName =>
      JsonUtils.fromString[MultiClassificationMetrics](json).getOrElse(error(classOf[MultiClassificationMetrics]))
    case n if n == classOf[RegressionMetrics].getSimpleName =>
      JsonUtils.fromString[RegressionMetrics](json).getOrElse(error(classOf[RegressionMetrics]))
    case n if n == classOf[SingleMetric].getSimpleName =>
      JsonUtils.fromString[SingleMetric](json).getOrElse(error(classOf[SingleMetric]))
    // Unknown type name: same mkString fix applied here as in error() above
    case n => throw new IllegalArgumentException(s"Could not extract metrics of type $n from $json")
  }
}


}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ import org.apache.spark.ml.param._
import org.apache.spark.sql.{Dataset, Row}
import com.salesforce.op.stages.impl.MetadataLike
import com.salesforce.op.stages.impl.selector.ModelSelectorBaseNames
import com.salesforce.op.utils.reflection.ReflectionUtils
import org.apache.spark.sql.types.{Metadata, MetadataBuilder}
import com.salesforce.op.utils.spark.RichMetadata._
import org.apache.spark.sql.types.Metadata

import scala.util.Try



Expand Down Expand Up @@ -118,7 +120,8 @@ trait SplitterSummary extends MetadataLike

private[op] object SplitterSummary {
val ClassName: String = "className"
def fromMap(map: Map[String, Any]): SplitterSummary = {
def fromMetadata(metadata: Metadata): Try[SplitterSummary] = Try {
val map = metadata.wrapped.underlyingMap
map(ClassName) match {
case s if s == classOf[DataSplitterSummary].getCanonicalName => DataSplitterSummary()
case s if s == classOf[DataBalancerSummary].getCanonicalName => DataBalancerSummary(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class ModelSelectorSummaryTest extends FlatSpec with TestSparkContext {
decoded.bestModelName shouldEqual summary.bestModelName
decoded.bestModelType shouldEqual summary.bestModelType
decoded.validationResults shouldEqual summary.validationResults
decoded.trainEvaluation.toJson() shouldEqual summary.trainEvaluation.toJson()
decoded.trainEvaluation shouldEqual summary.trainEvaluation
decoded.holdoutEvaluation shouldEqual summary.holdoutEvaluation
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ import org.apache.spark.sql.types.MetadataBuilder
import com.salesforce.op.utils.spark.RichDataset._

@RunWith(classOf[JUnitRunner])
class OpValidatorTest extends FlatSpec with TestSparkContext {
class OpValidatorTest extends FlatSpec with TestSparkContext {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove redundant end line

// Random Data
val count = 1000
val sizeOfVector = 2
Expand Down Expand Up @@ -107,7 +108,6 @@ class OpValidatorTest extends FlatSpec with TestSparkContext {
assertFractions(Array(1 - p, p), train)
assertFractions(Array(1 - p, p), validate)
}
println(balancer.get.summary)
balancer.get.summary.get.toMetadata() should not be new MetadataBuilder().build()
}

Expand Down