
Commit

clean-ups, comments.
Your Name committed Feb 28, 2017
1 parent 832b066 commit ebd2604
Showing 2 changed files with 19 additions and 15 deletions.
32 changes: 18 additions & 14 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -40,7 +40,8 @@ import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.CholeskyDecomposition
 import org.apache.spark.mllib.optimization.NNLS
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
@@ -250,7 +251,8 @@ class ALSModel private[ml] (
 
   private val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
     if (userFeatures != null && itemFeatures != null) {
-      // TODO: try dot-producting on Seqs or another non-converted type for potential optimization
+      // TODO(SPARK-19759): try dot-producting on Seqs or another non-converted type for
+      // potential optimization.
       blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
     } else {
       Float.NaN
@@ -288,48 +290,50 @@ class ALSModel private[ml] (
   override def write: MLWriter = new ALSModel.ALSModelWriter(this)
 
   /**
-   * Returns top `num` items recommended for each user, for all users.
-   * @param num number of recommendations for each user
+   * Returns top `numItems` items recommended for each user, for all users.
+   * @param numItems max number of recommendations for each user
    * @return a DataFrame of (userCol: Int, recommendations), where recommendations are
    *         stored as an array of (itemId: Int, rating: Double) tuples.
    */
   @Since("2.2.0")
-  def recommendForAllUsers(num: Int): DataFrame = {
-    recommendForAll(userFactors, itemFactors, $(userCol), num)
+  def recommendForAllUsers(numItems: Int): DataFrame = {
+    recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems)
   }
 
   /**
-   * Returns top `num` users recommended for each item, for all items.
-   * @param num number of recommendations for each item
+   * Returns top `numUsers` users recommended for each item, for all items.
+   * @param numUsers max number of recommendations for each item
    * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are
    *         stored as an array of (userId: Int, rating: Double) tuples.
    */
   @Since("2.2.0")
-  def recommendForAllItems(num: Int): DataFrame = {
-    recommendForAll(itemFactors, userFactors, $(itemCol), num)
+  def recommendForAllItems(numUsers: Int): DataFrame = {
+    recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers)
   }
 
   /**
    * Makes recommendations for all users (or items).
    * @param srcFactors src factors for which to generate recommendations
    * @param dstFactors dst factors used to make recommendations
    * @param srcOutputColumn name of the column for the source in the output DataFrame
-   * @param num number of recommendations for each record
+   * @param dstOutputColumn name of the column for the destination in the output DataFrame
+   * @param num max number of recommendations for each record
    * @return a DataFrame of (srcOutputColumn: Int, recommendations), where recommendations are
    *         stored as an array of (dstId: Int, rating: Double) tuples.
    */
   private def recommendForAll(
       srcFactors: DataFrame,
       dstFactors: DataFrame,
       srcOutputColumn: String,
+      dstOutputColumn: String,
       num: Int): DataFrame = {
     import srcFactors.sparkSession.implicits._
 
     val ratings = srcFactors.crossJoin(dstFactors)
       .select(
-        srcFactors("id").as("srcId"),
-        dstFactors("id").as("dstId"),
-        predict(srcFactors("features"), dstFactors("features")).as($(predictionCol)))
+        srcFactors("id"),
+        dstFactors("id"),
+        predict(srcFactors("features"), dstFactors("features")))
+    // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
     val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
     ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
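Aside (not part of the commit): the predict UDF in the second hunk scores a (user, item) pair as the dot product of the two factor vectors via BLAS. A minimal sketch of that computation, assuming ALS.scala's existing alias import com.github.fommil.netlib.BLAS.{getInstance => blas}; the factor values are illustrative:

// Illustrative only: computes one predicted rating the way the predict UDF does.
val rank = 3
val userFeatures = Array(0.1f, 0.2f, 0.3f)
val itemFeatures = Array(0.4f, 0.5f, 0.6f)
// sdot(n, x, incx, y, incy): single-precision dot product of the first n elements.
val rating = blas.sdot(rank, userFeatures, 1, itemFeatures, 1)
// Equivalent to: userFeatures.zip(itemFeatures).map { case (u, i) => u * i }.sum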
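A hedged usage sketch for the two new public methods shown above; `model` is a hypothetical fitted ALSModel (the methods themselves are confirmed by the diff, tagged @Since("2.2.0")):

// Assumes `model: ALSModel` was produced by ALS.fit(...) on a ratings DataFrame.
val userRecs = model.recommendForAllUsers(10)  // top 10 items per user
val itemRecs = model.recommendForAllItems(10)  // top 10 users per item
// Per the scaladoc, each row pairs an ID with an array of (id, rating) tuples.
userRecs.printSchema()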
2 changes: 1 addition & 1 deletion mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
@@ -56,5 +56,5 @@ private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: Ty
     Encoders.kryo[BoundedPriorityQueue[(K2, V)]]
   }
 
-  override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]
+  override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]()
 }
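For context, a small sketch of the aggregation pattern that recommendForAll builds on. TopByKeyAggregator is private[recommendation], so this only compiles inside that package; the SparkSession `spark` and the sample triples are assumptions for illustration:

import spark.implicits._

// (srcId, dstId, rating) triples, shaped like the cross join + predict output.
val ratings = Seq((0, 10, 0.5f), (0, 11, 0.9f), (1, 10, 0.3f)).toDS()

// Keep only the top-2 (dstId, rating) pairs per srcId, ranked by rating.
val topK = new TopByKeyAggregator[Int, Int, Float](2, Ordering.by(_._2))
val recs = ratings.groupByKey(_._1).agg(topK.toColumn)
// recs: Dataset[(Int, Array[(Int, Float)])]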
