diff --git a/repos/lila/modules/api/src/main/RoundApiBalancer.scala b/repos/lila/modules/api/src/main/RoundApiBalancer.scala index 7e51afd35eb..8a0ac500f44 100644 --- a/repos/lila/modules/api/src/main/RoundApiBalancer.scala +++ b/repos/lila/modules/api/src/main/RoundApiBalancer.scala @@ -1,7 +1,7 @@ package lila.api import akka.actor._ -import akka.pattern.{ ask, pipe } +import akka.pattern.{ask, pipe} import play.api.libs.json.JsObject import scala.concurrent.duration._ @@ -22,51 +22,97 @@ private[api] final class RoundApiBalancer( implicit val timeout = makeTimeout seconds 20 case class Player(pov: Pov, apiVersion: Int, ctx: Context) - case class Watcher(pov: Pov, apiVersion: Int, tv: Option[lila.round.OnTv], - analysis: Option[(Pgn, Analysis)] = None, - initialFenO: Option[Option[String]] = None, - withMoveTimes: Boolean = false, - withOpening: Boolean = false, - ctx: Context) - case class UserAnalysis(pov: Pov, pref: Pref, initialFen: Option[String], orientation: chess.Color, owner: Boolean) + case class Watcher( + pov: Pov, + apiVersion: Int, + tv: Option[lila.round.OnTv], + analysis: Option[(Pgn, Analysis)] = None, + initialFenO: Option[Option[String]] = None, + withMoveTimes: Boolean = false, + withOpening: Boolean = false, + ctx: Context) + case class UserAnalysis( + pov: Pov, + pref: Pref, + initialFen: Option[String], + orientation: chess.Color, + owner: Boolean) val router = system.actorOf( - akka.routing.RoundRobinPool(nbActors).props(Props(new lila.hub.SequentialProvider { - val futureTimeout = 20.seconds - val logger = RoundApiBalancer.this.logger - def process = { - case Player(pov, apiVersion, ctx) => { - api.player(pov, apiVersion)(ctx) addFailureEffect { e => - logger.error(pov.toString, e) - } - }.chronometer.logIfSlow(500, logger) { _ => s"inner player $pov" }.result - case Watcher(pov, apiVersion, tv, analysis, initialFenO, withMoveTimes, withOpening, ctx) => - api.watcher(pov, apiVersion, tv, analysis, initialFenO, withMoveTimes, withOpening)(ctx) - case UserAnalysis(pov, pref, initialFen, orientation, owner) => - api.userAnalysisJson(pov, pref, initialFen, orientation, owner) - } - })), "api.round.router") + akka.routing + .RoundRobinPool(nbActors) + .props(Props(new lila.hub.SequentialProvider { + val futureTimeout = 20.seconds + val logger = RoundApiBalancer.this.logger + def process = { + case Player(pov, apiVersion, ctx) => { + api.player(pov, apiVersion)(ctx) addFailureEffect { e => + logger.error(pov.toString, e) + } + }.chronometer + .logIfSlow(500, logger) { _ => s"inner player $pov" } + .result + case Watcher( + pov, + apiVersion, + tv, + analysis, + initialFenO, + withMoveTimes, + withOpening, + ctx) => + api.watcher( + pov, + apiVersion, + tv, + analysis, + initialFenO, + withMoveTimes, + withOpening)(ctx) + case UserAnalysis(pov, pref, initialFen, orientation, owner) => + api.userAnalysisJson(pov, pref, initialFen, orientation, owner) + } + })), + "api.round.router" + ) } import implementation._ def player(pov: Pov, apiVersion: Int)(implicit ctx: Context): Fu[JsObject] = { - router ? Player(pov, apiVersion, ctx) mapTo manifest[JsObject] addFailureEffect { e => - logger.error(pov.toString, e) + router ? 
Player(pov, apiVersion, ctx) mapTo manifest[JsObject] addFailureEffect { + e => logger.error(pov.toString, e) } }.chronometer .mon(_.round.api.player) .logIfSlow(500, logger) { _ => s"outer player $pov" } .result - def watcher(pov: Pov, apiVersion: Int, tv: Option[lila.round.OnTv], - analysis: Option[(Pgn, Analysis)] = None, - initialFenO: Option[Option[String]] = None, - withMoveTimes: Boolean = false, - withOpening: Boolean = false)(implicit ctx: Context): Fu[JsObject] = { - router ? Watcher(pov, apiVersion, tv, analysis, initialFenO, withMoveTimes, withOpening, ctx) mapTo manifest[JsObject] + def watcher( + pov: Pov, + apiVersion: Int, + tv: Option[lila.round.OnTv], + analysis: Option[(Pgn, Analysis)] = None, + initialFenO: Option[Option[String]] = None, + withMoveTimes: Boolean = false, + withOpening: Boolean = false)(implicit ctx: Context): Fu[JsObject] = { + router ? Watcher( + pov, + apiVersion, + tv, + analysis, + initialFenO, + withMoveTimes, + withOpening, + ctx) mapTo manifest[JsObject] }.mon(_.round.api.watcher) - def userAnalysisJson(pov: Pov, pref: Pref, initialFen: Option[String], orientation: chess.Color, owner: Boolean): Fu[JsObject] = - router ? UserAnalysis(pov, pref, initialFen, orientation, owner) mapTo manifest[JsObject] + def userAnalysisJson( + pov: Pov, + pref: Pref, + initialFen: Option[String], + orientation: chess.Color, + owner: Boolean): Fu[JsObject] = + router ? UserAnalysis(pov, pref, initialFen, orientation, owner) mapTo manifest[ + JsObject] } diff --git a/repos/marathon/src/main/scala/mesosphere/marathon/api/v2/json/Formats.scala b/repos/marathon/src/main/scala/mesosphere/marathon/api/v2/json/Formats.scala index 3d0bc23b4ae..be820d82d8e 100644 --- a/repos/marathon/src/main/scala/mesosphere/marathon/api/v2/json/Formats.scala +++ b/repos/marathon/src/main/scala/mesosphere/marathon/api/v2/json/Formats.scala @@ -5,17 +5,17 @@ import mesosphere.marathon.Protos.Constraint.Operator import mesosphere.marathon.Protos.HealthCheckDefinition.Protocol import mesosphere.marathon.Protos.ResidencyDefinition.TaskLostBehavior import mesosphere.marathon.core.appinfo._ -import mesosphere.marathon.core.plugin.{ PluginDefinition, PluginDefinitions } +import mesosphere.marathon.core.plugin.{PluginDefinition, PluginDefinitions} import mesosphere.marathon.core.task.Task import mesosphere.marathon.event._ import mesosphere.marathon.event.http.EventSubscribers -import mesosphere.marathon.health.{ Health, HealthCheck } +import mesosphere.marathon.health.{Health, HealthCheck} import mesosphere.marathon.state.Container.Docker import mesosphere.marathon.state.Container.Docker.PortMapping import mesosphere.marathon.state._ import mesosphere.marathon.upgrade._ import org.apache.mesos.Protos.ContainerInfo.DockerInfo -import org.apache.mesos.{ Protos => mesos } +import org.apache.mesos.{Protos => mesos} import play.api.data.validation.ValidationError import play.api.libs.functional.syntax._ import play.api.libs.json._ @@ -26,11 +26,13 @@ import scala.concurrent.duration._ // scalastyle:off file.size.limit object Formats extends Formats { - implicit class ReadsWithDefault[A](val reads: Reads[Option[A]]) extends AnyVal { + implicit class ReadsWithDefault[A](val reads: Reads[Option[A]]) + extends AnyVal { def withDefault(a: A): Reads[A] = reads.map(_.getOrElse(a)) } - implicit class FormatWithDefault[A](val m: OFormat[Option[A]]) extends AnyVal { + implicit class FormatWithDefault[A](val m: OFormat[Option[A]]) + extends AnyVal { def withDefault(a: A): OFormat[A] = m.inmap(_.getOrElse(a), 
Some(_)) } @@ -67,48 +69,63 @@ trait Formats "taskId" -> failure.taskId.getValue, "timestamp" -> failure.timestamp, "version" -> failure.version, - "slaveId" -> (if (failure.slaveId.isDefined) failure.slaveId.get.getValue else JsNull) + "slaveId" -> (if (failure.slaveId.isDefined) failure.slaveId.get.getValue + else JsNull) ) } - implicit lazy val networkInfoProtocolWrites = Writes[mesos.NetworkInfo.Protocol] { protocol => - JsString(protocol.name) - } + implicit lazy val networkInfoProtocolWrites = + Writes[mesos.NetworkInfo.Protocol] { protocol => JsString(protocol.name) } private[this] val allowedProtocolString = - mesos.NetworkInfo.Protocol.values().toSeq.map(_.getDescriptorForType.getName).mkString(", ") - - implicit lazy val networkInfoProtocolReads = Reads[mesos.NetworkInfo.Protocol] { json => - json.validate[String].flatMap { protocolString: String => + mesos.NetworkInfo.Protocol + .values() + .toSeq + .map(_.getDescriptorForType.getName) + .mkString(", ") + + implicit lazy val networkInfoProtocolReads = + Reads[mesos.NetworkInfo.Protocol] { json => + json.validate[String].flatMap { protocolString: String => + Option(mesos.NetworkInfo.Protocol.valueOf(protocolString)) match { + case Some(protocol) => JsSuccess(protocol) + case None => + JsError( + s"'$protocolString' is not a valid protocol. Allowed values: $allowedProtocolString") + } - Option(mesos.NetworkInfo.Protocol.valueOf(protocolString)) match { - case Some(protocol) => JsSuccess(protocol) - case None => - JsError(s"'$protocolString' is not a valid protocol. Allowed values: $allowedProtocolString") } - } - } implicit lazy val ipAddressFormat: Format[mesos.NetworkInfo.IPAddress] = { - def toIpAddress(ipAddress: String, protocol: mesos.NetworkInfo.Protocol): mesos.NetworkInfo.IPAddress = - mesos.NetworkInfo.IPAddress.newBuilder().setIpAddress(ipAddress).setProtocol(protocol).build() - - def toTuple(ipAddress: mesos.NetworkInfo.IPAddress): (String, mesos.NetworkInfo.Protocol) = + def toIpAddress( + ipAddress: String, + protocol: mesos.NetworkInfo.Protocol): mesos.NetworkInfo.IPAddress = + mesos.NetworkInfo.IPAddress + .newBuilder() + .setIpAddress(ipAddress) + .setProtocol(protocol) + .build() + + def toTuple(ipAddress: mesos.NetworkInfo.IPAddress) + : (String, mesos.NetworkInfo.Protocol) = (ipAddress.getIpAddress, ipAddress.getProtocol) ( (__ \ "ipAddress").format[String] ~ - (__ \ "protocol").format[mesos.NetworkInfo.Protocol] + (__ \ "protocol").format[mesos.NetworkInfo.Protocol] )(toIpAddress, toTuple) } - implicit lazy val TaskIdWrite: Writes[Task.Id] = Writes { id => JsString(id.idString) } - implicit lazy val LocalVolumeIdWrite: Writes[Task.LocalVolumeId] = Writes { id => - Json.obj( - "containerPath" -> id.containerPath, - "persistenceId" -> id.idString - ) + implicit lazy val TaskIdWrite: Writes[Task.Id] = Writes { id => + JsString(id.idString) + } + implicit lazy val LocalVolumeIdWrite: Writes[Task.LocalVolumeId] = Writes { + id => + Json.obj( + "containerPath" -> id.containerPath, + "persistenceId" -> id.idString + ) } implicit lazy val TaskWrites: Writes[Task] = Writes { task => @@ -118,21 +135,25 @@ trait Formats "host" -> task.agentInfo.host ) - val launched = task.launched.map { launched => - base ++ Json.obj ( - "startedAt" -> launched.status.startedAt, - "stagedAt" -> launched.status.stagedAt, - "ports" -> launched.ports, - "ipAddresses" -> launched.ipAddresses, - "version" -> launched.appVersion - ) - }.getOrElse(base) + val launched = task.launched + .map { launched => + base ++ Json.obj( + "startedAt" -> 
launched.status.startedAt, + "stagedAt" -> launched.status.stagedAt, + "ports" -> launched.ports, + "ipAddresses" -> launched.ipAddresses, + "version" -> launched.appVersion + ) + } + .getOrElse(base) - val reservation = task.reservationWithVolumes.map { reservation => - launched ++ Json.obj( - "localVolumes" -> reservation.volumeIds - ) - }.getOrElse(launched) + val reservation = task.reservationWithVolumes + .map { reservation => + launched ++ Json.obj( + "localVolumes" -> reservation.volumeIds + ) + } + .getOrElse(launched) reservation } @@ -144,13 +165,15 @@ trait Formats "appId" -> task.appId ) - val withServicePorts = if (task.servicePorts.nonEmpty) - enrichedJson ++ Json.obj("servicePorts" -> task.servicePorts) - else - enrichedJson + val withServicePorts = + if (task.servicePorts.nonEmpty) + enrichedJson ++ Json.obj("servicePorts" -> task.servicePorts) + else + enrichedJson if (task.healthCheckResults.nonEmpty) - withServicePorts ++ Json.obj("healthCheckResults" -> task.healthCheckResults) + withServicePorts ++ Json.obj( + "healthCheckResults" -> task.healthCheckResults) else withServicePorts } @@ -174,28 +197,32 @@ trait Formats implicit lazy val ParameterFormat: Format[Parameter] = ( (__ \ "key").format[String] ~ - (__ \ "value").format[String] + (__ \ "value").format[String] )(Parameter(_, _), unlift(Parameter.unapply)) /* - * Helpers - */ + * Helpers + */ - def uniquePorts: Reads[Seq[Int]] = Format.of[Seq[Int]].filter(ValidationError("Ports must be unique.")) { ports => - val withoutRandom = ports.filterNot(_ == AppDefinition.RandomPortValue) - withoutRandom.distinct.size == withoutRandom.size - } + def uniquePorts: Reads[Seq[Int]] = + Format.of[Seq[Int]].filter(ValidationError("Ports must be unique.")) { + ports => + val withoutRandom = ports.filterNot(_ == AppDefinition.RandomPortValue) + withoutRandom.distinct.size == withoutRandom.size + } def nonEmpty[C <: Iterable[_]](implicit reads: Reads[C]): Reads[C] = - Reads.filterNot[C](ValidationError(s"set must not be empty"))(_.isEmpty)(reads) + Reads.filterNot[C](ValidationError(s"set must not be empty"))(_.isEmpty)( + reads) - def enumFormat[A <: java.lang.Enum[A]](read: String => A, errorMsg: String => String): Format[A] = { + def enumFormat[A <: java.lang.Enum[A]]( + read: String => A, + errorMsg: String => String): Format[A] = { val reads = Reads[A] { case JsString(str) => try { JsSuccess(read(str)) - } - catch { + } catch { case _: IllegalArgumentException => JsError(errorMsg(str)) } @@ -212,45 +239,62 @@ trait ContainerFormats { import Formats._ implicit lazy val DockerNetworkFormat: Format[DockerInfo.Network] = - enumFormat(DockerInfo.Network.valueOf, str => s"$str is not a valid network type") + enumFormat( + DockerInfo.Network.valueOf, + str => s"$str is not a valid network type") implicit lazy val PortMappingFormat: Format[Docker.PortMapping] = ( - (__ \ "containerPort").formatNullable[Int].withDefault(AppDefinition.RandomPortValue) ~ - (__ \ "hostPort").formatNullable[Int].withDefault(AppDefinition.RandomPortValue) ~ - (__ \ "servicePort").formatNullable[Int].withDefault(AppDefinition.RandomPortValue) ~ - (__ \ "protocol").formatNullable[String].withDefault("tcp") ~ - (__ \ "name").formatNullable[String] ~ - (__ \ "labels").formatNullable[Map[String, String]].withDefault(Map.empty[String, String]) + (__ \ "containerPort") + .formatNullable[Int] + .withDefault(AppDefinition.RandomPortValue) ~ + (__ \ "hostPort") + .formatNullable[Int] + .withDefault(AppDefinition.RandomPortValue) ~ + (__ \ "servicePort") + 
.formatNullable[Int] + .withDefault(AppDefinition.RandomPortValue) ~ + (__ \ "protocol").formatNullable[String].withDefault("tcp") ~ + (__ \ "name").formatNullable[String] ~ + (__ \ "labels") + .formatNullable[Map[String, String]] + .withDefault(Map.empty[String, String]) )(PortMapping(_, _, _, _, _, _), unlift(PortMapping.unapply)) implicit lazy val DockerFormat: Format[Docker] = ( (__ \ "image").format[String] ~ - (__ \ "network").formatNullable[DockerInfo.Network] ~ - (__ \ "portMappings").formatNullable[Seq[Docker.PortMapping]] ~ - (__ \ "privileged").formatNullable[Boolean].withDefault(false) ~ - (__ \ "parameters").formatNullable[Seq[Parameter]].withDefault(Seq.empty) ~ - (__ \ "forcePullImage").formatNullable[Boolean].withDefault(false) + (__ \ "network").formatNullable[DockerInfo.Network] ~ + (__ \ "portMappings").formatNullable[Seq[Docker.PortMapping]] ~ + (__ \ "privileged").formatNullable[Boolean].withDefault(false) ~ + (__ \ "parameters") + .formatNullable[Seq[Parameter]] + .withDefault(Seq.empty) ~ + (__ \ "forcePullImage").formatNullable[Boolean].withDefault(false) )(Docker(_, _, _, _, _, _), unlift(Docker.unapply)) implicit lazy val ModeFormat: Format[mesos.Volume.Mode] = enumFormat(mesos.Volume.Mode.valueOf, str => s"$str is not a valid mde") - implicit lazy val PersistentVolumeInfoFormat: Format[PersistentVolumeInfo] = Json.format[PersistentVolumeInfo] + implicit lazy val PersistentVolumeInfoFormat: Format[PersistentVolumeInfo] = + Json.format[PersistentVolumeInfo] implicit lazy val VolumeFormat: Format[Volume] = ( (__ \ "containerPath").format[String] ~ - (__ \ "hostPath").formatNullable[String] ~ - (__ \ "mode").format[mesos.Volume.Mode] ~ - (__ \ "persistent").formatNullable[PersistentVolumeInfo] + (__ \ "hostPath").formatNullable[String] ~ + (__ \ "mode").format[mesos.Volume.Mode] ~ + (__ \ "persistent").formatNullable[PersistentVolumeInfo] )(Volume(_, _, _, _), unlift(Volume.unapply)) implicit lazy val ContainerTypeFormat: Format[mesos.ContainerInfo.Type] = - enumFormat(mesos.ContainerInfo.Type.valueOf, str => s"$str is not a valid container type") + enumFormat( + mesos.ContainerInfo.Type.valueOf, + str => s"$str is not a valid container type") implicit lazy val ContainerFormat: Format[Container] = ( - (__ \ "type").formatNullable[mesos.ContainerInfo.Type].withDefault(mesos.ContainerInfo.Type.DOCKER) ~ - (__ \ "volumes").formatNullable[Seq[Volume]].withDefault(Nil) ~ - (__ \ "docker").formatNullable[Docker] + (__ \ "type") + .formatNullable[mesos.ContainerInfo.Type] + .withDefault(mesos.ContainerInfo.Type.DOCKER) ~ + (__ \ "volumes").formatNullable[Seq[Volume]].withDefault(Nil) ~ + (__ \ "docker").formatNullable[Docker] )(Container(_, _, _), unlift(Container.unapply)) } @@ -259,7 +303,8 @@ trait IpAddressFormats { private[this] lazy val ValidPortProtocol: Reads[String] = { implicitly[Reads[String]] - .filter(ValidationError("Invalid protocol. Only 'udp' or 'tcp' are allowed."))( + .filter( + ValidationError("Invalid protocol. 
Only 'udp' or 'tcp' are allowed."))( DiscoveryInfo.Port.AllowedProtocols ) } @@ -275,19 +320,22 @@ trait IpAddressFormats { implicitly[Reads[Seq[DiscoveryInfo.Port]]] .filter(ValidationError("Port names are not unique."))(hasUniquePortNames) - .filter(ValidationError("There may be only one port with a particular port number/protocol combination."))( + .filter(ValidationError( + "There may be only one port with a particular port number/protocol combination."))( hasUniquePortNumberProtocol ) } implicit lazy val PortFormat: Format[DiscoveryInfo.Port] = ( (__ \ "number").format[Int] ~ - (__ \ "name").format[String] ~ - (__ \ "protocol").format[String](ValidPortProtocol) + (__ \ "name").format[String] ~ + (__ \ "protocol").format[String](ValidPortProtocol) )(DiscoveryInfo.Port(_, _, _), unlift(DiscoveryInfo.Port.unapply)) implicit lazy val DiscoveryInfoFormat: Format[DiscoveryInfo] = Format( - (__ \ "ports").read[Seq[DiscoveryInfo.Port]](ValidPorts).map(DiscoveryInfo(_)), + (__ \ "ports") + .read[Seq[DiscoveryInfo.Port]](ValidPorts) + .map(DiscoveryInfo(_)), Writes[DiscoveryInfo] { discoveryInfo => Json.obj("ports" -> discoveryInfo.ports.map(PortFormat.writes)) } @@ -295,8 +343,12 @@ trait IpAddressFormats { implicit lazy val IpAddressFormat: Format[IpAddress] = ( (__ \ "groups").formatNullable[Seq[String]].withDefault(Nil) ~ - (__ \ "labels").formatNullable[Map[String, String]].withDefault(Map.empty[String, String]) ~ - (__ \ "discovery").formatNullable[DiscoveryInfo].withDefault(DiscoveryInfo.empty) + (__ \ "labels") + .formatNullable[Map[String, String]] + .withDefault(Map.empty[String, String]) ~ + (__ \ "discovery") + .formatNullable[DiscoveryInfo] + .withDefault(DiscoveryInfo.empty) )(IpAddress(_, _, _), unlift(IpAddress.unapply)) } @@ -306,44 +358,45 @@ trait DeploymentFormats { implicit lazy val ByteArrayFormat: Format[Array[Byte]] = Format( Reads.of[Seq[Int]].map(_.map(_.toByte).toArray), - Writes { xs => - JsArray(xs.to[Seq].map(b => JsNumber(b.toInt))) - } + Writes { xs => JsArray(xs.to[Seq].map(b => JsNumber(b.toInt))) } ) implicit lazy val GroupUpdateFormat: Format[GroupUpdate] = ( (__ \ "id").formatNullable[PathId] ~ - (__ \ "apps").formatNullable[Set[AppDefinition]] ~ - (__ \ "groups").lazyFormatNullable(implicitly[Format[Set[GroupUpdate]]]) ~ - (__ \ "dependencies").formatNullable[Set[PathId]] ~ - (__ \ "scaleBy").formatNullable[Double] ~ - (__ \ "version").formatNullable[Timestamp] - ) (GroupUpdate(_, _, _, _, _, _), unlift(GroupUpdate.unapply)) - - implicit lazy val URLToStringMapFormat: Format[Map[java.net.URL, String]] = Format( - Reads.of[Map[String, String]] - .map( - _.map { case (k, v) => new java.net.URL(k) -> v } - ), - Writes[Map[java.net.URL, String]] { m => - Json.toJson(m) - } - ) - - implicit lazy val DeploymentActionWrites: Writes[DeploymentAction] = Writes { action => - Json.obj( - "type" -> action.getClass.getSimpleName, - "app" -> action.app.id + (__ \ "apps").formatNullable[Set[AppDefinition]] ~ + (__ \ "groups").lazyFormatNullable(implicitly[Format[Set[GroupUpdate]]]) ~ + (__ \ "dependencies").formatNullable[Set[PathId]] ~ + (__ \ "scaleBy").formatNullable[Double] ~ + (__ \ "version").formatNullable[Timestamp] + )(GroupUpdate(_, _, _, _, _, _), unlift(GroupUpdate.unapply)) + + implicit lazy val URLToStringMapFormat: Format[Map[java.net.URL, String]] = + Format( + Reads + .of[Map[String, String]] + .map( + _.map { case (k, v) => new java.net.URL(k) -> v } + ), + Writes[Map[java.net.URL, String]] { m => Json.toJson(m) } ) + + implicit lazy val 
DeploymentActionWrites: Writes[DeploymentAction] = Writes { + action => + Json.obj( + "type" -> action.getClass.getSimpleName, + "app" -> action.app.id + ) } - implicit lazy val DeploymentStepWrites: Writes[DeploymentStep] = Json.writes[DeploymentStep] + implicit lazy val DeploymentStepWrites: Writes[DeploymentStep] = + Json.writes[DeploymentStep] } trait EventFormats { import Formats._ - implicit lazy val AppTerminatedEventWrites: Writes[AppTerminatedEvent] = Json.writes[AppTerminatedEvent] + implicit lazy val AppTerminatedEventWrites: Writes[AppTerminatedEvent] = + Json.writes[AppTerminatedEvent] implicit lazy val ApiPostEventWrites: Writes[ApiPostEvent] = Writes { event => Json.obj( @@ -355,39 +408,59 @@ trait EventFormats { ) } - implicit lazy val DeploymentPlanWrites: Writes[DeploymentPlan] = Writes { plan => - Json.obj( - "id" -> plan.id, - "original" -> plan.original, - "target" -> plan.target, - "steps" -> plan.steps, - "version" -> plan.version - ) + implicit lazy val DeploymentPlanWrites: Writes[DeploymentPlan] = Writes { + plan => + Json.obj( + "id" -> plan.id, + "original" -> plan.original, + "target" -> plan.target, + "steps" -> plan.steps, + "version" -> plan.version + ) } implicit lazy val SubscribeWrites: Writes[Subscribe] = Json.writes[Subscribe] - implicit lazy val UnsubscribeWrites: Writes[Unsubscribe] = Json.writes[Unsubscribe] - implicit lazy val EventStreamAttachedWrites: Writes[EventStreamAttached] = Json.writes[EventStreamAttached] - implicit lazy val EventStreamDetachedWrites: Writes[EventStreamDetached] = Json.writes[EventStreamDetached] - implicit lazy val AddHealthCheckWrites: Writes[AddHealthCheck] = Json.writes[AddHealthCheck] - implicit lazy val RemoveHealthCheckWrites: Writes[RemoveHealthCheck] = Json.writes[RemoveHealthCheck] - implicit lazy val FailedHealthCheckWrites: Writes[FailedHealthCheck] = Json.writes[FailedHealthCheck] - implicit lazy val HealthStatusChangedWrites: Writes[HealthStatusChanged] = Json.writes[HealthStatusChanged] - implicit lazy val GroupChangeSuccessWrites: Writes[GroupChangeSuccess] = Json.writes[GroupChangeSuccess] - implicit lazy val GroupChangeFailedWrites: Writes[GroupChangeFailed] = Json.writes[GroupChangeFailed] - implicit lazy val DeploymentSuccessWrites: Writes[DeploymentSuccess] = Json.writes[DeploymentSuccess] - implicit lazy val DeploymentFailedWrites: Writes[DeploymentFailed] = Json.writes[DeploymentFailed] - implicit lazy val DeploymentStatusWrites: Writes[DeploymentStatus] = Json.writes[DeploymentStatus] - implicit lazy val DeploymentStepSuccessWrites: Writes[DeploymentStepSuccess] = Json.writes[DeploymentStepSuccess] - implicit lazy val DeploymentStepFailureWrites: Writes[DeploymentStepFailure] = Json.writes[DeploymentStepFailure] - implicit lazy val MesosStatusUpdateEventWrites: Writes[MesosStatusUpdateEvent] = Json.writes[MesosStatusUpdateEvent] - implicit lazy val MesosFrameworkMessageEventWrites: Writes[MesosFrameworkMessageEvent] = + implicit lazy val UnsubscribeWrites: Writes[Unsubscribe] = + Json.writes[Unsubscribe] + implicit lazy val EventStreamAttachedWrites: Writes[EventStreamAttached] = + Json.writes[EventStreamAttached] + implicit lazy val EventStreamDetachedWrites: Writes[EventStreamDetached] = + Json.writes[EventStreamDetached] + implicit lazy val AddHealthCheckWrites: Writes[AddHealthCheck] = + Json.writes[AddHealthCheck] + implicit lazy val RemoveHealthCheckWrites: Writes[RemoveHealthCheck] = + Json.writes[RemoveHealthCheck] + implicit lazy val FailedHealthCheckWrites: Writes[FailedHealthCheck] = + 
Json.writes[FailedHealthCheck] + implicit lazy val HealthStatusChangedWrites: Writes[HealthStatusChanged] = + Json.writes[HealthStatusChanged] + implicit lazy val GroupChangeSuccessWrites: Writes[GroupChangeSuccess] = + Json.writes[GroupChangeSuccess] + implicit lazy val GroupChangeFailedWrites: Writes[GroupChangeFailed] = + Json.writes[GroupChangeFailed] + implicit lazy val DeploymentSuccessWrites: Writes[DeploymentSuccess] = + Json.writes[DeploymentSuccess] + implicit lazy val DeploymentFailedWrites: Writes[DeploymentFailed] = + Json.writes[DeploymentFailed] + implicit lazy val DeploymentStatusWrites: Writes[DeploymentStatus] = + Json.writes[DeploymentStatus] + implicit lazy val DeploymentStepSuccessWrites: Writes[DeploymentStepSuccess] = + Json.writes[DeploymentStepSuccess] + implicit lazy val DeploymentStepFailureWrites: Writes[DeploymentStepFailure] = + Json.writes[DeploymentStepFailure] + implicit lazy val MesosStatusUpdateEventWrites + : Writes[MesosStatusUpdateEvent] = Json.writes[MesosStatusUpdateEvent] + implicit lazy val MesosFrameworkMessageEventWrites + : Writes[MesosFrameworkMessageEvent] = Json.writes[MesosFrameworkMessageEvent] - implicit lazy val SchedulerDisconnectedEventWrites: Writes[SchedulerDisconnectedEvent] = + implicit lazy val SchedulerDisconnectedEventWrites + : Writes[SchedulerDisconnectedEvent] = Json.writes[SchedulerDisconnectedEvent] - implicit lazy val SchedulerRegisteredEventWritesWrites: Writes[SchedulerRegisteredEvent] = + implicit lazy val SchedulerRegisteredEventWritesWrites + : Writes[SchedulerRegisteredEvent] = Json.writes[SchedulerRegisteredEvent] - implicit lazy val SchedulerReregisteredEventWritesWrites: Writes[SchedulerReregisteredEvent] = + implicit lazy val SchedulerReregisteredEventWritesWrites + : Writes[SchedulerReregisteredEvent] = Json.writes[SchedulerReregisteredEvent] //scalastyle:off cyclomatic.complexity @@ -420,10 +493,11 @@ trait EventFormats { trait EventSubscribersFormats { - implicit lazy val EventSubscribersWrites: Writes[EventSubscribers] = Writes { eventSubscribers => - Json.obj( - "callbackUrls" -> eventSubscribers.urls - ) + implicit lazy val EventSubscribersWrites: Writes[EventSubscribers] = Writes { + eventSubscribers => + Json.obj( + "callbackUrls" -> eventSubscribers.urls + ) } } @@ -453,15 +527,30 @@ trait HealthCheckFormats { ( (__ \ "path").formatNullable[String] ~ - (__ \ "protocol").formatNullable[Protocol].withDefault(DefaultProtocol) ~ - (__ \ "portIndex").formatNullable[Int] ~ - (__ \ "command").formatNullable[Command] ~ - (__ \ "gracePeriodSeconds").formatNullable[Long].withDefault(DefaultGracePeriod.toSeconds).asSeconds ~ - (__ \ "intervalSeconds").formatNullable[Long].withDefault(DefaultInterval.toSeconds).asSeconds ~ - (__ \ "timeoutSeconds").formatNullable[Long].withDefault(DefaultTimeout.toSeconds).asSeconds ~ - (__ \ "maxConsecutiveFailures").formatNullable[Int].withDefault(DefaultMaxConsecutiveFailures) ~ - (__ \ "ignoreHttp1xx").formatNullable[Boolean].withDefault(DefaultIgnoreHttp1xx) ~ - (__ \ "port").formatNullable[Int] + (__ \ "protocol") + .formatNullable[Protocol] + .withDefault(DefaultProtocol) ~ + (__ \ "portIndex").formatNullable[Int] ~ + (__ \ "command").formatNullable[Command] ~ + (__ \ "gracePeriodSeconds") + .formatNullable[Long] + .withDefault(DefaultGracePeriod.toSeconds) + .asSeconds ~ + (__ \ "intervalSeconds") + .formatNullable[Long] + .withDefault(DefaultInterval.toSeconds) + .asSeconds ~ + (__ \ "timeoutSeconds") + .formatNullable[Long] + .withDefault(DefaultTimeout.toSeconds) + 
.asSeconds ~ + (__ \ "maxConsecutiveFailures") + .formatNullable[Int] + .withDefault(DefaultMaxConsecutiveFailures) ~ + (__ \ "ignoreHttp1xx") + .formatNullable[Boolean] + .withDefault(DefaultIgnoreHttp1xx) ~ + (__ \ "port").formatNullable[Int] )(HealthCheck.apply, unlift(HealthCheck.unapply)) } } @@ -472,9 +561,9 @@ trait FetchUriFormats { implicit lazy val FetchUriFormat: Format[FetchUri] = { ( (__ \ "uri").format[String] ~ - (__ \ "extract").formatNullable[Boolean].withDefault(true) ~ - (__ \ "executable").formatNullable[Boolean].withDefault(false) ~ - (__ \ "cache").formatNullable[Boolean].withDefault(false) + (__ \ "extract").formatNullable[Boolean].withDefault(true) ~ + (__ \ "executable").formatNullable[Boolean].withDefault(false) ~ + (__ \ "cache").formatNullable[Boolean].withDefault(false) )(FetchUri(_, _, _, _), unlift(FetchUri.unapply)) } } @@ -489,9 +578,13 @@ trait AppAndGroupFormats { implicit lazy val UpgradeStrategyReads: Reads[UpgradeStrategy] = { import mesosphere.marathon.state.AppDefinition._ ( - (__ \ "minimumHealthCapacity").readNullable[Double].withDefault(DefaultUpgradeStrategy.minimumHealthCapacity) ~ - (__ \ "maximumOverCapacity").readNullable[Double].withDefault(DefaultUpgradeStrategy.maximumOverCapacity) - ) (UpgradeStrategy(_, _)) + (__ \ "minimumHealthCapacity") + .readNullable[Double] + .withDefault(DefaultUpgradeStrategy.minimumHealthCapacity) ~ + (__ \ "maximumOverCapacity") + .readNullable[Double] + .withDefault(DefaultUpgradeStrategy.maximumOverCapacity) + )(UpgradeStrategy(_, _)) } implicit lazy val ConstraintFormat: Format[Constraint] = Format( @@ -502,14 +595,20 @@ trait AppAndGroupFormats { json.asOpt[Seq[String]] match { case Some(seq) if seq.size >= 2 && seq.size <= 3 => if (validOperators.contains(seq(1))) { - val builder = Constraint.newBuilder().setField(seq(0)).setOperator(Operator.valueOf(seq(1))) + val builder = Constraint + .newBuilder() + .setField(seq(0)) + .setOperator(Operator.valueOf(seq(1))) if (seq.size == 3) builder.setValue(seq(2)) JsSuccess(builder.build()) + } else { + JsError( + s"Constraint operator must be one of the following: [${validOperators + .mkString(", ")}]") } - else { - JsError(s"Constraint operator must be one of the following: [${validOperators.mkString(", ")}]") - } - case _ => JsError("Constraint definition must be an array of strings in format: , [, value]") + case _ => + JsError( + "Constraint definition must be an array of strings in format: , [, value]") } } }, @@ -526,112 +625,192 @@ trait AppAndGroupFormats { val executorPattern = "^(//cmd)|(/?[^/]+(/[^/]+)*)|$".r ( (__ \ "id").read[PathId].filterNot(_.isRoot) ~ - (__ \ "cmd").readNullable[String](Reads.minLength(1)) ~ - (__ \ "args").readNullable[Seq[String]] ~ - (__ \ "user").readNullable[String] ~ - (__ \ "env").readNullable[Map[String, String]].withDefault(AppDefinition.DefaultEnv) ~ - (__ \ "instances").readNullable[Int].withDefault(AppDefinition.DefaultInstances) ~ - (__ \ "cpus").readNullable[Double].withDefault(AppDefinition.DefaultCpus) ~ - (__ \ "mem").readNullable[Double].withDefault(AppDefinition.DefaultMem) ~ - (__ \ "disk").readNullable[Double].withDefault(AppDefinition.DefaultDisk) ~ - (__ \ "executor").readNullable[String](Reads.pattern(executorPattern)) - .withDefault(AppDefinition.DefaultExecutor) ~ - (__ \ "constraints").readNullable[Set[Constraint]].withDefault(AppDefinition.DefaultConstraints) ~ - (__ \ "storeUrls").readNullable[Seq[String]].withDefault(AppDefinition.DefaultStoreUrls) ~ - (__ \ 
"requirePorts").readNullable[Boolean].withDefault(AppDefinition.DefaultRequirePorts) ~ - (__ \ "backoffSeconds").readNullable[Long].withDefault(AppDefinition.DefaultBackoff.toSeconds).asSeconds ~ - (__ \ "backoffFactor").readNullable[Double].withDefault(AppDefinition.DefaultBackoffFactor) ~ - (__ \ "maxLaunchDelaySeconds").readNullable[Long] - .withDefault(AppDefinition.DefaultMaxLaunchDelay.toSeconds).asSeconds ~ - (__ \ "container").readNullable[Container] ~ - (__ \ "healthChecks").readNullable[Set[HealthCheck]].withDefault(AppDefinition.DefaultHealthChecks) - ) (( - id, cmd, args, maybeString, env, instances, cpus, mem, disk, executor, constraints, storeUrls, - requirePorts, backoff, backoffFactor, maxLaunchDelay, container, checks - ) => AppDefinition( - id = id, cmd = cmd, args = args, user = maybeString, env = env, instances = instances, cpus = cpus, - mem = mem, disk = disk, executor = executor, constraints = constraints, storeUrls = storeUrls, - requirePorts = requirePorts, backoff = backoff, - backoffFactor = backoffFactor, maxLaunchDelay = maxLaunchDelay, container = container, - healthChecks = checks)).flatMap { app => - // necessary because of case class limitations (good for another 21 fields) - case class ExtraFields( - uris: Seq[String], - fetch: Seq[FetchUri], - dependencies: Set[PathId], - maybePorts: Option[Seq[Int]], - upgradeStrategy: Option[UpgradeStrategy], - labels: Map[String, String], - acceptedResourceRoles: Option[Set[String]], - ipAddress: Option[IpAddress], - version: Timestamp, - residency: Option[Residency], - maybePortDefinitions: Option[Seq[PortDefinition]]) { - def upgradeStrategyOrDefault: UpgradeStrategy = { - import UpgradeStrategy.{ forResidentTasks, empty } - upgradeStrategy.getOrElse(if (residency.isDefined) forResidentTasks else empty) - } - def residencyOrDefault: Option[Residency] = { - residency.orElse(if (app.persistentVolumes.nonEmpty) Some(Residency.defaultResidency) else None) - } + (__ \ "cmd").readNullable[String](Reads.minLength(1)) ~ + (__ \ "args").readNullable[Seq[String]] ~ + (__ \ "user").readNullable[String] ~ + (__ \ "env") + .readNullable[Map[String, String]] + .withDefault(AppDefinition.DefaultEnv) ~ + (__ \ "instances") + .readNullable[Int] + .withDefault(AppDefinition.DefaultInstances) ~ + (__ \ "cpus") + .readNullable[Double] + .withDefault(AppDefinition.DefaultCpus) ~ + (__ \ "mem") + .readNullable[Double] + .withDefault(AppDefinition.DefaultMem) ~ + (__ \ "disk") + .readNullable[Double] + .withDefault(AppDefinition.DefaultDisk) ~ + (__ \ "executor") + .readNullable[String](Reads.pattern(executorPattern)) + .withDefault(AppDefinition.DefaultExecutor) ~ + (__ \ "constraints") + .readNullable[Set[Constraint]] + .withDefault(AppDefinition.DefaultConstraints) ~ + (__ \ "storeUrls") + .readNullable[Seq[String]] + .withDefault(AppDefinition.DefaultStoreUrls) ~ + (__ \ "requirePorts") + .readNullable[Boolean] + .withDefault(AppDefinition.DefaultRequirePorts) ~ + (__ \ "backoffSeconds") + .readNullable[Long] + .withDefault(AppDefinition.DefaultBackoff.toSeconds) + .asSeconds ~ + (__ \ "backoffFactor") + .readNullable[Double] + .withDefault(AppDefinition.DefaultBackoffFactor) ~ + (__ \ "maxLaunchDelaySeconds") + .readNullable[Long] + .withDefault(AppDefinition.DefaultMaxLaunchDelay.toSeconds) + .asSeconds ~ + (__ \ "container").readNullable[Container] ~ + (__ \ "healthChecks") + .readNullable[Set[HealthCheck]] + .withDefault(AppDefinition.DefaultHealthChecks) + )( + ( + id, + cmd, + args, + maybeString, + env, + instances, + cpus, + 
mem, + disk, + executor, + constraints, + storeUrls, + requirePorts, + backoff, + backoffFactor, + maxLaunchDelay, + container, + checks + ) => + AppDefinition( + id = id, + cmd = cmd, + args = args, + user = maybeString, + env = env, + instances = instances, + cpus = cpus, + mem = mem, + disk = disk, + executor = executor, + constraints = constraints, + storeUrls = storeUrls, + requirePorts = requirePorts, + backoff = backoff, + backoffFactor = backoffFactor, + maxLaunchDelay = maxLaunchDelay, + container = container, + healthChecks = checks + )).flatMap { app => + // necessary because of case class limitations (good for another 21 fields) + case class ExtraFields( + uris: Seq[String], + fetch: Seq[FetchUri], + dependencies: Set[PathId], + maybePorts: Option[Seq[Int]], + upgradeStrategy: Option[UpgradeStrategy], + labels: Map[String, String], + acceptedResourceRoles: Option[Set[String]], + ipAddress: Option[IpAddress], + version: Timestamp, + residency: Option[Residency], + maybePortDefinitions: Option[Seq[PortDefinition]]) { + def upgradeStrategyOrDefault: UpgradeStrategy = { + import UpgradeStrategy.{forResidentTasks, empty} + upgradeStrategy.getOrElse( + if (residency.isDefined) forResidentTasks else empty) + } + def residencyOrDefault: Option[Residency] = { + residency.orElse( + if (app.persistentVolumes.nonEmpty) Some(Residency.defaultResidency) + else None) } + } - val extraReads: Reads[ExtraFields] = - ( - (__ \ "uris").readNullable[Seq[String]].withDefault(AppDefinition.DefaultUris) ~ - (__ \ "fetch").readNullable[Seq[FetchUri]].withDefault(AppDefinition.DefaultFetch) ~ - (__ \ "dependencies").readNullable[Set[PathId]].withDefault(AppDefinition.DefaultDependencies) ~ + val extraReads: Reads[ExtraFields] = + ( + (__ \ "uris") + .readNullable[Seq[String]] + .withDefault(AppDefinition.DefaultUris) ~ + (__ \ "fetch") + .readNullable[Seq[FetchUri]] + .withDefault(AppDefinition.DefaultFetch) ~ + (__ \ "dependencies") + .readNullable[Set[PathId]] + .withDefault(AppDefinition.DefaultDependencies) ~ (__ \ "ports").readNullable[Seq[Int]](uniquePorts) ~ (__ \ "upgradeStrategy").readNullable[UpgradeStrategy] ~ - (__ \ "labels").readNullable[Map[String, String]].withDefault(AppDefinition.DefaultLabels) ~ + (__ \ "labels") + .readNullable[Map[String, String]] + .withDefault(AppDefinition.DefaultLabels) ~ (__ \ "acceptedResourceRoles").readNullable[Set[String]](nonEmpty) ~ (__ \ "ipAddress").readNullable[IpAddress] ~ - (__ \ "version").readNullable[Timestamp].withDefault(Timestamp.now()) ~ + (__ \ "version") + .readNullable[Timestamp] + .withDefault(Timestamp.now()) ~ (__ \ "residency").readNullable[Residency] ~ (__ \ "portDefinitions").readNullable[Seq[PortDefinition]] - )(ExtraFields) - .filter(ValidationError("You cannot specify both uris and fetch fields")) { extra => - !(extra.uris.nonEmpty && extra.fetch.nonEmpty) - } - .filter(ValidationError("You cannot specify both an IP address and ports")) { extra => - val appWithoutPorts = extra.maybePorts.forall(_.isEmpty) && extra.maybePortDefinitions.forall(_.isEmpty) - appWithoutPorts || extra.ipAddress.isEmpty - } - .filter(ValidationError("You cannot specify both ports and port definitions")) { extra => - val portDefinitionsIsEquivalentToPorts = extra.maybePortDefinitions.map(_.map(_.port)) == extra.maybePorts - portDefinitionsIsEquivalentToPorts || extra.maybePorts.isEmpty || extra.maybePortDefinitions.isEmpty - } - - extraReads.map { extra => - // Normally, our default is one port. 
If an ipAddress is defined that would lead to an error - // if left unchanged. - def fetch: Seq[FetchUri] = - if (extra.fetch.nonEmpty) extra.fetch - else extra.uris.map { uri => FetchUri(uri = uri, extract = FetchUri.isExtract(uri)) } - - def portDefinitions: Seq[PortDefinition] = extra.ipAddress match { - case Some(ipAddress) => Seq.empty[PortDefinition] - case None => - extra.maybePortDefinitions.getOrElse { - extra.maybePorts.map { ports => - PortDefinitions.apply(ports: _*) - }.getOrElse(AppDefinition.DefaultPortDefinitions) - } + )(ExtraFields) + .filter(ValidationError( + "You cannot specify both uris and fetch fields")) { extra => + !(extra.uris.nonEmpty && extra.fetch.nonEmpty) + } + .filter(ValidationError( + "You cannot specify both an IP address and ports")) { extra => + val appWithoutPorts = + extra.maybePorts.forall(_.isEmpty) && extra.maybePortDefinitions + .forall(_.isEmpty) + appWithoutPorts || extra.ipAddress.isEmpty + } + .filter(ValidationError( + "You cannot specify both ports and port definitions")) { extra => + val portDefinitionsIsEquivalentToPorts = + extra.maybePortDefinitions.map(_.map(_.port)) == extra.maybePorts + portDefinitionsIsEquivalentToPorts || extra.maybePorts.isEmpty || extra.maybePortDefinitions.isEmpty } - app.copy( - fetch = fetch, - dependencies = extra.dependencies, - portDefinitions = portDefinitions, - upgradeStrategy = extra.upgradeStrategyOrDefault, - labels = extra.labels, - acceptedResourceRoles = extra.acceptedResourceRoles, - ipAddress = extra.ipAddress, - versionInfo = AppDefinition.VersionInfo.OnlyVersion(extra.version), - residency = extra.residencyOrDefault - ) + extraReads.map { extra => + // Normally, our default is one port. If an ipAddress is defined that would lead to an error + // if left unchanged. + def fetch: Seq[FetchUri] = + if (extra.fetch.nonEmpty) extra.fetch + else + extra.uris.map { uri => + FetchUri(uri = uri, extract = FetchUri.isExtract(uri)) + } + + def portDefinitions: Seq[PortDefinition] = extra.ipAddress match { + case Some(ipAddress) => Seq.empty[PortDefinition] + case None => + extra.maybePortDefinitions.getOrElse { + extra.maybePorts + .map { ports => PortDefinitions.apply(ports: _*) } + .getOrElse(AppDefinition.DefaultPortDefinitions) + } } + + app.copy( + fetch = fetch, + dependencies = extra.dependencies, + portDefinitions = portDefinitions, + upgradeStrategy = extra.upgradeStrategyOrDefault, + labels = extra.labels, + acceptedResourceRoles = extra.acceptedResourceRoles, + ipAddress = extra.ipAddress, + versionInfo = AppDefinition.VersionInfo.OnlyVersion(extra.version), + residency = extra.residencyOrDefault + ) } + } }.map(addHealthCheckPortIndexIfNecessary) /** @@ -640,18 +819,22 @@ trait AppAndGroupFormats { * In the past, healthCheck.portIndex was required and had a default value 0. When we introduced healthCheck.port, we * made it optional (also with ip-per-container in mind) and we have to re-add it in cases where it makes sense. 
*/ - private[this] def addHealthCheckPortIndexIfNecessary(app: AppDefinition): AppDefinition = { - val hasPortMappings = app.container.exists(_.docker.exists(_.portMappings.exists(_.nonEmpty))) + private[this] def addHealthCheckPortIndexIfNecessary( + app: AppDefinition): AppDefinition = { + val hasPortMappings = + app.container.exists(_.docker.exists(_.portMappings.exists(_.nonEmpty))) val portIndexesMakeSense = app.portDefinitions.nonEmpty || hasPortMappings app.copy(healthChecks = app.healthChecks.map { healthCheck => def needsDefaultPortIndex = healthCheck.port.isEmpty && healthCheck.portIndex.isEmpty && healthCheck.protocol != Protocol.COMMAND - if (portIndexesMakeSense && needsDefaultPortIndex) healthCheck.copy(portIndex = Some(0)) + if (portIndexesMakeSense && needsDefaultPortIndex) + healthCheck.copy(portIndex = Some(0)) else healthCheck }) } - private[this] def addHealthCheckPortIndexIfNecessary(appUpdate: AppUpdate): AppUpdate = { + private[this] def addHealthCheckPortIndexIfNecessary( + appUpdate: AppUpdate): AppUpdate = { appUpdate.copy(healthChecks = appUpdate.healthChecks.map { healthChecks => healthChecks.map { healthCheck => def needsDefaultPortIndex = @@ -662,20 +845,24 @@ trait AppAndGroupFormats { }) } - implicit lazy val taskLostBehaviorWrites = Writes[TaskLostBehavior] { taskLostBehavior => - JsString(taskLostBehavior.name()) + implicit lazy val taskLostBehaviorWrites = Writes[TaskLostBehavior] { + taskLostBehavior => JsString(taskLostBehavior.name()) } implicit lazy val taskLostBehaviorReads = Reads[TaskLostBehavior] { json => json.validate[String].flatMap { behaviorString: String => - Option(TaskLostBehavior.valueOf(behaviorString)) match { case Some(taskLostBehavior) => JsSuccess(taskLostBehavior) case None => { val allowedTaskLostBehaviorString = - TaskLostBehavior.values().toSeq.map(_.getDescriptorForType.getName).mkString(", ") - - JsError(s"'$behaviorString' is not a valid taskLostBehavior. Allowed values: $allowedTaskLostBehaviorString") + TaskLostBehavior + .values() + .toSeq + .map(_.getDescriptorForType.getName) + .mkString(", ") + + JsError( + s"'$behaviorString' is not a valid taskLostBehavior. 
Allowed values: $allowedTaskLostBehaviorString") } } @@ -683,11 +870,13 @@ trait AppAndGroupFormats { } implicit lazy val ResidencyFormat: Format[Residency] = ( - (__ \ "relaunchEscalationTimeoutSeconds").formatNullable[Long] - .withDefault(Residency.defaultRelaunchEscalationTimeoutSeconds) ~ - (__ \ "taskLostBehavior").formatNullable[TaskLostBehavior] - .withDefault(Residency.defaultTaskLostBehaviour) - ) (Residency(_, _), unlift(Residency.unapply)) + (__ \ "relaunchEscalationTimeoutSeconds") + .formatNullable[Long] + .withDefault(Residency.defaultRelaunchEscalationTimeoutSeconds) ~ + (__ \ "taskLostBehavior") + .formatNullable[TaskLostBehavior] + .withDefault(Residency.defaultTaskLostBehaviour) + )(Residency(_, _), unlift(Residency.unapply)) implicit lazy val AppDefinitionWrites: Writes[AppDefinition] = { implicit lazy val durationWrites = Writes[FiniteDuration] { d => @@ -714,7 +903,8 @@ trait AppAndGroupFormats { // it should contain the service ports "ports" -> app.servicePorts, "portDefinitions" -> app.portDefinitions.zip(app.servicePorts).map { - case (portDefinition, servicePort) => portDefinition.copy(port = servicePort) + case (portDefinition, servicePort) => + portDefinition.copy(port = servicePort) }, "requirePorts" -> app.requirePorts, "backoffSeconds" -> app.backoff, @@ -739,7 +929,8 @@ trait AppAndGroupFormats { implicit lazy val VersionInfoWrites: Writes[AppDefinition.VersionInfo] = Writes[AppDefinition.VersionInfo] { - case AppDefinition.VersionInfo.FullVersionInfo(_, lastScalingAt, lastConfigChangeAt) => + case AppDefinition.VersionInfo + .FullVersionInfo(_, lastScalingAt, lastConfigChangeAt) => Json.obj( "lastScalingAt" -> lastScalingAt, "lastConfigChangeAt" -> lastConfigChangeAt @@ -779,11 +970,11 @@ trait AppAndGroupFormats { implicit lazy val TaskStatsWrites: Writes[TaskStats] = Writes { stats => - val statsJson = Json.obj("counts" -> TaskCountsWritesWithoutPrefix.writes(stats.counts)) + val statsJson = + Json.obj("counts" -> TaskCountsWritesWithoutPrefix.writes(stats.counts)) Json.obj( "stats" -> stats.maybeLifeTime.fold(ifEmpty = statsJson)(lifeTime => - statsJson ++ Json.obj("lifeTime" -> lifeTime) - ) + statsJson ++ Json.obj("lifeTime" -> lifeTime)) ) } @@ -808,9 +999,11 @@ trait AppAndGroupFormats { val maybeJson = Seq[Option[JsObject]]( info.maybeCounts.map(TaskCountsWrites.writes(_).as[JsObject]), - info.maybeDeployments.map(deployments => Json.obj("deployments" -> deployments)), + info.maybeDeployments.map(deployments => + Json.obj("deployments" -> deployments)), info.maybeTasks.map(tasks => Json.obj("tasks" -> tasks)), - info.maybeLastTaskFailure.map(lastFailure => Json.obj("lastTaskFailure" -> lastFailure)), + info.maybeLastTaskFailure.map(lastFailure => + Json.obj("lastTaskFailure" -> lastFailure)), info.maybeTaskStats.map(taskStats => Json.obj("taskStats" -> taskStats)) ).flatten @@ -819,13 +1012,12 @@ trait AppAndGroupFormats { implicit lazy val GroupInfoWrites: Writes[GroupInfo] = Writes { info => - val maybeJson = Seq[Option[JsObject]]( info.maybeApps.map(apps => Json.obj("apps" -> apps)), info.maybeGroups.map(groups => Json.obj("groups" -> groups)) ).flatten - val groupJson = Json.obj ( + val groupJson = Json.obj( "id" -> info.group.id, "dependencies" -> info.group.dependencies, "version" -> info.group.version @@ -836,66 +1028,104 @@ trait AppAndGroupFormats { implicit lazy val AppUpdateReads: Reads[AppUpdate] = ( (__ \ "id").readNullable[PathId].filterNot(_.exists(_.isRoot)) ~ - (__ \ "cmd").readNullable[String](Reads.minLength(1)) ~ - (__ \ 
"args").readNullable[Seq[String]] ~ - (__ \ "user").readNullable[String] ~ - (__ \ "env").readNullable[Map[String, String]] ~ - (__ \ "instances").readNullable[Int] ~ - (__ \ "cpus").readNullable[Double] ~ - (__ \ "mem").readNullable[Double] ~ - (__ \ "disk").readNullable[Double] ~ - (__ \ "executor").readNullable[String](Reads.pattern("^(//cmd)|(/?[^/]+(/[^/]+)*)|$".r)) ~ - (__ \ "constraints").readNullable[Set[Constraint]] ~ - (__ \ "storeUrls").readNullable[Seq[String]] ~ - (__ \ "requirePorts").readNullable[Boolean] ~ - (__ \ "backoffSeconds").readNullable[Long].map(_.map(_.seconds)) ~ - (__ \ "backoffFactor").readNullable[Double] ~ - (__ \ "maxLaunchDelaySeconds").readNullable[Long].map(_.map(_.seconds)) ~ - (__ \ "container").readNullable[Container] ~ - (__ \ "healthChecks").readNullable[Set[HealthCheck]] ~ - (__ \ "dependencies").readNullable[Set[PathId]] - ) ((id, cmd, args, user, env, instances, cpus, mem, disk, executor, constraints, storeUrls, requirePorts, - backoffSeconds, backoffFactor, maxLaunchDelaySeconds, container, healthChecks, dependencies) => + (__ \ "cmd").readNullable[String](Reads.minLength(1)) ~ + (__ \ "args").readNullable[Seq[String]] ~ + (__ \ "user").readNullable[String] ~ + (__ \ "env").readNullable[Map[String, String]] ~ + (__ \ "instances").readNullable[Int] ~ + (__ \ "cpus").readNullable[Double] ~ + (__ \ "mem").readNullable[Double] ~ + (__ \ "disk").readNullable[Double] ~ + (__ \ "executor").readNullable[String]( + Reads.pattern("^(//cmd)|(/?[^/]+(/[^/]+)*)|$".r)) ~ + (__ \ "constraints").readNullable[Set[Constraint]] ~ + (__ \ "storeUrls").readNullable[Seq[String]] ~ + (__ \ "requirePorts").readNullable[Boolean] ~ + (__ \ "backoffSeconds").readNullable[Long].map(_.map(_.seconds)) ~ + (__ \ "backoffFactor").readNullable[Double] ~ + (__ \ "maxLaunchDelaySeconds").readNullable[Long].map(_.map(_.seconds)) ~ + (__ \ "container").readNullable[Container] ~ + (__ \ "healthChecks").readNullable[Set[HealthCheck]] ~ + (__ \ "dependencies").readNullable[Set[PathId]] + )( + ( + id, + cmd, + args, + user, + env, + instances, + cpus, + mem, + disk, + executor, + constraints, + storeUrls, + requirePorts, + backoffSeconds, + backoffFactor, + maxLaunchDelaySeconds, + container, + healthChecks, + dependencies) => AppUpdate( - id = id, cmd = cmd, args = args, user = user, env = env, instances = instances, cpus = cpus, mem = mem, - disk = disk, executor = executor, constraints = constraints, storeUrls = storeUrls, requirePorts = requirePorts, - backoff = backoffSeconds, backoffFactor = backoffFactor, maxLaunchDelay = maxLaunchDelaySeconds, - container = container, healthChecks = healthChecks, dependencies = dependencies - ) - ).flatMap { update => + id = id, + cmd = cmd, + args = args, + user = user, + env = env, + instances = instances, + cpus = cpus, + mem = mem, + disk = disk, + executor = executor, + constraints = constraints, + storeUrls = storeUrls, + requirePorts = requirePorts, + backoff = backoffSeconds, + backoffFactor = backoffFactor, + maxLaunchDelay = maxLaunchDelaySeconds, + container = container, + healthChecks = healthChecks, + dependencies = dependencies + )) + .flatMap { update => // necessary because of case class limitations (good for another 21 fields) case class ExtraFields( - uris: Option[Seq[String]], - fetch: Option[Seq[FetchUri]], - upgradeStrategy: Option[UpgradeStrategy], - labels: Option[Map[String, String]], - version: Option[Timestamp], - acceptedResourceRoles: Option[Set[String]], - ipAddress: Option[IpAddress], - residency: Option[Residency], 
- ports: Option[Seq[Int]], - portDefinitions: Option[Seq[PortDefinition]]) + uris: Option[Seq[String]], + fetch: Option[Seq[FetchUri]], + upgradeStrategy: Option[UpgradeStrategy], + labels: Option[Map[String, String]], + version: Option[Timestamp], + acceptedResourceRoles: Option[Set[String]], + ipAddress: Option[IpAddress], + residency: Option[Residency], + ports: Option[Seq[Int]], + portDefinitions: Option[Seq[PortDefinition]]) val extraReads: Reads[ExtraFields] = ( (__ \ "uris").readNullable[Seq[String]] ~ - (__ \ "fetch").readNullable[Seq[FetchUri]] ~ - (__ \ "upgradeStrategy").readNullable[UpgradeStrategy] ~ - (__ \ "labels").readNullable[Map[String, String]] ~ - (__ \ "version").readNullable[Timestamp] ~ - (__ \ "acceptedResourceRoles").readNullable[Set[String]](nonEmpty) ~ - (__ \ "ipAddress").readNullable[IpAddress] ~ - (__ \ "residency").readNullable[Residency] ~ - (__ \ "ports").readNullable[Seq[Int]](uniquePorts) ~ - (__ \ "portDefinitions").readNullable[Seq[PortDefinition]] + (__ \ "fetch").readNullable[Seq[FetchUri]] ~ + (__ \ "upgradeStrategy").readNullable[UpgradeStrategy] ~ + (__ \ "labels").readNullable[Map[String, String]] ~ + (__ \ "version").readNullable[Timestamp] ~ + (__ \ "acceptedResourceRoles").readNullable[Set[String]](nonEmpty) ~ + (__ \ "ipAddress").readNullable[IpAddress] ~ + (__ \ "residency").readNullable[Residency] ~ + (__ \ "ports").readNullable[Seq[Int]](uniquePorts) ~ + (__ \ "portDefinitions").readNullable[Seq[PortDefinition]] )(ExtraFields) extraReads - .filter(ValidationError("You cannot specify both uris and fetch fields")) { extra => - !(extra.uris.nonEmpty && extra.fetch.nonEmpty) + .filter( + ValidationError("You cannot specify both uris and fetch fields")) { + extra => !(extra.uris.nonEmpty && extra.fetch.nonEmpty) } - .filter(ValidationError("You cannot specify both ports and port definitions")) { extra => - val portDefinitionsIsEquivalentToPorts = extra.portDefinitions.map(_.map(_.port)) == extra.ports + .filter(ValidationError( + "You cannot specify both ports and port definitions")) { extra => + val portDefinitionsIsEquivalentToPorts = + extra.portDefinitions.map(_.map(_.port)) == extra.ports portDefinitionsIsEquivalentToPorts || extra.ports.isEmpty || extra.portDefinitions.isEmpty } .map { extra => @@ -905,28 +1135,43 @@ trait AppAndGroupFormats { version = extra.version, acceptedResourceRoles = extra.acceptedResourceRoles, ipAddress = extra.ipAddress, - fetch = extra.fetch.orElse(extra.uris.map { seq => seq.map(FetchUri.apply(_)) }), + fetch = extra.fetch.orElse(extra.uris.map { seq => + seq.map(FetchUri.apply(_)) + }), residency = extra.residency, portDefinitions = extra.portDefinitions.orElse { extra.ports.map { ports => PortDefinitions.apply(ports: _*) } } ) } - }.map(addHealthCheckPortIndexIfNecessary) + } + .map(addHealthCheckPortIndexIfNecessary) implicit lazy val GroupFormat: Format[Group] = ( (__ \ "id").format[PathId] ~ - (__ \ "apps").formatNullable[Set[AppDefinition]].withDefault(Group.defaultApps) ~ - (__ \ "groups").lazyFormatNullable(implicitly[Format[Set[Group]]]).withDefault(Group.defaultGroups) ~ - (__ \ "dependencies").formatNullable[Set[PathId]].withDefault(Group.defaultDependencies) ~ - (__ \ "version").formatNullable[Timestamp].withDefault(Group.defaultVersion) - ) (Group(_, _, _, _, _), unlift(Group.unapply)) + (__ \ "apps") + .formatNullable[Set[AppDefinition]] + .withDefault(Group.defaultApps) ~ + (__ \ "groups") + .lazyFormatNullable(implicitly[Format[Set[Group]]]) + .withDefault(Group.defaultGroups) ~ + (__ \ 
"dependencies") + .formatNullable[Set[PathId]] + .withDefault(Group.defaultDependencies) ~ + (__ \ "version") + .formatNullable[Timestamp] + .withDefault(Group.defaultVersion) + )(Group(_, _, _, _, _), unlift(Group.unapply)) implicit lazy val PortDefinitionFormat: Format[PortDefinition] = ( - (__ \ "port").formatNullable[Int].withDefault(AppDefinition.RandomPortValue) ~ - (__ \ "protocol").formatNullable[String].withDefault("tcp") ~ - (__ \ "name").formatNullable[String] ~ - (__ \ "labels").formatNullable[Map[String, String]].withDefault(Map.empty[String, String]) + (__ \ "port") + .formatNullable[Int] + .withDefault(AppDefinition.RandomPortValue) ~ + (__ \ "protocol").formatNullable[String].withDefault("tcp") ~ + (__ \ "name").formatNullable[String] ~ + (__ \ "labels") + .formatNullable[Map[String, String]] + .withDefault(Map.empty[String, String]) )(PortDefinition(_, _, _, _), unlift(PortDefinition.unapply)) } @@ -934,11 +1179,12 @@ trait PluginFormats { implicit lazy val pluginDefinitionFormat: Writes[PluginDefinition] = ( (__ \ "id").write[String] ~ - (__ \ "plugin").write[String] ~ - (__ \ "implementation").write[String] ~ - (__ \ "tags").writeNullable[Set[String]] ~ - (__ \ "info").writeNullable[JsObject] - ) (d => (d.id, d.plugin, d.implementation, d.tags, d.info)) - - implicit lazy val pluginDefinitionsFormat: Writes[PluginDefinitions] = Json.writes[PluginDefinitions] + (__ \ "plugin").write[String] ~ + (__ \ "implementation").write[String] ~ + (__ \ "tags").writeNullable[Set[String]] ~ + (__ \ "info").writeNullable[JsObject] + )(d => (d.id, d.plugin, d.implementation, d.tags, d.info)) + + implicit lazy val pluginDefinitionsFormat: Writes[PluginDefinitions] = + Json.writes[PluginDefinitions] } diff --git a/repos/playframework/framework/src/play-integration-test/src/test/scala/play/it/bindings/GlobalSettingsSpec.scala b/repos/playframework/framework/src/play-integration-test/src/test/scala/play/it/bindings/GlobalSettingsSpec.scala index c4fa3e26782..94eaca1c54b 100644 --- a/repos/playframework/framework/src/play-integration-test/src/test/scala/play/it/bindings/GlobalSettingsSpec.scala +++ b/repos/playframework/framework/src/play-integration-test/src/test/scala/play/it/bindings/GlobalSettingsSpec.scala @@ -13,31 +13,39 @@ import play.api.mvc._ import play.api.mvc.Results._ import play.api.test._ import play.it._ -import play.it.http.{ MockController, JAction } +import play.it.http.{MockController, JAction} import play.mvc.Http import play.mvc.Http.Context -object NettyGlobalSettingsSpec extends GlobalSettingsSpec with NettyIntegrationSpecification +object NettyGlobalSettingsSpec + extends GlobalSettingsSpec + with NettyIntegrationSpecification -trait GlobalSettingsSpec extends PlaySpecification with WsTestClient with ServerIntegrationSpecification { +trait GlobalSettingsSpec + extends PlaySpecification + with WsTestClient + with ServerIntegrationSpecification { sequential - def withServer[T](applicationGlobal: Option[String])(uri: String)(block: String => T) = { + def withServer[T](applicationGlobal: Option[String])(uri: String)( + block: String => T) = { implicit val port = testServerPort - val additionalSettings = applicationGlobal.fold(Map.empty[String, String]) { s: String => - Map("application.global" -> s"play.it.bindings.$s") + val additionalSettings = applicationGlobal.fold(Map.empty[String, String]) { + s: String => Map("application.global" -> s"play.it.bindings.$s") } + ("play.http.requestHandler" -> "play.http.GlobalSettingsHttpRequestHandler") import play.api.inject._ 
import play.api.routing.sird._ lazy val app: Application = new GuiceApplicationBuilder() .configure(additionalSettings) .overrides(bind[Router].to(Router.from { - case p"/scala" => Action { request => - Ok(request.headers.get("X-Foo").getOrElse("null")) - } + case p"/scala" => + Action { request => + Ok(request.headers.get("X-Foo").getOrElse("null")) + } case p"/java" => JAction(app, JavaAction) - })).build() + })) + .build() running(TestServer(port, app)) { val response = await(wsUrl(uri).get()) block(response.body) @@ -45,16 +53,18 @@ trait GlobalSettingsSpec extends PlaySpecification with WsTestClient with Server } "GlobalSettings filters" should { - "not have X-Foo header when no Global is configured" in withServer(None)("/scala") { body => - body must_== "null" - } - "have X-Foo header when Scala Global with filters is configured" in withServer(Some("FooFilteringScalaGlobal"))("/scala") { body => + "not have X-Foo header when no Global is configured" in withServer(None)( + "/scala") { body => body must_== "null" } + "have X-Foo header when Scala Global with filters is configured" in withServer( + Some("FooFilteringScalaGlobal"))("/scala") { body => body must_== "filter-constructor-called-by-scala-global" } - "have X-Foo header when Java Global with filters is configured" in withServer(Some("FooFilteringJavaGlobal"))("/scala") { body => + "have X-Foo header when Java Global with filters is configured" in withServer( + Some("FooFilteringJavaGlobal"))("/scala") { body => body must_== "filter-default-constructor" } - "allow intercepting by Java GlobalSettings.onRequest" in withServer(Some("OnRequestJavaGlobal"))("/java") { body => + "allow intercepting by Java GlobalSettings.onRequest" in withServer( + Some("OnRequestJavaGlobal"))("/java") { body => body must_== "intercepted" } } @@ -65,7 +75,8 @@ trait GlobalSettingsSpec extends PlaySpecification with WsTestClient with Server class FooFilter(headerValue: String) extends EssentialFilter { def this() = this("filter-default-constructor") def apply(next: EssentialAction) = EssentialAction { request => - val fooBarHeaders = request.copy(headers = request.headers.add("X-Foo" -> headerValue)) + val fooBarHeaders = + request.copy(headers = request.headers.add("X-Foo" -> headerValue)) next(fooBarHeaders) } @@ -74,23 +85,28 @@ class FooFilter(headerValue: String) extends EssentialFilter { /** Scala GlobalSettings object that uses a filter */ object FooFilteringScalaGlobal extends play.api.GlobalSettings { override def doFilter(next: EssentialAction): EssentialAction = { - Filters(super.doFilter(next), new FooFilter("filter-constructor-called-by-scala-global")) + Filters( + super.doFilter(next), + new FooFilter("filter-constructor-called-by-scala-global")) } } /** Java GlobalSettings class that uses a filter */ class FooFilteringJavaGlobal extends play.GlobalSettings { - override def filters[T]() = Array[Class[T]](classOf[FooFilter].asInstanceOf[Class[T]]) + override def filters[T]() = + Array[Class[T]](classOf[FooFilter].asInstanceOf[Class[T]]) } class OnRequestJavaGlobal extends play.GlobalSettings { override def onRequest(request: Http.Request, actionMethod: Method) = { new play.mvc.Action.Simple { - def call(ctx: Context) = CompletableFuture.completedFuture(play.mvc.Results.ok("intercepted")) + def call(ctx: Context) = + CompletableFuture.completedFuture(play.mvc.Results.ok("intercepted")) } } } object JavaAction extends MockController { - def action = play.mvc.Results.ok(Option(request.getHeader("X-Foo")).getOrElse("null")) + def action = + 
play.mvc.Results.ok(Option(request.getHeader("X-Foo")).getOrElse("null")) } diff --git a/repos/scalding/scalding-commons/src/main/scala/com/twitter/scalding/examples/MergeTest.scala b/repos/scalding/scalding-commons/src/main/scala/com/twitter/scalding/examples/MergeTest.scala index b02a6311d0a..9c076823393 100644 --- a/repos/scalding/scalding-commons/src/main/scala/com/twitter/scalding/examples/MergeTest.scala +++ b/repos/scalding/scalding-commons/src/main/scala/com/twitter/scalding/examples/MergeTest.scala @@ -5,18 +5,19 @@ import scala.annotation.tailrec import com.twitter.scalding._ /** - * This example job does not yet work. It is a test for Kyro serialization - */ + * This example job does not yet work. It is a test for Kyro serialization + */ class MergeTest(args: Args) extends Job(args) { - TextLine(args("input")).flatMapTo('word) { _.split("""\s+""") } + TextLine(args("input")) + .flatMapTo('word) { _.split("""\s+""") } .groupBy('word) { _.size } //Now, let's get the top 10 words: .groupAll { - _.mapReduceMap(('word, 'size) -> 'list) /* map1 */ { tup: (String, Long) => List(tup) } /* reduce */ { (l1: List[(String, Long)], l2: List[(String, Long)]) => + _.mapReduceMap(('word, 'size) -> 'list) /* map1 */ { + tup: (String, Long) => List(tup) + } /* reduce */ { (l1: List[(String, Long)], l2: List[(String, Long)]) => mergeSort2(l1, l2, 10, cmpTup) - } /* map2 */ { - lout: List[(String, Long)] => lout - } + } /* map2 */ { lout: List[(String, Long)] => lout } } //Now expand out the list. .flatMap('list -> ('word, 'cnt)) { list: List[(String, Long)] => list } @@ -26,9 +27,17 @@ class MergeTest(args: Args) extends Job(args) { //Reverse sort to get the top items def cmpTup(t1: (String, Long), t2: (String, Long)) = t2._2.compareTo(t1._2) - def mergeSort2[T](v1: List[T], v2: List[T], k: Int, cmp: Function2[T, T, Int]) = { + def mergeSort2[T]( + v1: List[T], + v2: List[T], + k: Int, + cmp: Function2[T, T, Int]) = { @tailrec - def mergeSortR(acc: List[T], list1: List[T], list2: List[T], k: Int): List[T] = { + def mergeSortR( + acc: List[T], + list1: List[T], + list2: List[T], + k: Int): List[T] = { (list1, list2, k) match { case (_, _, 0) => acc case (x1 :: t1, x2 :: t2, _) => { @@ -40,7 +49,7 @@ class MergeTest(args: Args) extends Job(args) { } case (x1 :: t1, Nil, _) => mergeSortR(x1 :: acc, t1, Nil, k - 1) case (Nil, x2 :: t2, _) => mergeSortR(x2 :: acc, Nil, t2, k - 1) - case (Nil, Nil, _) => acc + case (Nil, Nil, _) => acc } } mergeSortR(Nil, v1, v2, k).reverse diff --git a/repos/scalding/scalding-core/src/main/scala/com/twitter/scalding/mathematics/Matrix.scala b/repos/scalding/scalding-core/src/main/scala/com/twitter/scalding/mathematics/Matrix.scala index 23e929a8467..c987d6db0ba 100644 --- a/repos/scalding/scalding-core/src/main/scala/com/twitter/scalding/mathematics/Matrix.scala +++ b/repos/scalding/scalding-core/src/main/scala/com/twitter/scalding/mathematics/Matrix.scala @@ -12,10 +12,10 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
-*/ + */ package com.twitter.scalding.mathematics -import com.twitter.algebird.{ Monoid, Group, Ring, Field } +import com.twitter.algebird.{Monoid, Group, Ring, Field} import com.twitter.scalding._ import cascading.pipe.assembly._ @@ -31,139 +31,207 @@ import scala.math.max import scala.annotation.tailrec /** - * Matrix class - represents an infinite (hopefully sparse) matrix. - * any elements without a row are interpretted to be zero. - * the pipe hold ('rowIdx, 'colIdx, 'val) where in principle - * each row/col/value type is generic, with the constraint that ValT is a Ring[T] - * In practice, RowT and ColT are going to be Strings, Integers or Longs in the usual case. - * - * WARNING: - * It is NOT OKAY to use the same instance of Matrix/Row/Col with DIFFERENT Monoids/Rings/Fields. - * If you want to change, midstream, the Monoid on your ValT, you have to construct a new Matrix. - * This is due to caching of internal computation graphs. - * - * RowVector - handles matrices of row dimension one. It is the result of some of the matrix methods and has methods - * that return ColVector and diagonal matrix - * - * ColVector - handles matrices of col dimension one. It is the result of some of the matrix methods and has methods - * that return RowVector and diagonal matrix - */ - + * Matrix class - represents an infinite (hopefully sparse) matrix. + * any elements without a row are interpretted to be zero. + * the pipe hold ('rowIdx, 'colIdx, 'val) where in principle + * each row/col/value type is generic, with the constraint that ValT is a Ring[T] + * In practice, RowT and ColT are going to be Strings, Integers or Longs in the usual case. + * + * WARNING: + * It is NOT OKAY to use the same instance of Matrix/Row/Col with DIFFERENT Monoids/Rings/Fields. + * If you want to change, midstream, the Monoid on your ValT, you have to construct a new Matrix. + * This is due to caching of internal computation graphs. + * + * RowVector - handles matrices of row dimension one. It is the result of some of the matrix methods and has methods + * that return ColVector and diagonal matrix + * + * ColVector - handles matrices of col dimension one. 
It is the result of some of the matrix methods and has methods + * that return RowVector and diagonal matrix + */ // Implicit coversions // Add methods we want to add to pipes here: class MatrixPipeExtensions(pipe: Pipe) { - def toMatrix[RowT, ColT, ValT](fields: Fields)(implicit conv: TupleConverter[(RowT, ColT, ValT)], setter: TupleSetter[(RowT, ColT, ValT)]) = { - val matPipe = RichPipe(pipe).mapTo(fields -> ('row, 'col, 'val))((tup: (RowT, ColT, ValT)) => tup)(conv, setter) + def toMatrix[RowT, ColT, ValT](fields: Fields)( + implicit conv: TupleConverter[(RowT, ColT, ValT)], + setter: TupleSetter[(RowT, ColT, ValT)]) = { + val matPipe = RichPipe(pipe).mapTo(fields -> ('row, 'col, 'val))( + (tup: (RowT, ColT, ValT)) => tup)(conv, setter) new Matrix[RowT, ColT, ValT]('row, 'col, 'val, matPipe) } - def mapToMatrix[T, RowT, ColT, ValT](fields: Fields)(mapfn: T => (RowT, ColT, ValT))(implicit conv: TupleConverter[T], setter: TupleSetter[(RowT, ColT, ValT)]) = { - val matPipe = RichPipe(pipe).mapTo(fields -> ('row, 'col, 'val))(mapfn)(conv, setter) + def mapToMatrix[T, RowT, ColT, ValT](fields: Fields)( + mapfn: T => (RowT, ColT, ValT))( + implicit conv: TupleConverter[T], + setter: TupleSetter[(RowT, ColT, ValT)]) = { + val matPipe = + RichPipe(pipe).mapTo(fields -> ('row, 'col, 'val))(mapfn)(conv, setter) new Matrix[RowT, ColT, ValT]('row, 'col, 'val, matPipe) } - def flatMapToMatrix[T, RowT, ColT, ValT](fields: Fields)(flatMapfn: T => Iterable[(RowT, ColT, ValT)])(implicit conv: TupleConverter[T], setter: TupleSetter[(RowT, ColT, ValT)]) = { - val matPipe = RichPipe(pipe).flatMapTo(fields -> ('row, 'col, 'val))(flatMapfn)(conv, setter) + def flatMapToMatrix[T, RowT, ColT, ValT](fields: Fields)( + flatMapfn: T => Iterable[(RowT, ColT, ValT)])( + implicit conv: TupleConverter[T], + setter: TupleSetter[(RowT, ColT, ValT)]) = { + val matPipe = RichPipe(pipe).flatMapTo(fields -> ('row, 'col, 'val))( + flatMapfn)(conv, setter) new Matrix[RowT, ColT, ValT]('row, 'col, 'val, matPipe) } private def groupPipeIntoMap[ColT, ValT](pipe: Pipe): Pipe = { - pipe.groupBy('group, 'row) { - _.mapReduceMap[(ColT, ValT), Map[ColT, ValT], Map[ColT, ValT]](('col, 'val) -> 'val) { (colval: (ColT, ValT)) => Map(colval._1 -> colval._2) } { (l: Map[ColT, ValT], r: Map[ColT, ValT]) => l ++ r } { (red: Map[ColT, ValT]) => red } - } + pipe + .groupBy('group, 'row) { + _.mapReduceMap[(ColT, ValT), Map[ColT, ValT], Map[ColT, ValT]]( + ('col, 'val) -> 'val) { (colval: (ColT, ValT)) => + Map(colval._1 -> colval._2) + } { (l: Map[ColT, ValT], r: Map[ColT, ValT]) => l ++ r } { + (red: Map[ColT, ValT]) => red + } + } .rename('group, 'col) } - def toBlockMatrix[GroupT, RowT, ColT, ValT](fields: Fields)(implicit conv: TupleConverter[(GroupT, RowT, ColT, ValT)], setter: TupleSetter[(GroupT, RowT, ColT, ValT)]) = { + def toBlockMatrix[GroupT, RowT, ColT, ValT](fields: Fields)( + implicit conv: TupleConverter[(GroupT, RowT, ColT, ValT)], + setter: TupleSetter[(GroupT, RowT, ColT, ValT)]) = { val matPipe = RichPipe(pipe) - .mapTo(fields -> ('group, 'row, 'col, 'val))((tup: (GroupT, RowT, ColT, ValT)) => tup)(conv, setter) + .mapTo(fields -> ('group, 'row, 'col, 'val))( + (tup: (GroupT, RowT, ColT, ValT)) => tup)(conv, setter) - new BlockMatrix[GroupT, RowT, ColT, ValT](new Matrix('row, 'col, 'val, groupPipeIntoMap(matPipe))) + new BlockMatrix[GroupT, RowT, ColT, ValT]( + new Matrix('row, 'col, 'val, groupPipeIntoMap(matPipe))) } - def mapToBlockMatrix[T, GroupT, RowT, ColT, ValT](fields: Fields)(mapfn: T => (GroupT, RowT, ColT, 
ValT))(implicit conv: TupleConverter[T], setter: TupleSetter[(GroupT, RowT, ColT, ValT)]) = { + def mapToBlockMatrix[T, GroupT, RowT, ColT, ValT](fields: Fields)( + mapfn: T => (GroupT, RowT, ColT, ValT))( + implicit conv: TupleConverter[T], + setter: TupleSetter[(GroupT, RowT, ColT, ValT)]) = { val matPipe = RichPipe(pipe) .mapTo(fields -> ('group, 'row, 'col, 'val))(mapfn)(conv, setter) - new BlockMatrix[GroupT, RowT, ColT, ValT](new Matrix('row, 'col, 'val, groupPipeIntoMap(matPipe))) + new BlockMatrix[GroupT, RowT, ColT, ValT]( + new Matrix('row, 'col, 'val, groupPipeIntoMap(matPipe))) } - def flatMapToBlockMatrix[T, GroupT, RowT, ColT, ValT](fields: Fields)(flatMapfn: T => Iterable[(GroupT, RowT, ColT, ValT)])(implicit conv: TupleConverter[T], setter: TupleSetter[(GroupT, RowT, ColT, ValT)]) = { - val matPipe = RichPipe(pipe).flatMapTo(fields -> ('group, 'row, 'col, 'val))(flatMapfn)(conv, setter) - new BlockMatrix[GroupT, RowT, ColT, ValT](new Matrix('row, 'col, 'val, groupPipeIntoMap(matPipe))) + def flatMapToBlockMatrix[T, GroupT, RowT, ColT, ValT](fields: Fields)( + flatMapfn: T => Iterable[(GroupT, RowT, ColT, ValT)])( + implicit conv: TupleConverter[T], + setter: TupleSetter[(GroupT, RowT, ColT, ValT)]) = { + val matPipe = RichPipe(pipe).flatMapTo( + fields -> ('group, 'row, 'col, 'val))(flatMapfn)(conv, setter) + new BlockMatrix[GroupT, RowT, ColT, ValT]( + new Matrix('row, 'col, 'val, groupPipeIntoMap(matPipe))) } - def toColVector[RowT, ValT](fields: Fields)(implicit conv: TupleConverter[(RowT, ValT)], setter: TupleSetter[(RowT, ValT)]) = { - val vecPipe = RichPipe(pipe).mapTo(fields -> ('row, 'val))((tup: (RowT, ValT)) => tup)(conv, setter) + def toColVector[RowT, ValT](fields: Fields)( + implicit conv: TupleConverter[(RowT, ValT)], + setter: TupleSetter[(RowT, ValT)]) = { + val vecPipe = RichPipe(pipe).mapTo(fields -> ('row, 'val))( + (tup: (RowT, ValT)) => tup)(conv, setter) new ColVector[RowT, ValT]('row, 'val, vecPipe) } - def mapToColVector[T, RowT, ValT](fields: Fields)(mapfn: T => (RowT, ValT))(implicit conv: TupleConverter[T], setter: TupleSetter[(RowT, ValT)]) = { - val vecPipe = RichPipe(pipe).mapTo(fields -> ('row, 'val))(mapfn)(conv, setter) + def mapToColVector[T, RowT, ValT](fields: Fields)(mapfn: T => (RowT, ValT))( + implicit conv: TupleConverter[T], + setter: TupleSetter[(RowT, ValT)]) = { + val vecPipe = + RichPipe(pipe).mapTo(fields -> ('row, 'val))(mapfn)(conv, setter) new ColVector[RowT, ValT]('row, 'val, vecPipe) } - def flatMapToColVector[T, RowT, ValT](fields: Fields)(flatMapfn: T => Iterable[(RowT, ValT)])(implicit conv: TupleConverter[T], setter: TupleSetter[(RowT, ValT)]) = { - val vecPipe = RichPipe(pipe).flatMapTo(fields -> ('row, 'val))(flatMapfn)(conv, setter) + def flatMapToColVector[T, RowT, ValT](fields: Fields)( + flatMapfn: T => Iterable[(RowT, ValT)])( + implicit conv: TupleConverter[T], + setter: TupleSetter[(RowT, ValT)]) = { + val vecPipe = + RichPipe(pipe).flatMapTo(fields -> ('row, 'val))(flatMapfn)(conv, setter) new ColVector[RowT, ValT]('row, 'val, vecPipe) } - def toRowVector[ColT, ValT](fields: Fields)(implicit conv: TupleConverter[(ColT, ValT)], setter: TupleSetter[(ColT, ValT)]) = { - val vecPipe = RichPipe(pipe).mapTo(fields -> ('col, 'val))((tup: (ColT, ValT)) => tup)(conv, setter) + def toRowVector[ColT, ValT](fields: Fields)( + implicit conv: TupleConverter[(ColT, ValT)], + setter: TupleSetter[(ColT, ValT)]) = { + val vecPipe = RichPipe(pipe).mapTo(fields -> ('col, 'val))( + (tup: (ColT, ValT)) => tup)(conv, setter) new 
RowVector[ColT, ValT]('col, 'val, vecPipe) } - def mapToRowVector[T, ColT, ValT](fields: Fields)(mapfn: T => (ColT, ValT))(implicit conv: TupleConverter[T], setter: TupleSetter[(ColT, ValT)]) = { - val vecPipe = RichPipe(pipe).mapTo(fields -> ('col, 'val))(mapfn)(conv, setter) + def mapToRowVector[T, ColT, ValT](fields: Fields)(mapfn: T => (ColT, ValT))( + implicit conv: TupleConverter[T], + setter: TupleSetter[(ColT, ValT)]) = { + val vecPipe = + RichPipe(pipe).mapTo(fields -> ('col, 'val))(mapfn)(conv, setter) new RowVector[ColT, ValT]('col, 'val, vecPipe) } - def flatMapToRowVector[T, ColT, ValT](fields: Fields)(flatMapfn: T => Iterable[(ColT, ValT)])(implicit conv: TupleConverter[T], setter: TupleSetter[(ColT, ValT)]) = { - val vecPipe = RichPipe(pipe).flatMapTo(fields -> ('col, 'val))(flatMapfn)(conv, setter) + def flatMapToRowVector[T, ColT, ValT](fields: Fields)( + flatMapfn: T => Iterable[(ColT, ValT)])( + implicit conv: TupleConverter[T], + setter: TupleSetter[(ColT, ValT)]) = { + val vecPipe = + RichPipe(pipe).flatMapTo(fields -> ('col, 'val))(flatMapfn)(conv, setter) new RowVector[ColT, ValT]('col, 'val, vecPipe) } } /** - * This is the enrichment pattern on Mappable[T] for converting to Matrix types - */ -class MatrixMappableExtensions[T](mappable: Mappable[T])(implicit fd: FlowDef, mode: Mode) { - def toMatrix[Row, Col, Val](implicit ev: <:<[T, (Row, Col, Val)], - setter: TupleSetter[(Row, Col, Val)]): Matrix[Row, Col, Val] = + * This is the enrichment pattern on Mappable[T] for converting to Matrix types + */ +class MatrixMappableExtensions[T](mappable: Mappable[T])( + implicit fd: FlowDef, + mode: Mode) { + def toMatrix[Row, Col, Val]( + implicit ev: <:<[T, (Row, Col, Val)], + setter: TupleSetter[(Row, Col, Val)]): Matrix[Row, Col, Val] = mapToMatrix { _.asInstanceOf[(Row, Col, Val)] } - def mapToMatrix[Row, Col, Val](fn: (T) => (Row, Col, Val))(implicit setter: TupleSetter[(Row, Col, Val)]): Matrix[Row, Col, Val] = { + def mapToMatrix[Row, Col, Val](fn: (T) => (Row, Col, Val))( + implicit setter: TupleSetter[(Row, Col, Val)]): Matrix[Row, Col, Val] = { val fields = ('row, 'col, 'val) val matPipe = mappable.mapTo(fields)(fn) new Matrix[Row, Col, Val]('row, 'col, 'val, matPipe) } - def toBlockMatrix[Group, Row, Col, Val](implicit ev: <:<[T, (Group, Row, Col, Val)], ord: Ordering[(Group, Row)], - setter: TupleSetter[(Group, Row, Col, Val)]): BlockMatrix[Group, Row, Col, Val] = + def toBlockMatrix[Group, Row, Col, Val]( + implicit ev: <:<[T, (Group, Row, Col, Val)], + ord: Ordering[(Group, Row)], + setter: TupleSetter[(Group, Row, Col, Val)]) + : BlockMatrix[Group, Row, Col, Val] = mapToBlockMatrix { _.asInstanceOf[(Group, Row, Col, Val)] } - def mapToBlockMatrix[Group, Row, Col, Val](fn: (T) => (Group, Row, Col, Val))(implicit ord: Ordering[(Group, Row)]): BlockMatrix[Group, Row, Col, Val] = { + def mapToBlockMatrix[Group, Row, Col, Val](fn: (T) => (Group, Row, Col, Val))( + implicit ord: Ordering[(Group, Row)]) + : BlockMatrix[Group, Row, Col, Val] = { val matPipe = TypedPipe .from(mappable) .map(fn) .groupBy(t => (t._1, t._2)) - .mapValueStream(s => Iterator(s.map{ case (_, _, c, v) => (c, v) }.toMap)) + .mapValueStream(s => + Iterator(s.map { case (_, _, c, v) => (c, v) }.toMap)) .toTypedPipe - .map{ case ((g, r), m) => (r, g, m) } + .map { case ((g, r), m) => (r, g, m) } .toPipe(('row, 'col, 'val)) new BlockMatrix[Group, Row, Col, Val](new Matrix('row, 'col, 'val, matPipe)) } - def toRow[Row, Val](implicit ev: <:<[T, (Row, Val)], setter: TupleSetter[(Row, Val)]): 
RowVector[Row, Val] = mapToRow { _.asInstanceOf[(Row, Val)] } + def toRow[Row, Val]( + implicit ev: <:<[T, (Row, Val)], + setter: TupleSetter[(Row, Val)]): RowVector[Row, Val] = mapToRow { + _.asInstanceOf[(Row, Val)] + } - def mapToRow[Row, Val](fn: (T) => (Row, Val))(implicit setter: TupleSetter[(Row, Val)], fd: FlowDef): RowVector[Row, Val] = { + def mapToRow[Row, Val](fn: (T) => (Row, Val))( + implicit setter: TupleSetter[(Row, Val)], + fd: FlowDef): RowVector[Row, Val] = { val fields = ('row, 'val) val rowPipe = mappable.mapTo(fields)(fn) new RowVector[Row, Val]('row, 'val, rowPipe) } - def toCol[Col, Val](implicit ev: <:<[T, (Col, Val)], setter: TupleSetter[(Col, Val)]): ColVector[Col, Val] = + def toCol[Col, Val]( + implicit ev: <:<[T, (Col, Val)], + setter: TupleSetter[(Col, Val)]): ColVector[Col, Val] = mapToCol { _.asInstanceOf[(Col, Val)] } - def mapToCol[Col, Val](fn: (T) => (Col, Val))(implicit setter: TupleSetter[(Col, Val)]): ColVector[Col, Val] = { + def mapToCol[Col, Val](fn: (T) => (Col, Val))( + implicit setter: TupleSetter[(Col, Val)]): ColVector[Col, Val] = { val fields = ('col, 'val) val colPipe = mappable.mapTo(fields)(fn) new ColVector[Col, Val]('col, 'val, colPipe) @@ -173,10 +241,12 @@ class MatrixMappableExtensions[T](mappable: Mappable[T])(implicit fd: FlowDef, m object Matrix { // If this function is implicit, you can use the PipeExtensions methods on pipe implicit def pipeExtensions[P <% Pipe](p: P) = new MatrixPipeExtensions(p) - implicit def mappableExtensions[T](mt: Mappable[T])(implicit fd: FlowDef, mode: Mode) = + implicit def mappableExtensions[T]( + mt: Mappable[T])(implicit fd: FlowDef, mode: Mode) = new MatrixMappableExtensions(mt)(fd, mode) - def filterOutZeros[ValT](fSym: Symbol, group: Monoid[ValT])(fpipe: Pipe): Pipe = { + def filterOutZeros[ValT](fSym: Symbol, group: Monoid[ValT])( + fpipe: Pipe): Pipe = { fpipe.filter(fSym) { tup: Tuple1[ValT] => group.isNonZero(tup._1) } } @@ -191,10 +261,16 @@ object Matrix { implicit def literalToScalar[ValT](v: ValT) = new LiteralScalar(v) // Converts to Matrix for addition - implicit def diagonalToMatrix[RowT, ValT](diag: DiagonalMatrix[RowT, ValT]): Matrix[RowT, RowT, ValT] = { + implicit def diagonalToMatrix[RowT, ValT]( + diag: DiagonalMatrix[RowT, ValT]): Matrix[RowT, RowT, ValT] = { val colSym = newSymbol(Set(diag.idxSym, diag.valSym), 'col) val newPipe = diag.pipe.map(diag.idxSym -> colSym) { (x: RowT) => x } - new Matrix[RowT, RowT, ValT](diag.idxSym, colSym, diag.valSym, newPipe, diag.sizeHint) + new Matrix[RowT, RowT, ValT]( + diag.idxSym, + colSym, + diag.valSym, + newPipe, + diag.sizeHint) } } @@ -204,15 +280,23 @@ object Matrix { trait WrappedPipe { def fields: Fields def pipe: Pipe - def writePipe(src: Source, outFields: Fields = Fields.NONE)(implicit fd: FlowDef, mode: Mode) { - val toWrite = if (outFields.isNone) pipe else pipe.rename(fields -> outFields) + def writePipe(src: Source, outFields: Fields = Fields.NONE)( + implicit fd: FlowDef, + mode: Mode) { + val toWrite = + if (outFields.isNone) pipe else pipe.rename(fields -> outFields) toWrite.write(src) } } -class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSym: Symbol, - inPipe: Pipe, val sizeHint: SizeHint = NoClue) - extends WrappedPipe with java.io.Serializable { +class Matrix[RowT, ColT, ValT]( + val rowSym: Symbol, + val colSym: Symbol, + val valSym: Symbol, + inPipe: Pipe, + val sizeHint: SizeHint = NoClue) + extends WrappedPipe + with java.io.Serializable { import Matrix._ import MatrixProduct._ import 
Dsl.ensureUniqueFields @@ -222,7 +306,8 @@ class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSy lazy val pipe = inPipe.project(rowSym, colSym, valSym) def fields = rowColValSymbols - def pipeAs(toFields: Fields) = pipe.rename((rowSym, colSym, valSym) -> toFields) + def pipeAs(toFields: Fields) = + pipe.rename((rowSym, colSym, valSym) -> toFields) def hasHint = sizeHint != NoClue @@ -231,23 +316,31 @@ class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSy (that != null) && (that.isInstanceOf[Matrix[_, _, _]]) && { val thatM = that.asInstanceOf[Matrix[RowT, ColT, ValT]] (this.rowSym == thatM.rowSym) && (this.colSym == thatM.colSym) && - (this.valSym == thatM.valSym) && (this.pipe == thatM.pipe) + (this.valSym == thatM.valSym) && (this.pipe == thatM.pipe) } } // Value operations - def mapValues[ValU](fn: (ValT) => ValU)(implicit mon: Monoid[ValU]): Matrix[RowT, ColT, ValU] = { + def mapValues[ValU](fn: (ValT) => ValU)( + implicit mon: Monoid[ValU]): Matrix[RowT, ColT, ValU] = { val newPipe = pipe.flatMap(valSym -> valSym) { imp: Tuple1[ValT] => //Ensure an arity of 1 //This annoying Tuple1 wrapping ensures we can handle ValT that may itself be a Tuple. mon.nonZeroOption(fn(imp._1)).map { Tuple1(_) } } - new Matrix[RowT, ColT, ValU](this.rowSym, this.colSym, this.valSym, newPipe, sizeHint) + new Matrix[RowT, ColT, ValU]( + this.rowSym, + this.colSym, + this.valSym, + newPipe, + sizeHint) } + /** - * like zipWithIndex.map but ONLY CHANGES THE VALUE not the index. - * Note you will only see non-zero elements on the matrix. This does not enumerate the zeros - */ - def mapWithIndex[ValNew](fn: (ValT, RowT, ColT) => ValNew)(implicit mon: Monoid[ValNew]): Matrix[RowT, ColT, ValNew] = { + * like zipWithIndex.map but ONLY CHANGES THE VALUE not the index. + * Note you will only see non-zero elements on the matrix. This does not enumerate the zeros + */ + def mapWithIndex[ValNew](fn: (ValT, RowT, ColT) => ValNew)( + implicit mon: Monoid[ValNew]): Matrix[RowT, ColT, ValNew] = { val newPipe = pipe.flatMap(fields -> fields) { imp: (RowT, ColT, ValT) => mon.nonZeroOption(fn(imp._3, imp._1, imp._2)).map { (imp._1, imp._2, _) } } @@ -260,12 +353,21 @@ class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSy //This annoying Tuple1 wrapping ensures we can handle ValT that may itself be a Tuple. 
fn(imp._1) } - new Matrix[RowT, ColT, ValT](this.rowSym, this.colSym, this.valSym, newPipe, sizeHint) + new Matrix[RowT, ColT, ValT]( + this.rowSym, + this.colSym, + this.valSym, + newPipe, + sizeHint) } // Binarize values, all x != 0 become 1 - def binarizeAs[NewValT](implicit mon: Monoid[ValT], ring: Ring[NewValT]): Matrix[RowT, ColT, NewValT] = { - mapValues(x => if (mon.isNonZero(x)) { ring.one } else { ring.zero })(ring) + def binarizeAs[NewValT]( + implicit mon: Monoid[ValT], + ring: Ring[NewValT]): Matrix[RowT, ColT, NewValT] = { + mapValues(x => + if (mon.isNonZero(x)) { ring.one } + else { ring.zero })(ring) } // Row Operations @@ -273,20 +375,23 @@ class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSy // Get a specific row def getRow(index: RowT): RowVector[ColT, ValT] = { val newPipe = inPipe - .filter(rowSym){ input: RowT => input == index } + .filter(rowSym) { input: RowT => input == index } .project(colSym, valSym) val newHint = sizeHint.setRows(1L) new RowVector[ColT, ValT](colSym, valSym, newPipe, newHint) } // Reduce all rows to a single row (zeros or ignored) - def reduceRowVectors(fn: (ValT, ValT) => ValT)(implicit mon: Monoid[ValT]): RowVector[ColT, ValT] = { + def reduceRowVectors(fn: (ValT, ValT) => ValT)( + implicit mon: Monoid[ValT]): RowVector[ColT, ValT] = { val newPipe = filterOutZeros(valSym, mon) { pipe.groupBy(colSym) { - _.reduce(valSym) { (x: Tuple1[ValT], y: Tuple1[ValT]) => Tuple1(fn(x._1, y._1)) } - // Matrices are generally huge and cascading has problems with diverse key spaces and - // mapside operations - // TODO continually evaluate if this is needed to avoid OOM + _.reduce(valSym) { (x: Tuple1[ValT], y: Tuple1[ValT]) => + Tuple1(fn(x._1, y._1)) + } + // Matrices are generally huge and cascading has problems with diverse key spaces and + // mapside operations + // TODO continually evaluate if this is needed to avoid OOM .reducers(MatrixProduct.numOfReducers(sizeHint)) .forceToReducers } @@ -303,77 +408,90 @@ class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSy // Maps rows using a per-row mapping function // Use this for non-decomposable vector processing functions // and with vectors that can fit in one-single machine memory - def mapRows(fn: Iterable[(ColT, ValT)] => Iterable[(ColT, ValT)])(implicit mon: Monoid[ValT]): Matrix[RowT, ColT, ValT] = { + def mapRows(fn: Iterable[(ColT, ValT)] => Iterable[(ColT, ValT)])( + implicit mon: Monoid[ValT]): Matrix[RowT, ColT, ValT] = { val newListSym = Symbol(colSym.name + "_" + valSym.name + "_list") // TODO, I think we can count the rows/cols for free here val newPipe = filterOutZeros(valSym, mon) { - pipe.groupBy(rowSym) { - _.toList[(ColT, ValT)]((colSym, valSym) -> newListSym) - } - .flatMapTo((rowSym, newListSym) -> (rowSym, colSym, valSym)) { tup: (RowT, List[(ColT, ValT)]) => - val row = tup._1 - val list = fn(tup._2) - // Now flatten out to (row, col, val): - list.map{ imp: (ColT, ValT) => (row, imp._1, imp._2) } + pipe + .groupBy(rowSym) { + _.toList[(ColT, ValT)]((colSym, valSym) -> newListSym) + } + .flatMapTo((rowSym, newListSym) -> (rowSym, colSym, valSym)) { + tup: (RowT, List[(ColT, ValT)]) => + val row = tup._1 + val list = fn(tup._2) + // Now flatten out to (row, col, val): + list.map { imp: (ColT, ValT) => (row, imp._1, imp._2) } } } new Matrix[RowT, ColT, ValT](rowSym, colSym, valSym, newPipe, sizeHint) } - def topRowElems(k: Int)(implicit ord: Ordering[ValT]): Matrix[RowT, ColT, ValT] = { + def topRowElems(k: Int)( + implicit ord: 
Ordering[ValT]): Matrix[RowT, ColT, ValT] = { if (k < 1000) { topRowWithTiny(k) } else { - val newPipe = pipe.groupBy(rowSym){ - _ - .sortBy(valSym) - .reverse - .take(k) - } + val newPipe = pipe + .groupBy(rowSym) { + _.sortBy(valSym).reverse + .take(k) + } .project(rowSym, colSym, valSym) - new Matrix[RowT, ColT, ValT](rowSym, colSym, valSym, newPipe, FiniteHint(-1L, k)) + new Matrix[RowT, ColT, ValT]( + rowSym, + colSym, + valSym, + newPipe, + FiniteHint(-1L, k)) } } - protected def topRowWithTiny(k: Int)(implicit ord: Ordering[ValT]): Matrix[RowT, ColT, ValT] = { + protected def topRowWithTiny(k: Int)( + implicit ord: Ordering[ValT]): Matrix[RowT, ColT, ValT] = { val topSym = Symbol(colSym.name + "_topK") - val newPipe = pipe.groupBy(rowSym){ - _ - .sortWithTake((colSym, valSym) -> 'top_vals, k) ((t0: (ColT, ValT), t1: (ColT, ValT)) => ord.gt(t0._2, t1._2)) - } - .flatMapTo((0, 1) -> (rowSym, topSym, valSym)) { imp: (RowT, List[(ColT, ValT)]) => - val row = imp._1 - val list = imp._2 - list.map{ imp: (ColT, ValT) => (row, imp._1, imp._2) } + val newPipe = pipe + .groupBy(rowSym) { + _.sortWithTake((colSym, valSym) -> 'top_vals, k)( + (t0: (ColT, ValT), t1: (ColT, ValT)) => ord.gt(t0._2, t1._2)) } - new Matrix[RowT, ColT, ValT](rowSym, topSym, valSym, newPipe, FiniteHint(-1L, k)) + .flatMapTo((0, 1) -> (rowSym, topSym, valSym)) { + imp: (RowT, List[(ColT, ValT)]) => + val row = imp._1 + val list = imp._2 + list.map { imp: (ColT, ValT) => (row, imp._1, imp._2) } + } + new Matrix[RowT, ColT, ValT]( + rowSym, + topSym, + valSym, + newPipe, + FiniteHint(-1L, k)) } protected lazy val rowL0Norm = { val matD = this.asInstanceOf[Matrix[RowT, ColT, Double]] - (matD.mapValues { x => 1.0 } - .sumColVectors - .diag - .inverse) * matD + (matD.mapValues { x => 1.0 }.sumColVectors.diag.inverse) * matD } - def rowL0Normalize(implicit ev: =:=[ValT, Double]): Matrix[RowT, ColT, Double] = rowL0Norm + def rowL0Normalize( + implicit ev: =:=[ValT, Double]): Matrix[RowT, ColT, Double] = rowL0Norm protected lazy val rowL1Norm = { val matD = this.asInstanceOf[Matrix[RowT, ColT, Double]] - (matD.mapValues { x => x.abs } - .sumColVectors - .diag - .inverse) * matD + (matD.mapValues { x => x.abs }.sumColVectors.diag.inverse) * matD } // Row L1 normalization, only makes sense for Doubles // At the end of L1 normalization, sum of row values is one - def rowL1Normalize(implicit ev: =:=[ValT, Double]): Matrix[RowT, ColT, Double] = rowL1Norm + def rowL1Normalize( + implicit ev: =:=[ValT, Double]): Matrix[RowT, ColT, Double] = rowL1Norm protected lazy val rowL2Norm = { val matD = this.asInstanceOf[Matrix[RowT, ColT, Double]] - (matD.mapValues { x => x * x } + (matD + .mapValues { x => x * x } .sumColVectors .diag .mapValues { x => scala.math.sqrt(x) } @@ -382,7 +500,8 @@ class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSy } // Row L2 normalization (can only be called for Double) // After this operation, the sum(|x|^2) along each row will be 1. - def rowL2Normalize(implicit ev: =:=[ValT, Double]): Matrix[RowT, ColT, Double] = rowL2Norm + def rowL2Normalize( + implicit ev: =:=[ValT, Double]): Matrix[RowT, ColT, Double] = rowL2Norm // Remove the mean of each row from each value in a row. 
// Double ValT only (only over the observed values, not dividing by the unobserved ones) @@ -400,15 +519,22 @@ class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSy val newPipe = inPipe .groupBy(rowSym) { _.sizeAveStdev((valSym) -> ('size, 'ave, 'stdev)) } - .flatMapTo((rowSym, 'size, 'ave, 'stdev) -> (rowSym, newColSym, newValSym)) { tup: (RowT, Long, Double, Double) => - val row = tup._1 - val size = tup._2.toDouble - val avg = tup._3 - val stdev = tup._4 - List((row, 1, size), (row, 2, avg), (row, 3, stdev)) + .flatMapTo( + (rowSym, 'size, 'ave, 'stdev) -> (rowSym, newColSym, newValSym)) { + tup: (RowT, Long, Double, Double) => + val row = tup._1 + val size = tup._2.toDouble + val avg = tup._3 + val stdev = tup._4 + List((row, 1, size), (row, 2, avg), (row, 3, stdev)) } val newHint = sizeHint.setCols(3L) - new Matrix[RowT, Int, Double](rowSym, newColSym, newValSym, newPipe, newHint) + new Matrix[RowT, Int, Double]( + rowSym, + newColSym, + newValSym, + newPipe, + newHint) } def rowColValSymbols: Fields = (rowSym, colSym, valSym) @@ -419,7 +545,8 @@ class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSy this.transpose.getRow(index).transpose } - def reduceColVectors(fn: (ValT, ValT) => ValT)(implicit mon: Monoid[ValT]): ColVector[RowT, ValT] = { + def reduceColVectors(fn: (ValT, ValT) => ValT)( + implicit mon: Monoid[ValT]): ColVector[RowT, ValT] = { this.transpose.reduceRowVectors(fn)(mon).transpose } @@ -427,11 +554,13 @@ class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSy this.transpose.sumRowVectors(mon).transpose } - def mapCols(fn: Iterable[(RowT, ValT)] => Iterable[(RowT, ValT)])(implicit mon: Monoid[ValT]): Matrix[RowT, ColT, ValT] = { + def mapCols(fn: Iterable[(RowT, ValT)] => Iterable[(RowT, ValT)])( + implicit mon: Monoid[ValT]): Matrix[RowT, ColT, ValT] = { this.transpose.mapRows(fn)(mon).transpose } - def topColElems(k: Int)(implicit ord: Ordering[ValT]): Matrix[RowT, ColT, ValT] = { + def topColElems(k: Int)( + implicit ord: Ordering[ValT]): Matrix[RowT, ColT, ValT] = { this.transpose.topRowElems(k)(ord).transpose } @@ -455,7 +584,9 @@ class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSy this.transpose.rowSizeAveStdev } - def *[That, Res](that: That)(implicit prod: MatrixProduct[Matrix[RowT, ColT, ValT], That, Res]): Res = { + def *[That, Res](that: That)( + implicit prod: MatrixProduct[Matrix[RowT, ColT, ValT], That, Res]) + : Res = { prod(this, that) } @@ -476,14 +607,16 @@ class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSy // It assumes that the function fn(0,0) = 0 // This function assumes only one value in each matrix for a given row and column index. (no stacking of operations yet) // TODO: Optimize this later and be lazy on groups and joins. 
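// ---------------------------------------------------------------------------
// Editorial aside, not part of the diff above: a minimal sketch of how the
// elementwise operations defined just below (elemWiseOp and the +, -, hProd
// wrappers) compose in user code. The object name and the `left`/`right`
// matrices are hypothetical; Monoid/Group/Ring[Double] are assumed to resolve
// from Algebird's companion objects, as they do for the built-in Double
// instances, so no extra implicit imports are shown.
object ElementwiseSketch {
  import com.twitter.scalding.mathematics.Matrix

  def combine(left: Matrix[Int, Int, Double],
              right: Matrix[Int, Int, Double]): Matrix[Int, Int, Double] = {
    val summed   = left + right      // zip + Monoid.plus; zeros stay sparse
    val diffed   = left - right      // zip + Group.minus
    val hadamard = left hProd right  // zip + Ring.times (elementwise product)
    summed + diffed + hadamard       // results compose like any other Matrix
  }
}
// ---------------------------------------------------------------------------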
- def elemWiseOp(that: Matrix[RowT, ColT, ValT])(fn: (ValT, ValT) => ValT)(implicit mon: Monoid[ValT]): Matrix[RowT, ColT, ValT] = { + def elemWiseOp(that: Matrix[RowT, ColT, ValT])(fn: (ValT, ValT) => ValT)( + implicit mon: Monoid[ValT]): Matrix[RowT, ColT, ValT] = { // If the following is not true, it's not clear this is meaningful // assert(mon.isZero(fn(mon.zero,mon.zero)), "f is illdefined") zip(that).mapValues({ pair => fn(pair._1, pair._2) })(mon) } // Matrix summation - def +(that: Matrix[RowT, ColT, ValT])(implicit mon: Monoid[ValT]): Matrix[RowT, ColT, ValT] = { + def +(that: Matrix[RowT, ColT, ValT])( + implicit mon: Monoid[ValT]): Matrix[RowT, ColT, ValT] = { if (equals(that)) { // No need to do any groupBy operation mapValues { v => mon.plus(v, v) }(mon) @@ -493,24 +626,29 @@ class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSy } // Matrix difference - def -(that: Matrix[RowT, ColT, ValT])(implicit grp: Group[ValT]): Matrix[RowT, ColT, ValT] = { + def -(that: Matrix[RowT, ColT, ValT])( + implicit grp: Group[ValT]): Matrix[RowT, ColT, ValT] = { elemWiseOp(that)((x, y) => grp.minus(x, y))(grp) } // Matrix elementwise product / Hadamard product // see http://en.wikipedia.org/wiki/Hadamard_product_(matrices) - def hProd(mat: Matrix[RowT, ColT, ValT])(implicit ring: Ring[ValT]): Matrix[RowT, ColT, ValT] = { + def hProd(mat: Matrix[RowT, ColT, ValT])( + implicit ring: Ring[ValT]): Matrix[RowT, ColT, ValT] = { elemWiseOp(mat)((x, y) => ring.times(x, y))(ring) } /** - * Considering the matrix as a graph, propagate the column: - * Does the calculation: \sum_{j where M(i,j) == true) c_j - */ - def propagate[ColValT](vec: ColVector[ColT, ColValT])(implicit ev: =:=[ValT, Boolean], monT: Monoid[ColValT]): ColVector[RowT, ColValT] = { + * Considering the matrix as a graph, propagate the column: + * Does the calculation: \sum_{j where M(i,j) == true) c_j + */ + def propagate[ColValT](vec: ColVector[ColT, ColValT])( + implicit ev: =:=[ValT, Boolean], + monT: Monoid[ColValT]): ColVector[RowT, ColValT] = { //This cast will always succeed: val boolMat = this.asInstanceOf[Matrix[RowT, ColT, Boolean]] - boolMat.zip(vec.transpose) + boolMat + .zip(vec.transpose) .mapValues { boolT => if (boolT._1) boolT._2 else monT.zero } .sumColVectors } @@ -527,16 +665,24 @@ class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSy } def transpose: Matrix[ColT, RowT, ValT] = { - new Matrix[ColT, RowT, ValT](colSym, rowSym, valSym, inPipe, sizeHint.transpose) + new Matrix[ColT, RowT, ValT]( + colSym, + rowSym, + valSym, + inPipe, + sizeHint.transpose) } // This should only be called by def diagonal, which verifies that RowT == ColT protected lazy val mainDiagonal: DiagonalMatrix[RowT, ValT] = { - val diagPipe = pipe.filter(rowSym, colSym) { input: (RowT, RowT) => - (input._1 == input._2) - } + val diagPipe = pipe + .filter(rowSym, colSym) { input: (RowT, RowT) => (input._1 == input._2) } .project(rowSym, valSym) - new DiagonalMatrix[RowT, ValT](rowSym, valSym, diagPipe, SizeHint.asDiagonal(sizeHint)) + new DiagonalMatrix[RowT, ValT]( + rowSym, + valSym, + diagPipe, + SizeHint.asDiagonal(sizeHint)) } // This method will only work if the row type and column type are the same // the type constraint below means there is evidence that RowT and ColT are @@ -546,9 +692,11 @@ class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSy /* * This just removes zeros after the join inside a zip */ - private def cleanUpZipJoin[ValU](otherVSym: Fields, 
pairMonoid: Monoid[(ValT, ValU)])(joinedPipe: Pipe): Pipe = { + private def cleanUpZipJoin[ValU]( + otherVSym: Fields, + pairMonoid: Monoid[(ValT, ValU)])(joinedPipe: Pipe): Pipe = { joinedPipe - //Make sure the zeros are set correctly: + //Make sure the zeros are set correctly: .map(valSym -> valSym) { (x: ValT) => if (null == x) pairMonoid.zero._1 else x } @@ -556,7 +704,9 @@ class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSy if (null == x) pairMonoid.zero._2 else x } //Put the pair into a single item, ugly in scalding sadly... - .map(valSym.append(otherVSym) -> valSym) { tup: (ValT, ValU) => Tuple1(tup) } + .map(valSym.append(otherVSym) -> valSym) { tup: (ValT, ValU) => + Tuple1(tup) + } .project(rowColValSymbols) } @@ -564,7 +714,9 @@ class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSy * This ensures both side rows and columns have correct indexes (fills in nulls from the other side * in the case of outerjoins) */ - private def cleanUpIndexZipJoin(fields: Fields, joinedPipe: RichPipe): Pipe = { + private def cleanUpIndexZipJoin( + fields: Fields, + joinedPipe: RichPipe): Pipe = { def anyRefOr(tup: (AnyRef, AnyRef)): (AnyRef, AnyRef) = { val newRef = Option(tup._1).getOrElse(tup._2) @@ -576,21 +728,38 @@ class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSy } // Similar to zip, but combine the scalar on the right with all non-zeros in this matrix: - def nonZerosWith[ValU](that: Scalar[ValU]): Matrix[RowT, ColT, (ValT, ValU)] = { - val (newRFields, newRPipe) = ensureUniqueFields(rowColValSymbols, that.valSym, that.pipe) - val newPipe = inPipe.crossWithTiny(newRPipe) - .map(valSym.append(getField(newRFields, 0)) -> valSym) { leftRight: (ValT, ValU) => Tuple1(leftRight) } + def nonZerosWith[ValU]( + that: Scalar[ValU]): Matrix[RowT, ColT, (ValT, ValU)] = { + val (newRFields, newRPipe) = + ensureUniqueFields(rowColValSymbols, that.valSym, that.pipe) + val newPipe = inPipe + .crossWithTiny(newRPipe) + .map(valSym.append(getField(newRFields, 0)) -> valSym) { + leftRight: (ValT, ValU) => Tuple1(leftRight) + } .project(rowColValSymbols) - new Matrix[RowT, ColT, (ValT, ValU)](rowSym, colSym, valSym, newPipe, sizeHint) + new Matrix[RowT, ColT, (ValT, ValU)]( + rowSym, + colSym, + valSym, + newPipe, + sizeHint) } // Similar to zip, but combine the scalar on the right with all non-zeros in this matrix: - def nonZerosWith[ValU](that: LiteralScalar[ValU]): Matrix[RowT, ColT, (ValT, ValU)] = { - val newPipe = inPipe.map(valSym -> valSym) { left: Tuple1[ValT] => - Tuple1((left._1, that.value)) - } + def nonZerosWith[ValU]( + that: LiteralScalar[ValU]): Matrix[RowT, ColT, (ValT, ValU)] = { + val newPipe = inPipe + .map(valSym -> valSym) { left: Tuple1[ValT] => + Tuple1((left._1, that.value)) + } .project(rowColValSymbols) - new Matrix[RowT, ColT, (ValT, ValU)](rowSym, colSym, valSym, newPipe, sizeHint) + new Matrix[RowT, ColT, (ValT, ValU)]( + rowSym, + colSym, + valSym, + newPipe, + sizeHint) } // Override the size hint @@ -599,154 +768,263 @@ class Matrix[RowT, ColT, ValT](val rowSym: Symbol, val colSym: Symbol, val valSy } // Zip the given row with all the rows of the matrix - def zip[ValU](that: ColVector[RowT, ValU])(implicit pairMonoid: Monoid[(ValT, ValU)]): Matrix[RowT, ColT, (ValT, ValU)] = { - val (newRFields, newRPipe) = ensureUniqueFields(rowColValSymbols, (that.rowS, that.valS), that.pipe) + def zip[ValU](that: ColVector[RowT, ValU])( + implicit pairMonoid: Monoid[(ValT, ValU)]) + : Matrix[RowT, ColT, 
(ValT, ValU)] = { + val (newRFields, newRPipe) = + ensureUniqueFields(rowColValSymbols, (that.rowS, that.valS), that.pipe) // we must do an outer join to preserve zeros on one side or the other. // joinWithTiny can't do outer. And since the number // of values for each key is 1,2 it doesn't matter if we do joinWithSmaller or Larger: // TODO optimize the number of reducers val zipped = cleanUpZipJoin(getField(newRFields, 1), pairMonoid) { pipe - .joinWithSmaller(rowSym -> getField(newRFields, 0), newRPipe, new OuterJoin) - .thenDo{ p: RichPipe => cleanUpIndexZipJoin(rowSym.append(getField(newRFields, 0)), p) } + .joinWithSmaller( + rowSym -> getField(newRFields, 0), + newRPipe, + new OuterJoin) + .thenDo { p: RichPipe => + cleanUpIndexZipJoin(rowSym.append(getField(newRFields, 0)), p) + } } - new Matrix[RowT, ColT, (ValT, ValU)](rowSym, colSym, valSym, zipped, sizeHint + that.sizeH) + new Matrix[RowT, ColT, (ValT, ValU)]( + rowSym, + colSym, + valSym, + zipped, + sizeHint + that.sizeH) } // Zip the given row with all the rows of the matrix - def zip[ValU](that: RowVector[ColT, ValU])(implicit pairMonoid: Monoid[(ValT, ValU)]): Matrix[RowT, ColT, (ValT, ValU)] = { - val (newRFields, newRPipe) = ensureUniqueFields(rowColValSymbols, (that.colS, that.valS), that.pipe) + def zip[ValU](that: RowVector[ColT, ValU])( + implicit pairMonoid: Monoid[(ValT, ValU)]) + : Matrix[RowT, ColT, (ValT, ValU)] = { + val (newRFields, newRPipe) = + ensureUniqueFields(rowColValSymbols, (that.colS, that.valS), that.pipe) // we must do an outer join to preserve zeros on one side or the other. // joinWithTiny can't do outer. And since the number // of values for each key is 1,2 it doesn't matter if we do joinWithSmaller or Larger: // TODO optimize the number of reducers val zipped = cleanUpZipJoin(getField(newRFields, 1), pairMonoid) { pipe - .joinWithSmaller(colSym -> getField(newRFields, 0), newRPipe, new OuterJoin) - .thenDo{ p: RichPipe => cleanUpIndexZipJoin(colSym.append(getField(newRFields, 0)), p) } + .joinWithSmaller( + colSym -> getField(newRFields, 0), + newRPipe, + new OuterJoin) + .thenDo { p: RichPipe => + cleanUpIndexZipJoin(colSym.append(getField(newRFields, 0)), p) + } } - new Matrix[RowT, ColT, (ValT, ValU)](rowSym, colSym, valSym, zipped, sizeHint + that.sizeH) + new Matrix[RowT, ColT, (ValT, ValU)]( + rowSym, + colSym, + valSym, + zipped, + sizeHint + that.sizeH) } // This creates the matrix with pairs for the entries - def zip[ValU](that: Matrix[RowT, ColT, ValU])(implicit pairMonoid: Monoid[(ValT, ValU)]): Matrix[RowT, ColT, (ValT, ValU)] = { - val (newRFields, newRPipe) = ensureUniqueFields(rowColValSymbols, that.rowColValSymbols, that.pipe) + def zip[ValU](that: Matrix[RowT, ColT, ValU])( + implicit pairMonoid: Monoid[(ValT, ValU)]) + : Matrix[RowT, ColT, (ValT, ValU)] = { + val (newRFields, newRPipe) = + ensureUniqueFields(rowColValSymbols, that.rowColValSymbols, that.pipe) // we must do an outer join to preserve zeros on one side or the other. // joinWithTiny can't do outer. 
And since the number // of values for each key is 1,2 it doesn't matter if we do joinWithSmaller or Larger: // TODO optimize the number of reducers val zipped = cleanUpZipJoin[ValU](getField(newRFields, 2), pairMonoid) { pipe - .joinWithSmaller((rowSym, colSym) -> - (getField(newRFields, 0).append(getField(newRFields, 1))), - newRPipe, new OuterJoin) - .thenDo{ p: RichPipe => cleanUpIndexZipJoin(rowSym.append(getField(newRFields, 0)), p) } - .thenDo{ p: RichPipe => cleanUpIndexZipJoin(colSym.append(getField(newRFields, 1)), p) } + .joinWithSmaller( + (rowSym, colSym) -> + (getField(newRFields, 0).append(getField(newRFields, 1))), + newRPipe, + new OuterJoin) + .thenDo { p: RichPipe => + cleanUpIndexZipJoin(rowSym.append(getField(newRFields, 0)), p) + } + .thenDo { p: RichPipe => + cleanUpIndexZipJoin(colSym.append(getField(newRFields, 1)), p) + } } - new Matrix[RowT, ColT, (ValT, ValU)](rowSym, colSym, valSym, zipped, sizeHint + that.sizeHint) + new Matrix[RowT, ColT, (ValT, ValU)]( + rowSym, + colSym, + valSym, + zipped, + sizeHint + that.sizeHint) } - def toBlockMatrix[G](grouping: (RowT) => (G, RowT)): BlockMatrix[G, RowT, ColT, ValT] = { - inPipe.map('row -> ('group, 'row))(grouping).toBlockMatrix(('group, 'row, 'col, 'val)) + def toBlockMatrix[G]( + grouping: (RowT) => (G, RowT)): BlockMatrix[G, RowT, ColT, ValT] = { + inPipe + .map('row -> ('group, 'row))(grouping) + .toBlockMatrix(('group, 'row, 'col, 'val)) } /** - * removes any elements in this matrix that also appear in the argument matrix - */ - def removeElementsBy[ValU](that: Matrix[RowT, ColT, ValU]): Matrix[RowT, ColT, ValT] = { + * removes any elements in this matrix that also appear in the argument matrix + */ + def removeElementsBy[ValU]( + that: Matrix[RowT, ColT, ValU]): Matrix[RowT, ColT, ValT] = { val filterR = '___filterR___ val filterC = '___filterC___ val filterV = '___filterV___ - val joined = pipe.joinWithSmaller((rowSym, colSym) -> (filterR, filterC), - that.pipe.rename((that.rowSym, that.colSym, that.valSym) -> (filterR, filterC, filterV)), new LeftJoin) - val filtered = joined.filter(filterV){ x: ValU => null == x } - new Matrix[RowT, ColT, ValT](rowSym, colSym, valSym, filtered.project(rowSym, colSym, valSym)) + val joined = pipe.joinWithSmaller( + (rowSym, colSym) -> (filterR, filterC), + that.pipe.rename( + (that.rowSym, that.colSym, that.valSym) -> (filterR, filterC, filterV)), + new LeftJoin) + val filtered = joined.filter(filterV) { x: ValU => null == x } + new Matrix[RowT, ColT, ValT]( + rowSym, + colSym, + valSym, + filtered.project(rowSym, colSym, valSym)) } /** - * keep only elements in this matrix that also appear in the argument matrix - */ - def keepElementsBy[ValU](that: Matrix[RowT, ColT, ValU]): Matrix[RowT, ColT, ValT] = { + * keep only elements in this matrix that also appear in the argument matrix + */ + def keepElementsBy[ValU]( + that: Matrix[RowT, ColT, ValU]): Matrix[RowT, ColT, ValT] = { val keepR = '___keepR___ val keepC = '___keepC___ val keepV = '___keepV___ - val joined = pipe.joinWithSmaller((rowSym, colSym) -> (keepR, keepC), - that.pipe.rename((that.rowSym, that.colSym, that.valSym) -> (keepR, keepC, keepV))) - new Matrix[RowT, ColT, ValT](rowSym, colSym, valSym, joined.project(rowSym, colSym, valSym)) + val joined = pipe.joinWithSmaller( + (rowSym, colSym) -> (keepR, keepC), + that.pipe.rename( + (that.rowSym, that.colSym, that.valSym) -> (keepR, keepC, keepV))) + new Matrix[RowT, ColT, ValT]( + rowSym, + colSym, + valSym, + joined.project(rowSym, colSym, valSym)) } /** - * 
keeps only those rows that are in the joining column - */ - def keepRowsBy[ValU](that: ColVector[RowT, ValU]): Matrix[RowT, ColT, ValT] = { + * keeps only those rows that are in the joining column + */ + def keepRowsBy[ValU]( + that: ColVector[RowT, ValU]): Matrix[RowT, ColT, ValT] = { val index = '____index____ - val joined = pipe.joinWithSmaller(rowSym -> index, that.pipe.rename(that.rowS -> index).project(index)) - new Matrix[RowT, ColT, ValT](rowSym, colSym, valSym, joined.project(rowSym, colSym, valSym)) + val joined = pipe.joinWithSmaller( + rowSym -> index, + that.pipe.rename(that.rowS -> index).project(index)) + new Matrix[RowT, ColT, ValT]( + rowSym, + colSym, + valSym, + joined.project(rowSym, colSym, valSym)) } /** - * keeps only those cols that are in the joining rows - */ - def keepColsBy[ValU](that: RowVector[ColT, ValU]): Matrix[RowT, ColT, ValT] = { + * keeps only those cols that are in the joining rows + */ + def keepColsBy[ValU]( + that: RowVector[ColT, ValU]): Matrix[RowT, ColT, ValT] = { val index = '____index____ - val joined = pipe.joinWithSmaller(colSym -> index, that.pipe.rename(that.colS -> index).project(index)) - new Matrix[RowT, ColT, ValT](rowSym, colSym, valSym, joined.project(rowSym, colSym, valSym)) + val joined = pipe.joinWithSmaller( + colSym -> index, + that.pipe.rename(that.colS -> index).project(index)) + new Matrix[RowT, ColT, ValT]( + rowSym, + colSym, + valSym, + joined.project(rowSym, colSym, valSym)) } /** - * removes those rows that are in the joining column - */ - def removeRowsBy[ValU](that: ColVector[RowT, ValU]): Matrix[RowT, ColT, ValT] = { + * removes those rows that are in the joining column + */ + def removeRowsBy[ValU]( + that: ColVector[RowT, ValU]): Matrix[RowT, ColT, ValT] = { val index = '____index____ - val joined = pipe.joinWithSmaller(rowSym -> index, that.pipe.rename(that.rowS -> index).project(index), joiner = new LeftJoin) - new Matrix[RowT, ColT, ValT](rowSym, colSym, valSym, joined.filter(index){ x: RowT => null == x } - .project(rowSym, colSym, valSym)) + val joined = pipe.joinWithSmaller( + rowSym -> index, + that.pipe.rename(that.rowS -> index).project(index), + joiner = new LeftJoin) + new Matrix[RowT, ColT, ValT]( + rowSym, + colSym, + valSym, + joined + .filter(index) { x: RowT => null == x } + .project(rowSym, colSym, valSym)) } /** - * removes those cols that are in the joining column - */ - def removeColsBy[ValU](that: RowVector[ColT, ValU]): Matrix[RowT, ColT, ValT] = { + * removes those cols that are in the joining column + */ + def removeColsBy[ValU]( + that: RowVector[ColT, ValU]): Matrix[RowT, ColT, ValT] = { val index = '____index____ - val joined = pipe.joinWithSmaller(colSym -> index, that.pipe.rename(that.colS -> index).project(index), joiner = new LeftJoin) - new Matrix[RowT, ColT, ValT](rowSym, colSym, valSym, joined.filter(index){ x: ColT => null == x } - .project(rowSym, colSym, valSym)) + val joined = pipe.joinWithSmaller( + colSym -> index, + that.pipe.rename(that.colS -> index).project(index), + joiner = new LeftJoin) + new Matrix[RowT, ColT, ValT]( + rowSym, + colSym, + valSym, + joined + .filter(index) { x: ColT => null == x } + .project(rowSym, colSym, valSym)) } /** - * Write the matrix, optionally renaming row,col,val fields to the given fields - * then return this. - */ - def write(src: Source, outFields: Fields = Fields.NONE)(implicit fd: FlowDef, mode: Mode): Matrix[RowT, ColT, ValT] = { + * Write the matrix, optionally renaming row,col,val fields to the given fields + * then return this. 
+ */ + def write(src: Source, outFields: Fields = Fields.NONE)( + implicit fd: FlowDef, + mode: Mode): Matrix[RowT, ColT, ValT] = { writePipe(src, outFields) this } } class LiteralScalar[ValT](val value: ValT) extends java.io.Serializable { - def *[That, Res](that: That)(implicit prod: MatrixProduct[LiteralScalar[ValT], That, Res]): Res = { prod(this, that) } + def *[That, Res](that: That)( + implicit prod: MatrixProduct[LiteralScalar[ValT], That, Res]): Res = { + prod(this, that) + } } -class Scalar[ValT](val valSym: Symbol, inPipe: Pipe) extends WrappedPipe with java.io.Serializable { +class Scalar[ValT](val valSym: Symbol, inPipe: Pipe) + extends WrappedPipe + with java.io.Serializable { def pipe = inPipe def fields = valSym - def *[That, Res](that: That)(implicit prod: MatrixProduct[Scalar[ValT], That, Res]): Res = { prod(this, that) } + def *[That, Res](that: That)( + implicit prod: MatrixProduct[Scalar[ValT], That, Res]): Res = { + prod(this, that) + } + /** - * Write the Scalar, optionally renaming val fields to the given fields - * then return this. - */ - def write(src: Source, outFields: Fields = Fields.NONE)(implicit fd: FlowDef, mode: Mode) = { + * Write the Scalar, optionally renaming val fields to the given fields + * then return this. + */ + def write(src: Source, outFields: Fields = Fields.NONE)( + implicit fd: FlowDef, + mode: Mode) = { writePipe(src, outFields) this } } -class DiagonalMatrix[IdxT, ValT](val idxSym: Symbol, - val valSym: Symbol, inPipe: Pipe, val sizeHint: SizeHint = FiniteHint(1L, -1L)) - extends WrappedPipe with java.io.Serializable { +class DiagonalMatrix[IdxT, ValT]( + val idxSym: Symbol, + val valSym: Symbol, + inPipe: Pipe, + val sizeHint: SizeHint = FiniteHint(1L, -1L)) + extends WrappedPipe + with java.io.Serializable { - def *[That, Res](that: That)(implicit prod: MatrixProduct[DiagonalMatrix[IdxT, ValT], That, Res]): Res = { prod(this, that) } + def *[That, Res](that: That)( + implicit prod: MatrixProduct[DiagonalMatrix[IdxT, ValT], That, Res]) + : Res = { prod(this, that) } def pipe = inPipe def fields = (idxSym, valSym) @@ -767,14 +1045,16 @@ class DiagonalMatrix[IdxT, ValT](val idxSym: Symbol, // Inverse of this matrix *IGNORING ZEROS* def inverse(implicit field: Field[ValT]): DiagonalMatrix[IdxT, ValT] = { val diagPipe = inPipe.flatMap(valSym -> valSym) { element: ValT => - field.nonZeroOption(element) + field + .nonZeroOption(element) .map { field.inverse } } new DiagonalMatrix[IdxT, ValT](idxSym, valSym, diagPipe, sizeHint) } // Value operations - def mapValues[ValU](fn: (ValT) => ValU)(implicit mon: Monoid[ValU]): DiagonalMatrix[IdxT, ValU] = { + def mapValues[ValU](fn: (ValT) => ValU)( + implicit mon: Monoid[ValU]): DiagonalMatrix[IdxT, ValU] = { val newPipe = pipe.flatMap(valSym -> valSym) { imp: Tuple1[ValT] => // Ensure an arity of 1 //This annoying Tuple1 wrapping ensures we can handle ValT that may itself be a Tuple. mon.nonZeroOption(fn(imp._1)).map { Tuple1(_) } @@ -783,28 +1063,42 @@ class DiagonalMatrix[IdxT, ValT](val idxSym: Symbol, } /** - * Write optionally renaming val fields to the given fields - * then return this. - */ - def write(src: Source, outFields: Fields = Fields.NONE)(implicit fd: FlowDef, mode: Mode) = { + * Write optionally renaming val fields to the given fields + * then return this. 
+ */ + def write(src: Source, outFields: Fields = Fields.NONE)( + implicit fd: FlowDef, + mode: Mode) = { writePipe(src, outFields) this } } -class RowVector[ColT, ValT](val colS: Symbol, val valS: Symbol, inPipe: Pipe, val sizeH: SizeHint = FiniteHint(1L, -1L)) - extends java.io.Serializable with WrappedPipe { +class RowVector[ColT, ValT]( + val colS: Symbol, + val valS: Symbol, + inPipe: Pipe, + val sizeH: SizeHint = FiniteHint(1L, -1L)) + extends java.io.Serializable + with WrappedPipe { def pipe = inPipe.project(colS, valS) def fields = (colS, valS) - def *[That, Res](that: That)(implicit prod: MatrixProduct[RowVector[ColT, ValT], That, Res]): Res = { prod(this, that) } + def *[That, Res](that: That)( + implicit prod: MatrixProduct[RowVector[ColT, ValT], That, Res]): Res = { + prod(this, that) + } - def +(that: RowVector[ColT, ValT])(implicit mon: Monoid[ValT]) = (this.toMatrix(true) + that.toMatrix(true)).getRow(true) + def +(that: RowVector[ColT, ValT])(implicit mon: Monoid[ValT]) = + (this.toMatrix(true) + that.toMatrix(true)).getRow(true) - def -(that: RowVector[ColT, ValT])(implicit group: Group[ValT]) = (this.toMatrix(true) - that.toMatrix(true)).getRow(true) + def -(that: RowVector[ColT, ValT])(implicit group: Group[ValT]) = + (this.toMatrix(true) - that.toMatrix(true)).getRow(true) - def hProd(that: RowVector[ColT, ValT])(implicit ring: Ring[ValT]): RowVector[ColT, ValT] = (this.transpose hProd that.transpose).transpose + def hProd(that: RowVector[ColT, ValT])( + implicit ring: Ring[ValT]): RowVector[ColT, ValT] = + (this.transpose hProd that.transpose).transpose def transpose: ColVector[ColT, ValT] = { new ColVector[ColT, ValT](colS, valS, inPipe, sizeH.transpose) @@ -816,17 +1110,22 @@ class RowVector[ColT, ValT](val colS: Symbol, val valS: Symbol, inPipe: Pipe, va } /** - * like zipWithIndex.map but ONLY CHANGES THE VALUE not the index. - * Note you will only see non-zero elements on the vector. This does not enumerate the zeros - */ - def mapWithIndex[ValNew](fn: (ValT, ColT) => ValNew)(implicit mon: Monoid[ValNew]): RowVector[ColT, ValNew] = { - val newPipe = pipe.mapTo((valS, colS) -> (valS, colS)) { tup: (ValT, ColT) => (fn(tup._1, tup._2), tup._2) } + * like zipWithIndex.map but ONLY CHANGES THE VALUE not the index. + * Note you will only see non-zero elements on the vector. This does not enumerate the zeros + */ + def mapWithIndex[ValNew](fn: (ValT, ColT) => ValNew)( + implicit mon: Monoid[ValNew]): RowVector[ColT, ValNew] = { + val newPipe = pipe + .mapTo((valS, colS) -> (valS, colS)) { tup: (ValT, ColT) => + (fn(tup._1, tup._2), tup._2) + } .filter(valS) { (v: ValNew) => mon.isNonZero(v) } new RowVector(colS, valS, newPipe, sizeH) } // Value operations - def mapValues[ValU](fn: (ValT) => ValU)(implicit mon: Monoid[ValU]): RowVector[ColT, ValU] = { + def mapValues[ValU](fn: (ValT) => ValU)( + implicit mon: Monoid[ValU]): RowVector[ColT, ValU] = { val newPipe = pipe.flatMap(valS -> valS) { imp: Tuple1[ValT] => // Ensure an arity of 1 //This annoying Tuple1 wrapping ensures we can handle ValT that may itself be a Tuple. 
mon.nonZeroOption(fn(imp._1)).map { Tuple1(_) } @@ -835,28 +1134,31 @@ class RowVector[ColT, ValT](val colS: Symbol, val valS: Symbol, inPipe: Pipe, va } /** - * Do a right-propogation of a row, transpose of Matrix.propagate - */ - def propagate[MatColT](mat: Matrix[ColT, MatColT, Boolean])(implicit monT: Monoid[ValT]): RowVector[MatColT, ValT] = { + * Do a right-propogation of a row, transpose of Matrix.propagate + */ + def propagate[MatColT](mat: Matrix[ColT, MatColT, Boolean])( + implicit monT: Monoid[ValT]): RowVector[MatColT, ValT] = { mat.transpose.propagate(this.transpose).transpose } def L0Normalize(implicit ev: =:=[ValT, Double]): RowVector[ColT, ValT] = { val normedMatrix = this.toMatrix(0).rowL0Normalize - new RowVector(normedMatrix.colSym, + new RowVector( + normedMatrix.colSym, normedMatrix.valSym, normedMatrix.pipe.project(normedMatrix.colSym, normedMatrix.valSym)) } def L1Normalize(implicit ev: =:=[ValT, Double]): RowVector[ColT, ValT] = { val normedMatrix = this.toMatrix(0).rowL1Normalize - new RowVector(normedMatrix.colSym, + new RowVector( + normedMatrix.colSym, normedMatrix.valSym, normedMatrix.pipe.project(normedMatrix.colSym, normedMatrix.valSym)) } def sum(implicit mon: Monoid[ValT]): Scalar[ValT] = { - val scalarPipe = pipe.groupAll{ + val scalarPipe = pipe.groupAll { _.reduce(valS -> valS) { (left: Tuple1[ValT], right: Tuple1[ValT]) => Tuple1(mon.plus(left._1, right._1)) } @@ -872,29 +1174,40 @@ class RowVector[ColT, ValT](val colS: Symbol, val valS: Symbol, inPipe: Pipe, va val ordValS = new Fields(fieldName) ordValS.setComparator(fieldName, ord) - val newPipe = pipe.groupAll{ - _ - .sortBy(ordValS) - .reverse - .take(k) - }.project(colS, valS) - new RowVector[ColT, ValT](colS, valS, newPipe, sizeH.setCols(k).setRows(1L)) + val newPipe = pipe + .groupAll { + _.sortBy(ordValS).reverse + .take(k) + } + .project(colS, valS) + new RowVector[ColT, ValT]( + colS, + valS, + newPipe, + sizeH.setCols(k).setRows(1L)) } } - protected def topWithTiny(k: Int)(implicit ord: Ordering[ValT]): RowVector[ColT, ValT] = { + protected def topWithTiny(k: Int)( + implicit ord: Ordering[ValT]): RowVector[ColT, ValT] = { val topSym = Symbol(colS.name + "_topK") - val newPipe = pipe.groupAll{ - _ - .sortWithTake((colS, valS) -> 'top_vals, k) ((t0: (ColT, ValT), t1: (ColT, ValT)) => ord.gt(t0._2, t1._2)) - } + val newPipe = pipe + .groupAll { + _.sortWithTake((colS, valS) -> 'top_vals, k)( + (t0: (ColT, ValT), t1: (ColT, ValT)) => ord.gt(t0._2, t1._2)) + } .flatMap('top_vals -> (topSym, valS)) { imp: List[(ColT, ValT)] => imp } - new RowVector[ColT, ValT](topSym, valS, newPipe, sizeH.setCols(k).setRows(1L)) + new RowVector[ColT, ValT]( + topSym, + valS, + newPipe, + sizeH.setCols(k).setRows(1L)) } def toMatrix[RowT](rowId: RowT): Matrix[RowT, ColT, ValT] = { val rowSym = newSymbol(Set(colS, valS), 'row) //Matrix.newSymbol(Set(colS, valS), 'row) - val newPipe = inPipe.map(() -> rowSym){ u: Unit => rowId } + val newPipe = inPipe + .map(() -> rowSym) { u: Unit => rowId } .project(rowSym, colS, valS) new Matrix[RowT, ColT, ValT](rowSym, colS, valS, newPipe, sizeH.setRows(1L)) } @@ -905,28 +1218,42 @@ class RowVector[ColT, ValT](val colS: Symbol, val valS: Symbol, inPipe: Pipe, va } /** - * Write optionally renaming val fields to the given fields - * then return this. - */ - def write(src: Source, outFields: Fields = Fields.NONE)(implicit fd: FlowDef, mode: Mode) = { + * Write optionally renaming val fields to the given fields + * then return this. 
+ */ + def write(src: Source, outFields: Fields = Fields.NONE)( + implicit fd: FlowDef, + mode: Mode) = { writePipe(src, outFields) this } } -class ColVector[RowT, ValT](val rowS: Symbol, val valS: Symbol, inPipe: Pipe, val sizeH: SizeHint = FiniteHint(-1L, 1L)) - extends java.io.Serializable with WrappedPipe { +class ColVector[RowT, ValT]( + val rowS: Symbol, + val valS: Symbol, + inPipe: Pipe, + val sizeH: SizeHint = FiniteHint(-1L, 1L)) + extends java.io.Serializable + with WrappedPipe { def pipe = inPipe.project(rowS, valS) def fields = (rowS, valS) - def *[That, Res](that: That)(implicit prod: MatrixProduct[ColVector[RowT, ValT], That, Res]): Res = { prod(this, that) } + def *[That, Res](that: That)( + implicit prod: MatrixProduct[ColVector[RowT, ValT], That, Res]): Res = { + prod(this, that) + } - def +(that: ColVector[RowT, ValT])(implicit mon: Monoid[ValT]) = (this.toMatrix(true) + that.toMatrix(true)).getCol(true) + def +(that: ColVector[RowT, ValT])(implicit mon: Monoid[ValT]) = + (this.toMatrix(true) + that.toMatrix(true)).getCol(true) - def -(that: ColVector[RowT, ValT])(implicit group: Group[ValT]) = (this.toMatrix(true) - that.toMatrix(true)).getCol(true) + def -(that: ColVector[RowT, ValT])(implicit group: Group[ValT]) = + (this.toMatrix(true) - that.toMatrix(true)).getCol(true) - def hProd(that: ColVector[RowT, ValT])(implicit ring: Ring[ValT]): ColVector[RowT, ValT] = (this.toMatrix(true) hProd that.toMatrix(true)).getCol(true) + def hProd(that: ColVector[RowT, ValT])( + implicit ring: Ring[ValT]): ColVector[RowT, ValT] = + (this.toMatrix(true) hProd that.toMatrix(true)).getCol(true) def transpose: RowVector[RowT, ValT] = { new RowVector[RowT, ValT](rowS, valS, inPipe, sizeH.transpose) @@ -938,13 +1265,16 @@ class ColVector[RowT, ValT](val rowS: Symbol, val valS: Symbol, inPipe: Pipe, va } /** - * like zipWithIndex.map but ONLY CHANGES THE VALUE not the index. - * Note you will only see non-zero elements on the vector. This does not enumerate the zeros - */ - def mapWithIndex[ValNew](fn: (ValT, RowT) => ValNew)(implicit mon: Monoid[ValNew]): ColVector[RowT, ValNew] = transpose.mapWithIndex(fn).transpose + * like zipWithIndex.map but ONLY CHANGES THE VALUE not the index. + * Note you will only see non-zero elements on the vector. This does not enumerate the zeros + */ + def mapWithIndex[ValNew](fn: (ValT, RowT) => ValNew)( + implicit mon: Monoid[ValNew]): ColVector[RowT, ValNew] = + transpose.mapWithIndex(fn).transpose // Value operations - def mapValues[ValU](fn: (ValT) => ValU)(implicit mon: Monoid[ValU]): ColVector[RowT, ValU] = { + def mapValues[ValU](fn: (ValT) => ValU)( + implicit mon: Monoid[ValU]): ColVector[RowT, ValU] = { val newPipe = pipe.flatMap(valS -> valS) { imp: Tuple1[ValT] => // Ensure an arity of 1 //This annoying Tuple1 wrapping ensures we can handle ValT that may itself be a Tuple. 
mon.nonZeroOption(fn(imp._1)).map { Tuple1(_) } @@ -953,7 +1283,7 @@ class ColVector[RowT, ValT](val rowS: Symbol, val valS: Symbol, inPipe: Pipe, va } def sum(implicit mon: Monoid[ValT]): Scalar[ValT] = { - val scalarPipe = pipe.groupAll{ + val scalarPipe = pipe.groupAll { _.reduce(valS -> valS) { (left: Tuple1[ValT], right: Tuple1[ValT]) => Tuple1(mon.plus(left._1, right._1)) } @@ -963,14 +1293,16 @@ class ColVector[RowT, ValT](val rowS: Symbol, val valS: Symbol, inPipe: Pipe, va def L0Normalize(implicit ev: =:=[ValT, Double]): ColVector[RowT, ValT] = { val normedMatrix = this.toMatrix(0).colL0Normalize - new ColVector(normedMatrix.rowSym, + new ColVector( + normedMatrix.rowSym, normedMatrix.valSym, normedMatrix.pipe.project(normedMatrix.rowSym, normedMatrix.valSym)) } def L1Normalize(implicit ev: =:=[ValT, Double]): ColVector[RowT, ValT] = { val normedMatrix = this.toMatrix(0).colL1Normalize - new ColVector(normedMatrix.rowSym, + new ColVector( + normedMatrix.rowSym, normedMatrix.valSym, normedMatrix.pipe.project(normedMatrix.rowSym, normedMatrix.valSym)) } @@ -978,29 +1310,40 @@ class ColVector[RowT, ValT](val rowS: Symbol, val valS: Symbol, inPipe: Pipe, va def topElems(k: Int)(implicit ord: Ordering[ValT]): ColVector[RowT, ValT] = { if (k < 1000) { topWithTiny(k) } else { - val newPipe = pipe.groupAll{ - _ - .sortBy(valS) - .reverse - .take(k) - }.project(rowS, valS) - new ColVector[RowT, ValT](rowS, valS, newPipe, sizeH.setCols(1L).setRows(k)) + val newPipe = pipe + .groupAll { + _.sortBy(valS).reverse + .take(k) + } + .project(rowS, valS) + new ColVector[RowT, ValT]( + rowS, + valS, + newPipe, + sizeH.setCols(1L).setRows(k)) } } - protected def topWithTiny(k: Int)(implicit ord: Ordering[ValT]): ColVector[RowT, ValT] = { + protected def topWithTiny(k: Int)( + implicit ord: Ordering[ValT]): ColVector[RowT, ValT] = { val topSym = Symbol(rowS.name + "_topK") - val newPipe = pipe.groupAll{ - _ - .sortWithTake((rowS, valS) -> 'top_vals, k) ((t0: (RowT, ValT), t1: (RowT, ValT)) => ord.gt(t0._2, t1._2)) - } + val newPipe = pipe + .groupAll { + _.sortWithTake((rowS, valS) -> 'top_vals, k)( + (t0: (RowT, ValT), t1: (RowT, ValT)) => ord.gt(t0._2, t1._2)) + } .flatMap('top_vals -> (topSym, valS)) { imp: List[(RowT, ValT)] => imp } - new ColVector[RowT, ValT](topSym, valS, newPipe, sizeH.setCols(1L).setRows(k)) + new ColVector[RowT, ValT]( + topSym, + valS, + newPipe, + sizeH.setCols(1L).setRows(k)) } def toMatrix[ColT](colIdx: ColT): Matrix[RowT, ColT, ValT] = { val colSym = newSymbol(Set(rowS, valS), 'col) //Matrix.newSymbol(Set(rowS, valS), 'col) - val newPipe = inPipe.map(() -> colSym){ u: Unit => colIdx } + val newPipe = inPipe + .map(() -> colSym) { u: Unit => colIdx } .project(rowS, colSym, valS) new Matrix[RowT, ColT, ValT](rowS, colSym, valS, newPipe, sizeH.setCols(1L)) } @@ -1011,23 +1354,30 @@ class ColVector[RowT, ValT](val rowS: Symbol, val valS: Symbol, inPipe: Pipe, va } /** - * Write optionally renaming val fields to the given fields - * then return this. - */ - def write(src: Source, outFields: Fields = Fields.NONE)(implicit fd: FlowDef, mode: Mode) = { + * Write optionally renaming val fields to the given fields + * then return this. + */ + def write(src: Source, outFields: Fields = Fields.NONE)( + implicit fd: FlowDef, + mode: Mode) = { writePipe(src, outFields) this } } /** - * BlockMatrix is 3 dimensional matrix where the rows are grouped - * It is useful for when we want to multiply groups of vectors only between themselves. 
- * For example, grouping users by countries and calculating products only between users from the same country - */ -class BlockMatrix[RowT, GroupT, ColT, ValT](private val mat: Matrix[RowT, GroupT, Map[ColT, ValT]]) { - def dotProd[RowT2](that: BlockMatrix[GroupT, RowT2, ColT, ValT])(implicit prod: MatrixProduct[Matrix[RowT, GroupT, Map[ColT, ValT]], Matrix[GroupT, RowT2, Map[ColT, ValT]], Matrix[RowT, RowT2, Map[ColT, ValT]]], - mon: Monoid[ValT]): Matrix[RowT, RowT2, ValT] = { + * BlockMatrix is 3 dimensional matrix where the rows are grouped + * It is useful for when we want to multiply groups of vectors only between themselves. + * For example, grouping users by countries and calculating products only between users from the same country + */ +class BlockMatrix[RowT, GroupT, ColT, ValT]( + private val mat: Matrix[RowT, GroupT, Map[ColT, ValT]]) { + def dotProd[RowT2](that: BlockMatrix[GroupT, RowT2, ColT, ValT])( + implicit prod: MatrixProduct[ + Matrix[RowT, GroupT, Map[ColT, ValT]], + Matrix[GroupT, RowT2, Map[ColT, ValT]], + Matrix[RowT, RowT2, Map[ColT, ValT]]], + mon: Monoid[ValT]): Matrix[RowT, RowT2, ValT] = { prod(mat, that.mat).mapValues(_.values.foldLeft(mon.zero)(mon.plus)) } diff --git a/repos/scalding/scalding-core/src/main/scala/com/twitter/scalding/mathematics/MatrixProduct.scala b/repos/scalding/scalding-core/src/main/scala/com/twitter/scalding/mathematics/MatrixProduct.scala index d06c5b61834..22367f6554f 100644 --- a/repos/scalding/scalding-core/src/main/scala/com/twitter/scalding/mathematics/MatrixProduct.scala +++ b/repos/scalding/scalding-core/src/main/scala/com/twitter/scalding/mathematics/MatrixProduct.scala @@ -12,15 +12,14 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
-*/ + */ package com.twitter.scalding.mathematics /** - * Handles the implementation of various versions of MatrixProducts - */ - -import com.twitter.algebird.{ Ring, Monoid, Group, Field } + * Handles the implementation of various versions of MatrixProducts + */ +import com.twitter.algebird.{Ring, Monoid, Group, Field} import com.twitter.scalding.RichPipe import com.twitter.scalding.Dsl._ @@ -31,31 +30,43 @@ import scala.math.Ordering import scala.annotation.tailrec /** - * Abstracts the approach taken to join the two matrices - */ + * Abstracts the approach taken to join the two matrices + */ abstract class MatrixJoiner extends java.io.Serializable { def apply(left: Pipe, joinFields: (Fields, Fields), right: Pipe): Pipe } case object AnyToTiny extends MatrixJoiner { - override def apply(left: Pipe, joinFields: (Fields, Fields), right: Pipe): Pipe = { + override def apply( + left: Pipe, + joinFields: (Fields, Fields), + right: Pipe): Pipe = { RichPipe(left).joinWithTiny(joinFields, right) } } class BigToSmall(red: Int) extends MatrixJoiner { - override def apply(left: Pipe, joinFields: (Fields, Fields), right: Pipe): Pipe = { + override def apply( + left: Pipe, + joinFields: (Fields, Fields), + right: Pipe): Pipe = { RichPipe(left).joinWithSmaller(joinFields, right, reducers = red) } } case object TinyToAny extends MatrixJoiner { - override def apply(left: Pipe, joinFields: (Fields, Fields), right: Pipe): Pipe = { + override def apply( + left: Pipe, + joinFields: (Fields, Fields), + right: Pipe): Pipe = { val reversed = (joinFields._2, joinFields._1) RichPipe(right).joinWithTiny(reversed, left) } } class SmallToBig(red: Int) extends MatrixJoiner { - override def apply(left: Pipe, joinFields: (Fields, Fields), right: Pipe): Pipe = { + override def apply( + left: Pipe, + joinFields: (Fields, Fields), + right: Pipe): Pipe = { RichPipe(left).joinWithLarger(joinFields, right, reducers = red) } } @@ -81,249 +92,494 @@ trait MatrixProduct[Left, Right, Result] extends java.io.Serializable { } /** - * TODO: Muliplication is the expensive stuff. We need to optimize the methods below: - * This object holds the implicits to handle matrix products between various types - */ + * TODO: Muliplication is the expensive stuff. 
We need to optimize the methods below: + * This object holds the implicits to handle matrix products between various types + */ object MatrixProduct extends java.io.Serializable { // These are VARS, so you can set them before you start: var maxTinyJoin = 100000L // Bigger than this, and we use joinWithSmaller var maxReducers = 200 def numOfReducers(hint: SizeHint) = { - hint.total.map { tot => - // + 1L is to make sure there is at least once reducer - (tot / MatrixProduct.maxTinyJoin + 1L).toInt min MatrixProduct.maxReducers - }.getOrElse(-1) + hint.total + .map { tot => + // + 1L is to make sure there is at least once reducer + (tot / MatrixProduct.maxTinyJoin + 1L).toInt min MatrixProduct.maxReducers + } + .getOrElse(-1) } def getJoiner(leftSize: SizeHint, rightSize: SizeHint): MatrixJoiner = { val newHint = leftSize * rightSize if (SizeHintOrdering.lteq(leftSize, rightSize)) { // If leftsize is definite: - leftSize.total.map { t => if (t < maxTinyJoin) TinyToAny else new SmallToBig(numOfReducers(newHint)) } + leftSize.total + .map { t => + if (t < maxTinyJoin) TinyToAny + else new SmallToBig(numOfReducers(newHint)) + } // Else just assume the right is smaller, but both are unknown: .getOrElse(new BigToSmall(numOfReducers(newHint))) } else { // left > right - rightSize.total.map { rs => - if (rs < maxTinyJoin) AnyToTiny else new BigToSmall(numOfReducers(newHint)) - }.getOrElse(new BigToSmall(numOfReducers(newHint))) + rightSize.total + .map { rs => + if (rs < maxTinyJoin) AnyToTiny + else new BigToSmall(numOfReducers(newHint)) + } + .getOrElse(new BigToSmall(numOfReducers(newHint))) } } def getCrosser(rightSize: SizeHint): MatrixCrosser = - rightSize.total.map { t => if (t < maxTinyJoin) AnyCrossTiny else AnyCrossSmall } + rightSize.total + .map { t => if (t < maxTinyJoin) AnyCrossTiny else AnyCrossSmall } .getOrElse(AnyCrossSmall) - implicit def literalScalarRightProduct[Row, Col, ValT](implicit ring: Ring[ValT]): MatrixProduct[Matrix[Row, Col, ValT], LiteralScalar[ValT], Matrix[Row, Col, ValT]] = - new MatrixProduct[Matrix[Row, Col, ValT], LiteralScalar[ValT], Matrix[Row, Col, ValT]] { + implicit def literalScalarRightProduct[Row, Col, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + Matrix[Row, Col, ValT], + LiteralScalar[ValT], + Matrix[Row, Col, ValT]] = + new MatrixProduct[ + Matrix[Row, Col, ValT], + LiteralScalar[ValT], + Matrix[Row, Col, ValT]] { def apply(left: Matrix[Row, Col, ValT], right: LiteralScalar[ValT]) = { val newPipe = left.pipe.map(left.valSym -> left.valSym) { (v: ValT) => ring.times(v, right.value) } - new Matrix[Row, Col, ValT](left.rowSym, left.colSym, left.valSym, newPipe, left.sizeHint) + new Matrix[Row, Col, ValT]( + left.rowSym, + left.colSym, + left.valSym, + newPipe, + left.sizeHint) } } - implicit def literalRightProduct[Row, Col, ValT](implicit ring: Ring[ValT]): MatrixProduct[Matrix[Row, Col, ValT], ValT, Matrix[Row, Col, ValT]] = + implicit def literalRightProduct[Row, Col, ValT](implicit ring: Ring[ValT]) + : MatrixProduct[Matrix[Row, Col, ValT], ValT, Matrix[Row, Col, ValT]] = new MatrixProduct[Matrix[Row, Col, ValT], ValT, Matrix[Row, Col, ValT]] { def apply(left: Matrix[Row, Col, ValT], right: ValT) = { val newPipe = left.pipe.map(left.valSym -> left.valSym) { (v: ValT) => ring.times(v, right) } - new Matrix[Row, Col, ValT](left.rowSym, left.colSym, left.valSym, newPipe, left.sizeHint) + new Matrix[Row, Col, ValT]( + left.rowSym, + left.colSym, + left.valSym, + newPipe, + left.sizeHint) } } - implicit def literalScalarLeftProduct[Row, Col, 
ValT](implicit ring: Ring[ValT]): MatrixProduct[LiteralScalar[ValT], Matrix[Row, Col, ValT], Matrix[Row, Col, ValT]] = - new MatrixProduct[LiteralScalar[ValT], Matrix[Row, Col, ValT], Matrix[Row, Col, ValT]] { + implicit def literalScalarLeftProduct[Row, Col, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + LiteralScalar[ValT], + Matrix[Row, Col, ValT], + Matrix[Row, Col, ValT]] = + new MatrixProduct[ + LiteralScalar[ValT], + Matrix[Row, Col, ValT], + Matrix[Row, Col, ValT]] { def apply(left: LiteralScalar[ValT], right: Matrix[Row, Col, ValT]) = { - val newPipe = right.pipe.map(right.valSym -> right.valSym) { (v: ValT) => - ring.times(left.value, v) + val newPipe = right.pipe.map(right.valSym -> right.valSym) { + (v: ValT) => ring.times(left.value, v) } - new Matrix[Row, Col, ValT](right.rowSym, right.colSym, right.valSym, newPipe, right.sizeHint) + new Matrix[Row, Col, ValT]( + right.rowSym, + right.colSym, + right.valSym, + newPipe, + right.sizeHint) } } - implicit def scalarPipeRightProduct[Row, Col, ValT](implicit ring: Ring[ValT]): MatrixProduct[Matrix[Row, Col, ValT], Scalar[ValT], Matrix[Row, Col, ValT]] = - new MatrixProduct[Matrix[Row, Col, ValT], Scalar[ValT], Matrix[Row, Col, ValT]] { + implicit def scalarPipeRightProduct[Row, Col, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + Matrix[Row, Col, ValT], + Scalar[ValT], + Matrix[Row, Col, ValT]] = + new MatrixProduct[ + Matrix[Row, Col, ValT], + Scalar[ValT], + Matrix[Row, Col, ValT]] { def apply(left: Matrix[Row, Col, ValT], right: Scalar[ValT]) = { - left.nonZerosWith(right).mapValues({ leftRight => - val (left, right) = leftRight - ring.times(left, right) - })(ring) + left + .nonZerosWith(right) + .mapValues({ leftRight => + val (left, right) = leftRight + ring.times(left, right) + })(ring) } } - implicit def scalarPipeLeftProduct[Row, Col, ValT](implicit ring: Ring[ValT]): MatrixProduct[Scalar[ValT], Matrix[Row, Col, ValT], Matrix[Row, Col, ValT]] = - new MatrixProduct[Scalar[ValT], Matrix[Row, Col, ValT], Matrix[Row, Col, ValT]] { + implicit def scalarPipeLeftProduct[Row, Col, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + Scalar[ValT], + Matrix[Row, Col, ValT], + Matrix[Row, Col, ValT]] = + new MatrixProduct[ + Scalar[ValT], + Matrix[Row, Col, ValT], + Matrix[Row, Col, ValT]] { def apply(left: Scalar[ValT], right: Matrix[Row, Col, ValT]) = { - right.nonZerosWith(left).mapValues({ matScal => - val (matVal, scalarVal) = matScal - ring.times(scalarVal, matVal) - })(ring) + right + .nonZerosWith(left) + .mapValues({ matScal => + val (matVal, scalarVal) = matScal + ring.times(scalarVal, matVal) + })(ring) } } - implicit def scalarRowRightProduct[Col, ValT](implicit ring: Ring[ValT]): MatrixProduct[RowVector[Col, ValT], Scalar[ValT], RowVector[Col, ValT]] = + implicit def scalarRowRightProduct[Col, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + RowVector[Col, ValT], + Scalar[ValT], + RowVector[Col, ValT]] = new MatrixProduct[RowVector[Col, ValT], Scalar[ValT], RowVector[Col, ValT]] { - def apply(left: RowVector[Col, ValT], right: Scalar[ValT]): RowVector[Col, ValT] = { + def apply( + left: RowVector[Col, ValT], + right: Scalar[ValT]): RowVector[Col, ValT] = { val prod = left.toMatrix(0) * right - new RowVector[Col, ValT](prod.colSym, prod.valSym, prod.pipe.project(prod.colSym, prod.valSym)) + new RowVector[Col, ValT]( + prod.colSym, + prod.valSym, + prod.pipe.project(prod.colSym, prod.valSym)) } } - implicit def scalarRowLeftProduct[Col, ValT](implicit ring: Ring[ValT]): MatrixProduct[Scalar[ValT], 
RowVector[Col, ValT], RowVector[Col, ValT]] = + implicit def scalarRowLeftProduct[Col, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + Scalar[ValT], + RowVector[Col, ValT], + RowVector[Col, ValT]] = new MatrixProduct[Scalar[ValT], RowVector[Col, ValT], RowVector[Col, ValT]] { - def apply(left: Scalar[ValT], right: RowVector[Col, ValT]): RowVector[Col, ValT] = { + def apply( + left: Scalar[ValT], + right: RowVector[Col, ValT]): RowVector[Col, ValT] = { val prod = (right.transpose.toMatrix(0)) * left - new RowVector[Col, ValT](prod.rowSym, prod.valSym, prod.pipe.project(prod.rowSym, prod.valSym)) + new RowVector[Col, ValT]( + prod.rowSym, + prod.valSym, + prod.pipe.project(prod.rowSym, prod.valSym)) } } - implicit def scalarColRightProduct[Row, ValT](implicit ring: Ring[ValT]): MatrixProduct[ColVector[Row, ValT], Scalar[ValT], ColVector[Row, ValT]] = + implicit def scalarColRightProduct[Row, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + ColVector[Row, ValT], + Scalar[ValT], + ColVector[Row, ValT]] = new MatrixProduct[ColVector[Row, ValT], Scalar[ValT], ColVector[Row, ValT]] { - def apply(left: ColVector[Row, ValT], right: Scalar[ValT]): ColVector[Row, ValT] = { + def apply( + left: ColVector[Row, ValT], + right: Scalar[ValT]): ColVector[Row, ValT] = { val prod = left.toMatrix(0) * right - new ColVector[Row, ValT](prod.rowSym, prod.valSym, prod.pipe.project(prod.rowSym, prod.valSym)) + new ColVector[Row, ValT]( + prod.rowSym, + prod.valSym, + prod.pipe.project(prod.rowSym, prod.valSym)) } } - implicit def scalarColLeftProduct[Row, ValT](implicit ring: Ring[ValT]): MatrixProduct[Scalar[ValT], ColVector[Row, ValT], ColVector[Row, ValT]] = + implicit def scalarColLeftProduct[Row, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + Scalar[ValT], + ColVector[Row, ValT], + ColVector[Row, ValT]] = new MatrixProduct[Scalar[ValT], ColVector[Row, ValT], ColVector[Row, ValT]] { - def apply(left: Scalar[ValT], right: ColVector[Row, ValT]): ColVector[Row, ValT] = { + def apply( + left: Scalar[ValT], + right: ColVector[Row, ValT]): ColVector[Row, ValT] = { val prod = (right.toMatrix(0)) * left - new ColVector[Row, ValT](prod.rowSym, prod.valSym, prod.pipe.project(prod.rowSym, prod.valSym)) + new ColVector[Row, ValT]( + prod.rowSym, + prod.valSym, + prod.pipe.project(prod.rowSym, prod.valSym)) } } - implicit def litScalarRowRightProduct[Col, ValT](implicit ring: Ring[ValT]): MatrixProduct[RowVector[Col, ValT], LiteralScalar[ValT], RowVector[Col, ValT]] = - new MatrixProduct[RowVector[Col, ValT], LiteralScalar[ValT], RowVector[Col, ValT]] { - def apply(left: RowVector[Col, ValT], right: LiteralScalar[ValT]): RowVector[Col, ValT] = { + implicit def litScalarRowRightProduct[Col, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + RowVector[Col, ValT], + LiteralScalar[ValT], + RowVector[Col, ValT]] = + new MatrixProduct[ + RowVector[Col, ValT], + LiteralScalar[ValT], + RowVector[Col, ValT]] { + def apply( + left: RowVector[Col, ValT], + right: LiteralScalar[ValT]): RowVector[Col, ValT] = { val prod = left.toMatrix(0) * right - new RowVector[Col, ValT](prod.colSym, prod.valSym, prod.pipe.project(prod.colSym, prod.valSym)) + new RowVector[Col, ValT]( + prod.colSym, + prod.valSym, + prod.pipe.project(prod.colSym, prod.valSym)) } } - implicit def litScalarRowLeftProduct[Col, ValT](implicit ring: Ring[ValT]): MatrixProduct[LiteralScalar[ValT], RowVector[Col, ValT], RowVector[Col, ValT]] = - new MatrixProduct[LiteralScalar[ValT], RowVector[Col, ValT], RowVector[Col, ValT]] { - def apply(left: 
LiteralScalar[ValT], right: RowVector[Col, ValT]): RowVector[Col, ValT] = { + implicit def litScalarRowLeftProduct[Col, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + LiteralScalar[ValT], + RowVector[Col, ValT], + RowVector[Col, ValT]] = + new MatrixProduct[ + LiteralScalar[ValT], + RowVector[Col, ValT], + RowVector[Col, ValT]] { + def apply( + left: LiteralScalar[ValT], + right: RowVector[Col, ValT]): RowVector[Col, ValT] = { val prod = (right.transpose.toMatrix(0)) * left - new RowVector[Col, ValT](prod.rowSym, prod.valSym, prod.pipe.project(prod.rowSym, prod.valSym)) + new RowVector[Col, ValT]( + prod.rowSym, + prod.valSym, + prod.pipe.project(prod.rowSym, prod.valSym)) } } - implicit def litScalarColRightProduct[Row, ValT](implicit ring: Ring[ValT]): MatrixProduct[ColVector[Row, ValT], LiteralScalar[ValT], ColVector[Row, ValT]] = - new MatrixProduct[ColVector[Row, ValT], LiteralScalar[ValT], ColVector[Row, ValT]] { - def apply(left: ColVector[Row, ValT], right: LiteralScalar[ValT]): ColVector[Row, ValT] = { + implicit def litScalarColRightProduct[Row, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + ColVector[Row, ValT], + LiteralScalar[ValT], + ColVector[Row, ValT]] = + new MatrixProduct[ + ColVector[Row, ValT], + LiteralScalar[ValT], + ColVector[Row, ValT]] { + def apply( + left: ColVector[Row, ValT], + right: LiteralScalar[ValT]): ColVector[Row, ValT] = { val prod = left.toMatrix(0) * right - new ColVector[Row, ValT](prod.rowSym, prod.valSym, prod.pipe.project(prod.rowSym, prod.valSym)) + new ColVector[Row, ValT]( + prod.rowSym, + prod.valSym, + prod.pipe.project(prod.rowSym, prod.valSym)) } } - implicit def litScalarColLeftProduct[Row, ValT](implicit ring: Ring[ValT]): MatrixProduct[LiteralScalar[ValT], ColVector[Row, ValT], ColVector[Row, ValT]] = - new MatrixProduct[LiteralScalar[ValT], ColVector[Row, ValT], ColVector[Row, ValT]] { - def apply(left: LiteralScalar[ValT], right: ColVector[Row, ValT]): ColVector[Row, ValT] = { + implicit def litScalarColLeftProduct[Row, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + LiteralScalar[ValT], + ColVector[Row, ValT], + ColVector[Row, ValT]] = + new MatrixProduct[ + LiteralScalar[ValT], + ColVector[Row, ValT], + ColVector[Row, ValT]] { + def apply( + left: LiteralScalar[ValT], + right: ColVector[Row, ValT]): ColVector[Row, ValT] = { val prod = (right.toMatrix(0)) * left - new ColVector[Row, ValT](prod.rowSym, prod.valSym, prod.pipe.project(prod.rowSym, prod.valSym)) + new ColVector[Row, ValT]( + prod.rowSym, + prod.valSym, + prod.pipe.project(prod.rowSym, prod.valSym)) } } - implicit def scalarDiagRightProduct[Row, ValT](implicit ring: Ring[ValT]): MatrixProduct[DiagonalMatrix[Row, ValT], Scalar[ValT], DiagonalMatrix[Row, ValT]] = - new MatrixProduct[DiagonalMatrix[Row, ValT], Scalar[ValT], DiagonalMatrix[Row, ValT]] { - def apply(left: DiagonalMatrix[Row, ValT], right: Scalar[ValT]): DiagonalMatrix[Row, ValT] = { + implicit def scalarDiagRightProduct[Row, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + DiagonalMatrix[Row, ValT], + Scalar[ValT], + DiagonalMatrix[Row, ValT]] = + new MatrixProduct[ + DiagonalMatrix[Row, ValT], + Scalar[ValT], + DiagonalMatrix[Row, ValT]] { + def apply( + left: DiagonalMatrix[Row, ValT], + right: Scalar[ValT]): DiagonalMatrix[Row, ValT] = { val prod = (left.toCol.toMatrix(0)) * right - new DiagonalMatrix[Row, ValT](prod.rowSym, prod.valSym, prod.pipe.project(prod.rowSym, prod.valSym)) + new DiagonalMatrix[Row, ValT]( + prod.rowSym, + prod.valSym, + prod.pipe.project(prod.rowSym, 
prod.valSym)) } } - implicit def scalarDiagLeftProduct[Row, ValT](implicit ring: Ring[ValT]): MatrixProduct[Scalar[ValT], DiagonalMatrix[Row, ValT], DiagonalMatrix[Row, ValT]] = - new MatrixProduct[Scalar[ValT], DiagonalMatrix[Row, ValT], DiagonalMatrix[Row, ValT]] { - def apply(left: Scalar[ValT], right: DiagonalMatrix[Row, ValT]): DiagonalMatrix[Row, ValT] = { + implicit def scalarDiagLeftProduct[Row, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + Scalar[ValT], + DiagonalMatrix[Row, ValT], + DiagonalMatrix[Row, ValT]] = + new MatrixProduct[ + Scalar[ValT], + DiagonalMatrix[Row, ValT], + DiagonalMatrix[Row, ValT]] { + def apply( + left: Scalar[ValT], + right: DiagonalMatrix[Row, ValT]): DiagonalMatrix[Row, ValT] = { val prod = (right.toCol.toMatrix(0)) * left - new DiagonalMatrix[Row, ValT](prod.rowSym, prod.valSym, prod.pipe.project(prod.rowSym, prod.valSym)) + new DiagonalMatrix[Row, ValT]( + prod.rowSym, + prod.valSym, + prod.pipe.project(prod.rowSym, prod.valSym)) } } - implicit def litScalarDiagRightProduct[Col, ValT](implicit ring: Ring[ValT]): MatrixProduct[DiagonalMatrix[Col, ValT], LiteralScalar[ValT], DiagonalMatrix[Col, ValT]] = - new MatrixProduct[DiagonalMatrix[Col, ValT], LiteralScalar[ValT], DiagonalMatrix[Col, ValT]] { - def apply(left: DiagonalMatrix[Col, ValT], right: LiteralScalar[ValT]): DiagonalMatrix[Col, ValT] = { + implicit def litScalarDiagRightProduct[Col, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + DiagonalMatrix[Col, ValT], + LiteralScalar[ValT], + DiagonalMatrix[Col, ValT]] = + new MatrixProduct[ + DiagonalMatrix[Col, ValT], + LiteralScalar[ValT], + DiagonalMatrix[Col, ValT]] { + def apply( + left: DiagonalMatrix[Col, ValT], + right: LiteralScalar[ValT]): DiagonalMatrix[Col, ValT] = { val prod = (left.toRow.toMatrix(0)) * right - new DiagonalMatrix[Col, ValT](prod.colSym, prod.valSym, prod.pipe.project(prod.colSym, prod.valSym)) + new DiagonalMatrix[Col, ValT]( + prod.colSym, + prod.valSym, + prod.pipe.project(prod.colSym, prod.valSym)) } } - implicit def litScalarDiagLeftProduct[Col, ValT](implicit ring: Ring[ValT]): MatrixProduct[LiteralScalar[ValT], DiagonalMatrix[Col, ValT], DiagonalMatrix[Col, ValT]] = - new MatrixProduct[LiteralScalar[ValT], DiagonalMatrix[Col, ValT], DiagonalMatrix[Col, ValT]] { - def apply(left: LiteralScalar[ValT], right: DiagonalMatrix[Col, ValT]): DiagonalMatrix[Col, ValT] = { + implicit def litScalarDiagLeftProduct[Col, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + LiteralScalar[ValT], + DiagonalMatrix[Col, ValT], + DiagonalMatrix[Col, ValT]] = + new MatrixProduct[ + LiteralScalar[ValT], + DiagonalMatrix[Col, ValT], + DiagonalMatrix[Col, ValT]] { + def apply( + left: LiteralScalar[ValT], + right: DiagonalMatrix[Col, ValT]): DiagonalMatrix[Col, ValT] = { val prod = (right.toCol.toMatrix(0)) * left - new DiagonalMatrix[Col, ValT](prod.rowSym, prod.valSym, prod.pipe.project(prod.rowSym, prod.valSym)) + new DiagonalMatrix[Col, ValT]( + prod.rowSym, + prod.valSym, + prod.pipe.project(prod.rowSym, prod.valSym)) } } //TODO: remove in 0.9.0, only here just for compatibility. 
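// --- editor's aside (not part of the patch above) --------------------------------------
// A minimal sketch of how the implicit MatrixProduct instances defined in this object get
// resolved at a call site. It assumes the Matrix API types shown in this diff (Matrix,
// RowVector, ColVector, Scalar) plus algebird's implicit Ring[Double]; the object, method
// and value names below are illustrative, not code from the patch.
object MatrixProductSketch {
  import com.twitter.scalding.mathematics._

  def products(a: Matrix[Long, Long, Double]): Unit = {
    // standardMatrixProduct: join on the shared dimension, multiply, then sum per cell.
    val gram: Matrix[Long, Long, Double] = a * a.transpose
    // literalRightProduct: scale every non-zero entry by a literal value via Ring[Double].
    val scaled: Matrix[Long, Long, Double] = a * 2.0
    // rowColProduct: a row vector times a column vector collapses to a Scalar inner product.
    val inner: Scalar[Double] = a.getRow(1L) * a.getCol(1L)
  }
}
// ----------------------------------------------------------------------------------------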
- def vectorInnerProduct[IdxT, ValT](implicit ring: Ring[ValT]): MatrixProduct[RowVector[IdxT, ValT], ColVector[IdxT, ValT], Scalar[ValT]] = + def vectorInnerProduct[IdxT, ValT](implicit ring: Ring[ValT]): MatrixProduct[ + RowVector[IdxT, ValT], + ColVector[IdxT, ValT], + Scalar[ValT]] = rowColProduct(ring) - implicit def rowColProduct[IdxT, ValT](implicit ring: Ring[ValT]): MatrixProduct[RowVector[IdxT, ValT], ColVector[IdxT, ValT], Scalar[ValT]] = - new MatrixProduct[RowVector[IdxT, ValT], ColVector[IdxT, ValT], Scalar[ValT]] { - def apply(left: RowVector[IdxT, ValT], right: ColVector[IdxT, ValT]): Scalar[ValT] = { + implicit def rowColProduct[IdxT, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + RowVector[IdxT, ValT], + ColVector[IdxT, ValT], + Scalar[ValT]] = + new MatrixProduct[ + RowVector[IdxT, ValT], + ColVector[IdxT, ValT], + Scalar[ValT]] { + def apply( + left: RowVector[IdxT, ValT], + right: ColVector[IdxT, ValT]): Scalar[ValT] = { // Normal matrix multiplication works here, but we need to convert to a Scalar - val prod = (left.toMatrix(0) * right.toMatrix(0)): Matrix[Int, Int, ValT] + val prod = + (left.toMatrix(0) * right.toMatrix(0)): Matrix[Int, Int, ValT] new Scalar[ValT](prod.valSym, prod.pipe.project(prod.valSym)) } } - implicit def rowMatrixProduct[Common, ColR, ValT](implicit ring: Ring[ValT]): MatrixProduct[RowVector[Common, ValT], Matrix[Common, ColR, ValT], RowVector[ColR, ValT]] = - new MatrixProduct[RowVector[Common, ValT], Matrix[Common, ColR, ValT], RowVector[ColR, ValT]] { - def apply(left: RowVector[Common, ValT], right: Matrix[Common, ColR, ValT]) = { + implicit def rowMatrixProduct[Common, ColR, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + RowVector[Common, ValT], + Matrix[Common, ColR, ValT], + RowVector[ColR, ValT]] = + new MatrixProduct[ + RowVector[Common, ValT], + Matrix[Common, ColR, ValT], + RowVector[ColR, ValT]] { + def apply( + left: RowVector[Common, ValT], + right: Matrix[Common, ColR, ValT]) = { (left.toMatrix(true) * right).getRow(true) } } - implicit def matrixColProduct[RowR, Common, ValT](implicit ring: Ring[ValT]): MatrixProduct[Matrix[RowR, Common, ValT], ColVector[Common, ValT], ColVector[RowR, ValT]] = - new MatrixProduct[Matrix[RowR, Common, ValT], ColVector[Common, ValT], ColVector[RowR, ValT]] { - def apply(left: Matrix[RowR, Common, ValT], right: ColVector[Common, ValT]) = { + implicit def matrixColProduct[RowR, Common, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + Matrix[RowR, Common, ValT], + ColVector[Common, ValT], + ColVector[RowR, ValT]] = + new MatrixProduct[ + Matrix[RowR, Common, ValT], + ColVector[Common, ValT], + ColVector[RowR, ValT]] { + def apply( + left: Matrix[RowR, Common, ValT], + right: ColVector[Common, ValT]) = { (left * right.toMatrix(true)).getCol(true) } } - implicit def vectorOuterProduct[RowT, ColT, ValT](implicit ring: Ring[ValT]): MatrixProduct[ColVector[RowT, ValT], RowVector[ColT, ValT], Matrix[RowT, ColT, ValT]] = - new MatrixProduct[ColVector[RowT, ValT], RowVector[ColT, ValT], Matrix[RowT, ColT, ValT]] { - def apply(left: ColVector[RowT, ValT], right: RowVector[ColT, ValT]): Matrix[RowT, ColT, ValT] = { + implicit def vectorOuterProduct[RowT, ColT, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + ColVector[RowT, ValT], + RowVector[ColT, ValT], + Matrix[RowT, ColT, ValT]] = + new MatrixProduct[ + ColVector[RowT, ValT], + RowVector[ColT, ValT], + Matrix[RowT, ColT, ValT]] { + def apply( + left: ColVector[RowT, ValT], + right: RowVector[ColT, ValT]): Matrix[RowT, ColT, ValT] = 
{ val (newRightFields, newRightPipe) = ensureUniqueFields( (left.rowS, left.valS), (right.colS, right.valS), right.pipe) val newColSym = Symbol(right.colS.name + "_newCol") val newHint = left.sizeH * right.sizeH - val productPipe = Matrix.filterOutZeros(left.valS, ring) { - getCrosser(right.sizeH) - .apply(left.pipe, newRightPipe) - .map(left.valS.append(getField(newRightFields, 1)) -> left.valS) { pair: (ValT, ValT) => - ring.times(pair._1, pair._2) - } - } + val productPipe = Matrix + .filterOutZeros(left.valS, ring) { + getCrosser(right.sizeH) + .apply(left.pipe, newRightPipe) + .map(left.valS.append(getField(newRightFields, 1)) -> left.valS) { + pair: (ValT, ValT) => ring.times(pair._1, pair._2) + } + } .rename(getField(newRightFields, 0) -> newColSym) - new Matrix[RowT, ColT, ValT](left.rowS, newColSym, left.valS, productPipe, newHint) + new Matrix[RowT, ColT, ValT]( + left.rowS, + newColSym, + left.valS, + productPipe, + newHint) } } - implicit def standardMatrixProduct[RowL, Common, ColR, ValT](implicit ring: Ring[ValT]): MatrixProduct[Matrix[RowL, Common, ValT], Matrix[Common, ColR, ValT], Matrix[RowL, ColR, ValT]] = - new MatrixProduct[Matrix[RowL, Common, ValT], Matrix[Common, ColR, ValT], Matrix[RowL, ColR, ValT]] { - def apply(left: Matrix[RowL, Common, ValT], right: Matrix[Common, ColR, ValT]) = { + implicit def standardMatrixProduct[RowL, Common, ColR, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + Matrix[RowL, Common, ValT], + Matrix[Common, ColR, ValT], + Matrix[RowL, ColR, ValT]] = + new MatrixProduct[ + Matrix[RowL, Common, ValT], + Matrix[Common, ColR, ValT], + Matrix[RowL, ColR, ValT]] { + def apply( + left: Matrix[RowL, Common, ValT], + right: Matrix[Common, ColR, ValT]) = { val (newRightFields, newRightPipe) = ensureUniqueFields( (left.rowSym, left.colSym, left.valSym), (right.rowSym, right.colSym, right.valSym), @@ -332,32 +588,53 @@ object MatrixProduct extends java.io.Serializable { // Hint of groupBy reducer size val grpReds = numOfReducers(newHint) - val productPipe = Matrix.filterOutZeros(left.valSym, ring) { - getJoiner(left.sizeHint, right.sizeHint) + val productPipe = Matrix + .filterOutZeros(left.valSym, ring) { + getJoiner(left.sizeHint, right.sizeHint) // TODO: we should use the size hints to set the number of reducers: - .apply(left.pipe, (left.colSym -> getField(newRightFields, 0)), newRightPipe) - // Do the product: - .map((left.valSym.append(getField(newRightFields, 2))) -> left.valSym) { pair: (ValT, ValT) => - ring.times(pair._1, pair._2) - } - .groupBy(left.rowSym.append(getField(newRightFields, 1))) { - // We should use the size hints to set the number of reducers here - _.reduce(left.valSym) { (x: Tuple1[ValT], y: Tuple1[ValT]) => Tuple1(ring.plus(x._1, y._1)) } + .apply( + left.pipe, + (left.colSym -> getField(newRightFields, 0)), + newRightPipe) + // Do the product: + .map((left.valSym + .append(getField(newRightFields, 2))) -> left.valSym) { + pair: (ValT, ValT) => ring.times(pair._1, pair._2) + } + .groupBy(left.rowSym.append(getField(newRightFields, 1))) { + // We should use the size hints to set the number of reducers here + _.reduce(left.valSym) { (x: Tuple1[ValT], y: Tuple1[ValT]) => + Tuple1(ring.plus(x._1, y._1)) + } // There is a low chance that many (row,col) keys are co-located, and the keyspace // is likely huge, just push to reducers .forceToReducers - .reducers(grpReds) - } - } + .reducers(grpReds) + } + } // Keep the names from the left: .rename(getField(newRightFields, 1) -> left.colSym) - new Matrix[RowL, ColR, 
ValT](left.rowSym, left.colSym, left.valSym, productPipe, newHint) + new Matrix[RowL, ColR, ValT]( + left.rowSym, + left.colSym, + left.valSym, + productPipe, + newHint) } } - implicit def diagMatrixProduct[RowT, ColT, ValT](implicit ring: Ring[ValT]): MatrixProduct[DiagonalMatrix[RowT, ValT], Matrix[RowT, ColT, ValT], Matrix[RowT, ColT, ValT]] = - new MatrixProduct[DiagonalMatrix[RowT, ValT], Matrix[RowT, ColT, ValT], Matrix[RowT, ColT, ValT]] { - def apply(left: DiagonalMatrix[RowT, ValT], right: Matrix[RowT, ColT, ValT]) = { + implicit def diagMatrixProduct[RowT, ColT, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + DiagonalMatrix[RowT, ValT], + Matrix[RowT, ColT, ValT], + Matrix[RowT, ColT, ValT]] = + new MatrixProduct[ + DiagonalMatrix[RowT, ValT], + Matrix[RowT, ColT, ValT], + Matrix[RowT, ColT, ValT]] { + def apply( + left: DiagonalMatrix[RowT, ValT], + right: Matrix[RowT, ColT, ValT]) = { val (newRightFields, newRightPipe) = ensureUniqueFields( (left.idxSym, left.valSym), (right.rowSym, right.colSym, right.valSym), @@ -365,61 +642,116 @@ object MatrixProduct extends java.io.Serializable { val newHint = left.sizeHint * right.sizeHint val productPipe = Matrix.filterOutZeros(right.valSym, ring) { getJoiner(left.sizeHint, right.sizeHint) - // TODO: we should use the size hints to set the number of reducers: - .apply(left.pipe, (left.idxSym -> getField(newRightFields, 0)), newRightPipe) + // TODO: we should use the size hints to set the number of reducers: + .apply( + left.pipe, + (left.idxSym -> getField(newRightFields, 0)), + newRightPipe) // Do the product: - .map((left.valSym.append(getField(newRightFields, 2))) -> getField(newRightFields, 2)) { pair: (ValT, ValT) => - ring.times(pair._1, pair._2) - } + .map( + (left.valSym.append(getField(newRightFields, 2))) -> getField( + newRightFields, + 2)) { pair: (ValT, ValT) => ring.times(pair._1, pair._2) } // Keep the names from the right: .project(newRightFields) - .rename(newRightFields -> (right.rowSym, right.colSym, right.valSym)) + .rename( + newRightFields -> (right.rowSym, right.colSym, right.valSym)) } - new Matrix[RowT, ColT, ValT](right.rowSym, right.colSym, right.valSym, productPipe, newHint) + new Matrix[RowT, ColT, ValT]( + right.rowSym, + right.colSym, + right.valSym, + productPipe, + newHint) } } - implicit def matrixDiagProduct[RowT, ColT, ValT](implicit ring: Ring[ValT]): MatrixProduct[Matrix[RowT, ColT, ValT], DiagonalMatrix[ColT, ValT], Matrix[RowT, ColT, ValT]] = - new MatrixProduct[Matrix[RowT, ColT, ValT], DiagonalMatrix[ColT, ValT], Matrix[RowT, ColT, ValT]] { - def apply(left: Matrix[RowT, ColT, ValT], right: DiagonalMatrix[ColT, ValT]) = { + implicit def matrixDiagProduct[RowT, ColT, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + Matrix[RowT, ColT, ValT], + DiagonalMatrix[ColT, ValT], + Matrix[RowT, ColT, ValT]] = + new MatrixProduct[ + Matrix[RowT, ColT, ValT], + DiagonalMatrix[ColT, ValT], + Matrix[RowT, ColT, ValT]] { + def apply( + left: Matrix[RowT, ColT, ValT], + right: DiagonalMatrix[ColT, ValT]) = { // (A * B) = (B^T * A^T)^T // note diagonal^T = diagonal (right * (left.transpose)).transpose } } - implicit def diagDiagProduct[IdxT, ValT](implicit ring: Ring[ValT]): MatrixProduct[DiagonalMatrix[IdxT, ValT], DiagonalMatrix[IdxT, ValT], DiagonalMatrix[IdxT, ValT]] = - new MatrixProduct[DiagonalMatrix[IdxT, ValT], DiagonalMatrix[IdxT, ValT], DiagonalMatrix[IdxT, ValT]] { - def apply(left: DiagonalMatrix[IdxT, ValT], right: DiagonalMatrix[IdxT, ValT]) = { + implicit def diagDiagProduct[IdxT, ValT]( 
+ implicit ring: Ring[ValT]): MatrixProduct[ + DiagonalMatrix[IdxT, ValT], + DiagonalMatrix[IdxT, ValT], + DiagonalMatrix[IdxT, ValT]] = + new MatrixProduct[ + DiagonalMatrix[IdxT, ValT], + DiagonalMatrix[IdxT, ValT], + DiagonalMatrix[IdxT, ValT]] { + def apply( + left: DiagonalMatrix[IdxT, ValT], + right: DiagonalMatrix[IdxT, ValT]) = { val (newRightFields, newRightPipe) = ensureUniqueFields( (left.idxSym, left.valSym), (right.idxSym, right.valSym), right.pipe) val newHint = left.sizeHint * right.sizeHint - val productPipe = Matrix.filterOutZeros(left.valSym, ring) { - getJoiner(left.sizeHint, right.sizeHint) + val productPipe = Matrix + .filterOutZeros(left.valSym, ring) { + getJoiner(left.sizeHint, right.sizeHint) // TODO: we should use the size hints to set the number of reducers: - .apply(left.pipe, (left.idxSym -> getField(newRightFields, 0)), newRightPipe) - // Do the product: - .map((left.valSym.append(getField(newRightFields, 1))) -> left.valSym) { pair: (ValT, ValT) => - ring.times(pair._1, pair._2) - } - } + .apply( + left.pipe, + (left.idxSym -> getField(newRightFields, 0)), + newRightPipe) + // Do the product: + .map((left.valSym + .append(getField(newRightFields, 1))) -> left.valSym) { + pair: (ValT, ValT) => ring.times(pair._1, pair._2) + } + } // Keep the names from the left: .project(left.idxSym, left.valSym) - new DiagonalMatrix[IdxT, ValT](left.idxSym, left.valSym, productPipe, newHint) + new DiagonalMatrix[IdxT, ValT]( + left.idxSym, + left.valSym, + productPipe, + newHint) } } - implicit def diagColProduct[IdxT, ValT](implicit ring: Ring[ValT]): MatrixProduct[DiagonalMatrix[IdxT, ValT], ColVector[IdxT, ValT], ColVector[IdxT, ValT]] = - new MatrixProduct[DiagonalMatrix[IdxT, ValT], ColVector[IdxT, ValT], ColVector[IdxT, ValT]] { - def apply(left: DiagonalMatrix[IdxT, ValT], right: ColVector[IdxT, ValT]) = { + implicit def diagColProduct[IdxT, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + DiagonalMatrix[IdxT, ValT], + ColVector[IdxT, ValT], + ColVector[IdxT, ValT]] = + new MatrixProduct[ + DiagonalMatrix[IdxT, ValT], + ColVector[IdxT, ValT], + ColVector[IdxT, ValT]] { + def apply( + left: DiagonalMatrix[IdxT, ValT], + right: ColVector[IdxT, ValT]) = { (left * (right.diag)).toCol } } - implicit def rowDiagProduct[IdxT, ValT](implicit ring: Ring[ValT]): MatrixProduct[RowVector[IdxT, ValT], DiagonalMatrix[IdxT, ValT], RowVector[IdxT, ValT]] = - new MatrixProduct[RowVector[IdxT, ValT], DiagonalMatrix[IdxT, ValT], RowVector[IdxT, ValT]] { - def apply(left: RowVector[IdxT, ValT], right: DiagonalMatrix[IdxT, ValT]) = { + implicit def rowDiagProduct[IdxT, ValT]( + implicit ring: Ring[ValT]): MatrixProduct[ + RowVector[IdxT, ValT], + DiagonalMatrix[IdxT, ValT], + RowVector[IdxT, ValT]] = + new MatrixProduct[ + RowVector[IdxT, ValT], + DiagonalMatrix[IdxT, ValT], + RowVector[IdxT, ValT]] { + def apply( + left: RowVector[IdxT, ValT], + right: DiagonalMatrix[IdxT, ValT]) = { ((left.diag) * right).toRow } } diff --git a/repos/scalding/scalding-core/src/test/scala/com/twitter/scalding/typed/RequireOrderedSerializationTest.scala b/repos/scalding/scalding-core/src/test/scala/com/twitter/scalding/typed/RequireOrderedSerializationTest.scala index 4d83c34cc22..903b29de718 100644 --- a/repos/scalding/scalding-core/src/test/scala/com/twitter/scalding/typed/RequireOrderedSerializationTest.scala +++ b/repos/scalding/scalding-core/src/test/scala/com/twitter/scalding/typed/RequireOrderedSerializationTest.scala @@ -12,20 +12,22 @@ distributed under the License is distributed on an "AS IS" 
BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -*/ + */ package com.twitter.scalding import com.twitter.scalding.serialization.CascadingBinaryComparator import com.twitter.scalding.serialization.OrderedSerialization import com.twitter.scalding.serialization.StringOrderedSerialization -import org.scalatest.{ Matchers, WordSpec } +import org.scalatest.{Matchers, WordSpec} class NoOrderdSerJob(args: Args) extends Job(args) { - override def config = super.config + (Config.ScaldingRequireOrderedSerialization -> "true") + override def config = + super.config + (Config.ScaldingRequireOrderedSerialization -> "true") - TypedPipe.from(TypedTsv[(String, String)]("input")) + TypedPipe + .from(TypedTsv[(String, String)]("input")) .group .max .write(TypedTsv[(String, String)]("output")) @@ -33,11 +35,14 @@ class NoOrderdSerJob(args: Args) extends Job(args) { class OrderdSerJob(args: Args) extends Job(args) { - implicit def stringOS: OrderedSerialization[String] = new StringOrderedSerialization + implicit def stringOS: OrderedSerialization[String] = + new StringOrderedSerialization - override def config = super.config + (Config.ScaldingRequireOrderedSerialization -> "true") + override def config = + super.config + (Config.ScaldingRequireOrderedSerialization -> "true") - TypedPipe.from(TypedTsv[(String, String)]("input")) + TypedPipe + .from(TypedTsv[(String, String)]("input")) .group .sorted .max @@ -50,8 +55,12 @@ class RequireOrderedSerializationTest extends WordSpec with Matchers { "throw when run" in { val ex = the[Exception] thrownBy { JobTest(new NoOrderdSerJob(_)) - .source(TypedTsv[(String, String)]("input"), List(("a", "a"), ("b", "b"))) - .sink[(String, String)](TypedTsv[(String, String)]("output")) { outBuf => () } + .source( + TypedTsv[(String, String)]("input"), + List(("a", "a"), ("b", "b"))) + .sink[(String, String)](TypedTsv[(String, String)]("output")) { + outBuf => () + } .run .finish } @@ -62,9 +71,11 @@ class RequireOrderedSerializationTest extends WordSpec with Matchers { // throw if we try to run in: "run" in { JobTest(new OrderdSerJob(_)) - .source(TypedTsv[(String, String)]("input"), List(("a", "a"), ("a", "b"), ("b", "b"))) - .sink[(String, String)](TypedTsv[(String, String)]("output")) { outBuf => - outBuf.toSet shouldBe Set(("a", "b"), ("b", "b")) + .source( + TypedTsv[(String, String)]("input"), + List(("a", "a"), ("a", "b"), ("b", "b"))) + .sink[(String, String)](TypedTsv[(String, String)]("output")) { + outBuf => outBuf.toSet shouldBe Set(("a", "b"), ("b", "b")) } .run .finish diff --git a/repos/scalding/scalding-hadoop-test/src/test/scala/com/twitter/scalding/platform/PlatformTest.scala b/repos/scalding/scalding-hadoop-test/src/test/scala/com/twitter/scalding/platform/PlatformTest.scala index e0ebdc1687e..9292523eded 100644 --- a/repos/scalding/scalding-hadoop-test/src/test/scala/com/twitter/scalding/platform/PlatformTest.scala +++ b/repos/scalding/scalding-hadoop-test/src/test/scala/com/twitter/scalding/platform/PlatformTest.scala @@ -12,18 +12,18 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
-*/ + */ package com.twitter.scalding.platform -import cascading.pipe.joiner.{ JoinerClosure, InnerJoin } +import cascading.pipe.joiner.{JoinerClosure, InnerJoin} import cascading.tuple.{Fields, Tuple} import com.twitter.scalding._ import com.twitter.scalding.serialization.OrderedSerialization -import java.util.{ Iterator => JIterator } -import org.scalacheck.{ Arbitrary, Gen } -import org.scalatest.{ Matchers, WordSpec } -import org.slf4j.{ LoggerFactory, Logger } +import java.util.{Iterator => JIterator} +import org.scalacheck.{Arbitrary, Gen} +import org.scalatest.{Matchers, WordSpec} +import org.slf4j.{LoggerFactory, Logger} import scala.collection.JavaConverters._ import scala.language.experimental.macros import scala.math.Ordering @@ -68,8 +68,10 @@ object TsvNoCacheJob { class TsvNoCacheJob(args: Args) extends Job(args) { import TsvNoCacheJob._ dataInput.read - .flatMap(new cascading.tuple.Fields(Integer.valueOf(0)) -> 'word){ line: String => line.split("\\s") } - .groupBy('word){ group => group.size } + .flatMap(new cascading.tuple.Fields(Integer.valueOf(0)) -> 'word) { + line: String => line.split("\\s") + } + .groupBy('word) { group => group.size } .mapTo('word -> 'num) { (w: String) => w.toFloat } .write(throwAwayOutput) .groupAll { _.sortBy('num) } @@ -89,7 +91,10 @@ class IterableSourceDistinctJob(args: Args) extends Job(args) { class IterableSourceDistinctIdentityJob(args: Args) extends Job(args) { import IterableSourceDistinctJob._ - TypedPipe.from(data ++ data ++ data).distinctBy(identity).write(TypedTsv("output")) + TypedPipe + .from(data ++ data ++ data) + .distinctBy(identity) + .write(TypedTsv("output")) } class NormalDistinctJob(args: Args) extends Job(args) { @@ -107,11 +112,13 @@ class MultipleGroupByJob(args: Args) extends Job(args) { import com.twitter.scalding.serialization._ import MultipleGroupByJobData._ implicit val stringOrdSer = new StringOrderedSerialization() - implicit val stringTup2OrdSer = new OrderedSerialization2(stringOrdSer, stringOrdSer) - val otherStream = TypedPipe.from(data).map{ k => (k, k) }.group + implicit val stringTup2OrdSer = + new OrderedSerialization2(stringOrdSer, stringOrdSer) + val otherStream = TypedPipe.from(data).map { k => (k, k) }.group - TypedPipe.from(data) - .map{ k => (k, 1L) } + TypedPipe + .from(data) + .map { k => (k, 1L) } .group[String, Long](implicitly, stringOrdSer) .sum .map { @@ -120,9 +127,7 @@ class MultipleGroupByJob(args: Args) extends Job(args) { } .sumByKey[(String, String), Long](implicitly, stringTup2OrdSer, implicitly) .map(_._1._1) - .map { t => - (t.toString, t) - } + .map { t => (t.toString, t) } .group .leftJoin(otherStream) .map(_._1) @@ -131,7 +136,8 @@ class MultipleGroupByJob(args: Args) extends Job(args) { } class TypedPipeWithDescriptionJob(args: Args) extends Job(args) { - TypedPipe.from[String](List("word1", "word1", "word2")) + TypedPipe + .from[String](List("word1", "word1", "word2")) .withDescription("map stage - assign words to 1") .map { w => (w, 1L) } .group @@ -148,7 +154,9 @@ class TypedPipeJoinWithDescriptionJob(args: Args) extends Job(args) { val y = TypedPipe.from[(Int, String)](List((1, "first"))) val z = TypedPipe.from[(Int, Boolean)](List((2, true))).group - x.hashJoin(y) // this triggers an implicit that somehow pushes the line number to the next one + x.hashJoin( + y + ) // this triggers an implicit that somehow pushes the line number to the next one .withDescription("hashJoin") .leftJoin(z) .withDescription("leftJoin") @@ -163,11 +171,11 @@ class 
TypedPipeHashJoinWithForceToDiskJob(args: Args) extends Job(args) { val y = TypedPipe.from[(Int, String)](List((1, "first"))) //trivial transform and forceToDisk on the rhs - val yMap = y.map( p => (p._1, p._2.toUpperCase)).forceToDisk + val yMap = y.map(p => (p._1, p._2.toUpperCase)).forceToDisk x.hashJoin(yMap) - .withDescription("hashJoin") - .write(TypedTsv[(Int, (Int, String))]("output")) + .withDescription("hashJoin") + .write(TypedTsv[(Int, (Int, String))]("output")) } class TypedPipeHashJoinWithForceToDiskFilterJob(args: Args) extends Job(args) { @@ -177,25 +185,30 @@ class TypedPipeHashJoinWithForceToDiskFilterJob(args: Args) extends Job(args) { val y = TypedPipe.from[(Int, String)](List((1, "first"))) //trivial transform and forceToDisk followed by filter on rhs - val yFilter = y.map( p => (p._1, p._2.toUpperCase)).forceToDisk.filter(p => p._1 == 1) + val yFilter = + y.map(p => (p._1, p._2.toUpperCase)).forceToDisk.filter(p => p._1 == 1) x.hashJoin(yFilter) - .withDescription("hashJoin") - .write(TypedTsv[(Int, (Int, String))]("output")) + .withDescription("hashJoin") + .write(TypedTsv[(Int, (Int, String))]("output")) } -class TypedPipeHashJoinWithForceToDiskWithComplete(args: Args) extends Job(args) { +class TypedPipeHashJoinWithForceToDiskWithComplete(args: Args) + extends Job(args) { PlatformTest.setAutoForceRight(mode, true) val x = TypedPipe.from[(Int, Int)](List((1, 1))) val y = TypedPipe.from[(Int, String)](List((1, "first"))) //trivial transform and forceToDisk followed by WithComplete on rhs - val yComplete = y.map( p => (p._1, p._2.toUpperCase)).forceToDisk.onComplete(() => println("step complete")) + val yComplete = y + .map(p => (p._1, p._2.toUpperCase)) + .forceToDisk + .onComplete(() => println("step complete")) x.hashJoin(yComplete) - .withDescription("hashJoin") - .write(TypedTsv[(Int, (Int, String))]("output")) + .withDescription("hashJoin") + .write(TypedTsv[(Int, (Int, String))]("output")) } class TypedPipeHashJoinWithForceToDiskMapJob(args: Args) extends Job(args) { @@ -204,20 +217,27 @@ class TypedPipeHashJoinWithForceToDiskMapJob(args: Args) extends Job(args) { val y = TypedPipe.from[(Int, String)](List((1, "first"))) //trivial transform and forceToDisk followed by map on rhs - val yMap = y.map( p => (p._1, p._2.toUpperCase)).forceToDisk.map(p => (p._1, p._2.toLowerCase)) + val yMap = y + .map(p => (p._1, p._2.toUpperCase)) + .forceToDisk + .map(p => (p._1, p._2.toLowerCase)) x.hashJoin(yMap) - .withDescription("hashJoin") - .write(TypedTsv[(Int, (Int, String))]("output")) + .withDescription("hashJoin") + .write(TypedTsv[(Int, (Int, String))]("output")) } -class TypedPipeHashJoinWithForceToDiskMapWithAutoForceJob(args: Args) extends Job(args) { +class TypedPipeHashJoinWithForceToDiskMapWithAutoForceJob(args: Args) + extends Job(args) { PlatformTest.setAutoForceRight(mode, true) val x = TypedPipe.from[(Int, Int)](List((1, 1))) val y = TypedPipe.from[(Int, String)](List((1, "first"))) //trivial transform and forceToDisk followed by map on rhs - val yMap = y.map( p => (p._1, p._2.toUpperCase)).forceToDisk.map(p => (p._1, p._2.toLowerCase)) + val yMap = y + .map(p => (p._1, p._2.toUpperCase)) + .forceToDisk + .map(p => (p._1, p._2.toLowerCase)) x.hashJoin(yMap) .withDescription("hashJoin") @@ -230,30 +250,35 @@ class TypedPipeHashJoinWithGroupByJob(args: Args) extends Job(args) { val x = TypedPipe.from[(String, Int)](Tsv("input1", ('x1, 'y1)), Fields.ALL) val y = Tsv("input2", ('x2, 'y2)) - val yGroup = y.groupBy('x2){p => p} + val yGroup = y.groupBy('x2) { p => 
p } val yTypedPipe = TypedPipe.from[(String, Int)](yGroup, Fields.ALL) x.hashJoin(yTypedPipe) - .withDescription("hashJoin") - .write(TypedTsv[(String, (Int, Int))]("output")) + .withDescription("hashJoin") + .write(TypedTsv[(String, (Int, Int))]("output")) } class TypedPipeHashJoinWithCoGroupJob(args: Args) extends Job(args) { PlatformTest.setAutoForceRight(mode, true) val x = TypedPipe.from[(Int, Int)](List((1, 1))) - val in0 = Tsv("input0").read.mapTo((0, 1) -> ('x0, 'a)) { input: (Int, Int) => input } - val in1 = Tsv("input1").read.mapTo((0, 1) -> ('x1, 'b)) { input: (Int, Int) => input } + val in0 = Tsv("input0").read.mapTo((0, 1) -> ('x0, 'a)) { input: (Int, Int) => + input + } + val in1 = Tsv("input1").read.mapTo((0, 1) -> ('x1, 'b)) { input: (Int, Int) => + input + } val coGroupPipe = in0.coGroupBy('x0) { _.coGroup('x1, in1, OuterJoinMode) } - val coGroupTypedPipe = TypedPipe.from[(Int, Int, Int)](coGroupPipe, Fields.ALL) - val coGroupTuplePipe = coGroupTypedPipe.map{ case(a,b,c) => (a, (b,c)) } + val coGroupTypedPipe = + TypedPipe.from[(Int, Int, Int)](coGroupPipe, Fields.ALL) + val coGroupTuplePipe = coGroupTypedPipe.map { case (a, b, c) => (a, (b, c)) } x.hashJoin(coGroupTuplePipe) - .withDescription("hashJoin") - .write(TypedTsv[(Int, (Int, (Int, Int)))]("output")) + .withDescription("hashJoin") + .write(TypedTsv[(Int, (Int, (Int, Int)))]("output")) } class TypedPipeHashJoinWithEveryJob(args: Args) extends Job(args) { @@ -261,18 +286,19 @@ class TypedPipeHashJoinWithEveryJob(args: Args) extends Job(args) { val x = TypedPipe.from[(Int, String)](Tsv("input1", ('x1, 'y1)), Fields.ALL) val y = Tsv("input2", ('x2, 'y2)).groupBy('x2) { - _.foldLeft('y2 -> 'y2)(0){ (b: Int, a: Int) => b + a} + _.foldLeft('y2 -> 'y2)(0) { (b: Int, a: Int) => b + a } } val yTypedPipe = TypedPipe.from[(Int, Int)](y, Fields.ALL) x.hashJoin(yTypedPipe) - .withDescription("hashJoin") - .write(TypedTsv[(Int, (String, Int))]("output")) + .withDescription("hashJoin") + .write(TypedTsv[(Int, (String, Int))]("output")) } class TypedPipeForceToDiskWithDescriptionJob(args: Args) extends Job(args) { val writeWords = { - TypedPipe.from[String](List("word1 word2", "word1", "word2")) + TypedPipe + .from[String](List("word1 word2", "word1", "word2")) .withDescription("write words to disk") .flatMap { _.split("\\s+") } .forceToDisk @@ -288,7 +314,7 @@ object OrderedSerializationTest { implicit val genASGK = Arbitrary { for { ts <- Arbitrary.arbitrary[Long] - b <- Gen.nonEmptyListOf(Gen.alphaNumChar).map (_.mkString) + b <- Gen.nonEmptyListOf(Gen.alphaNumChar).map(_.mkString) } yield NestedCaseClass(RichDate(ts), (b, b)) } @@ -299,25 +325,36 @@ object OrderedSerializationTest { case class NestedCaseClass(day: RichDate, key: (String, String)) class ComplexJob(input: List[NestedCaseClass], args: Args) extends Job(args) { - implicit def primitiveOrderedBufferSupplier[T]: OrderedSerialization[T] = macro com.twitter.scalding.serialization.macros.impl.OrderedSerializationProviderImpl[T] + implicit def primitiveOrderedBufferSupplier[T]: OrderedSerialization[T] = + macro com.twitter.scalding.serialization.macros.impl + .OrderedSerializationProviderImpl[T] - val ds1 = TypedPipe.from(input).map(_ -> 1L).group.sorted.mapValueStream(_.map(_ * 2)).toTypedPipe.group + val ds1 = TypedPipe + .from(input) + .map(_ -> 1L) + .group + .sorted + .mapValueStream(_.map(_ * 2)) + .toTypedPipe + .group val ds2 = TypedPipe.from(input).map(_ -> 1L).distinct.group - ds2 - .keys + ds2.keys .map(s => s.toString) 
.write(TypedTsv[String](args("output1"))) - ds2.join(ds1) + ds2 + .join(ds1) .values .map(_.toString) .write(TypedTsv[String](args("output2"))) } class ComplexJob2(input: List[NestedCaseClass], args: Args) extends Job(args) { - implicit def primitiveOrderedBufferSupplier[T]: OrderedSerialization[T] = macro com.twitter.scalding.serialization.macros.impl.OrderedSerializationProviderImpl[T] + implicit def primitiveOrderedBufferSupplier[T]: OrderedSerialization[T] = + macro com.twitter.scalding.serialization.macros.impl + .OrderedSerializationProviderImpl[T] val ds1 = TypedPipe.from(input).map(_ -> (1L, "asfg")) @@ -351,15 +388,16 @@ class CheckForFlowProcessInFieldsJob(args: Args) extends Job(args) { val inA = Tsv("inputA", ('a, 'b)) val inB = Tsv("inputB", ('x, 'y)) - val p = inA.joinWithSmaller('a -> 'x, inB).map(('b, 'y) -> 'z) { args: (String, String) => - stat.inc + val p = inA.joinWithSmaller('a -> 'x, inB).map(('b, 'y) -> 'z) { + args: (String, String) => + stat.inc - val flowProcess = RuntimeStats.getFlowProcessForUniqueId(uniqueID) - if (flowProcess == null) { - throw new NullPointerException("No active FlowProcess was available.") - } + val flowProcess = RuntimeStats.getFlowProcessForUniqueId(uniqueID) + if (flowProcess == null) { + throw new NullPointerException("No active FlowProcess was available.") + } - s"${args._1},${args._2}" + s"${args._1},${args._2}" } p.write(Tsv("output", ('b, 'y))) @@ -372,16 +410,21 @@ class CheckForFlowProcessInTypedJob(args: Args) extends Job(args) { val inA = TypedPipe.from(TypedTsv[(String, String)]("inputA")) val inB = TypedPipe.from(TypedTsv[(String, String)]("inputB")) - inA.group.join(inB.group).forceToReducers.mapGroup((key, valuesIter) => { - stat.inc - - val flowProcess = RuntimeStats.getFlowProcessForUniqueId(uniqueID) - if (flowProcess == null) { - throw new NullPointerException("No active FlowProcess was available.") - } - - valuesIter.map({ case (a, b) => s"$a:$b" }) - }).toTypedPipe.write(TypedTsv[(String, String)]("output")) + inA.group + .join(inB.group) + .forceToReducers + .mapGroup((key, valuesIter) => { + stat.inc + + val flowProcess = RuntimeStats.getFlowProcessForUniqueId(uniqueID) + if (flowProcess == null) { + throw new NullPointerException("No active FlowProcess was available.") + } + + valuesIter.map({ case (a, b) => s"$a:$b" }) + }) + .toTypedPipe + .write(TypedTsv[(String, String)]("output")) } object PlatformTest { @@ -397,7 +440,10 @@ object PlatformTest { // Keeping all of the specifications in the same tests puts the result output all together at the end. // This is useful given that the Hadoop MiniMRCluster and MiniDFSCluster spew a ton of logging. 
-class PlatformTest extends WordSpec with Matchers with HadoopSharedPlatformTest { +class PlatformTest + extends WordSpec + with Matchers + with HadoopSharedPlatformTest { "An InAndOutTest" should { val inAndOut = Seq("a", "b", "c") @@ -429,7 +475,11 @@ class PlatformTest extends WordSpec with Matchers with HadoopSharedPlatformTest HadoopPlatformJobTest(new TsvNoCacheJob(_), cluster) .source(dataInput, data) .sink(typedThrowAwayOutput) { _.toSet should have size 4 } - .sink(typedRealOutput) { _.map{ f: Float => (f * 10).toInt }.toList shouldBe (outputData.map{ f: Float => (f * 10).toInt }.toList) } + .sink(typedRealOutput) { + _.map { f: Float => (f * 10).toInt }.toList shouldBe (outputData.map { + f: Float => (f * 10).toInt + }.toList) + } .run } } @@ -448,177 +498,207 @@ class PlatformTest extends WordSpec with Matchers with HadoopSharedPlatformTest "A TypedPipeForceToDiskWithDescriptionPipe" should { "have a custom step name from withDescription" in { - HadoopPlatformJobTest(new TypedPipeForceToDiskWithDescriptionJob(_), cluster) - .inspectCompletedFlow { flow => - val steps = flow.getFlowSteps.asScala - val firstStep = steps.filter(_.getName.startsWith("(1/2")) - val secondStep = steps.filter(_.getName.startsWith("(2/2")) - val lab1 = firstStep.map(_.getConfig.get(Config.StepDescriptions)) - lab1 should have size 1 - lab1(0) should include ("write words to disk") - val lab2 = secondStep.map(_.getConfig.get(Config.StepDescriptions)) - lab2 should have size 1 - lab2(0) should include ("output frequency by length") - } - .run + HadoopPlatformJobTest( + new TypedPipeForceToDiskWithDescriptionJob(_), + cluster).inspectCompletedFlow { flow => + val steps = flow.getFlowSteps.asScala + val firstStep = steps.filter(_.getName.startsWith("(1/2")) + val secondStep = steps.filter(_.getName.startsWith("(2/2")) + val lab1 = firstStep.map(_.getConfig.get(Config.StepDescriptions)) + lab1 should have size 1 + lab1(0) should include("write words to disk") + val lab2 = secondStep.map(_.getConfig.get(Config.StepDescriptions)) + lab2 should have size 1 + lab2(0) should include("output frequency by length") + }.run } } //also tests HashJoin behavior to verify that we don't introduce a forceToDisk as the RHS pipe is source Pipe "A TypedPipeJoinWithDescriptionPipe" should { "have a custom step name from withDescription and no extra forceToDisk steps on hashJoin's rhs" in { - HadoopPlatformJobTest(new TypedPipeJoinWithDescriptionJob(_), cluster) - .inspectCompletedFlow { flow => + HadoopPlatformJobTest(new TypedPipeJoinWithDescriptionJob(_), cluster).inspectCompletedFlow { + flow => val steps = flow.getFlowSteps.asScala steps should have size 1 - val firstStep = steps.headOption.map(_.getConfig.get(Config.StepDescriptions)).getOrElse("") + val firstStep = steps.headOption + .map(_.getConfig.get(Config.StepDescriptions)) + .getOrElse("") val lines = List(149, 151, 152, 155, 156).map { i => s"com.twitter.scalding.platform.TypedPipeJoinWithDescriptionJob.(PlatformTest.scala:$i" } - firstStep should include ("leftJoin") - firstStep should include ("hashJoin") - lines.foreach { l => firstStep should include (l) } - steps.map(_.getConfig.get(Config.StepDescriptions)).foreach(s => info(s)) - } - .run + firstStep should include("leftJoin") + firstStep should include("hashJoin") + lines.foreach { l => firstStep should include(l) } + steps + .map(_.getConfig.get(Config.StepDescriptions)) + .foreach(s => info(s)) + }.run } } //expect two jobs - one for the map prior to the Checkpoint and one for the hashJoin "A 
TypedPipeHashJoinWithForceToDiskJob" should { "have a custom step name from withDescription and only one user provided forceToDisk on hashJoin's rhs" in { - HadoopPlatformJobTest(new TypedPipeHashJoinWithForceToDiskJob(_), cluster) - .inspectCompletedFlow { flow => - val steps = flow.getFlowSteps.asScala - steps should have size 2 - val secondStep = steps.lastOption.map(_.getConfig.get(Config.StepDescriptions)).getOrElse("") - secondStep should include ("hashJoin") - } - .run + HadoopPlatformJobTest(new TypedPipeHashJoinWithForceToDiskJob(_), cluster).inspectCompletedFlow { + flow => + val steps = flow.getFlowSteps.asScala + steps should have size 2 + val secondStep = steps.lastOption + .map(_.getConfig.get(Config.StepDescriptions)) + .getOrElse("") + secondStep should include("hashJoin") + }.run } } //expect 3 jobs - one extra compared to previous as there's a new forceToDisk added "A TypedPipeHashJoinWithForceToDiskFilterJob" should { "have a custom step name from withDescription and an extra forceToDisk due to a filter operation on hashJoin's rhs" in { - HadoopPlatformJobTest(new TypedPipeHashJoinWithForceToDiskFilterJob(_), cluster) - .inspectCompletedFlow { flow => - val steps = flow.getFlowSteps.asScala - steps should have size 3 - val lastStep = steps.lastOption.map(_.getConfig.get(Config.StepDescriptions)).getOrElse("") - lastStep should include ("hashJoin") - } - .run + HadoopPlatformJobTest( + new TypedPipeHashJoinWithForceToDiskFilterJob(_), + cluster).inspectCompletedFlow { flow => + val steps = flow.getFlowSteps.asScala + steps should have size 3 + val lastStep = steps.lastOption + .map(_.getConfig.get(Config.StepDescriptions)) + .getOrElse("") + lastStep should include("hashJoin") + }.run } } //expect two jobs - one for the map prior to the Checkpoint and one for the rest "A TypedPipeHashJoinWithForceToDiskWithComplete" should { "have a custom step name from withDescription and no extra forceToDisk due to with complete operation on hashJoin's rhs" in { - HadoopPlatformJobTest(new TypedPipeHashJoinWithForceToDiskWithComplete(_), cluster) - .inspectCompletedFlow { flow => - val steps = flow.getFlowSteps.asScala - steps should have size 2 - val lastStep = steps.lastOption.map(_.getConfig.get(Config.StepDescriptions)).getOrElse("") - lastStep should include ("hashJoin") - } - .run + HadoopPlatformJobTest( + new TypedPipeHashJoinWithForceToDiskWithComplete(_), + cluster).inspectCompletedFlow { flow => + val steps = flow.getFlowSteps.asScala + steps should have size 2 + val lastStep = steps.lastOption + .map(_.getConfig.get(Config.StepDescriptions)) + .getOrElse("") + lastStep should include("hashJoin") + }.run } } //expect two jobs - one for the map prior to the Checkpoint and one for the rest "A TypedPipeHashJoinWithForceToDiskMapJob" should { "have a custom step name from withDescription and no extra forceToDisk due to map (autoForce = false) on forceToDisk operation on hashJoin's rhs" in { - HadoopPlatformJobTest(new TypedPipeHashJoinWithForceToDiskMapJob(_), cluster) - .inspectCompletedFlow { flow => - val steps = flow.getFlowSteps.asScala - steps should have size 2 - val lastStep = steps.lastOption.map(_.getConfig.get(Config.StepDescriptions)).getOrElse("") - lastStep should include ("hashJoin") - } - .run + HadoopPlatformJobTest( + new TypedPipeHashJoinWithForceToDiskMapJob(_), + cluster).inspectCompletedFlow { flow => + val steps = flow.getFlowSteps.asScala + steps should have size 2 + val lastStep = steps.lastOption + .map(_.getConfig.get(Config.StepDescriptions)) + 
.getOrElse("") + lastStep should include("hashJoin") + }.run } } //expect one extra job from the above - we end up performing a forceToDisk after the map "A TypedPipeHashJoinWithForceToDiskMapWithAutoForceJob" should { "have a custom step name from withDescription and an extra forceToDisk due to map (autoForce = true) on forceToDisk operation on hashJoin's rhs" in { - HadoopPlatformJobTest(new TypedPipeHashJoinWithForceToDiskMapWithAutoForceJob(_), cluster) - .inspectCompletedFlow { flow => - val steps = flow.getFlowSteps.asScala - steps should have size 3 - val lastStep = steps.lastOption.map(_.getConfig.get(Config.StepDescriptions)).getOrElse("") - lastStep should include ("hashJoin") - } - .run + HadoopPlatformJobTest( + new TypedPipeHashJoinWithForceToDiskMapWithAutoForceJob(_), + cluster).inspectCompletedFlow { flow => + val steps = flow.getFlowSteps.asScala + steps should have size 3 + val lastStep = steps.lastOption + .map(_.getConfig.get(Config.StepDescriptions)) + .getOrElse("") + lastStep should include("hashJoin") + }.run } } "A TypedPipeHashJoinWithGroupByJob" should { "have a custom step name from withDescription and no extra forceToDisk after groupBy on hashJoin's rhs" in { HadoopPlatformJobTest(new TypedPipeHashJoinWithGroupByJob(_), cluster) - .source(TypedTsv[(String, Int)]("input1"), Seq(("first", 45))) - .source(TypedTsv[(String, Int)]("input2"), Seq(("first", 1), ("first", 2), ("first", 3), ("second", 1), ("second", 2))) - .inspectCompletedFlow { flow => - val steps = flow.getFlowSteps.asScala - steps should have size 2 - val lastStep = steps.lastOption.map(_.getConfig.get(Config.StepDescriptions)).getOrElse("") - lastStep should include ("hashJoin") - } - .run + .source(TypedTsv[(String, Int)]("input1"), Seq(("first", 45))) + .source( + TypedTsv[(String, Int)]("input2"), + Seq( + ("first", 1), + ("first", 2), + ("first", 3), + ("second", 1), + ("second", 2))) + .inspectCompletedFlow { flow => + val steps = flow.getFlowSteps.asScala + steps should have size 2 + val lastStep = steps.lastOption + .map(_.getConfig.get(Config.StepDescriptions)) + .getOrElse("") + lastStep should include("hashJoin") + } + .run } } "A TypedPipeHashJoinWithCoGroupJob" should { "have a custom step name from withDescription and no extra forceToDisk after coGroup + map on hashJoin's rhs" in { HadoopPlatformJobTest(new TypedPipeHashJoinWithCoGroupJob(_), cluster) - .source(TypedTsv[(Int, Int)]("input0"), List((0, 1), (1, 1), (2, 1), (3, 2))) - .source(TypedTsv[(Int, Int)]("input1"), List((0, 1), (2, 5), (3, 2))) - .inspectCompletedFlow { flow => - val steps = flow.getFlowSteps.asScala - steps should have size 2 - val lastStep = steps.lastOption.map(_.getConfig.get(Config.StepDescriptions)).getOrElse("") - lastStep should include ("hashJoin") - } - .run + .source( + TypedTsv[(Int, Int)]("input0"), + List((0, 1), (1, 1), (2, 1), (3, 2))) + .source(TypedTsv[(Int, Int)]("input1"), List((0, 1), (2, 5), (3, 2))) + .inspectCompletedFlow { flow => + val steps = flow.getFlowSteps.asScala + steps should have size 2 + val lastStep = steps.lastOption + .map(_.getConfig.get(Config.StepDescriptions)) + .getOrElse("") + lastStep should include("hashJoin") + } + .run } } "A TypedPipeHashJoinWithEveryJob" should { "have a custom step name from withDescription and no extra forceToDisk after an Every on hashJoin's rhs" in { HadoopPlatformJobTest(new TypedPipeHashJoinWithEveryJob(_), cluster) - .source(TypedTsv[(Int, String)]("input1"), Seq((1, "foo"))) - .source(TypedTsv[(Int, Int)]("input2"), Seq((1, 30), (1,10), 
(1,20), (2,20))) - .inspectCompletedFlow { flow => - val steps = flow.getFlowSteps.asScala - steps should have size 2 - val lastStep = steps.lastOption.map(_.getConfig.get(Config.StepDescriptions)).getOrElse("") - lastStep should include ("hashJoin") - } - .run + .source(TypedTsv[(Int, String)]("input1"), Seq((1, "foo"))) + .source( + TypedTsv[(Int, Int)]("input2"), + Seq((1, 30), (1, 10), (1, 20), (2, 20))) + .inspectCompletedFlow { flow => + val steps = flow.getFlowSteps.asScala + steps should have size 2 + val lastStep = steps.lastOption + .map(_.getConfig.get(Config.StepDescriptions)) + .getOrElse("") + lastStep should include("hashJoin") + } + .run } } "A TypedPipeWithDescriptionPipe" should { "have a custom step name from withDescription" in { - HadoopPlatformJobTest(new TypedPipeWithDescriptionJob(_), cluster) - .inspectCompletedFlow { flow => + HadoopPlatformJobTest(new TypedPipeWithDescriptionJob(_), cluster).inspectCompletedFlow { + flow => val steps = flow.getFlowSteps.asScala - val descs = List("map stage - assign words to 1", + val descs = List( + "map stage - assign words to 1", "reduce stage - sum", "write", // should see the .group and the .write show up as line numbers "com.twitter.scalding.platform.TypedPipeWithDescriptionJob.(PlatformTest.scala:137)", - "com.twitter.scalding.platform.TypedPipeWithDescriptionJob.(PlatformTest.scala:141)") + "com.twitter.scalding.platform.TypedPipeWithDescriptionJob.(PlatformTest.scala:141)" + ) val foundDescs = steps.map(_.getConfig.get(Config.StepDescriptions)) descs.foreach { d => assert(foundDescs.size == 1) assert(foundDescs(0).contains(d)) } - //steps.map(_.getConfig.get(Config.StepDescriptions)).foreach(s => info(s)) - } - .run + //steps.map(_.getConfig.get(Config.StepDescriptions)).foreach(s => info(s)) + }.run } } @@ -677,26 +757,36 @@ class PlatformTest extends WordSpec with Matchers with HadoopSharedPlatformTest "Methods called from a Joiner" should { "have access to a FlowProcess from a join in the Fields-based API" in { HadoopPlatformJobTest(new CheckForFlowProcessInFieldsJob(_), cluster) - .source(TypedTsv[(String, String)]("inputA"), Seq(("1", "alpha"), ("2", "beta"))) - .source(TypedTsv[(String, String)]("inputB"), Seq(("1", "first"), ("2", "second"))) + .source( + TypedTsv[(String, String)]("inputA"), + Seq(("1", "alpha"), ("2", "beta"))) + .source( + TypedTsv[(String, String)]("inputB"), + Seq(("1", "first"), ("2", "second"))) .sink(TypedTsv[(String, String)]("output")) { _ => // The job will fail with an exception if the FlowProcess is unavailable. } .inspectCompletedFlow({ flow => - flow.getFlowStats.getCounterValue(Stats.ScaldingGroup, "joins") shouldBe 2 + flow.getFlowStats + .getCounterValue(Stats.ScaldingGroup, "joins") shouldBe 2 }) .run } "have access to a FlowProcess from a join in the Typed API" in { HadoopPlatformJobTest(new CheckForFlowProcessInTypedJob(_), cluster) - .source(TypedTsv[(String, String)]("inputA"), Seq(("1", "alpha"), ("2", "beta"))) - .source(TypedTsv[(String, String)]("inputB"), Seq(("1", "first"), ("2", "second"))) + .source( + TypedTsv[(String, String)]("inputA"), + Seq(("1", "alpha"), ("2", "beta"))) + .source( + TypedTsv[(String, String)]("inputB"), + Seq(("1", "first"), ("2", "second"))) .sink[(String, String)](TypedTsv[(String, String)]("output")) { _ => // The job will fail with an exception if the FlowProcess is unavailable. 
} .inspectCompletedFlow({ flow => - flow.getFlowStats.getCounterValue(Stats.ScaldingGroup, "joins") shouldBe 2 + flow.getFlowStats + .getCounterValue(Stats.ScaldingGroup, "joins") shouldBe 2 }) .run } diff --git a/repos/slick/slick/src/sphinx/code/LiftedEmbedding.scala b/repos/slick/slick/src/sphinx/code/LiftedEmbedding.scala index 83aa1a7719a..5331c8d1a2a 100644 --- a/repos/slick/slick/src/sphinx/code/LiftedEmbedding.scala +++ b/repos/slick/slick/src/sphinx/code/LiftedEmbedding.scala @@ -9,33 +9,36 @@ import scala.reflect.ClassTag object LiftedEmbedding extends App { // Simple Coffees for Rep types comparison { - //#reptypes - class Coffees(tag: Tag) extends Table[(String, Double)](tag, "COFFEES") { - def name = column[String]("COF_NAME") - def price = column[Double]("PRICE") - def * = (name, price) - } - val coffees = TableQuery[Coffees] + //#reptypes + class Coffees(tag: Tag) extends Table[(String, Double)](tag, "COFFEES") { + def name = column[String]("COF_NAME") + def price = column[Double]("PRICE") + def * = (name, price) + } + val coffees = TableQuery[Coffees] - //#reptypes + //#reptypes } { //#plaintypes - case class Coffee(name: String, price: Double) - val coffees: List[Coffee] = //... + case class Coffee(name: String, price: Double) + val coffees: List[Coffee] = //... //#plaintypes - Nil + Nil //#plaintypes - val l = coffees.filter(_.price > 8.0).map(_.name) - // ^ ^ ^ - // Double Double String + val l = coffees.filter(_.price > 8.0).map(_.name) + // ^ ^ ^ + // Double Double String //#plaintypes } //#foreignkey - class Suppliers(tag: Tag) extends Table[(Int, String, String, String, String, String)](tag, "SUPPLIERS") { + class Suppliers(tag: Tag) + extends Table[(Int, String, String, String, String, String)]( + tag, + "SUPPLIERS") { def id = column[Int]("SUP_ID", O.PrimaryKey) //... //#foreignkey @@ -53,7 +56,8 @@ object LiftedEmbedding extends App { //#foreignkey //#tabledef //#foreignkey - class Coffees(tag: Tag) extends Table[(String, Int, Double, Int, Int)](tag, "COFFEES") { + class Coffees(tag: Tag) + extends Table[(String, Int, Double, Int, Int)](tag, "COFFEES") { //#foreignkey def name = column[String]("COF_NAME", O.PrimaryKey) //#foreignkey @@ -71,7 +75,11 @@ object LiftedEmbedding extends App { //#tabledef //#foreignkeynav //#foreignkey - def supplier = foreignKey("SUP_FK", supID, suppliers)(_.id, onUpdate=ForeignKeyAction.Restrict, onDelete=ForeignKeyAction.Cascade) + def supplier = + foreignKey("SUP_FK", supID, suppliers)( + _.id, + onUpdate = ForeignKeyAction.Restrict, + onDelete = ForeignKeyAction.Cascade) //#foreignkeynav // compiles to SQL: // alter table "COFFEES" add constraint "SUP_FK" foreign key("SUP_ID") @@ -91,22 +99,25 @@ object LiftedEmbedding extends App { //#foreignkey { - //#schemaname - class Coffees(tag: Tag) - extends Table[(String, Int, Double, Int, Int)](tag, Some("MYSCHEMA"), "COFFEES") { - //... - //#schemaname - def * = ??? - def name = column[String]("NAME") - //#schemaname - } - //#schemaname + //#schemaname + class Coffees(tag: Tag) + extends Table[(String, Int, Double, Int, Int)]( + tag, + Some("MYSCHEMA"), + "COFFEES") { + //... + //#schemaname + def * = ??? 
+ def name = column[String]("NAME") + //#schemaname + } + //#schemaname - //#tablequery2 - object coffees extends TableQuery(new Coffees(_)) { - val findByName = this.findBy(_.name) - } - //#tablequery2 + //#tablequery2 + object coffees extends TableQuery(new Coffees(_)) { + val findByName = this.findBy(_.name) + } + //#tablequery2 } //#reptypes val q = coffees.filter(_.price > 8.0).map(_.name) @@ -126,8 +137,12 @@ object LiftedEmbedding extends App { } val users = TableQuery[Users] //#mappedtable - def usersForInsert = users.map(u => (u.first, u.last).shaped <> - ({ t => User(None, t._1, t._2)}, { (u: User) => Some((u.first, u.last))})) + def usersForInsert = + users.map(u => + (u.first, u.last).shaped <> + ({ t => User(None, t._1, t._2) }, { (u: User) => + Some((u.first, u.last)) + })) //#insert2 //#index @@ -152,26 +167,27 @@ object LiftedEmbedding extends App { val db: Database = Database.forConfig("h2mem1") try { - //#ddl + //#ddl val schema = coffees.schema ++ suppliers.schema - //#ddl + //#ddl Await.result( - //#ddl - db.run(DBIO.seq( - schema.create, - //... - schema.drop - )) - //#ddl - , Duration.Inf) - - //#ddl2 + //#ddl + db.run( + DBIO.seq( + schema.create, + //... + schema.drop + )) + //#ddl + , + Duration.Inf) + + //#ddl2 schema.create.statements.foreach(println) schema.drop.statements.foreach(println) - //#ddl2 - TableQuery[A].schema.create.statements.foreach(println) - - ;{ + //#ddl2 + TableQuery[A].schema.create.statements.foreach(println); + { //#filtering val q1 = coffees.filter(_.supID === 101) // compiles to SQL (simplified): @@ -194,14 +210,18 @@ object LiftedEmbedding extends App { // building criteria using a "dynamic filter" e.g. from a webform. val criteriaColombian = Option("Colombian") val criteriaEspresso = Option("Espresso") - val criteriaRoast:Option[String] = None + val criteriaRoast: Option[String] = None val q4 = coffees.filter { coffee => List( - criteriaColombian.map(coffee.name === _), - criteriaEspresso.map(coffee.name === _), - criteriaRoast.map(coffee.name === _) // not a condition as `criteriaRoast` evaluates to `None` - ).collect({case Some(criteria) => criteria}).reduceLeftOption(_ || _).getOrElse(true: Rep[Boolean]) + criteriaColombian.map(coffee.name === _), + criteriaEspresso.map(coffee.name === _), + criteriaRoast.map( + coffee.name === _ + ) // not a condition as `criteriaRoast` evaluates to `None` + ).collect({ case Some(criteria) => criteria }) + .reduceLeftOption(_ || _) + .getOrElse(true: Rep[Boolean]) } // compiles to SQL (simplified): // select "COF_NAME", "SUP_ID", "PRICE", "SALES", "TOTAL" @@ -213,9 +233,8 @@ object LiftedEmbedding extends App { println(q2.result.statements.head) println(q3.result.statements.head) println(q4.result.statements.head) - } - - ;{ + }; + { //#aggregation1 val q = coffees.map(_.price) @@ -240,9 +259,8 @@ object LiftedEmbedding extends App { println(q2.shaped.result.statements.head) println(q3.shaped.result.statements.head) println(q4.shaped.result.statements.head) - } - - ;{ + }; + { Await.result(db.run(schema.create), Duration.Inf) //#aggregation2 val q1 = coffees.length @@ -289,17 +307,17 @@ object LiftedEmbedding extends App { //#delete2 Await.result(affectedRowsCount, Duration.Inf) } - } - - ;{ + }; + { //#aggregation3 val q = (for { c <- coffees s <- c.supplier } yield (c, s)).groupBy(_._1.supID) - val q2 = q.map { case (supID, css) => - (supID, css.length, css.map(_._1.price).avg) + val q2 = q.map { + case (supID, css) => + (supID, css.length, css.map(_._1.price).avg) } // compiles to SQL: // select 
x2."SUP_ID", count(1), avg(x2."PRICE") @@ -308,20 +326,18 @@ object LiftedEmbedding extends App { // group by x2."SUP_ID" //#aggregation3 println(q2.result.statements.head) - } - - ;{ + }; + { //#insert1 val insertActions = DBIO.seq( coffees += ("Colombian", 101, 7.99, 0, 0), - coffees ++= Seq( ("French_Roast", 49, 8.99, 0, 0), - ("Espresso", 150, 9.99, 0, 0) + ("Espresso", 150, 9.99, 0, 0) ), - // "sales" and "total" will use the default value 0: - coffees.map(c => (c.name, c.supID, c.price)) += ("Colombian_Decaf", 101, 8.99) + coffees + .map(c => (c.name, c.supID, c.price)) += ("Colombian_Decaf", 101, 8.99) ) // Get the statement without having to specify a value to insert: @@ -332,14 +348,17 @@ object LiftedEmbedding extends App { //#insert1 println(sql) - Await.result(db.run(DBIO.seq( - (suppliers ++= Seq( - (101, "", "", "", "", ""), - (49, "", "", "", "", ""), - (150, "", "", "", "", "") - )), - insertActions - )), Duration.Inf) + Await.result( + db.run( + DBIO.seq( + (suppliers ++= Seq( + (101, "", "", "", "", ""), + (49, "", "", "", "", ""), + (150, "", "", "", "", "") + )), + insertActions + )), + Duration.Inf) //#insert3 val userId = @@ -350,10 +369,13 @@ object LiftedEmbedding extends App { //#insert3b val userWithId = (users returning users.map(_.id) - into ((user,id) => user.copy(id=Some(id))) - ) += User(None, "Stefan", "Zeiger") + into ((user, id) => user.copy(id = Some(id)))) += User( + None, + "Stefan", + "Zeiger") //#insert3b - val userWithIdRes = Await.result(db.run(users.schema.create >> userWithId), Duration.Inf) + val userWithIdRes = + Await.result(db.run(users.schema.create >> userWithId), Duration.Inf) println(userWithIdRes) //#insert4 @@ -366,7 +388,9 @@ object LiftedEmbedding extends App { val actions = DBIO.seq( users2.schema.create, - users2 forceInsertQuery (users.map { u => (u.id, u.first ++ " " ++ u.last) }), + users2 forceInsertQuery (users.map { u => + (u.id, u.first ++ " " ++ u.last) + }), users2 forceInsertExpr (users.length + 1, "admin") ) //#insert4 @@ -376,14 +400,14 @@ object LiftedEmbedding extends App { val updated = users.insertOrUpdate(User(Some(1), "Admin", "Zeiger")) // returns: number of rows updated - val updatedAdmin = (users returning users).insertOrUpdate(User(Some(1), "Slick Admin", "Zeiger")) + val updatedAdmin = (users returning users).insertOrUpdate( + User(Some(1), "Slick Admin", "Zeiger")) // returns: None if updated, Some((Int, String)) if row inserted //#insertOrUpdate Await.result(db.run(updated), Duration.Inf) Await.result(db.run(updatedAdmin), Duration.Inf) - } - - ;{ + }; + { //#update1 val q = for { c <- coffees if c.name === "Espresso" } yield c.price val updateAction = q.update(10.49) @@ -395,15 +419,16 @@ object LiftedEmbedding extends App { // update "COFFEES" set "PRICE" = ? 
where "COFFEES"."COF_NAME" = 'Espresso' //#update1 println(sql) - } - - ;{ - Await.result(db.run( - usersForInsert ++= Seq( - User(None,"",""), - User(None,"","") - ) - ), Duration.Inf) + }; + { + Await.result( + db.run( + usersForInsert ++= Seq( + User(None, "", ""), + User(None, "", "") + ) + ), + Duration.Inf) { //#compiled1 @@ -423,7 +448,8 @@ object LiftedEmbedding extends App { { //#compiled2 - val userPaged = Compiled((d: ConstColumn[Long], t: ConstColumn[Long]) => users.drop(d).take(t)) + val userPaged = Compiled((d: ConstColumn[Long], t: ConstColumn[Long]) => + users.drop(d).take(t)) val usersAction1 = userPaged(2, 1).result val usersAction2 = userPaged(1, 3).result @@ -447,10 +473,10 @@ object LiftedEmbedding extends App { val namesAction = userNameByIDRange(2, 5).result //#template1 } - } - - ;{ - class SalesPerDay(tag: Tag) extends Table[(Date, Int)](tag, "SALES_PER_DAY") { + }; + { + class SalesPerDay(tag: Tag) + extends Table[(Date, Int)](tag, "SALES_PER_DAY") { def day = column[Date]("DAY", O.PrimaryKey) def count = column[Int]("COUNT") def * = (day, count) @@ -462,7 +488,9 @@ object LiftedEmbedding extends App { // Use the lifted function in a query to group by day of week val q1 = for { - (dow, q) <- salesPerDay.map(s => (dayOfWeek(s.day), s.count)).groupBy(_._1) + (dow, q) <- salesPerDay + .map(s => (dayOfWeek(s.day), s.count)) + .groupBy(_._1) } yield (dow, q.map(_._2).sum) //#simplefunction1 @@ -471,17 +499,21 @@ object LiftedEmbedding extends App { SimpleFunction[Int]("day_of_week").apply(Seq(c)) //#simplefunction2 - assert{ - Await.result(db.run( - salesPerDay.schema.create >> - (salesPerDay += ( (new Date(999999999), 999) )) >> - { - //#simpleliteral - val current_date = SimpleLiteral[java.sql.Date]("CURRENT_DATE") - salesPerDay.map(_ => current_date) - //#simpleliteral - }.result.head - ), Duration.Inf).isInstanceOf[java.sql.Date] + assert { + Await + .result( + db.run( + salesPerDay.schema.create >> + (salesPerDay += ((new Date(999999999), 999))) >> { + //#simpleliteral + val current_date = SimpleLiteral[java.sql.Date]("CURRENT_DATE") + salesPerDay.map(_ => current_date) + //#simpleliteral + }.result.head + ), + Duration.Inf + ) + .isInstanceOf[java.sql.Date] } } @@ -494,8 +526,8 @@ object LiftedEmbedding extends App { // And a ColumnType that maps it to Int values 1 and 0 implicit val boolColumnType = MappedColumnType.base[Bool, Int]( - { b => if(b == True) 1 else 0 }, // map Bool to Int - { i => if(i == 1) True else False } // map Int to Bool + { b => if (b == True) 1 else 0 }, // map Bool to Int + { i => if (i == 1) True else False } // map Int to Bool ) // You can now use Bool like any built-in column type (in tables, queries, etc.) 
@@ -514,24 +546,28 @@ object LiftedEmbedding extends App { def * = (id, data) } //#mappedtype2 - } - - ;{ + }; + { //#recordtype1 // A custom record class case class Pair[A, B](a: A, b: B) // A Shape implementation for Pair - final class PairShape[Level <: ShapeLevel, M <: Pair[_,_], U <: Pair[_,_] : ClassTag, P <: Pair[_,_]]( - val shapes: Seq[Shape[_, _, _, _]]) - extends MappedScalaProductShape[Level, Pair[_,_], M, U, P] { + final class PairShape[Level <: ShapeLevel, M <: Pair[_, _], + U <: Pair[_, _]: ClassTag, P <: Pair[_, _]]( + val shapes: Seq[Shape[_, _, _, _]]) + extends MappedScalaProductShape[Level, Pair[_, _], M, U, P] { def buildValue(elems: IndexedSeq[Any]) = Pair(elems(0), elems(1)) - def copy(shapes: Seq[Shape[_ <: ShapeLevel, _, _, _]]) = new PairShape(shapes) + def copy(shapes: Seq[Shape[_ <: ShapeLevel, _, _, _]]) = + new PairShape(shapes) } implicit def pairShape[Level <: ShapeLevel, M1, M2, U1, U2, P1, P2]( - implicit s1: Shape[_ <: Level, M1, U1, P1], s2: Shape[_ <: Level, M2, U2, P2] - ) = new PairShape[Level, Pair[M1, M2], Pair[U1, U2], Pair[P1, P2]](Seq(s1, s2)) + implicit s1: Shape[_ <: Level, M1, U1, P1], + s2: Shape[_ <: Level, M2, U2, P2] + ) = + new PairShape[Level, Pair[M1, M2], Pair[U1, U2], Pair[P1, P2]]( + Seq(s1, s2)) //#recordtype1 //#recordtype2 @@ -555,11 +591,16 @@ object LiftedEmbedding extends App { .map { case a => Pair(a.id, (a.s ++ a.s)) } .filter { case Pair(id, _) => id =!= 1 } .sortBy { case Pair(_, ss) => ss } - .map { case Pair(id, ss) => Pair(id, Pair(42 , ss)) } + .map { case Pair(id, ss) => Pair(id, Pair(42, ss)) } // returns: Vector(Pair(3,Pair(42,"bb")), Pair(2,Pair(42,"cc"))) //#recordtype2 - assert(Await.result(db.run(as.schema.create >> insertAction >> q2.result), Duration.Inf) == Vector(Pair(3,Pair(42,"bb")), Pair(2,Pair(42,"cc")))) + assert( + Await.result( + db.run(as.schema.create >> insertAction >> q2.result), + Duration.Inf) == Vector( + Pair(3, Pair(42, "bb")), + Pair(2, Pair(42, "cc")))) //#case-class-shape // two custom case class variants @@ -589,12 +630,15 @@ object LiftedEmbedding extends App { // returns: Vector(B(3,"bb"), B(2,"cc")) //#case-class-shape - assert(Await.result(db.run(bs.schema.create >> insertActions >> q3.result), Duration.Inf) == Vector(B(3,"bb"), B(2,"cc"))) + assert( + Await.result( + db.run(bs.schema.create >> insertActions >> q3.result), + Duration.Inf) == Vector(B(3, "bb"), B(2, "cc"))) //#combining-shapes // Combining multiple mapped types - case class LiftedC(p: Pair[Rep[Int],Rep[String]], b: LiftedB) - case class C(p: Pair[Int,String], b: B) + case class LiftedC(p: Pair[Rep[Int], Rep[String]], b: LiftedB) + case class C(p: Pair[Int, String], b: B) implicit object CShape extends CaseClassShape(LiftedC.tupled, C.tupled) @@ -602,27 +646,35 @@ object LiftedEmbedding extends App { def id = column[Int]("id") def s = column[String]("s") def projection = LiftedC( - Pair(column("p1"),column("p2")), // (cols defined inline, type inferred) - LiftedB(id,s) + Pair( + column("p1"), + column("p2") + ), // (cols defined inline, type inferred) + LiftedB(id, s) ) def * = projection } val cs = TableQuery[CRow] val insertActions2 = DBIO.seq( - cs += C(Pair(7,"x"), B(1,"a")), - cs += C(Pair(8,"y"), B(2,"c")), - cs += C(Pair(9,"z"), B(3,"b")) + cs += C(Pair(7, "x"), B(1, "a")), + cs += C(Pair(8, "y"), B(2, "c")), + cs += C(Pair(9, "z"), B(3, "b")) ) val q4 = cs - .map { case c => LiftedC(c.projection.p, LiftedB(c.id,(c.s ++ c.s))) } - .filter { case LiftedC(_, LiftedB(id,_)) => id =!= 1 } - .sortBy { case 
LiftedC(Pair(_,p2), LiftedB(_,ss)) => ss++p2 } + .map { case c => LiftedC(c.projection.p, LiftedB(c.id, (c.s ++ c.s))) } + .filter { case LiftedC(_, LiftedB(id, _)) => id =!= 1 } + .sortBy { case LiftedC(Pair(_, p2), LiftedB(_, ss)) => ss ++ p2 } // returns: Vector(C(Pair(9,"z"),B(3,"bb")), C(Pair(8,"y"),B(2,"cc"))) //#combining-shapes - assert(Await.result(db.run(cs.schema.create >> insertActions2 >> q4.result), Duration.Inf) == Vector(C(Pair(9,"z"),B(3,"bb")), C(Pair(8,"y"),B(2,"cc")))) + assert( + Await.result( + db.run(cs.schema.create >> insertActions2 >> q4.result), + Duration.Inf) == Vector( + C(Pair(9, "z"), B(3, "bb")), + C(Pair(8, "y"), B(2, "cc")))) () } diff --git a/repos/spark/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/repos/spark/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index afbb9d974d4..52166dc2c27 100644 --- a/repos/spark/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/repos/spark/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -28,9 +28,16 @@ import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} -import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator, - TimeTracker} +import org.apache.spark.mllib.tree.configuration.{ + Algo => OldAlgo, + Strategy => OldStrategy +} +import org.apache.spark.mllib.tree.impl.{ + BaggedPoint, + DecisionTreeMetadata, + DTStatsAggregator, + TimeTracker +} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD @@ -38,14 +45,13 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} - private[ml] object RandomForest extends Logging { /** - * Train a random forest. - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @return an unweighted set of trees - */ + * Train a random forest. + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @return an unweighted set of trees + */ def run( input: RDD[LabeledPoint], strategy: OldStrategy, @@ -62,7 +68,11 @@ private[ml] object RandomForest extends Logging { val retaggedInput = input.retag(classOf[LabeledPoint]) val metadata = - DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy) + DecisionTreeMetadata.buildMetadata( + retaggedInput, + strategy, + numTrees, + featureSubsetStrategy) logDebug("algo = " + strategy.algo) logDebug("numTrees = " + numTrees) logDebug("seed = " + seed) @@ -77,9 +87,12 @@ private[ml] object RandomForest extends Logging { val splits = findSplits(retaggedInput, metadata, seed) timer.stop("findSplitsBins") logDebug("numBins: feature: number of bins") - logDebug(Range(0, metadata.numFeatures).map { featureIndex => - s"\t$featureIndex\t${metadata.numBins(featureIndex)}" - }.mkString("\n")) + logDebug( + Range(0, metadata.numFeatures) + .map { featureIndex => + s"\t$featureIndex\t${metadata.numBins(featureIndex)}" + } + .mkString("\n")) // Bin feature values (TreePoint representation). // Cache input RDD for speedup during multiple passes. 
@@ -88,12 +101,18 @@ private[ml] object RandomForest extends Logging { val withReplacement = numTrees > 1 val baggedInput = BaggedPoint - .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed) + .convertToBaggedRDD( + treeInput, + strategy.subsamplingRate, + numTrees, + withReplacement, + seed) .persist(StorageLevel.MEMORY_AND_DISK) // depth of the decision tree val maxDepth = strategy.maxDepth - require(maxDepth <= 30, + require( + maxDepth <= 30, s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") // Max memory usage for aggregates @@ -101,19 +120,25 @@ private[ml] object RandomForest extends Logging { val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") val maxMemoryPerNode = { - val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { - // Find numFeaturesPerNode largest bins to get an upper bound on memory usage. - Some(metadata.numBins.zipWithIndex.sortBy(- _._1) - .take(metadata.numFeaturesPerNode).map(_._2)) - } else { - None - } + val featureSubset: Option[Array[Int]] = + if (metadata.subsamplingFeatures) { + // Find numFeaturesPerNode largest bins to get an upper bound on memory usage. + Some( + metadata.numBins.zipWithIndex + .sortBy(-_._1) + .take(metadata.numFeaturesPerNode) + .map(_._2)) + } else { + None + } RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L } - require(maxMemoryPerNode <= maxMemoryUsage, + require( + maxMemoryPerNode <= maxMemoryUsage, s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," + " which is too small for the given features." + - s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}") + s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}" + ) timer.stop("init") @@ -127,11 +152,12 @@ private[ml] object RandomForest extends Logging { // Create an RDD of node Id cache. // At first, all the rows belong to the root nodes (node Id == 1). val nodeIdCache = if (strategy.useNodeIdCache) { - Some(NodeIdCache.init( - data = baggedInput, - numTrees = numTrees, - checkpointInterval = strategy.checkpointInterval, - initVal = 1)) + Some( + NodeIdCache.init( + data = baggedInput, + numTrees = numTrees, + checkpointInterval = strategy.checkpointInterval, + initVal = 1)) } else { None } @@ -143,22 +169,37 @@ private[ml] object RandomForest extends Logging { rng.setSeed(seed) // Allocate and queue root nodes. - val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1)) - Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))) + val topNodes = + Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1)) + Range(0, numTrees).foreach(treeIndex => + nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))) while (nodeQueue.nonEmpty) { // Collect some nodes to split, and choose features for each node (if subsampling). // Each group of nodes may come from one or multiple trees, and at multiple levels. val (nodesForGroup, treeToNodeToIndexInfo) = - RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) + RandomForest.selectNodesToSplit( + nodeQueue, + maxMemoryUsage, + metadata, + rng) // Sanity check (should never occur): - assert(nodesForGroup.nonEmpty, + assert( + nodesForGroup.nonEmpty, s"RandomForest selected empty nodesForGroup. Error for unknown reason.") // Choose node splits, and enqueue new nodes as needed. 
timer.start("findBestSplits") - RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, - treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache) + RandomForest.findBestSplits( + baggedInput, + metadata, + topNodes, + nodesForGroup, + treeToNodeToIndexInfo, + splits, + nodeQueue, + timer, + nodeIdCache) timer.stop("findBestSplits") } @@ -175,7 +216,8 @@ private[ml] object RandomForest extends Logging { nodeIdCache.get.deleteAllCheckpoints() } catch { case e: IOException => - logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}") + logWarning( + s"delete all checkpoints failed. Error reason: ${e.getMessage}") } } @@ -185,7 +227,10 @@ private[ml] object RandomForest extends Logging { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures, + new DecisionTreeClassificationModel( + uid, + rootNode.toNode, + numFeatures, strategy.getNumClasses) } } else { @@ -196,29 +241,32 @@ private[ml] object RandomForest extends Logging { case None => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(rootNode.toNode, numFeatures, + new DecisionTreeClassificationModel( + rootNode.toNode, + numFeatures, strategy.getNumClasses) } } else { - topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures)) + topNodes.map(rootNode => + new DecisionTreeRegressionModel(rootNode.toNode, numFeatures)) } } } /** - * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features. - * - * For ordered features, a single bin is updated. - * For unordered features, bins correspond to subsets of categories; either the left or right bin - * for each subset is updated. - * - * @param agg Array storing aggregate calculation, with a set of sufficient statistics for - * each (feature, bin). - * @param treePoint Data point being aggregated. - * @param splits possible splits indexed (numFeatures)(numSplits) - * @param unorderedFeatures Set of indices of unordered features. - * @param instanceWeight Weight (importance) of instance in dataset. - */ + * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features. + * + * For ordered features, a single bin is updated. + * For unordered features, bins correspond to subsets of categories; either the left or right bin + * for each subset is updated. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (feature, bin). + * @param treePoint Data point being aggregated. + * @param splits possible splits indexed (numFeatures)(numSplits) + * @param unorderedFeatures Set of indices of unordered features. + * @param instanceWeight Weight (importance) of instance in dataset. 
+ */ private def mixedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, @@ -250,8 +298,14 @@ private[ml] object RandomForest extends Logging { val featureSplits = splits(featureIndex) var splitIndex = 0 while (splitIndex < numSplits) { - if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { - agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) + if (featureSplits(splitIndex).shouldGoLeft( + featureValue, + featureSplits)) { + agg.featureUpdate( + leftNodeFeatureOffset, + splitIndex, + treePoint.label, + instanceWeight) } splitIndex += 1 } @@ -265,15 +319,15 @@ private[ml] object RandomForest extends Logging { } /** - * Helper for binSeqOp, for regression and for classification with only ordered features. - * - * For each feature, the sufficient statistics of one bin are updated. - * - * @param agg Array storing aggregate calculation, with a set of sufficient statistics for - * each (feature, bin). - * @param treePoint Data point being aggregated. - * @param instanceWeight Weight (importance) of instance in dataset. - */ + * Helper for binSeqOp, for regression and for classification with only ordered features. + * + * For each feature, the sufficient statistics of one bin are updated. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (feature, bin). + * @param treePoint Data point being aggregated. + * @param instanceWeight Weight (importance) of instance in dataset. + */ private def orderedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, @@ -286,7 +340,8 @@ private[ml] object RandomForest extends Logging { // Use subsampled features var featureIndexIdx = 0 while (featureIndexIdx < featuresForNode.get.length) { - val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx)) + val binIndex = + treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx)) agg.update(featureIndexIdx, binIndex, label, instanceWeight) featureIndexIdx += 1 } @@ -303,24 +358,24 @@ private[ml] object RandomForest extends Logging { } /** - * Given a group of nodes, this finds the best split for each node. - * - * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] - * @param metadata Learning and dataset metadata - * @param topNodes Root node for each tree. Used for matching instances with nodes. - * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree - * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, - * where nodeIndexInfo stores the index in the group and the - * feature subsets (if using feature subsets). - * @param splits possible splits for all features, indexed (numFeatures)(numSplits) - * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). - * Updated with new non-leaf nodes which are created. - * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where - * each value in the array is the data point's node Id - * for a corresponding tree. This is used to prevent the need - * to pass the entire tree to the executors during - * the node stat aggregation phase. - */ + * Given a group of nodes, this finds the best split for each node. + * + * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] + * @param metadata Learning and dataset metadata + * @param topNodes Root node for each tree. Used for matching instances with nodes. 
+ * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree + * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, + * where nodeIndexInfo stores the index in the group and the + * feature subsets (if using feature subsets). + * @param splits possible splits for all features, indexed (numFeatures)(numSplits) + * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). + * Updated with new non-leaf nodes which are created. + * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where + * each value in the array is the data point's node Id + * for a corresponding tree. This is used to prevent the need + * to pass the entire tree to the executors during + * the node stat aggregation phase. + */ private[tree] def findBestSplits( input: RDD[BaggedPoint[TreePoint]], metadata: DecisionTreeMetadata, @@ -360,22 +415,23 @@ private[ml] object RandomForest extends Logging { logDebug("numFeatures = " + metadata.numFeatures) logDebug("numClasses = " + metadata.numClasses) logDebug("isMulticlass = " + metadata.isMulticlass) - logDebug("isMulticlassWithCategoricalFeatures = " + - metadata.isMulticlassWithCategoricalFeatures) + logDebug( + "isMulticlassWithCategoricalFeatures = " + + metadata.isMulticlassWithCategoricalFeatures) logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString) /** - * Performs a sequential aggregation over a partition for a particular tree and node. - * - * For each feature, the aggregate sufficient statistics are updated for the relevant - * bins. - * - * @param treeIndex Index of the tree that we want to perform aggregation for. - * @param nodeInfo The node info for the tree node. - * @param agg Array storing aggregate calculation, with a set of sufficient statistics - * for each (node, feature, bin). - * @param baggedPoint Data point being aggregated. - */ + * Performs a sequential aggregation over a partition for a particular tree and node. + * + * For each feature, the aggregate sufficient statistics are updated for the relevant + * bins. + * + * @param treeIndex Index of the tree that we want to perform aggregation for. + * @param nodeInfo The node info for the tree node. + * @param agg Array storing aggregate calculation, with a set of sufficient statistics + * for each (node, feature, bin). + * @param baggedPoint Data point being aggregated. + */ def nodeBinSeqOp( treeIndex: Int, nodeInfo: NodeIndexInfo, @@ -386,58 +442,80 @@ private[ml] object RandomForest extends Logging { val featuresForNode = nodeInfo.featureSubset val instanceWeight = baggedPoint.subsampleWeights(treeIndex) if (metadata.unorderedFeatures.isEmpty) { - orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) + orderedBinSeqOp( + agg(aggNodeIndex), + baggedPoint.datum, + instanceWeight, + featuresForNode) } else { - mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits, - metadata.unorderedFeatures, instanceWeight, featuresForNode) + mixedBinSeqOp( + agg(aggNodeIndex), + baggedPoint.datum, + splits, + metadata.unorderedFeatures, + instanceWeight, + featuresForNode) } agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight) } } /** - * Performs a sequential aggregation over a partition. - * - * Each data point contributes to one node. For each feature, - * the aggregate sufficient statistics are updated for the relevant bins. - * - * @param agg Array storing aggregate calculation, with a set of sufficient statistics for - * each (node, feature, bin). 
- * @param baggedPoint Data point being aggregated. - * @return agg - */ + * Performs a sequential aggregation over a partition. + * + * Each data point contributes to one node. For each feature, + * the aggregate sufficient statistics are updated for the relevant bins. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (node, feature, bin). + * @param baggedPoint Data point being aggregated. + * @return agg + */ def binSeqOp( agg: Array[DTStatsAggregator], baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = { - treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => - val nodeIndex = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits) - nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) + treeToNodeToIndexInfo.foreach { + case (treeIndex, nodeIndexToInfo) => + val nodeIndex = topNodes(treeIndex) + .predictImpl(baggedPoint.datum.binnedFeatures, splits) + nodeBinSeqOp( + treeIndex, + nodeIndexToInfo.getOrElse(nodeIndex, null), + agg, + baggedPoint) } agg } /** - * Do the same thing as binSeqOp, but with nodeIdCache. - */ + * Do the same thing as binSeqOp, but with nodeIdCache. + */ def binSeqOpWithNodeIdCache( agg: Array[DTStatsAggregator], - dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = { - treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => - val baggedPoint = dataPoint._1 - val nodeIdCache = dataPoint._2 - val nodeIndex = nodeIdCache(treeIndex) - nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) + dataPoint: (BaggedPoint[TreePoint], Array[Int])) + : Array[DTStatsAggregator] = { + treeToNodeToIndexInfo.foreach { + case (treeIndex, nodeIndexToInfo) => + val baggedPoint = dataPoint._1 + val nodeIdCache = dataPoint._2 + val nodeIndex = nodeIdCache(treeIndex) + nodeBinSeqOp( + treeIndex, + nodeIndexToInfo.getOrElse(nodeIndex, null), + agg, + baggedPoint) } agg } /** - * Get node index in group --> features indices map, - * which is a short cut to find feature indices for a node given node index in group. - */ + * Get node index in group --> features indices map, + * which is a short cut to find feature indices for a node given node index in group. 
+ */ def getNodeToFeatures( - treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = { + treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) + : Option[Map[Int, Array[Int]]] = { if (!metadata.subsamplingFeatures) { None } else { @@ -445,7 +523,8 @@ private[ml] object RandomForest extends Logging { treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo => nodeIdToNodeInfo.values.foreach { nodeIndexInfo => assert(nodeIndexInfo.featureSubset.isDefined) - mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get + mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = + nodeIndexInfo.featureSubset.get } } Some(mutableNodeToFeatures.toMap) @@ -454,10 +533,12 @@ private[ml] object RandomForest extends Logging { // array of nodes to train indexed by node index in group val nodes = new Array[LearningNode](numNodes) - nodesForGroup.foreach { case (treeIndex, nodesForTree) => - nodesForTree.foreach { node => - nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node - } + nodesForGroup.foreach { + case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = + node + } } // Calculate best splits for all nodes in the group @@ -472,112 +553,127 @@ private[ml] object RandomForest extends Logging { val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo) val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures) - val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { - input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points => - // Construct a nodeStatsAggregators array to hold node aggregate stats, - // each node will have a nodeStatsAggregator - val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => - val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures => - nodeToFeatures(nodeIndex) + val partitionAggregates: RDD[(Int, DTStatsAggregator)] = + if (nodeIdCache.nonEmpty) { + input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points => + // Construct a nodeStatsAggregators array to hold node aggregate stats, + // each node will have a nodeStatsAggregator + val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => + val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures => + nodeToFeatures(nodeIndex) + } + new DTStatsAggregator(metadata, featuresForNode) } - new DTStatsAggregator(metadata, featuresForNode) - } - // iterator all instances in current partition and update aggregate stats - points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _)) + // iterator all instances in current partition and update aggregate stats + points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _)) - // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, - // which can be combined with other partition using `reduceByKey` - nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator - } - } else { - input.mapPartitions { points => - // Construct a nodeStatsAggregators array to hold node aggregate stats, - // each node will have a nodeStatsAggregator - val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => - val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => - Some(nodeToFeatures(nodeIndex)) - } - new DTStatsAggregator(metadata, featuresForNode) + // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, + // which can be combined with other partition using 
`reduceByKey` + nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator } + } else { + input.mapPartitions { points => + // Construct a nodeStatsAggregators array to hold node aggregate stats, + // each node will have a nodeStatsAggregator + val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => + val featuresForNode = nodeToFeaturesBc.value.flatMap { + nodeToFeatures => Some(nodeToFeatures(nodeIndex)) + } + new DTStatsAggregator(metadata, featuresForNode) + } - // iterator all instances in current partition and update aggregate stats - points.foreach(binSeqOp(nodeStatsAggregators, _)) + // iterator all instances in current partition and update aggregate stats + points.foreach(binSeqOp(nodeStatsAggregators, _)) - // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, - // which can be combined with other partition using `reduceByKey` - nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator + // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, + // which can be combined with other partition using `reduceByKey` + nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator + } } - } - val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map { - case (nodeIndex, aggStats) => - val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => - Some(nodeToFeatures(nodeIndex)) - } + val nodeToBestSplits = partitionAggregates + .reduceByKey((a, b) => a.merge(b)) + .map { + case (nodeIndex, aggStats) => + val featuresForNode = nodeToFeaturesBc.value.flatMap { + nodeToFeatures => Some(nodeToFeatures(nodeIndex)) + } - // find best split for each node - val (split: Split, stats: ImpurityStats) = - binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) - (nodeIndex, (split, stats)) - }.collectAsMap() + // find best split for each node + val (split: Split, stats: ImpurityStats) = + binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) + (nodeIndex, (split, stats)) + } + .collectAsMap() timer.stop("chooseSplits") val nodeIdUpdaters = if (nodeIdCache.nonEmpty) { - Array.fill[mutable.Map[Int, NodeIndexUpdater]]( - metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]()) + Array.fill[mutable.Map[Int, NodeIndexUpdater]](metadata.numTrees)( + mutable.Map[Int, NodeIndexUpdater]()) } else { null } // Iterate over all nodes in this group. - nodesForGroup.foreach { case (treeIndex, nodesForTree) => - nodesForTree.foreach { node => - val nodeIndex = node.id - val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) - val aggNodeIndex = nodeInfo.nodeIndexInGroup - val (split: Split, stats: ImpurityStats) = - nodeToBestSplits(aggNodeIndex) - logDebug("best split = " + split) - - // Extract info for this node. Create children if not leaf. 
- val isLeaf = - (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth) - node.isLeaf = isLeaf - node.stats = stats - logDebug("Node = " + node) - - if (!isLeaf) { - node.split = Some(split) - val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth - val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) - val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) - node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex), - leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator))) - node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex), - rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator))) - - if (nodeIdCache.nonEmpty) { - val nodeIndexUpdater = NodeIndexUpdater( - split = split, - nodeIndex = nodeIndex) - nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater) - } + nodesForGroup.foreach { + case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + val nodeIndex = node.id + val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val (split: Split, stats: ImpurityStats) = + nodeToBestSplits(aggNodeIndex) + logDebug("best split = " + split) + + // Extract info for this node. Create children if not leaf. + val isLeaf = + (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth) + node.isLeaf = isLeaf + node.stats = stats + logDebug("Node = " + node) + + if (!isLeaf) { + node.split = Some(split) + val childIsLeaf = + (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth + val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) + val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) + node.leftChild = Some( + LearningNode( + LearningNode.leftChildIndex(nodeIndex), + leftChildIsLeaf, + ImpurityStats.getEmptyImpurityStats( + stats.leftImpurityCalculator))) + node.rightChild = Some( + LearningNode( + LearningNode.rightChildIndex(nodeIndex), + rightChildIsLeaf, + ImpurityStats.getEmptyImpurityStats( + stats.rightImpurityCalculator))) + + if (nodeIdCache.nonEmpty) { + val nodeIndexUpdater = + NodeIndexUpdater(split = split, nodeIndex = nodeIndex) + nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater) + } - // enqueue left child and right child if they are not leaves - if (!leftChildIsLeaf) { - nodeQueue.enqueue((treeIndex, node.leftChild.get)) - } - if (!rightChildIsLeaf) { - nodeQueue.enqueue((treeIndex, node.rightChild.get)) - } + // enqueue left child and right child if they are not leaves + if (!leftChildIsLeaf) { + nodeQueue.enqueue((treeIndex, node.leftChild.get)) + } + if (!rightChildIsLeaf) { + nodeQueue.enqueue((treeIndex, node.rightChild.get)) + } - logDebug("leftChildIndex = " + node.leftChild.get.id + - ", impurity = " + stats.leftImpurity) - logDebug("rightChildIndex = " + node.rightChild.get.id + - ", impurity = " + stats.rightImpurity) + logDebug( + "leftChildIndex = " + node.leftChild.get.id + + ", impurity = " + stats.leftImpurity) + logDebug( + "rightChildIndex = " + node.rightChild.get.id + + ", impurity = " + stats.rightImpurity) + } } - } } if (nodeIdCache.nonEmpty) { @@ -587,14 +683,14 @@ private[ml] object RandomForest extends Logging { } /** - * Calculate the impurity statistics for a give (feature, split) based upon left/right aggregates. 
- * @param stats the recycle impurity statistics for this feature's all splits, - * only 'impurity' and 'impurityCalculator' are valid between each iteration - * @param leftImpurityCalculator left node aggregates for this (feature, split) - * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @param metadata learning and dataset metadata for DecisionTree - * @return Impurity statistics for this (feature, split) - */ + * Calculate the impurity statistics for a give (feature, split) based upon left/right aggregates. + * @param stats the recycle impurity statistics for this feature's all splits, + * only 'impurity' and 'impurityCalculator' are valid between each iteration + * @param leftImpurityCalculator left node aggregates for this (feature, split) + * @param rightImpurityCalculator right node aggregate for this (feature, split) + * @param metadata learning and dataset metadata for DecisionTree + * @return Impurity statistics for this (feature, split) + */ private def calculateImpurityStats( stats: ImpurityStats, leftImpurityCalculator: ImpurityCalculator, @@ -621,7 +717,7 @@ private[ml] object RandomForest extends Logging { // If left child or right child doesn't satisfy minimum instances per node, // then this split is invalid, return invalid information gain stats. if ((leftCount < metadata.minInstancesPerNode) || - (rightCount < metadata.minInstancesPerNode)) { + (rightCount < metadata.minInstancesPerNode)) { return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) } @@ -631,7 +727,8 @@ private[ml] object RandomForest extends Logging { val leftWeight = leftCount / totalCount.toDouble val rightWeight = rightCount / totalCount.toDouble - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + val gain = + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity // if information gain doesn't satisfy minimum information gain, // then this split is invalid, return invalid information gain stats. @@ -639,15 +736,19 @@ private[ml] object RandomForest extends Logging { return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) } - new ImpurityStats(gain, impurity, parentImpurityCalculator, - leftImpurityCalculator, rightImpurityCalculator) + new ImpurityStats( + gain, + impurity, + parentImpurityCalculator, + leftImpurityCalculator, + rightImpurityCalculator) } /** - * Find the best split for a node. - * @param binAggregates Bin statistics. - * @return tuple for best split: (Split, information gain, prediction at node) - */ + * Find the best split for a node. + * @param binAggregates Bin statistics. + * @return tuple for best split: (Split, information gain, prediction at node) + */ private[tree] def binsToBestSplit( binAggregates: DTStatsAggregator, splits: Array[Array[Split]], @@ -664,159 +765,209 @@ private[ml] object RandomForest extends Logging { // For each (feature, split), calculate the gain, and select the best (feature, split). val (bestSplit, bestSplitStats) = - Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => - val featureIndex = if (featuresForNode.nonEmpty) { - featuresForNode.get.apply(featureIndexIdx) - } else { - featureIndexIdx - } - val numSplits = binAggregates.metadata.numSplits(featureIndex) - if (binAggregates.metadata.isContinuous(featureIndex)) { - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. 
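// Aside: a sketch of the continuous-feature logic on toy data. The per-bin label counts
// are prefix-summed so that "left of split i" is bins 0..i, and each candidate split is
// scored with the same weighted-impurity gain that calculateImpurityStats computes.
// Gini impurity and the toy thresholds are illustrative choices, not Spark's
// ImpurityCalculator / DTStatsAggregator classes.
object ContinuousSplitSketch {
  def gini(counts: Array[Double]): Double = {
    val total = counts.sum
    if (total == 0.0) 0.0
    else 1.0 - counts.map(c => (c / total) * (c / total)).sum
  }

  def main(args: Array[String]): Unit = {
    // label counts per bin for a binary problem: bins(b)(label)
    val bins = Array(Array(8.0, 1.0), Array(5.0, 2.0), Array(2.0, 6.0), Array(1.0, 9.0))
    val minInstancesPerNode = 1
    val minInfoGain = 0.0

    // cumulative sums: prefix(i) = stats of bins 0..i
    val prefix = bins.scanLeft(Array(0.0, 0.0)) { (acc, b) =>
      Array(acc(0) + b(0), acc(1) + b(1))
    }.tail
    val total = prefix.last
    val parentImpurity = gini(total)

    // one candidate split between each pair of adjacent bins
    val scored = (0 until bins.length - 1).map { i =>
      val left = prefix(i)
      val right = Array(total(0) - left(0), total(1) - left(1))
      val (lc, rc) = (left.sum, right.sum)
      val gain =
        if (lc < minInstancesPerNode || rc < minInstancesPerNode) Double.NegativeInfinity
        else parentImpurity - (lc / total.sum) * gini(left) - (rc / total.sum) * gini(right)
      (i, gain)
    }
    val (bestSplit, bestGain) = scored.maxBy(_._2)
    if (bestGain >= minInfoGain) println(s"best split after bin $bestSplit, gain $bestGain")
    else println("no valid split; node becomes a leaf")
  }
}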
- val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - var splitIndex = 0 - while (splitIndex < numSplits) { - binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) - splitIndex += 1 + Range(0, binAggregates.metadata.numFeaturesPerNode) + .map { featureIndexIdx => + val featureIndex = if (featuresForNode.nonEmpty) { + featuresForNode.get.apply(featureIndexIdx) + } else { + featureIndexIdx } - // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { case splitIdx => - val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIdx, gainAndImpurityStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else if (binAggregates.metadata.isUnordered(featureIndex)) { - // Unordered categorical feature - val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - val rightChildStats = binAggregates.getParentImpurityCalculator() - .subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else { - // Ordered categorical feature - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val numCategories = binAggregates.metadata.numBins(featureIndex) - - /* Each bin is one category (feature value). - * The bins are ordered based on centroidForCategories, and this ordering determines which - * splits are considered. (With K categories, we consider K - 1 possible splits.) - * - * centroidForCategories is a list: (category, centroid) - */ - val centroidForCategories = Range(0, numCategories).map { case featureValue => - val categoryStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { - if (binAggregates.metadata.isMulticlass) { - // multiclass classification - // For categorical variables in multiclass classification, - // the bins are ordered by the impurity of their corresponding labels. - categoryStats.calculate() - } else if (binAggregates.metadata.isClassification) { - // binary classification - // For categorical variables in binary classification, - // the bins are ordered by the count of class 1. - categoryStats.stats(1) - } else { - // regression - // For categorical variables in regression and binary classification, - // the bins are ordered by the prediction. - categoryStats.predict - } - } else { - Double.MaxValue + val numSplits = binAggregates.metadata.numSplits(featureIndex) + if (binAggregates.metadata.isContinuous(featureIndex)) { + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. 
+ val nodeFeatureOffset = + binAggregates.getFeatureOffset(featureIndexIdx) + var splitIndex = 0 + while (splitIndex < numSplits) { + binAggregates.mergeForFeature( + nodeFeatureOffset, + splitIndex + 1, + splitIndex) + splitIndex += 1 + } + // Find best split. + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits) + .map { + case splitIdx => + val leftChildStats = binAggregates.getImpurityCalculator( + nodeFeatureOffset, + splitIdx) + val rightChildStats = + binAggregates.getImpurityCalculator( + nodeFeatureOffset, + numSplits) + rightChildStats.subtract(leftChildStats) + gainAndImpurityStats = calculateImpurityStats( + gainAndImpurityStats, + leftChildStats, + rightChildStats, + binAggregates.metadata) + (splitIdx, gainAndImpurityStats) + } + .maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + } else if (binAggregates.metadata.isUnordered(featureIndex)) { + // Unordered categorical feature + val leftChildOffset = + binAggregates.getFeatureOffset(featureIndexIdx) + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits) + .map { splitIndex => + val leftChildStats = binAggregates.getImpurityCalculator( + leftChildOffset, + splitIndex) + val rightChildStats = binAggregates + .getParentImpurityCalculator() + .subtract(leftChildStats) + gainAndImpurityStats = calculateImpurityStats( + gainAndImpurityStats, + leftChildStats, + rightChildStats, + binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + } + .maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + } else { + // Ordered categorical feature + val nodeFeatureOffset = + binAggregates.getFeatureOffset(featureIndexIdx) + val numCategories = binAggregates.metadata.numBins(featureIndex) + + /* Each bin is one category (feature value). + * The bins are ordered based on centroidForCategories, and this ordering determines which + * splits are considered. (With K categories, we consider K - 1 possible splits.) + * + * centroidForCategories is a list: (category, centroid) + */ + val centroidForCategories = Range(0, numCategories).map { + case featureValue => + val categoryStats = + binAggregates.getImpurityCalculator( + nodeFeatureOffset, + featureValue) + val centroid = if (categoryStats.count != 0) { + if (binAggregates.metadata.isMulticlass) { + // multiclass classification + // For categorical variables in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. + categoryStats.calculate() + } else if (binAggregates.metadata.isClassification) { + // binary classification + // For categorical variables in binary classification, + // the bins are ordered by the count of class 1. + categoryStats.stats(1) + } else { + // regression + // For categorical variables in regression and binary classification, + // the bins are ordered by the prediction. + categoryStats.predict + } + } else { + Double.MaxValue + } + (featureValue, centroid) } - (featureValue, centroid) - } - - logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) - - // bins sorted by centroids - val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) - - logDebug("Sorted centroids for categorical variable = " + - categoriesSortedByCentroid.mkString(",")) - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. 
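// Aside: the ordered-categorical handling, sketched for the binary case. Categories are
// ordered by their class-1 count (the "centroid" mentioned in the comments), and only
// the K - 1 contiguous splits along that order are evaluated. Toy counts and Gini
// impurity are illustrative; the real code walks DTStatsAggregator offsets instead of
// plain maps.
object OrderedCategoricalSketch {
  def gini(c0: Double, c1: Double): Double = {
    val n = c0 + c1
    if (n == 0) 0.0 else 1.0 - (c0 / n) * (c0 / n) - (c1 / n) * (c1 / n)
  }

  def main(args: Array[String]): Unit = {
    // per-category label counts: category -> (count of label 0, count of label 1)
    val stats = Map(0 -> (9.0, 1.0), 1 -> (2.0, 8.0), 2 -> (6.0, 4.0), 3 -> (1.0, 9.0))

    // centroid for binary classification: count of class 1
    val sortedCats = stats.toSeq.sortBy { case (_, (_, c1)) => c1 }.map(_._1)

    val total0 = stats.values.map(_._1).sum
    val total1 = stats.values.map(_._2).sum
    val parent = gini(total0, total1)

    // evaluate the K - 1 splits "first i categories vs. the rest"
    val best = (1 until sortedCats.length).map { i =>
      val leftCats = sortedCats.take(i)
      val (l0, l1) = leftCats.foldLeft((0.0, 0.0)) { case ((a, b), c) =>
        (a + stats(c)._1, b + stats(c)._2)
      }
      val (r0, r1) = (total0 - l0, total1 - l1)
      val n = total0 + total1
      val gain = parent - ((l0 + l1) / n) * gini(l0, l1) - ((r0 + r1) / n) * gini(r0, r1)
      (leftCats, gain)
    }.maxBy(_._2)

    println(s"send categories ${best._1.mkString(",")} left, gain ${best._2}")
  }
}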
- var splitIndex = 0 - while (splitIndex < numSplits) { - val currentCategory = categoriesSortedByCentroid(splitIndex)._1 - val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 - binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) - splitIndex += 1 + logDebug( + "Centroids for categorical variable: " + centroidForCategories + .mkString(",")) + + // bins sorted by centroids + val categoriesSortedByCentroid = + centroidForCategories.toList.sortBy(_._2) + + logDebug( + "Sorted centroids for categorical variable = " + + categoriesSortedByCentroid.mkString(",")) + + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + var splitIndex = 0 + while (splitIndex < numSplits) { + val currentCategory = categoriesSortedByCentroid(splitIndex)._1 + val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 + binAggregates.mergeForFeature( + nodeFeatureOffset, + nextCategory, + currentCategory) + splitIndex += 1 + } + // lastCategory = index of bin with total aggregates for this (node, feature) + val lastCategory = categoriesSortedByCentroid.last._1 + // Find best split. + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits) + .map { splitIndex => + val featureValue = categoriesSortedByCentroid(splitIndex)._1 + val leftChildStats = + binAggregates.getImpurityCalculator( + nodeFeatureOffset, + featureValue) + val rightChildStats = + binAggregates.getImpurityCalculator( + nodeFeatureOffset, + lastCategory) + rightChildStats.subtract(leftChildStats) + gainAndImpurityStats = calculateImpurityStats( + gainAndImpurityStats, + leftChildStats, + rightChildStats, + binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + } + .maxBy(_._2.gain) + val categoriesForSplit = + categoriesSortedByCentroid + .map(_._1.toDouble) + .slice(0, bestFeatureSplitIndex + 1) + val bestFeatureSplit = + new CategoricalSplit( + featureIndex, + categoriesForSplit.toArray, + numCategories) + (bestFeatureSplit, bestFeatureGainStats) } - // lastCategory = index of bin with total aggregates for this (node, feature) - val lastCategory = categoriesSortedByCentroid.last._1 - // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val featureValue = categoriesSortedByCentroid(splitIndex)._1 - val leftChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - }.maxBy(_._2.gain) - val categoriesForSplit = - categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) - val bestFeatureSplit = - new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories) - (bestFeatureSplit, bestFeatureGainStats) } - }.maxBy(_._2.gain) + .maxBy(_._2.gain) (bestSplit, bestSplitStats) } /** - * Returns splits and bins for decision tree calculation. - * Continuous and categorical features are handled differently. - * - * Continuous features: - * For each feature, there are numBins - 1 possible splits representing the possible binary - * decisions at each node in the tree. - * This finds locations (feature values) for splits using a subsample of the data. 
- * - * Categorical features: - * For each feature, there is 1 bin per split. - * Splits and bins are handled in 2 ways: - * (a) "unordered features" - * For multiclass classification with a low-arity feature - * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), - * the feature is split based on subsets of categories. - * (b) "ordered features" - * For regression and binary classification, - * and for multiclass classification with a high-arity feature, - * there is one bin per category. - * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @param metadata Learning and dataset metadata - * @param seed random seed - * @return A tuple of (splits, bins). - * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] - * of size (numFeatures, numSplits). - * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] - * of size (numFeatures, numBins). - */ + * Returns splits and bins for decision tree calculation. + * Continuous and categorical features are handled differently. + * + * Continuous features: + * For each feature, there are numBins - 1 possible splits representing the possible binary + * decisions at each node in the tree. + * This finds locations (feature values) for splits using a subsample of the data. + * + * Categorical features: + * For each feature, there is 1 bin per split. + * Splits and bins are handled in 2 ways: + * (a) "unordered features" + * For multiclass classification with a low-arity feature + * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), + * the feature is split based on subsets of categories. + * (b) "ordered features" + * For regression and binary classification, + * and for multiclass classification with a high-arity feature, + * there is one bin per category. + * + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param metadata Learning and dataset metadata + * @param seed random seed + * @return A tuple of (splits, bins). + * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] + * of size (numFeatures, numSplits). + * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] + * of size (numFeatures, numBins). + */ protected[tree] def findSplits( input: RDD[LabeledPoint], metadata: DecisionTreeMetadata, @@ -837,7 +988,10 @@ private[ml] object RandomForest extends Logging { 1.0 } logDebug("fraction of data used for calculating quantiles = " + fraction) - input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()) + input.sample( + withReplacement = false, + fraction, + new XORShiftRandom(seed).nextInt()) } else { input.sparkContext.emptyRDD[LabeledPoint] } @@ -854,17 +1008,23 @@ private[ml] object RandomForest extends Logging { // reduce the parallelism for split computations when there are less // continuous features than input partitions. this prevents tasks from // being spun up that will definitely do no work. 
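// Aside: a rough sketch of the sub-sampling and parallelism arithmetic used when
// locating candidate thresholds. The required-sample floor of 10000 mirrors what this
// version of the code appears to use, but treat the exact constant as an assumption;
// the counts below are toy numbers.
object SampleFractionSketch {
  def main(args: Array[String]): Unit = {
    val numExamples = 1000000L
    val maxBins = 32
    val requiredSamples = math.max(maxBins * maxBins, 10000)
    val fraction =
      if (requiredSamples < numExamples) requiredSamples.toDouble / numExamples else 1.0

    val numContinuousFeatures = 3
    val inputPartitions = 200
    // cap the parallelism so tasks that would definitely do no work are never spun up
    val numPartitions = math.min(numContinuousFeatures, inputPartitions)
    println(s"fraction=$fraction, numPartitions=$numPartitions")
  }
}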
- val numPartitions = math.min(continuousFeatures.length, input.partitions.length) + val numPartitions = + math.min(continuousFeatures.length, input.partitions.length) input - .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx)))) + .flatMap(point => + continuousFeatures.map(idx => (idx, point.features(idx)))) .groupByKey(numPartitions) - .map { case (idx, samples) => - val thresholds = findSplitsForContinuousFeature(samples, metadata, idx) - val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh)) - logDebug(s"featureIndex = $idx, numSplits = ${splits.length}") - (idx, splits) - }.collectAsMap() + .map { + case (idx, samples) => + val thresholds = + findSplitsForContinuousFeature(samples, metadata, idx) + val splits: Array[Split] = + thresholds.map(thresh => new ContinuousSplit(idx, thresh)) + logDebug(s"featureIndex = $idx, numSplits = ${splits.length}") + (idx, splits) + } + .collectAsMap() } val numFeatures = metadata.numFeatures @@ -879,7 +1039,8 @@ private[ml] object RandomForest extends Logging { // 2^(maxFeatureValue - 1) - 1 combinations val featureArity = metadata.featureArity(i) Array.tabulate[Split](metadata.numSplits(i)) { splitIndex => - val categories = extractMultiClassCategories(splitIndex + 1, featureArity) + val categories = + extractMultiClassCategories(splitIndex + 1, featureArity) new CategoricalSplit(i, categories.toArray, featureArity) } @@ -893,11 +1054,11 @@ private[ml] object RandomForest extends Logging { } /** - * Nested method to extract list of eligible categories given an index. It extracts the - * position of ones in a binary representation of the input. If binary - * representation of an number is 01101 (13), the output list should (3.0, 2.0, - * 0.0). The maxFeatureValue depict the number of rightmost digits that will be tested for ones. - */ + * Nested method to extract list of eligible categories given an index. It extracts the + * position of ones in a binary representation of the input. If binary + * representation of an number is 01101 (13), the output list should (3.0, 2.0, + * 0.0). The maxFeatureValue depict the number of rightmost digits that will be tested for ones. + */ private[tree] def extractMultiClassCategories( input: Int, maxFeatureValue: Int): List[Double] = { @@ -917,32 +1078,34 @@ private[ml] object RandomForest extends Logging { } /** - * Find splits for a continuous feature - * NOTE: Returned number of splits is set based on `featureSamples` and - * could be different from the specified `numSplits`. - * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly. - * @param featureSamples feature values of each sample - * @param metadata decision tree metadata - * NOTE: `metadata.numbins` will be changed accordingly - * if there are not enough splits to be found - * @param featureIndex feature index to find splits - * @return array of splits - */ + * Find splits for a continuous feature + * NOTE: Returned number of splits is set based on `featureSamples` and + * could be different from the specified `numSplits`. + * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly. 
+ * @param featureSamples feature values of each sample + * @param metadata decision tree metadata + * NOTE: `metadata.numbins` will be changed accordingly + * if there are not enough splits to be found + * @param featureIndex feature index to find splits + * @return array of splits + */ private[tree] def findSplitsForContinuousFeature( featureSamples: Iterable[Double], metadata: DecisionTreeMetadata, featureIndex: Int): Array[Double] = { - require(metadata.isContinuous(featureIndex), + require( + metadata.isContinuous(featureIndex), "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") val splits = { val numSplits = metadata.numSplits(featureIndex) // get count for each distinct value - val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) { - case ((m, cnt), x) => - (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1) - } + val (valueCountMap, numSamples) = + featureSamples.foldLeft((Map.empty[Double, Int], 0)) { + case ((m, cnt), x) => + (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1) + } // sort distinct values val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray @@ -985,42 +1148,47 @@ private[ml] object RandomForest extends Logging { } // TODO: Do not fail; just ignore the useless feature. - assert(splits.length > 0, + assert( + splits.length > 0, s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." + - " Please remove this feature and then try again.") + " Please remove this feature and then try again." + ) splits } private[tree] class NodeIndexInfo( val nodeIndexInGroup: Int, - val featureSubset: Option[Array[Int]]) extends Serializable + val featureSubset: Option[Array[Int]]) + extends Serializable /** - * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration. - * This tracks the memory usage for aggregates and stops adding nodes when too much memory - * will be needed; this allows an adaptive number of nodes since different nodes may require - * different amounts of memory (if featureSubsetStrategy is not "all"). - * - * @param nodeQueue Queue of nodes to split. - * @param maxMemoryUsage Bound on size of aggregate statistics. - * @return (nodesForGroup, treeToNodeToIndexInfo). - * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree. - * - * treeToNodeToIndexInfo holds indices selected features for each node: - * treeIndex --> (global) node index --> (node index in group, feature indices). - * The (global) node index is the index in the tree; the node index in group is the - * index in [0, numNodesInGroup) of the node in this group. - * The feature indices are None if not subsampling features. - */ + * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration. + * This tracks the memory usage for aggregates and stops adding nodes when too much memory + * will be needed; this allows an adaptive number of nodes since different nodes may require + * different amounts of memory (if featureSubsetStrategy is not "all"). + * + * @param nodeQueue Queue of nodes to split. + * @param maxMemoryUsage Bound on size of aggregate statistics. + * @return (nodesForGroup, treeToNodeToIndexInfo). + * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree. + * + * treeToNodeToIndexInfo holds indices selected features for each node: + * treeIndex --> (global) node index --> (node index in group, feature indices). 
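// Aside: a simplified illustration of the idea behind findSplitsForContinuousFeature:
// count the distinct sampled values, then either use the distinct values directly
// (when there are few of them) or walk the sorted counts with an equal-frequency
// stride to pick approximate quantile thresholds. This is a sketch of the idea, not a
// line-for-line port of the stride logic above.
object ContinuousThresholdsSketch {
  def approxThresholds(samples: Seq[Double], numSplits: Int): Array[Double] = {
    val valueCounts = samples
      .foldLeft(Map.empty[Double, Int]) { (m, x) => m + (x -> (m.getOrElse(x, 0) + 1)) }
      .toSeq.sortBy(_._1).toArray
    if (valueCounts.length <= numSplits) {
      // few distinct values: split between each pair of adjacent distinct values
      valueCounts.map(_._1).dropRight(1)
    } else {
      val stride = samples.length.toDouble / (numSplits + 1)
      val thresholds = scala.collection.mutable.ArrayBuffer.empty[Double]
      var target = stride
      var seen = 0
      var i = 0
      while (i < valueCounts.length - 1 && thresholds.length < numSplits) {
        seen += valueCounts(i)._2
        if (seen >= target) {
          thresholds += valueCounts(i)._1
          target += stride
        }
        i += 1
      }
      thresholds.toArray
    }
  }

  def main(args: Array[String]): Unit = {
    val samples = Seq.tabulate(100)(i => (i % 10).toDouble)
    println(approxThresholds(samples, 3).mkString(", "))
  }
}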
+ * The (global) node index is the index in the tree; the node index in group is the + * index in [0, numNodesInGroup) of the node in this group. + * The feature indices are None if not subsampling features. + */ private[tree] def selectNodesToSplit( nodeQueue: mutable.Queue[(Int, LearningNode)], maxMemoryUsage: Long, metadata: DecisionTreeMetadata, - rng: Random): (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = { + rng: Random) + : (Map[Int, Array[LearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = { // Collect some nodes to split: // nodesForGroup(treeIndex) = nodes to split - val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[LearningNode]]() + val mutableNodesForGroup = + new mutable.HashMap[Int, mutable.ArrayBuffer[LearningNode]]() val mutableTreeToNodeToIndexInfo = new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]() var memUsage: Long = 0L @@ -1028,21 +1196,32 @@ private[ml] object RandomForest extends Logging { while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) { val (treeIndex, node) = nodeQueue.head // Choose subset of features for node (if subsampling). - val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { - Some(SamplingUtils.reservoirSampleAndCount(Range(0, - metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1) - } else { - None - } + val featureSubset: Option[Array[Int]] = + if (metadata.subsamplingFeatures) { + Some( + SamplingUtils + .reservoirSampleAndCount( + Range(0, metadata.numFeatures).iterator, + metadata.numFeaturesPerNode, + rng.nextLong()) + ._1) + } else { + None + } // Check if enough memory remains to add this node to the group. - val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L + val nodeMemUsage = + RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L if (memUsage + nodeMemUsage <= maxMemoryUsage) { nodeQueue.dequeue() - mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[LearningNode]()) += + mutableNodesForGroup.getOrElseUpdate( + treeIndex, + new mutable.ArrayBuffer[LearningNode]()) += node mutableTreeToNodeToIndexInfo - .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id) - = new NodeIndexInfo(numNodesInGroup, featureSubset) + .getOrElseUpdate( + treeIndex, + new mutable.HashMap[Int, NodeIndexInfo]())(node.id) = + new NodeIndexInfo(numNodesInGroup, featureSubset) } numNodesInGroup += 1 memUsage += nodeMemUsage @@ -1050,20 +1229,23 @@ private[ml] object RandomForest extends Logging { // Convert mutable maps to immutable ones. val nodesForGroup: Map[Int, Array[LearningNode]] = mutableNodesForGroup.mapValues(_.toArray).toMap - val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap + val treeToNodeToIndexInfo = + mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap (nodesForGroup, treeToNodeToIndexInfo) } /** - * Get the number of values to be stored for this node in the bin aggregates. - * @param featureSubset Indices of features which may be split at this node. - * If None, then use all features. - */ + * Get the number of values to be stored for this node in the bin aggregates. + * @param featureSubset Indices of features which may be split at this node. + * If None, then use all features. 
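// Aside: a sketch of the memory-bounded grouping performed by selectNodesToSplit: keep
// dequeuing nodes while the estimated aggregate size fits in maxMemoryUsage, drawing a
// random feature subset per node. A plain shuffle stands in for the reservoir sampling
// used above, and the sizes and counts are toy numbers.
object SelectNodesSketch {
  import scala.collection.mutable
  import scala.util.Random

  final case class NodeToSplit(treeIndex: Int, nodeId: Int)

  def main(args: Array[String]): Unit = {
    val rng = new Random(42)
    val numFeatures = 100
    val numFeaturesPerNode = 10
    val numBins = Array.fill(numFeatures)(32)   // bins per feature
    val statsSize = 3                           // values kept per bin
    val maxMemoryUsage = 256L * 1024            // pretend budget in bytes

    val queue = mutable.Queue(
      NodeToSplit(0, 1), NodeToSplit(0, 2), NodeToSplit(1, 1), NodeToSplit(1, 3))

    val group = mutable.ArrayBuffer.empty[(NodeToSplit, Array[Int])]
    var memUsage = 0L
    while (queue.nonEmpty && memUsage < maxMemoryUsage) {
      val node = queue.head
      val featureSubset =
        rng.shuffle((0 until numFeatures).toList).take(numFeaturesPerNode).toArray
      val totalBins = featureSubset.map(f => numBins(f).toLong).sum
      val nodeMemUsage = totalBins * statsSize * 8L   // 8 bytes per Double
      if (memUsage + nodeMemUsage <= maxMemoryUsage) {
        queue.dequeue()
        group += ((node, featureSubset))
      }
      // counted even when the node does not fit, so the loop stops on overflow
      memUsage += nodeMemUsage
    }
    println(s"selected ${group.size} nodes for this iteration, est. $memUsage bytes")
  }
}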
+ */ private def aggregateSizeForNode( metadata: DecisionTreeMetadata, featureSubset: Option[Array[Int]]): Long = { val totalBins = if (featureSubset.nonEmpty) { - featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum + featureSubset.get + .map(featureIndex => metadata.numBins(featureIndex).toLong) + .sum } else { metadata.numBins.map(_.toLong).sum } @@ -1075,25 +1257,27 @@ private[ml] object RandomForest extends Logging { } /** - * Given a Random Forest model, compute the importance of each feature. - * This generalizes the idea of "Gini" importance to other losses, - * following the explanation of Gini importance from "Random Forests" documentation - * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. - * - * This feature importance is calculated as follows: - * - Average over trees: - * - importance(feature j) = sum (over nodes which split on feature j) of the gain, - * where gain is scaled by the number of instances passing through node - * - Normalize importances for tree to sum to 1. - * - Normalize feature importance vector to sum to 1. - * - * @param trees Unweighted forest of trees - * @param numFeatures Number of features in model (even if not all are explicitly used by - * the model). - * If -1, then numFeatures is set based on the max feature index in all trees. - * @return Feature importance values, of length numFeatures. - */ - private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = { + * Given a Random Forest model, compute the importance of each feature. + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * This feature importance is calculated as follows: + * - Average over trees: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree to sum to 1. + * - Normalize feature importance vector to sum to 1. + * + * @param trees Unweighted forest of trees + * @param numFeatures Number of features in model (even if not all are explicitly used by + * the model). + * If -1, then numFeatures is set based on the max feature index in all trees. + * @return Feature importance values, of length numFeatures. + */ + private[ml] def featureImportances( + trees: Array[DecisionTreeModel], + numFeatures: Int): Vector = { val totalImportances = new OpenHashMap[Int, Double]() trees.foreach { tree => // Aggregate feature importance vector for this tree @@ -1103,9 +1287,10 @@ private[ml] object RandomForest extends Logging { // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count? 
val treeNorm = importances.map(_._2).sum if (treeNorm != 0) { - importances.foreach { case (idx, impt) => - val normImpt = impt / treeNorm - totalImportances.changeValue(idx, normImpt, _ + normImpt) + importances.foreach { + case (idx, impt) => + val normImpt = impt / treeNorm + totalImportances.changeValue(idx, normImpt, _ + normImpt) } } } @@ -1120,40 +1305,44 @@ private[ml] object RandomForest extends Logging { maxFeatureIndex + 1 } if (d == 0) { - assert(totalImportances.size == 0, s"Unknown error in computing feature" + - s" importance: No splits found, but some non-zero importances.") + assert( + totalImportances.size == 0, + s"Unknown error in computing feature" + + s" importance: No splits found, but some non-zero importances.") } val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip Vectors.sparse(d, indices.toArray, values.toArray) } /** - * Given a Decision Tree model, compute the importance of each feature. - * This generalizes the idea of "Gini" importance to other losses, - * following the explanation of Gini importance from "Random Forests" documentation - * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. - * - * This feature importance is calculated as follows: - * - importance(feature j) = sum (over nodes which split on feature j) of the gain, - * where gain is scaled by the number of instances passing through node - * - Normalize importances for tree to sum to 1. - * - * @param tree Decision tree to compute importances for. - * @param numFeatures Number of features in model (even if not all are explicitly used by - * the model). - * If -1, then numFeatures is set based on the max feature index in all trees. - * @return Feature importance values, of length numFeatures. - */ - private[ml] def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = { + * Given a Decision Tree model, compute the importance of each feature. + * This generalizes the idea of "Gini" importance to other losses, + * following the explanation of Gini importance from "Random Forests" documentation + * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + * + * This feature importance is calculated as follows: + * - importance(feature j) = sum (over nodes which split on feature j) of the gain, + * where gain is scaled by the number of instances passing through node + * - Normalize importances for tree to sum to 1. + * + * @param tree Decision tree to compute importances for. + * @param numFeatures Number of features in model (even if not all are explicitly used by + * the model). + * If -1, then numFeatures is set based on the max feature index in all trees. + * @return Feature importance values, of length numFeatures. + */ + private[ml] def featureImportances( + tree: DecisionTreeModel, + numFeatures: Int): Vector = { featureImportances(Array(tree), numFeatures) } /** - * Recursive method for computing feature importances for one tree. - * This walks down the tree, adding to the importance of 1 feature at each node. - * @param node Current node in recursion - * @param importances Aggregate feature importances, modified by this method - */ + * Recursive method for computing feature importances for one tree. + * This walks down the tree, adding to the importance of 1 feature at each node. 
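// Aside: a plain-Scala sketch of the importance aggregation described here. Each tree
// contributes a map of featureIndex -> (gain scaled by instance count), each tree's map
// is normalized to sum to 1, the maps are summed across the forest, and the result is
// normalized once more. Ordinary Maps stand in for OpenHashMap and the sparse Vector.
object FeatureImportanceSketch {
  def normalize(m: Map[Int, Double]): Map[Int, Double] = {
    val total = m.values.sum
    if (total == 0.0) m else m.map { case (k, v) => k -> v / total }
  }

  def main(args: Array[String]): Unit = {
    // per-tree importances: featureIndex -> sum over splits of (gain * count at node)
    val perTree = Seq(
      Map(0 -> 12.0, 2 -> 4.0),
      Map(0 -> 3.0, 1 -> 9.0),
      Map(2 -> 5.0))

    val total = perTree
      .map(normalize)                     // normalize within each tree
      .foldLeft(Map.empty[Int, Double]) { (acc, t) =>
        t.foldLeft(acc) { case (a, (idx, v)) => a + (idx -> (a.getOrElse(idx, 0.0) + v)) }
      }

    val importances = normalize(total)    // normalize across the forest
    importances.toSeq.sortBy(_._1).foreach { case (i, v) => println(f"feature $i: $v%.3f") }
  }
}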
+ * @param node Current node in recursion + * @param importances Aggregate feature importances, modified by this method + */ private[impl] def computeFeatureImportance( node: Node, importances: OpenHashMap[Int, Double]): Unit = { @@ -1165,15 +1354,15 @@ private[ml] object RandomForest extends Logging { computeFeatureImportance(n.leftChild, importances) computeFeatureImportance(n.rightChild, importances) case n: LeafNode => - // do nothing + // do nothing } } /** - * Normalize the values of this map to sum to 1, in place. - * If all values are 0, this method does nothing. - * @param map Map with non-negative values. - */ + * Normalize the values of this map to sum to 1, in place. + * If all values are 0, this method does nothing. + * @param map Map with non-negative values. + */ private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = { val total = map.map(_._2).sum if (total != 0) { diff --git a/repos/spark/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/repos/spark/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 7ed1c51360f..11507597c83 100644 --- a/repos/spark/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/repos/spark/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -25,134 +25,144 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project} -import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource} +import org.apache.spark.sql.execution.datasources.{ + BucketSpec, + CreateTableUsingAsSelect, + DataSource +} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.sources.HadoopFsRelation /** - * :: Experimental :: - * Interface used to write a [[DataFrame]] to external storage systems (e.g. file systems, - * key-value stores, etc) or data streams. Use [[DataFrame.write]] to access this. - * - * @since 1.4.0 - */ + * :: Experimental :: + * Interface used to write a [[DataFrame]] to external storage systems (e.g. file systems, + * key-value stores, etc) or data streams. Use [[DataFrame.write]] to access this. + * + * @since 1.4.0 + */ @Experimental -final class DataFrameWriter private[sql](df: DataFrame) { +final class DataFrameWriter private[sql] (df: DataFrame) { /** - * Specifies the behavior when data or table already exists. Options include: - * - `SaveMode.Overwrite`: overwrite the existing data. - * - `SaveMode.Append`: append the data. - * - `SaveMode.Ignore`: ignore the operation (i.e. no-op). - * - `SaveMode.ErrorIfExists`: default option, throw an exception at runtime. - * - * @since 1.4.0 - */ + * Specifies the behavior when data or table already exists. Options include: + * - `SaveMode.Overwrite`: overwrite the existing data. + * - `SaveMode.Append`: append the data. + * - `SaveMode.Ignore`: ignore the operation (i.e. no-op). + * - `SaveMode.ErrorIfExists`: default option, throw an exception at runtime. + * + * @since 1.4.0 + */ def mode(saveMode: SaveMode): DataFrameWriter = { this.mode = saveMode this } /** - * Specifies the behavior when data or table already exists. Options include: - * - `overwrite`: overwrite the existing data. - * - `append`: append the data. - * - `ignore`: ignore the operation (i.e. no-op). 
- * - `error`: default option, throw an exception at runtime. - * - * @since 1.4.0 - */ + * Specifies the behavior when data or table already exists. Options include: + * - `overwrite`: overwrite the existing data. + * - `append`: append the data. + * - `ignore`: ignore the operation (i.e. no-op). + * - `error`: default option, throw an exception at runtime. + * + * @since 1.4.0 + */ def mode(saveMode: String): DataFrameWriter = { this.mode = saveMode.toLowerCase match { - case "overwrite" => SaveMode.Overwrite - case "append" => SaveMode.Append - case "ignore" => SaveMode.Ignore + case "overwrite" => SaveMode.Overwrite + case "append" => SaveMode.Append + case "ignore" => SaveMode.Ignore case "error" | "default" => SaveMode.ErrorIfExists - case _ => throw new IllegalArgumentException(s"Unknown save mode: $saveMode. " + - "Accepted modes are 'overwrite', 'append', 'ignore', 'error'.") + case _ => + throw new IllegalArgumentException( + s"Unknown save mode: $saveMode. " + + "Accepted modes are 'overwrite', 'append', 'ignore', 'error'.") } this } /** - * Specifies the underlying output data source. Built-in options include "parquet", "json", etc. - * - * @since 1.4.0 - */ + * Specifies the underlying output data source. Built-in options include "parquet", "json", etc. + * + * @since 1.4.0 + */ def format(source: String): DataFrameWriter = { this.source = source this } /** - * Adds an output option for the underlying data source. - * - * @since 1.4.0 - */ + * Adds an output option for the underlying data source. + * + * @since 1.4.0 + */ def option(key: String, value: String): DataFrameWriter = { this.extraOptions += (key -> value) this } /** - * Adds an output option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: Boolean): DataFrameWriter = option(key, value.toString) + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Boolean): DataFrameWriter = + option(key, value.toString) /** - * Adds an output option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: Long): DataFrameWriter = option(key, value.toString) + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Long): DataFrameWriter = + option(key, value.toString) /** - * Adds an output option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: Double): DataFrameWriter = option(key, value.toString) + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Double): DataFrameWriter = + option(key, value.toString) /** - * (Scala-specific) Adds output options for the underlying data source. - * - * @since 1.4.0 - */ - def options(options: scala.collection.Map[String, String]): DataFrameWriter = { + * (Scala-specific) Adds output options for the underlying data source. + * + * @since 1.4.0 + */ + def options( + options: scala.collection.Map[String, String]): DataFrameWriter = { this.extraOptions ++= options this } /** - * Adds output options for the underlying data source. - * - * @since 1.4.0 - */ + * Adds output options for the underlying data source. + * + * @since 1.4.0 + */ def options(options: java.util.Map[String, String]): DataFrameWriter = { this.options(options.asScala) this } /** - * Partitions the output by the given columns on the file system. 
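// Aside: a small end-to-end usage sketch of the writer options covered here (mode,
// format, option, partitionBy). The SparkConf/SQLContext setup, the toy schema and the
// output path are illustrative assumptions for the pre-SparkSession API this file uses.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{SQLContext, SaveMode}

object WriterUsageSketch {
  final case class Event(year: Int, month: Int, value: Double)

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("writer-sketch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize(Seq(Event(2016, 1, 1.0), Event(2016, 2, 2.0))).toDF()

    df.write
      .mode(SaveMode.Overwrite)           // or .mode("overwrite")
      .format("parquet")
      .option("compression", "snappy")    // forwarded to the data source
      .partitionBy("year", "month")       // year=.../month=... directory layout
      .save("/tmp/writer-sketch/events")

    sc.stop()
  }
}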
If specified, the output is - * laid out on the file system similar to Hive's partitioning scheme. As an example, when we - * partition a dataset by year and then month, the directory layout would look like: - * - * - year=2016/month=01/ - * - year=2016/month=02/ - * - * Partitioning is one of the most widely used techniques to optimize physical data layout. - * It provides a coarse-grained index for skipping unnecessary data reads when queries have - * predicates on the partitioned columns. In order for partitioning to work well, the number - * of distinct values in each column should typically be less than tens of thousands. - * - * This was initially applicable for Parquet but in 1.5+ covers JSON, text, ORC and avro as well. - * - * @since 1.4.0 - */ + * Partitions the output by the given columns on the file system. If specified, the output is + * laid out on the file system similar to Hive's partitioning scheme. As an example, when we + * partition a dataset by year and then month, the directory layout would look like: + * + * - year=2016/month=01/ + * - year=2016/month=02/ + * + * Partitioning is one of the most widely used techniques to optimize physical data layout. + * It provides a coarse-grained index for skipping unnecessary data reads when queries have + * predicates on the partitioned columns. In order for partitioning to work well, the number + * of distinct values in each column should typically be less than tens of thousands. + * + * This was initially applicable for Parquet but in 1.5+ covers JSON, text, ORC and avro as well. + * + * @since 1.4.0 + */ @scala.annotation.varargs def partitionBy(colNames: String*): DataFrameWriter = { this.partitioningColumns = Option(colNames) @@ -160,27 +170,30 @@ final class DataFrameWriter private[sql](df: DataFrame) { } /** - * Buckets the output by the given columns. If specified, the output is laid out on the file - * system similar to Hive's bucketing scheme. - * - * This is applicable for Parquet, JSON and ORC. - * - * @since 2.0 - */ + * Buckets the output by the given columns. If specified, the output is laid out on the file + * system similar to Hive's bucketing scheme. + * + * This is applicable for Parquet, JSON and ORC. + * + * @since 2.0 + */ @scala.annotation.varargs - def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter = { + def bucketBy( + numBuckets: Int, + colName: String, + colNames: String*): DataFrameWriter = { this.numBuckets = Option(numBuckets) this.bucketColumnNames = Option(colName +: colNames) this } /** - * Sorts the output in each bucket by the given columns. - * - * This is applicable for Parquet, JSON and ORC. - * - * @since 2.0 - */ + * Sorts the output in each bucket by the given columns. + * + * This is applicable for Parquet, JSON and ORC. + * + * @since 2.0 + */ @scala.annotation.varargs def sortBy(colName: String, colNames: String*): DataFrameWriter = { this.sortColumnNames = Option(colName +: colNames) @@ -188,20 +201,20 @@ final class DataFrameWriter private[sql](df: DataFrame) { } /** - * Saves the content of the [[DataFrame]] at the specified path. - * - * @since 1.4.0 - */ + * Saves the content of the [[DataFrame]] at the specified path. + * + * @since 1.4.0 + */ def save(path: String): Unit = { this.extraOptions += ("path" -> path) save() } /** - * Saves the content of the [[DataFrame]] as the specified table. - * - * @since 1.4.0 - */ + * Saves the content of the [[DataFrame]] as the specified table. 
+ * + * @since 1.4.0 + */ def save(): Unit = { assertNotBucketed() val dataSource = DataSource( @@ -215,34 +228,34 @@ final class DataFrameWriter private[sql](df: DataFrame) { } /** - * Specifies the name of the [[ContinuousQuery]] that can be started with `startStream()`. - * This name must be unique among all the currently active queries in the associated SQLContext. - * - * @since 2.0.0 - */ + * Specifies the name of the [[ContinuousQuery]] that can be started with `startStream()`. + * This name must be unique among all the currently active queries in the associated SQLContext. + * + * @since 2.0.0 + */ def queryName(queryName: String): DataFrameWriter = { this.extraOptions += ("queryName" -> queryName) this } /** - * Starts the execution of the streaming query, which will continually output results to the given - * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with - * the stream. - * - * @since 2.0.0 - */ + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with + * the stream. + * + * @since 2.0.0 + */ def startStream(path: String): ContinuousQuery = { option("path", path).startStream() } /** - * Starts the execution of the streaming query, which will continually output results to the given - * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with - * the stream. - * - * @since 2.0.0 - */ + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with + * the stream. + * + * @since 2.0.0 + */ def startStream(): ContinuousQuery = { val dataSource = DataSource( @@ -252,89 +265,107 @@ final class DataFrameWriter private[sql](df: DataFrame) { partitionColumns = normalizedParCols.getOrElse(Nil)) df.sqlContext.sessionState.continuousQueryManager.startQuery( - extraOptions.getOrElse("queryName", StreamExecution.nextName), df, dataSource.createSink()) + extraOptions.getOrElse("queryName", StreamExecution.nextName), + df, + dataSource.createSink()) } /** - * Inserts the content of the [[DataFrame]] to the specified table. It requires that - * the schema of the [[DataFrame]] is the same as the schema of the table. - * - * Because it inserts data to an existing table, format or options will be ignored. - * - * @since 1.4.0 - */ + * Inserts the content of the [[DataFrame]] to the specified table. It requires that + * the schema of the [[DataFrame]] is the same as the schema of the table. + * + * Because it inserts data to an existing table, format or options will be ignored. + * + * @since 1.4.0 + */ def insertInto(tableName: String): Unit = { - insertInto(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)) + insertInto( + df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)) } private def insertInto(tableIdent: TableIdentifier): Unit = { assertNotBucketed() - val partitions = normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap) + val partitions = + normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap) val overwrite = mode == SaveMode.Overwrite // A partitioned relation's schema can be different from the input logicalPlan, since // partition columns are all moved after data columns. We Project to adjust the ordering. // TODO: this belongs to the analyzer. 
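// Aside: a sketch contrasting saveAsTable and insertInto as described here: saveAsTable
// creates (or replaces) the table definition, while insertInto only appends rows into an
// existing table with a matching schema and ignores format/options. Table and app names
// are illustrative, and with builds of this vintage persistent tables may require a
// HiveContext rather than the plain SQLContext assumed below.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object InsertIntoSketch {
  final case class User(id: Long, name: String)

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("insert-sketch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val first = sc.parallelize(Seq(User(1L, "ada"))).toDF()
    first.write.saveAsTable("users")             // creates the managed table

    val more = sc.parallelize(Seq(User(2L, "grace"))).toDF()
    more.write.insertInto("users")               // appends; format/options are ignored

    println(sqlContext.table("users").count())   // expected: 2
    sc.stop()
  }
}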
- val input = normalizedParCols.map { parCols => - val (inputPartCols, inputDataCols) = df.logicalPlan.output.partition { attr => - parCols.contains(attr.name) + val input = normalizedParCols + .map { parCols => + val (inputPartCols, inputDataCols) = df.logicalPlan.output.partition { + attr => parCols.contains(attr.name) + } + Project(inputDataCols ++ inputPartCols, df.logicalPlan) } - Project(inputDataCols ++ inputPartCols, df.logicalPlan) - }.getOrElse(df.logicalPlan) - - df.sqlContext.executePlan( - InsertIntoTable( - UnresolvedRelation(tableIdent), - partitions.getOrElse(Map.empty[String, Option[String]]), - input, - overwrite, - ifNotExists = false)).toRdd + .getOrElse(df.logicalPlan) + + df.sqlContext + .executePlan( + InsertIntoTable( + UnresolvedRelation(tableIdent), + partitions.getOrElse(Map.empty[String, Option[String]]), + input, + overwrite, + ifNotExists = false)) + .toRdd } - private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => - cols.map(normalize(_, "Partition")) + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { + cols => cols.map(normalize(_, "Partition")) } - private def normalizedBucketColNames: Option[Seq[String]] = bucketColumnNames.map { cols => - cols.map(normalize(_, "Bucketing")) - } + private def normalizedBucketColNames: Option[Seq[String]] = + bucketColumnNames.map { cols => cols.map(normalize(_, "Bucketing")) } - private def normalizedSortColNames: Option[Seq[String]] = sortColumnNames.map { cols => - cols.map(normalize(_, "Sorting")) - } + private def normalizedSortColNames: Option[Seq[String]] = + sortColumnNames.map { cols => cols.map(normalize(_, "Sorting")) } private def getBucketSpec: Option[BucketSpec] = { if (sortColumnNames.isDefined) { - require(numBuckets.isDefined, "sortBy must be used together with bucketBy") + require( + numBuckets.isDefined, + "sortBy must be used together with bucketBy") } for { n <- numBuckets } yield { - require(n > 0 && n < 100000, "Bucket number must be greater than 0 and less than 100000.") + require( + n > 0 && n < 100000, + "Bucket number must be greater than 0 and less than 100000.") // partitionBy columns cannot be used in bucketBy if (normalizedParCols.nonEmpty && - normalizedBucketColNames.get.toSet.intersect(normalizedParCols.get.toSet).nonEmpty) { - throw new AnalysisException( - s"bucketBy columns '${bucketColumnNames.get.mkString(", ")}' should not be part of " + + normalizedBucketColNames.get.toSet + .intersect(normalizedParCols.get.toSet) + .nonEmpty) { + throw new AnalysisException( + s"bucketBy columns '${bucketColumnNames.get.mkString(", ")}' should not be part of " + s"partitionBy columns '${partitioningColumns.get.mkString(", ")}'") } - BucketSpec(n, normalizedBucketColNames.get, normalizedSortColNames.getOrElse(Nil)) + BucketSpec( + n, + normalizedBucketColNames.get, + normalizedSortColNames.getOrElse(Nil)) } } /** - * The given column name may not be equal to any of the existing column names if we were in - * case-insensitive context. Normalize the given column name to the real one so that we don't - * need to care about case sensitivity afterwards. - */ + * The given column name may not be equal to any of the existing column names if we were in + * case-insensitive context. Normalize the given column name to the real one so that we don't + * need to care about case sensitivity afterwards. 
+ */ private def normalize(columnName: String, columnType: String): String = { val validColumnNames = df.logicalPlan.output.map(_.name) - validColumnNames.find(df.sqlContext.sessionState.analyzer.resolver(_, columnName)) - .getOrElse(throw new AnalysisException(s"$columnType column $columnName not found in " + - s"existing columns (${validColumnNames.mkString(", ")})")) + validColumnNames + .find(df.sqlContext.sessionState.analyzer.resolver(_, columnName)) + .getOrElse( + throw new AnalysisException( + s"$columnType column $columnName not found in " + + s"existing columns (${validColumnNames.mkString(", ")})")) } private def assertNotBucketed(): Unit = { @@ -345,25 +376,26 @@ final class DataFrameWriter private[sql](df: DataFrame) { } /** - * Saves the content of the [[DataFrame]] as the specified table. - * - * In the case the table already exists, behavior of this function depends on the - * save mode, specified by the `mode` function (default to throwing an exception). - * When `mode` is `Overwrite`, the schema of the [[DataFrame]] does not need to be - * the same as that of the existing table. - * When `mode` is `Append`, the schema of the [[DataFrame]] need to be - * the same as that of the existing table, and format or options will be ignored. - * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. - * - * @since 1.4.0 - */ + * Saves the content of the [[DataFrame]] as the specified table. + * + * In the case the table already exists, behavior of this function depends on the + * save mode, specified by the `mode` function (default to throwing an exception). + * When `mode` is `Overwrite`, the schema of the [[DataFrame]] does not need to be + * the same as that of the existing table. + * When `mode` is `Append`, the schema of the [[DataFrame]] need to be + * the same as that of the existing table, and format or options will be ignored. + * + * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input + * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC + * and Parquet), the table is persisted in a Hive compatible format, which means other systems + * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL + * specific format. + * + * @since 1.4.0 + */ def saveAsTable(tableName: String): Unit = { - saveAsTable(df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)) + saveAsTable( + df.sqlContext.sessionState.sqlParser.parseTableIdentifier(tableName)) } private def saveAsTable(tableIdent: TableIdentifier): Unit = { @@ -371,7 +403,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { (tableExists, mode) match { case (true, SaveMode.Ignore) => - // Do nothing + // Do nothing case (true, SaveMode.ErrorIfExists) => throw new AnalysisException(s"Table $tableIdent already exists.") @@ -386,30 +418,35 @@ final class DataFrameWriter private[sql](df: DataFrame) { getBucketSpec, mode, extraOptions.toMap, - df.logicalPlan) + df.logicalPlan + ) df.sqlContext.executePlan(cmd).toRdd } } /** - * Saves the content of the [[DataFrame]] to a external database table via JDBC. 
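// Aside: a hedged sketch of the JDBC write path documented here. The URL, table name and
// credentials are placeholders, the matching JDBC driver is assumed to be on the
// classpath, and the coalesce keeps the number of parallel connections modest, per the
// warning about overwhelming the external database.
import java.util.Properties
import org.apache.spark.sql.{DataFrame, SaveMode}

object JdbcWriteSketch {
  def writeEvents(df: DataFrame): Unit = {
    val props = new Properties()
    props.put("user", "report")        // placeholder credentials
    props.put("password", "secret")

    df.coalesce(8)                     // limit concurrent connections to the database
      .write
      .mode(SaveMode.Append)
      .jdbc("jdbc:postgresql://db.example.com:5432/analytics", "events", props)
  }
}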
   /**
-   * Saves the content of the [[DataFrame]] to a external database table via JDBC. In the case the
-   * table already exists in the external database, behavior of this function depends on the
-   * save mode, specified by the `mode` function (default to throwing an exception).
-   *
-   * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash
-   * your external database systems.
-   *
-   * @param url JDBC database url of the form `jdbc:subprotocol:subname`
-   * @param table Name of the table in the external database.
-   * @param connectionProperties JDBC database connection arguments, a list of arbitrary string
-   *                             tag/value. Normally at least a "user" and "password" property
-   *                             should be included.
-   * @since 1.4.0
-   */
-  def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
+    * Saves the content of the [[DataFrame]] to a external database table via JDBC. In the case the
+    * table already exists in the external database, behavior of this function depends on the
+    * save mode, specified by the `mode` function (default to throwing an exception).
+    *
+    * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash
+    * your external database systems.
+    *
+    * @param url JDBC database url of the form `jdbc:subprotocol:subname`
+    * @param table Name of the table in the external database.
+    * @param connectionProperties JDBC database connection arguments, a list of arbitrary string
+    *                             tag/value. Normally at least a "user" and "password" property
+    *                             should be included.
+    * @since 1.4.0
+    */
+  def jdbc(
+      url: String,
+      table: String,
+      connectionProperties: Properties): Unit = {
     val props = new Properties()
-    extraOptions.foreach { case (key, value) =>
-      props.put(key, value)
+    extraOptions.foreach {
+      case (key, value) =>
+        props.put(key, value)
     }
     // connectionProperties should override settings in extraOptions
     props.putAll(connectionProperties)
@@ -450,89 +487,89 @@ final class DataFrameWriter private[sql](df: DataFrame) {
   }
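A usage sketch for the reformatted `jdbc` overload; the URL, table name, and credentials are placeholders. Because every `option(...)` entry is copied into the `Properties` before `putAll(connectionProperties)` runs, keys supplied through `connectionProperties` win on collision:

import java.util.Properties

val connectionProperties = new Properties() // placeholder connection details
connectionProperties.put("user", "writer")
connectionProperties.put("password", "secret")

df.write.jdbc(
  "jdbc:postgresql://db.example.com:5432/analytics", // jdbc:subprotocol:subname form
  "public.events",                                   // table in the external database
  connectionProperties)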
   /**
-   * Saves the content of the [[DataFrame]] in JSON format at the specified path.
-   * This is equivalent to:
-   * {{{
-   *   format("json").save(path)
-   * }}}
-   *
-   * You can set the following JSON-specific option(s) for writing JSON files:
-   * <li>`compression` (default `null`): compression codec to use when saving to file. This can be
-   * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`,
-   * `snappy` and `deflate`). </li>
-   *
-   * @since 1.4.0
-   */
+    * Saves the content of the [[DataFrame]] in JSON format at the specified path.
+    * This is equivalent to:
+    * {{{
+    *   format("json").save(path)
+    * }}}
+    *
+    * You can set the following JSON-specific option(s) for writing JSON files:
+    * <li>`compression` (default `null`): compression codec to use when saving to file. This can be
+    * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`,
+    * `snappy` and `deflate`). </li>
+    *
+    * @since 1.4.0
+    */
   def json(path: String): Unit = format("json").save(path)
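For illustration, the `compression` option named in the scaladoc can be passed through either spelling; the output paths are hypothetical and `df` is an existing DataFrame:

// Dedicated method ...
df.write.option("compression", "gzip").json("/tmp/events_json")

// ... equivalent to the generic form.
df.write.format("json").option("compression", "gzip").save("/tmp/events_json_generic")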
   /**
-   * Saves the content of the [[DataFrame]] in Parquet format at the specified path.
-   * This is equivalent to:
-   * {{{
-   *   format("parquet").save(path)
-   * }}}
-   *
-   * You can set the following Parquet-specific option(s) for writing Parquet files:
-   * <li>`compression` (default `null`): compression codec to use when saving to file. This can be
-   * one of the known case-insensitive shorten names(`none`, `snappy`, `gzip`, and `lzo`).
-   * This will overwrite `spark.sql.parquet.compression.codec`. </li>
-   *
-   * @since 1.4.0
-   */
+    * Saves the content of the [[DataFrame]] in Parquet format at the specified path.
+    * This is equivalent to:
+    * {{{
+    *   format("parquet").save(path)
+    * }}}
+    *
+    * You can set the following Parquet-specific option(s) for writing Parquet files:
+    * <li>`compression` (default `null`): compression codec to use when saving to file. This can be
+    * one of the known case-insensitive shorten names(`none`, `snappy`, `gzip`, and `lzo`).
+    * This will overwrite `spark.sql.parquet.compression.codec`. </li>
+    *
+    * @since 1.4.0
+    */
   def parquet(path: String): Unit = format("parquet").save(path)
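Likewise for Parquet, a one-liner showing the per-write codec override mentioned above (the path is hypothetical):

// Takes precedence over spark.sql.parquet.compression.codec for this write only.
df.write.option("compression", "snappy").parquet("/tmp/events_parquet")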
   /**
-   * Saves the content of the [[DataFrame]] in ORC format at the specified path.
-   * This is equivalent to:
-   * {{{
-   *   format("orc").save(path)
-   * }}}
-   *
-   * You can set the following ORC-specific option(s) for writing ORC files:
-   * <li>`compression` (default `null`): compression codec to use when saving to file. This can be
-   * one of the known case-insensitive shorten names(`none`, `snappy`, `zlib`, and `lzo`).
-   * This will overwrite `orc.compress`. </li>
-   *
-   * @since 1.5.0
-   * @note Currently, this method can only be used together with `HiveContext`.
-   */
+    * Saves the content of the [[DataFrame]] in ORC format at the specified path.
+    * This is equivalent to:
+    * {{{
+    *   format("orc").save(path)
+    * }}}
+    *
+    * You can set the following ORC-specific option(s) for writing ORC files:
+    * <li>`compression` (default `null`): compression codec to use when saving to file. This can be
+    * one of the known case-insensitive shorten names(`none`, `snappy`, `zlib`, and `lzo`).
+    * This will overwrite `orc.compress`. </li>
+    *
+    * @since 1.5.0
+    * @note Currently, this method can only be used together with `HiveContext`.
+    */
   def orc(path: String): Unit = format("orc").save(path)
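And for ORC, keeping the @note in mind (the DataFrame must come from a Hive-enabled context); the path is hypothetical:

// Overrides orc.compress for this write.
df.write.option("compression", "zlib").orc("/tmp/events_orc")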
   /**
-   * Saves the content of the [[DataFrame]] in a text file at the specified path.
-   * The DataFrame must have only one column that is of string type.
-   * Each row becomes a new line in the output file. For example:
-   * {{{
-   *   // Scala:
-   *   df.write.text("/path/to/output")
-   *
-   *   // Java:
-   *   df.write().text("/path/to/output")
-   * }}}
-   *
-   * You can set the following option(s) for writing text files:
-   * <li>`compression` (default `null`): compression codec to use when saving to file. This can be
-   * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`,
-   * `snappy` and `deflate`). </li>
-   *
-   * @since 1.6.0
-   */
+    * Saves the content of the [[DataFrame]] in a text file at the specified path.
+    * The DataFrame must have only one column that is of string type.
+    * Each row becomes a new line in the output file. For example:
+    * {{{
+    *   // Scala:
+    *   df.write.text("/path/to/output")
+    *
+    *   // Java:
+    *   df.write().text("/path/to/output")
+    * }}}
+    *
+    * You can set the following option(s) for writing text files:
+    * <li>`compression` (default `null`): compression codec to use when saving to file. This can be
+    * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`,
+    * `snappy` and `deflate`). </li>
+    *
+    * @since 1.6.0
+    */
   def text(path: String): Unit = format("text").save(path)
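The text writer is the one sink with a shape requirement: exactly one string column. A sketch assuming `df` has a string column named `message` (an illustrative name); the path is hypothetical:

// Select the single string column first; each row becomes one output line.
df.select("message")
  .write
  .option("compression", "bzip2")
  .text("/tmp/events_text")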
   /**
-   * Saves the content of the [[DataFrame]] in CSV format at the specified path.
-   * This is equivalent to:
-   * {{{
-   *   format("csv").save(path)
-   * }}}
-   *
-   * You can set the following CSV-specific option(s) for writing CSV files:
-   * <li>`compression` (default `null`): compression codec to use when saving to file. This can be
-   * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`,
-   * `snappy` and `deflate`). </li>
-   *
-   * @since 2.0.0
-   */
+    * Saves the content of the [[DataFrame]] in CSV format at the specified path.
+    * This is equivalent to:
+    * {{{
+    *   format("csv").save(path)
+    * }}}
+    *
+    * You can set the following CSV-specific option(s) for writing CSV files:
+    * <li>`compression` (default `null`): compression codec to use when saving to file. This can be
+    * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`,
+    * `snappy` and `deflate`). </li>
+    *
+    * @since 2.0.0
+    */
   def csv(path: String): Unit = format("csv").save(path)

   ///////////////////////////////////////////////////////////////////////////////////////
@@ -543,7 +580,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {

   private var mode: SaveMode = SaveMode.ErrorIfExists

-  private var extraOptions = new scala.collection.mutable.HashMap[String, String]
+  private var extraOptions =
+    new scala.collection.mutable.HashMap[String, String]

   private var partitioningColumns: Option[Seq[String]] = None
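Finally the CSV writer; this also shows where the builder state above ends up: `mode(...)` sets the `mode` field, and each `option(...)` call lands in the mutable `extraOptions` map before being handed to the data source (the path is hypothetical):

import org.apache.spark.sql.SaveMode

df.write
  .mode(SaveMode.Overwrite)
  .option("compression", "gzip") // stored in extraOptions
  .csv("/tmp/events_csv")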