[SPARK-33518][ML] Improve performance of ML ALS recommendForAll by GEMV
### What changes were proposed in this pull request?
There has been a lot of work on improving ALS's `recommendForAll`.

For now, I found that it can be further optimized by:

1. using GEMV and sharing a pre-allocated buffer per task (a standalone sketch follows the list);

2. using Guava's `Ordering` instead of `BoundedPriorityQueue`.
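
A minimal standalone sketch of idea 1 (my illustration, not the patch itself): each source factor is scored against a whole block of destination factors with a single `sgemv` call into a buffer that is reused across rows. It assumes netlib-java's `F2jBLAS` on the classpath; the toy matrices and the `GemvSketch` object are hypothetical.

```scala
import com.github.fommil.netlib.F2jBLAS

object GemvSketch {
  def main(args: Array[String]): Unit = {
    val blas = new F2jBLAS
    val rank = 2
    val srcMat = Array(1f, 2f, 3f, 4f)          // 2 source factors, row-major
    val dstMat = Array(1f, 0f, 0f, 1f, 1f, 1f)  // 3 destination factors, row-major
    val m = srcMat.length / rank
    val n = dstMat.length / rank
    val scores = new Array[Float](n)            // allocated once, reused for every row

    for (i <- 0 until m) {
      // Row-major rank-length factors double as a column-major rank x n matrix,
      // so "T" yields scores = dstMat^T * (i-th source factor): one GEMV call
      // replaces n separate sdot calls.
      blas.sgemv("T", rank, n, 1.0f, dstMat, 0, rank, srcMat, i * rank, 1,
        0.0f, scores, 0, 1)
      println(s"source $i -> ${scores.mkString(", ")}")
    }
    // source 0 -> 1.0, 2.0, 3.0
    // source 1 -> 3.0, 4.0, 7.0
  }
}
```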

### Why are the changes needed?
In my test, using `f2jBLAS.sgemv` is about 2.3X faster than the existing implementation.

|Impl| Master | GEMM | GEMV | GEMV + array aggregator | GEMV + guava ordering + array aggregator  | GEMV + guava ordering|
|------|----------|------------|----------|------------|------------|------------|
|Duration (lower is better)|341229|363741|191201|189790|148417|147222|

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing test suites.

Closes #30468 from zhengruifeng/als_rec_opt.

Authored-by: zhengruifeng <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
zhengruifeng authored and srowen committed Dec 19, 2020
1 parent de234ee commit 44563a0
Showing 1 changed file with 33 additions and 20 deletions:
`mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala`
```diff
@@ -27,6 +27,7 @@ import scala.util.{Sorting, Try}
 import scala.util.hashing.byteswap64
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import com.google.common.collect.{Ordering => GuavaOrdering}
 import org.apache.hadoop.fs.Path
 import org.json4s.DefaultFormats
 import org.json4s.JsonDSL._
@@ -47,7 +48,7 @@ import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{BoundedPriorityQueue, Utils}
+import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
 import org.apache.spark.util.random.XORShiftRandom
 
```
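The new Guava import supports idea 2: `Ordering.greatestOf` selects the k greatest elements in a single pass, replacing `BoundedPriorityQueue`. A standalone sketch of the index-ordering trick (toy scores; the `TopKSketch` object is hypothetical):

```scala
import com.google.common.collect.{Ordering => GuavaOrdering}
import scala.collection.JavaConverters._

object TopKSketch {
  def main(args: Array[String]): Unit = {
    val scores = Array(0.3f, 0.9f, 0.1f, 0.7f)
    // Rank candidate indices by the score they point at, so the top-k pass
    // shuffles integers around instead of (id, score) tuples.
    val idxOrd = new GuavaOrdering[Int] {
      override def compare(left: Int, right: Int): Int =
        Ordering[Float].compare(scores(left), scores(right))
    }
    // Top-2 indices, greatest score first: 1 (0.9f), then 3 (0.7f).
    val top = idxOrd.greatestOf(Iterator.range(0, scores.length).asJava, 2).asScala
    println(top.mkString(", "))  // 1, 3
  }
}
```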
```diff
@@ -456,30 +457,39 @@ class ALSModel private[ml] (
       num: Int,
       blockSize: Int): DataFrame = {
     import srcFactors.sparkSession.implicits._
+    import scala.collection.JavaConverters._
 
     val srcFactorsBlocked = blockify(srcFactors.as[(Int, Array[Float])], blockSize)
     val dstFactorsBlocked = blockify(dstFactors.as[(Int, Array[Float])], blockSize)
     val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
-      .as[(Seq[(Int, Array[Float])], Seq[(Int, Array[Float])])]
-      .flatMap { case (srcIter, dstIter) =>
-        val m = srcIter.size
-        val n = math.min(dstIter.size, num)
-        val output = new Array[(Int, Int, Float)](m * n)
-        var i = 0
-        val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2))
-        srcIter.foreach { case (srcId, srcFactor) =>
-          dstIter.foreach { case (dstId, dstFactor) =>
-            // We use F2jBLAS which is faster than a call to native BLAS for vector dot product
-            val score = BLAS.f2jBLAS.sdot(rank, srcFactor, 1, dstFactor, 1)
-            pq += dstId -> score
+      .as[(Array[Int], Array[Float], Array[Int], Array[Float])]
+      .mapPartitions { iter =>
+        var scores: Array[Float] = null
+        var idxOrd: GuavaOrdering[Int] = null
+        iter.flatMap { case (srcIds, srcMat, dstIds, dstMat) =>
+          require(srcMat.length == srcIds.length * rank)
+          require(dstMat.length == dstIds.length * rank)
+          val m = srcIds.length
+          val n = dstIds.length
+          if (scores == null || scores.length < n) {
+            scores = Array.ofDim[Float](n)
+            idxOrd = new GuavaOrdering[Int] {
+              override def compare(left: Int, right: Int): Int = {
+                Ordering[Float].compare(scores(left), scores(right))
+              }
+            }
           }
-          pq.foreach { case (dstId, score) =>
-            output(i) = (srcId, dstId, score)
-            i += 1
+
+          Iterator.range(0, m).flatMap { i =>
+            // buffer = i-th vec in srcMat * dstMat
+            BLAS.f2jBLAS.sgemv("T", rank, n, 1.0F, dstMat, 0, rank,
+              srcMat, i * rank, 1, 0.0F, scores, 0, 1)
+
+            val srcId = srcIds(i)
+            idxOrd.greatestOf(Iterator.range(0, n).asJava, num).asScala
+              .iterator.map { j => (srcId, dstIds(j), scores(j)) }
           }
-          pq.clear()
         }
-        output.toSeq
       }
     // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
     val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
```
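
Reading the `sgemv` call above: BLAS is column-major, so the row-major `dstMat` (one `rank`-length factor per row) doubles as a `rank x n` column-major matrix whose columns are the destination factors, and `"T"` multiplies its transpose by the i-th source factor, taken from `srcMat` at offset `i * rank`. As intuition only, a hypothetical plain-Scala equivalent of one call:

```scala
// scores(j) = <source factor i, destination factor j> for the first n slots,
// which is what sgemv("T", rank, n, 1.0F, dstMat, 0, rank, srcMat, i * rank,
// 1, 0.0F, scores, 0, 1) computes in one shot.
def naiveGemv(rank: Int, i: Int, srcMat: Array[Float],
    dstMat: Array[Float], scores: Array[Float]): Unit = {
  val n = dstMat.length / rank
  var j = 0
  while (j < n) {
    var s = 0.0f
    var k = 0
    while (k < rank) {
      s += srcMat(i * rank + k) * dstMat(j * rank + k)
      k += 1
    }
    scores(j) = s
    j += 1
  }
}
```

Note that the shared `scores` buffer only grows and may be longer than the current block's `n`, which is why the top-k scan runs over `Iterator.range(0, n)` rather than the whole array, and why `idxOrd` is rebuilt whenever a new buffer is allocated: the ordering closes over the array reference.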
```diff
@@ -499,9 +509,12 @@ class ALSModel private[ml] (
    */
   private def blockify(
       factors: Dataset[(Int, Array[Float])],
-      blockSize: Int): Dataset[Seq[(Int, Array[Float])]] = {
+      blockSize: Int): Dataset[(Array[Int], Array[Float])] = {
     import factors.sparkSession.implicits._
-    factors.mapPartitions(_.grouped(blockSize))
+    factors.mapPartitions { iter =>
+      iter.grouped(blockSize)
+        .map(block => (block.map(_._1).toArray, block.flatMap(_._2).toArray))
+    }
   }
 
 }
```
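
To see what the new `blockify` emits, the same grouping logic can be run on a plain iterator (toy rows with `rank = 2`; the `BlockifySketch` object is hypothetical and stands in for one Dataset partition):

```scala
object BlockifySketch {
  def main(args: Array[String]): Unit = {
    val rows = Iterator(
      (1, Array(0.1f, 0.2f)),
      (2, Array(0.3f, 0.4f)),
      (3, Array(0.5f, 0.6f)))

    // As in blockify with blockSize = 2: ids packed into one array,
    // factors flattened row-major into another.
    val blocks = rows.grouped(2)
      .map(block => (block.map(_._1).toArray, block.flatMap(_._2).toArray))

    blocks.foreach { case (ids, mat) =>
      println(s"ids=[${ids.mkString(", ")}] mat=[${mat.mkString(", ")}]")
    }
    // ids=[1, 2] mat=[0.1, 0.2, 0.3, 0.4]
    // ids=[3] mat=[0.5, 0.6]
  }
}
```

Packing each block as `(Array[Int], Array[Float])` is what lets the scoring side hand a contiguous factor matrix straight to GEMV instead of walking a `Seq` of per-row arrays.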
