diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 4a5743650c547..60dd7367053e2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -379,20 +379,14 @@ class ALSModel private[ml] (
     // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
     val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
     val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
-      .toDF(srcOutputColumn, "recommendations")
-
-    // There is some performance hit from converting the (Int, Float) tuples to
-    // (dstOutputColumn: Int, rating: Float) structs using .rdd. Need SPARK-16483 for a fix.
-    val schema = new StructType()
-      .add(srcOutputColumn, IntegerType)
-      .add("recommendations",
-        ArrayType(
-          StructType(
-            StructField(dstOutputColumn, IntegerType, nullable = false) ::
-            StructField("rating", FloatType, nullable = false) ::
-            Nil
-          )))
-    recs.sparkSession.createDataFrame(recs.rdd, schema)
+      .toDF("id", "recommendations")
+
+    val arrayType = ArrayType(
+      new StructType()
+        .add(dstOutputColumn, IntegerType)
+        .add("rating", FloatType)
+    )
+    recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType)
   }
 }
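For context on the trick the new code relies on: casting an array-of-struct column to an `ArrayType` whose struct carries the desired field names renames the nested fields entirely in the DataFrame API, avoiding the `createDataFrame(recs.rdd, schema)` round trip through the RDD API that the removed code used. The snippet below is a minimal, self-contained sketch of that cast, not part of the patch; the `userId`/`itemId` column names and the local `SparkSession` scaffolding are illustrative assumptions.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{ArrayType, FloatType, IntegerType, StructType}

object CastArrayStructExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("cast-array-struct-example")
      .getOrCreate()
    import spark.implicits._

    // Tuples nested inside an array become structs with Spark's default _1/_2 field names.
    val recs = Seq(
      (0, Seq((10, 0.9f), (11, 0.7f))),
      (1, Seq((12, 0.8f)))
    ).toDF("id", "recommendations")

    // Casting to an ArrayType whose struct has the desired field names renames the
    // nested fields; the cast matches struct fields by position.
    val arrayType = ArrayType(
      new StructType()
        .add("itemId", IntegerType)
        .add("rating", FloatType)
    )

    // The printed schema now exposes itemId/rating instead of _1/_2.
    recs.select($"id" as "userId", $"recommendations" cast arrayType).printSchema()

    spark.stop()
  }
}
```

This is the same pattern as the `recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType)` line in the diff, which uses it to surface `dstOutputColumn` and `rating` as the struct field names in the returned recommendations.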