diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd index 9b9e40321ea93..3cd0579aea8d3 100644 --- a/bin/compute-classpath.cmd +++ b/bin/compute-classpath.cmd @@ -38,7 +38,7 @@ if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" rem Build up classpath set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH% -if "x%SPARK_CONF_DIR%"!="x" ( +if not "x%SPARK_CONF_DIR%"=="x" ( set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR% ) else ( set CLASSPATH=%CLASSPATH%;%FWDIR%conf diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 445110d63e184..152bde5f6994f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -51,6 +51,11 @@ table.sortable thead { cursor: pointer; } +table.sortable td { + word-wrap: break-word; + max-width: 600px; +} + .progress { margin-bottom: 0px; position: relative } diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index f8584b90cabe6..d89bb50076c9a 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -168,8 +168,6 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { arr.iterator.asInstanceOf[Iterator[T]] case Right(it) => // There is not enough space to cache this partition in memory - logWarning(s"Not enough space to cache partition $key in memory! " + - s"Free memory is ${blockManager.memoryStore.freeMemory} bytes.") val returnValues = it.asInstanceOf[Iterator[T]] if (putLevel.useDisk) { logWarning(s"Persisting partition $key to disk instead.") diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 97109b9f41b60..396cdd1247e07 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -779,20 +779,20 @@ class SparkContext(config: SparkConf) extends Logging { /** * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values * with `+=`. Only the driver can access the accumuable's `value`. - * @tparam T accumulator type - * @tparam R type that can be added to the accumulator + * @tparam R accumulator result type + * @tparam T type that can be added to the accumulator */ - def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) = + def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) = new Accumulable(initialValue, param) /** * Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the * Spark UI. Tasks can add values to the accumuable using the `+=` operator. Only the driver can * access the accumuable's `value`. 
- * @tparam T accumulator type - * @tparam R type that can be added to the accumulator + * @tparam R accumulator result type + * @tparam T type that can be added to the accumulator */ - def accumulable[T, R](initialValue: T, name: String)(implicit param: AccumulableParam[T, R]) = + def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) = new Accumulable(initialValue, param, Some(name)) /** diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 72cac42cd2b2b..aba713cb4267a 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -43,9 +43,8 @@ import org.apache.spark.util.{AkkaUtils, Utils} * :: DeveloperApi :: * Holds all the runtime environment objects for a running Spark instance (either master or worker), * including the serializer, Akka actor system, block manager, map output tracker, etc. Currently - * Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these - * objects needs to have the right SparkEnv set. You can get the current environment with - * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set. + * Spark code finds the SparkEnv through a global variable, so all the threads can access the same + * SparkEnv. It can be accessed by SparkEnv.get (e.g. after creating a SparkContext). * * NOTE: This is not intended for external use. This is exposed for Shark and may be made private * in a future release. @@ -119,30 +118,28 @@ class SparkEnv ( } object SparkEnv extends Logging { - private val env = new ThreadLocal[SparkEnv] - @volatile private var lastSetSparkEnv : SparkEnv = _ + @volatile private var env: SparkEnv = _ private[spark] val driverActorSystemName = "sparkDriver" private[spark] val executorActorSystemName = "sparkExecutor" def set(e: SparkEnv) { - lastSetSparkEnv = e - env.set(e) + env = e } /** - * Returns the ThreadLocal SparkEnv, if non-null. Else returns the SparkEnv - * previously set in any thread. + * Returns the SparkEnv. */ def get: SparkEnv = { - Option(env.get()).getOrElse(lastSetSparkEnv) + env } /** * Returns the ThreadLocal SparkEnv. */ + @deprecated("Use SparkEnv.get instead", "1.2") def getThreadLocal: SparkEnv = { - env.get() + env } private[spark] def create( diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 924141475383d..c74f86548ef85 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -196,7 +196,6 @@ private[spark] class PythonRDD( override def run(): Unit = Utils.logUncaughtExceptions { try { - SparkEnv.set(env) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) // Partition index @@ -248,6 +247,11 @@ private[spark] class PythonRDD( // will kill the whole executor (see org.apache.spark.executor.Executor). 
_exception = e worker.shutdownOutput() + } finally { + // Release memory used by this thread for shuffles + env.shuffleMemoryManager.releaseMemoryForThisThread() + // Release memory used by this thread for unrolling blocks + env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() } } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 942dc7d7eac87..4cd4f4f96fd16 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -163,18 +163,23 @@ private[broadcast] object HttpBroadcast extends Logging { private def write(id: Long, value: Any) { val file = getFile(id) - val out: OutputStream = { - if (compress) { - compressionCodec.compressedOutputStream(new FileOutputStream(file)) - } else { - new BufferedOutputStream(new FileOutputStream(file), bufferSize) + val fileOutputStream = new FileOutputStream(file) + try { + val out: OutputStream = { + if (compress) { + compressionCodec.compressedOutputStream(fileOutputStream) + } else { + new BufferedOutputStream(fileOutputStream, bufferSize) + } } + val ser = SparkEnv.get.serializer.newInstance() + val serOut = ser.serializeStream(out) + serOut.writeObject(value) + serOut.close() + files += file + } finally { + fileOutputStream.close() } - val ser = SparkEnv.get.serializer.newInstance() - val serOut = ser.serializeStream(out) - serOut.writeObject(value) - serOut.close() - files += file } private def read[T: ClassTag](id: Long): T = { diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 065ddda50e65e..f2687ce6b42b4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -130,7 +130,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") System.exit(-1) - case AssociationErrorEvent(cause, _, remoteAddress, _) => + case AssociationErrorEvent(cause, _, remoteAddress, _, _) => println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") println(s"Cause was: $cause") System.exit(-1) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 32790053a6be8..98a93d1fcb2a3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -154,7 +154,7 @@ private[spark] class AppClient( logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() - case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) => + case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => logWarning(s"Could not connect to $address: $cause") case StopAppClient => diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index aa85aa060d9c1..08a99bbe68578 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -83,15 +83,21 @@ private[spark] class 
FileSystemPersistenceEngine( val serialized = serializer.toBinary(value) val out = new FileOutputStream(file) - out.write(serialized) - out.close() + try { + out.write(serialized) + } finally { + out.close() + } } def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = { val fileData = new Array[Byte](file.length().asInstanceOf[Int]) val dis = new DataInputStream(new FileInputStream(file)) - dis.readFully(fileData) - dis.close() + try { + dis.readFully(fileData) + } finally { + dis.close() + } val clazz = m.runtimeClass.asInstanceOf[Class[T]] val serializer = serialization.serializerFor(clazz) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 3b13f43a1868c..9b52cb06fb6fa 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -29,7 +29,6 @@ import scala.language.postfixOps import akka.actor._ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} -import org.apache.commons.io.FileUtils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 6d0d0bbe5ecec..63a8ac817b618 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -54,7 +54,7 @@ private[spark] class WorkerWatcher(workerUrl: String) case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => logInfo(s"Successfully connected to $workerUrl") - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound) + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) if isWorker(remoteAddress) => // These logs may not be seen if the worker (and associated pipe) has died logError(s"Could not initialize connection to worker $workerUrl. 
Exiting.") diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 9bbfcdc4a0b6e..616c7e6a46368 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -148,7 +148,6 @@ private[spark] class Executor( override def run() { val startTime = System.currentTimeMillis() - SparkEnv.set(env) Thread.currentThread.setContextClassLoader(replClassLoader) val ser = SparkEnv.get.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") @@ -158,7 +157,6 @@ private[spark] class Executor( val startGCTime = gcTime try { - SparkEnv.set(env) Accumulators.clear() val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) updateDependencies(taskFiles, taskJars) diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index f368209980f93..4f6f5e235811d 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -20,11 +20,14 @@ package org.apache.spark.network.nio import java.net._ import java.nio._ import java.nio.channels._ +import java.util.concurrent.ConcurrentLinkedQueue import java.util.LinkedList import org.apache.spark._ +import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.util.control.NonFatal private[nio] abstract class Connection(val channel: SocketChannel, val selector: Selector, @@ -51,7 +54,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, @volatile private var closed = false var onCloseCallback: Connection => Unit = null - var onExceptionCallback: (Connection, Exception) => Unit = null + val onExceptionCallbacks = new ConcurrentLinkedQueue[(Connection, Throwable) => Unit] var onKeyInterestChangeCallback: (Connection, Int) => Unit = null val remoteAddress = getRemoteAddress() @@ -130,20 +133,24 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, onCloseCallback = callback } - def onException(callback: (Connection, Exception) => Unit) { - onExceptionCallback = callback + def onException(callback: (Connection, Throwable) => Unit) { + onExceptionCallbacks.add(callback) } def onKeyInterestChange(callback: (Connection, Int) => Unit) { onKeyInterestChangeCallback = callback } - def callOnExceptionCallback(e: Exception) { - if (onExceptionCallback != null) { - onExceptionCallback(this, e) - } else { - logError("Error in connection to " + getRemoteConnectionManagerId() + - " and OnExceptionCallback not registered", e) + def callOnExceptionCallbacks(e: Throwable) { + onExceptionCallbacks foreach { + callback => + try { + callback(this, e) + } catch { + case NonFatal(e) => { + logWarning("Ignored error in onExceptionCallback", e) + } + } } } @@ -323,7 +330,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logError("Error connecting to " + address, e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) } } } @@ -348,7 +355,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logWarning("Error finishing connection to " + address, e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) } } true @@ -393,7 +400,7 @@ class SendingConnection(val address: 
InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() return false } @@ -420,7 +427,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, case e: Exception => logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() } @@ -577,7 +584,7 @@ private[spark] class ReceivingConnection( } catch { case e: Exception => { logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() return false } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 01cd27a907eea..6b00190c5eccc 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -34,6 +34,8 @@ import scala.language.postfixOps import org.apache.spark._ import org.apache.spark.util.Utils +import scala.util.Try +import scala.util.control.NonFatal private[nio] class ConnectionManager( port: Int, @@ -51,14 +53,23 @@ private[nio] class ConnectionManager( class MessageStatus( val message: Message, val connectionManagerId: ConnectionManagerId, - completionHandler: MessageStatus => Unit) { + completionHandler: Try[Message] => Unit) { - /** This is non-None if message has been ack'd */ - var ackMessage: Option[Message] = None + def success(ackMessage: Message) { + if (ackMessage == null) { + failure(new NullPointerException) + } + else { + completionHandler(scala.util.Success(ackMessage)) + } + } - def markDone(ackMessage: Option[Message]) { - this.ackMessage = ackMessage - completionHandler(this) + def failWithoutAck() { + completionHandler(scala.util.Failure(new IOException("Failed without being ACK'd"))) + } + + def failure(e: Throwable) { + completionHandler(scala.util.Failure(e)) } } @@ -72,14 +83,32 @@ private[nio] class ConnectionManager( conf.getInt("spark.core.connection.handler.threads.max", 60), conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-message-executor")) + Utils.namedThreadFactory("handle-message-executor")) { + + override def afterExecute(r: Runnable, t: Throwable): Unit = { + super.afterExecute(r, t) + if (t != null && NonFatal(t)) { + logError("Error in handleMessageExecutor is not handled properly", t) + } + } + + } private val handleReadWriteExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.io.threads.min", 4), conf.getInt("spark.core.connection.io.threads.max", 32), conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-read-write-executor")) + Utils.namedThreadFactory("handle-read-write-executor")) { + + override def afterExecute(r: Runnable, t: Throwable): Unit = { + super.afterExecute(r, t) + if (t != null && NonFatal(t)) { + logError("Error in handleReadWriteExecutor is not handled properly", t) + } + } + + } // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : // which should be executed asap @@ -153,17 +182,24 @@ private[nio] class ConnectionManager( } 
handleReadWriteExecutor.execute(new Runnable { override def run() { - var register: Boolean = false try { - register = conn.write() - } finally { - writeRunnableStarted.synchronized { - writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() - if (needReregister && conn.changeInterestForWrite()) { - conn.registerInterest() + var register: Boolean = false + try { + register = conn.write() + } finally { + writeRunnableStarted.synchronized { + writeRunnableStarted -= key + val needReregister = register || conn.resetForceReregister() + if (needReregister && conn.changeInterestForWrite()) { + conn.registerInterest() + } } } + } catch { + case NonFatal(e) => { + logError("Error when writing to " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } } } ) @@ -187,16 +223,23 @@ private[nio] class ConnectionManager( } handleReadWriteExecutor.execute(new Runnable { override def run() { - var register: Boolean = false try { - register = conn.read() - } finally { - readRunnableStarted.synchronized { - readRunnableStarted -= key - if (register && conn.changeInterestForRead()) { - conn.registerInterest() + var register: Boolean = false + try { + register = conn.read() + } finally { + readRunnableStarted.synchronized { + readRunnableStarted -= key + if (register && conn.changeInterestForRead()) { + conn.registerInterest() + } } } + } catch { + case NonFatal(e) => { + logError("Error when reading from " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } } } ) @@ -213,19 +256,25 @@ private[nio] class ConnectionManager( handleConnectExecutor.execute(new Runnable { override def run() { + try { + var tries: Int = 10 + while (tries >= 0) { + if (conn.finishConnect(false)) return + // Sleep ? + Thread.sleep(1) + tries -= 1 + } - var tries: Int = 10 - while (tries >= 0) { - if (conn.finishConnect(false)) return - // Sleep ? - Thread.sleep(1) - tries -= 1 + // fallback to previous behavior : we should not really come here since this method was + // triggered since channel became connectable : but at times, the first finishConnect need + // not succeed : hence the loop to retry a few 'times'. + conn.finishConnect(true) + } catch { + case NonFatal(e) => { + logError("Error when finishConnect for " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } - - // fallback to previous behavior : we should not really come here since this method was - // triggered since channel became connectable : but at times, the first finishConnect need - // not succeed : hence the loop to retry a few 'times'. 
- conn.finishConnect(true) } } ) } @@ -246,16 +295,16 @@ private[nio] class ConnectionManager( handleConnectExecutor.execute(new Runnable { override def run() { try { - conn.callOnExceptionCallback(e) + conn.callOnExceptionCallbacks(e) } catch { // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) + case NonFatal(e) => logDebug("Ignoring exception", e) } try { conn.close() } catch { // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) + case NonFatal(e) => logDebug("Ignoring exception", e) } } }) @@ -448,7 +497,7 @@ private[nio] class ConnectionManager( messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) .foreach(status => { logInfo("Notifying " + status) - status.markDone(None) + status.failWithoutAck() }) messageStatuses.retain((i, status) => { @@ -477,7 +526,7 @@ private[nio] class ConnectionManager( for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { logInfo("Notifying " + s) - s.markDone(None) + s.failWithoutAck() } messageStatuses.retain((i, status) => { @@ -492,7 +541,7 @@ private[nio] class ConnectionManager( } } - def handleConnectionError(connection: Connection, e: Exception) { + def handleConnectionError(connection: Connection, e: Throwable) { logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId()) removeConnection(connection) @@ -510,9 +559,17 @@ private[nio] class ConnectionManager( val runnable = new Runnable() { val creationTime = System.currentTimeMillis def run() { - logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message, connection) - logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") + try { + logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") + handleMessage(connectionManagerId, message, connection) + logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") + } catch { + case NonFatal(e) => { + logError("Error when handling messages from " + + connection.getRemoteConnectionManagerId(), e) + connection.callOnExceptionCallbacks(e) + } + } } } handleMessageExecutor.execute(runnable) @@ -651,7 +708,7 @@ private[nio] class ConnectionManager( messageStatuses.get(bufferMessage.ackId) match { case Some(status) => { messageStatuses -= bufferMessage.ackId - status.markDone(Some(message)) + status.success(message) } case None => { /** @@ -770,6 +827,12 @@ private[nio] class ConnectionManager( val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, newConnectionId, securityManager) + newConnection.onException { + case (conn, e) => { + logError("Exception while sending message.", e) + reportSendingMessageFailure(message.id, e) + } + } logTrace("creating new sending connection: " + newConnectionId) registerRequests.enqueue(newConnection) @@ -782,13 +845,36 @@ private[nio] class ConnectionManager( "connectionid: " + connection.connectionId) if (authEnabled) { - checkSendAuthFirst(connectionManagerId, connection) + try { + checkSendAuthFirst(connectionManagerId, connection) + } catch { + case NonFatal(e) => { + reportSendingMessageFailure(message.id, e) + } + } } logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") connection.send(message) wakeupSelector() } + private def reportSendingMessageFailure(messageId: Int, e: 
Throwable): Unit = { + // need to tell sender it failed + messageStatuses.synchronized { + val s = messageStatuses.get(messageId) + s match { + case Some(msgStatus) => { + messageStatuses -= messageId + logInfo("Notifying " + msgStatus.connectionManagerId) + msgStatus.failure(e) + } + case None => { + logError("no messageStatus for failed message id: " + messageId) + } + } + } + } + private def wakeupSelector() { selector.wakeup() } @@ -807,9 +893,11 @@ private[nio] class ConnectionManager( override def run(): Unit = { messageStatuses.synchronized { messageStatuses.remove(message.id).foreach ( s => { - promise.failure( - new IOException("sendMessageReliably failed because ack " + - s"was not received within $ackTimeout sec")) + val e = new IOException("sendMessageReliably failed because ack " + + s"was not received within $ackTimeout sec") + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } }) } } @@ -817,15 +905,23 @@ private[nio] class ConnectionManager( val status = new MessageStatus(message, connectionManagerId, s => { timeoutTask.cancel() - s.ackMessage match { - case None => // Indicates a failure where we either never sent or never got ACK'd - promise.failure(new IOException("sendMessageReliably failed without being ACK'd")) - case Some(ackMessage) => + s match { + case scala.util.Failure(e) => + // Indicates a failure where we either never sent or never got ACK'd + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } + case scala.util.Success(ackMessage) => if (ackMessage.hasError) { - promise.failure( - new IOException("sendMessageReliably failed with ACK that signalled a remote error")) + val e = new IOException( + "sendMessageReliably failed with ACK that signalled a remote error") + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } } else { - promise.success(ackMessage) + if (!promise.trySuccess(ackMessage)) { + logWarning("Drop ackMessage because promise is completed") + } } } }) diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index 5d77d37378458..56ac7a69be0d3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -131,7 +131,6 @@ private[spark] class PipedRDD[T: ClassTag]( // Start a thread to feed the process input from our parent's iterator new Thread("stdin writer for " + command) { override def run() { - SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) // input the pipe context firstly diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8135cdbb4c31f..788eb1ff4e455 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -630,7 +630,6 @@ class DAGScheduler( protected def runLocallyWithinThread(job: ActiveJob) { var jobResult: JobResult = JobSucceeded try { - SparkEnv.set(env) val rdd = job.finalStage.rdd val split = rdd.partitions(job.partitions(0)) val taskContext = diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 4dc550413c13c..6d697e3d003f6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ 
b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -216,8 +216,6 @@ private[spark] class TaskSchedulerImpl( * that tasks are balanced across the cluster. */ def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized { - SparkEnv.set(sc.env) - // Mark each slave as alive and remember its hostname // Also track if new executor is added var newExecAvail = false diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index b11786368e661..e0f2fd622f54c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -372,6 +372,13 @@ private[spark] class MesosSchedulerBackend( recordSlaveLost(d, slaveId, ExecutorExited(status)) } + override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { + driver.killTask( + TaskID.newBuilder() + .setValue(taskId.toString).build() + ) + } + // TODO: query Mesos for number of cores override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index e9304f6bb45d0..bac459e835a3f 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -73,7 +73,21 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) val outputStream = new FileOutputStream(file) - blockManager.dataSerializeStream(blockId, outputStream, values) + try { + try { + blockManager.dataSerializeStream(blockId, outputStream, values) + } finally { + // Close outputStream here because it should be closed before file is deleted. + outputStream.close() + } + } catch { + case e: Throwable => + if (file.exists()) { + file.delete() + } + throw e + } + val length = file.length val timeTaken = System.currentTimeMillis - startTime diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 0a09c24d61879..edbc729c17ade 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -132,8 +132,6 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) PutResult(res.size, res.data, droppedBlocks) case Right(iteratorValues) => // Not enough space to unroll this block; drop to disk if applicable - logWarning(s"Not enough space to store block $blockId in memory! 
" + - s"Free memory is $freeMemory bytes.") if (level.useDisk && allowPersistToDisk) { logWarning(s"Persisting block $blockId to disk instead.") val res = blockManager.diskStore.putIterator(blockId, iteratorValues, level, returnValues) @@ -265,6 +263,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) Left(vector.toArray) } else { // We ran out of space while unrolling the values for this block + logUnrollFailureMessage(blockId, vector.estimateSize()) Right(vector.iterator ++ values) } @@ -424,7 +423,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * Reserve additional memory for unrolling blocks used by this thread. * Return whether the request is granted. */ - private[spark] def reserveUnrollMemoryForThisThread(memory: Long): Boolean = { + def reserveUnrollMemoryForThisThread(memory: Long): Boolean = { accountingLock.synchronized { val granted = freeMemory > currentUnrollMemory + memory if (granted) { @@ -439,7 +438,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * Release memory used by this thread for unrolling blocks. * If the amount is not specified, remove the current thread's allocation altogether. */ - private[spark] def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = { + def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = { val threadId = Thread.currentThread().getId accountingLock.synchronized { if (memory < 0) { @@ -457,16 +456,50 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) /** * Return the amount of memory currently occupied for unrolling blocks across all threads. */ - private[spark] def currentUnrollMemory: Long = accountingLock.synchronized { + def currentUnrollMemory: Long = accountingLock.synchronized { unrollMemoryMap.values.sum } /** * Return the amount of memory currently occupied for unrolling blocks by this thread. */ - private[spark] def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized { + def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized { unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L) } + + /** + * Return the number of threads currently unrolling blocks. + */ + def numThreadsUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } + + /** + * Log information about current memory usage. + */ + def logMemoryUsage(): Unit = { + val blocksMemory = currentMemory + val unrollMemory = currentUnrollMemory + val totalMemory = blocksMemory + unrollMemory + logInfo( + s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " + + s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " + + s"$numThreadsUnrolling thread(s)) = ${Utils.bytesToString(totalMemory)}. " + + s"Storage limit = ${Utils.bytesToString(maxMemory)}." + ) + } + + /** + * Log a warning for failing to unroll a block. + * + * @param blockId ID of the block we are trying to unroll. + * @param finalVectorSize Final size of the vector before unrolling failed. + */ + def logUnrollFailureMessage(blockId: BlockId, finalVectorSize: Long): Unit = { + logWarning( + s"Not enough space to cache $blockId in memory! 
" + + s"(computed ${Utils.bytesToString(finalVectorSize)} so far)" + ) + logMemoryUsage() + } } private[spark] case class ResultWithDroppedBlocks( diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index f0006b42aee4f..32e6b15bb0999 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -21,6 +21,7 @@ import java.text.SimpleDateFormat import java.util.{Locale, Date} import scala.xml.Node + import org.apache.spark.Logging /** Utility functions for generating XML pages with spark content. */ @@ -169,6 +170,7 @@ private[spark] object UIUtils extends Logging { refreshInterval: Option[Int] = None): Seq[Node] = { val appName = activeTab.appName + val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..." val header = activeTab.headerTabs.map { tab =>
  • {tab.name} @@ -187,7 +189,9 @@ private[spark] object UIUtils extends Logging { - +
    @@ -216,8 +220,10 @@ private[spark] object UIUtils extends Logging {

    - + + + {title}

    diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index db01be596e073..2414e4c65237e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -103,7 +103,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val taskHeaders: Seq[String] = Seq( - "Index", "ID", "Attempt", "Status", "Locality Level", "Executor", + "Index", "ID", "Attempt", "Status", "Locality Level", "Executor ID / Host", "Launch Time", "Duration", "GC Time", "Accumulators") ++ {if (hasInput) Seq("Input") else Nil} ++ {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ @@ -282,7 +282,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { } {info.status} {info.taskLocality} - {info.host} + {info.executorId} / {info.host} {UIUtils.formatDate(new Date(info.launchTime))} {formatDuration} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a67124140f9da..3d307b3c16d3e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -35,8 +35,6 @@ import scala.util.control.{ControlThrowable, NonFatal} import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder -import org.apache.commons.io.FileUtils -import org.apache.commons.io.filefilter.TrueFileFilter import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration import org.apache.log4j.PropertyConfigurator @@ -710,18 +708,20 @@ private[spark] object Utils extends Logging { * Determines if a directory contains any files newer than cutoff seconds. * * @param dir must be the path to a directory, or IllegalArgumentException is thrown - * @param cutoff measured in seconds. Returns true if there are any files in dir newer than this. + * @param cutoff measured in seconds. 
Returns true if there are any files or directories in the + * given directory whose last modified time is later than this many seconds ago */ def doesDirectoryContainAnyNewFiles(dir: File, cutoff: Long): Boolean = { - val currentTimeMillis = System.currentTimeMillis if (!dir.isDirectory) { - throw new IllegalArgumentException (dir + " is not a directory!") - } else { - val files = FileUtils.listFilesAndDirs(dir, TrueFileFilter.TRUE, TrueFileFilter.TRUE) - val cutoffTimeInMillis = (currentTimeMillis - (cutoff * 1000)) - val newFiles = files.filter { _.lastModified > cutoffTimeInMillis } - newFiles.nonEmpty + throw new IllegalArgumentException(s"$dir is not a directory!") } + val filesAndDirs = dir.listFiles() + val cutoffTimeInMillis = System.currentTimeMillis - (cutoff * 1000) + + filesAndDirs.exists(_.lastModified() > cutoffTimeInMillis) || + filesAndDirs.filter(_.isDirectory).exists( + subdir => doesDirectoryContainAnyNewFiles(subdir, cutoff) + ) } /** diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 1fef79ad1001f..cbc0bd178d894 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -146,7 +146,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) + Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~123B, and no exception should be thrown @@ -164,7 +164,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) + Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index a8e92e36fe0d8..02ac20984add9 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -73,11 +73,10 @@ def fail(msg): def run_cmd(cmd): + print cmd if isinstance(cmd, list): - print " ".join(cmd) return subprocess.check_output(cmd) else: - print cmd return subprocess.check_output(cmd.split(" ")) diff --git a/dev/run-tests b/dev/run-tests index c3d8f49cdd993..f47fcf66ff7e7 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -24,6 +24,16 @@ cd "$FWDIR" # Remove work directory rm -rf ./work +source "$FWDIR/dev/run-tests-codes.sh" + +CURRENT_BLOCK=$BLOCK_GENERAL + +function handle_error () { + echo "[error] Got a return code of $? on line $1 of the run-tests script." + exit $CURRENT_BLOCK +} + + # Build against the right verison of Hadoop. 
{ if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then @@ -32,7 +42,7 @@ rm -rf ./work elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.0" ]; then export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=2.0.0-mr1-cdh4.1.1" elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.2" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Dhadoop.version=2.2.0" + export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0" elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.3" ]; then export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" fi @@ -91,26 +101,34 @@ if [ -n "$AMPLAB_JENKINS" ]; then fi fi -# Fail fast -set -e set -o pipefail +trap 'handle_error $LINENO' ERR echo "" echo "=========================================================================" echo "Running Apache RAT checks" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_RAT + ./dev/check-license echo "" echo "=========================================================================" echo "Running Scala style checks" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_SCALA_STYLE + ./dev/lint-scala echo "" echo "=========================================================================" echo "Running Python style checks" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_PYTHON_STYLE + ./dev/lint-python echo "" @@ -118,6 +136,8 @@ echo "=========================================================================" echo "Building Spark" echo "=========================================================================" +CURRENT_BLOCK=$BLOCK_BUILD + { # We always build with Hive because the PySpark Spark SQL tests need it. BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive" @@ -141,6 +161,8 @@ echo "=========================================================================" echo "Running Spark unit tests" echo "=========================================================================" +CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS + { # If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled. # This must be a single argument, as it is. @@ -175,10 +197,16 @@ echo "" echo "=========================================================================" echo "Running PySpark tests" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS + ./python/run-tests echo "" echo "=========================================================================" echo "Detecting binary incompatibilites with MiMa" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_MIMA + ./dev/mima diff --git a/python/epydoc.conf b/dev/run-tests-codes.sh similarity index 55% rename from python/epydoc.conf rename to dev/run-tests-codes.sh index 8593e08deda19..1348e0609dda4 100644 --- a/python/epydoc.conf +++ b/dev/run-tests-codes.sh @@ -1,4 +1,4 @@ -[epydoc] # Epydoc section marker (required by ConfigParser) +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more @@ -17,22 +17,11 @@ # limitations under the License. # -# Information about the project. -name: Spark 1.0.0 Python API Docs -url: http://spark.apache.org - -# The list of modules to document. Modules can be named using -# dotted names, module filenames, or package directory names. -# This option may be repeated. 
-modules: pyspark - -# Write html output to the directory "apidocs" -output: html -target: docs/ - -private: no - -exclude: pyspark.cloudpickle pyspark.worker pyspark.join - pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests - pyspark.rddsampler pyspark.daemon - pyspark.mllib.tests pyspark.shuffle +readonly BLOCK_GENERAL=10 +readonly BLOCK_RAT=11 +readonly BLOCK_SCALA_STYLE=12 +readonly BLOCK_PYTHON_STYLE=13 +readonly BLOCK_BUILD=14 +readonly BLOCK_SPARK_UNIT_TESTS=15 +readonly BLOCK_PYSPARK_UNIT_TESTS=16 +readonly BLOCK_MIMA=17 diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 0b1e31b9413cf..451f3b771cc76 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -26,9 +26,23 @@ FWDIR="$(cd `dirname $0`/..; pwd)" cd "$FWDIR" +source "$FWDIR/dev/run-tests-codes.sh" + COMMENTS_URL="https://api.github.com/repos/apache/spark/issues/$ghprbPullId/comments" PULL_REQUEST_URL="https://github.com/apache/spark/pull/$ghprbPullId" +# Important Environment Variables +# --- +# $ghprbActualCommit +#+ This is the hash of the most recent commit in the PR. +#+ The merge-base of this and master is the commit from which the PR was branched. +# $sha1 +#+ If the patch merges cleanly, this is a reference to the merge commit hash +#+ (e.g. "origin/pr/2606/merge"). +#+ If the patch does not merge cleanly, it is equal to $ghprbActualCommit. +#+ The merge-base of this and master in the case of a clean merge is the most recent commit +#+ against master. + COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" # GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" @@ -84,42 +98,46 @@ function post_message () { fi } + +# We diff master...$ghprbActualCommit because that gets us changes introduced in the PR +#+ and not anything else added to master since the PR was branched. + # check PR merge-ability and check for new public classes { if [ "$sha1" == "$ghprbActualCommit" ]; then - merge_note=" * This patch **does not** merge cleanly!" + merge_note=" * This patch **does not merge cleanly**." else merge_note=" * This patch merges cleanly." + fi + + source_files=$( + git diff master...$ghprbActualCommit --name-only `# diff patch against master from branch point` \ + | grep -v -e "\/test" `# ignore files in test directories` \ + | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \ + | tr "\n" " " + ) + new_public_classes=$( + git diff master...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \ + | grep "^\+" `# filter in only added lines` \ + | sed -r -e "s/^\+//g" `# remove the leading +` \ + | grep -e "trait " -e "class " `# filter in lines with these key words` \ + | grep -e "{" -e "(" `# filter in lines with these key words, too` \ + | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \ + | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \ + | sed -r -e "s/\{.*//g" `# remove from the { onwards` \ + | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ + | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ + | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ + | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \ + | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \ + | tr -d "\n" `# remove actual LF characters` + ) - source_files=$( - git diff master... 
--name-only `# diff patch against master from branch point` \ - | grep -v -e "\/test" `# ignore files in test directories` \ - | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \ - | tr "\n" " " - ) - new_public_classes=$( - git diff master... ${source_files} `# diff patch against master from branch point` \ - | grep "^\+" `# filter in only added lines` \ - | sed -r -e "s/^\+//g" `# remove the leading +` \ - | grep -e "trait " -e "class " `# filter in lines with these key words` \ - | grep -e "{" -e "(" `# filter in lines with these key words, too` \ - | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \ - | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \ - | sed -r -e "s/\{.*//g" `# remove from the { onwards` \ - | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ - | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ - | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ - | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \ - | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \ - | tr -d "\n" `# remove actual LF characters` - ) - - if [ "$new_public_classes" == "" ]; then - public_classes_note=" * This patch adds no public classes." - else - public_classes_note=" * This patch adds the following public classes _(experimental)_:" - public_classes_note="${public_classes_note}\n${new_public_classes}" - fi + if [ -z "$new_public_classes" ]; then + public_classes_note=" * This patch adds no public classes." + else + public_classes_note=" * This patch adds the following public classes _(experimental)_:" + public_classes_note="${public_classes_note}\n${new_public_classes}" fi } @@ -147,12 +165,30 @@ function post_message () { post_message "$fail_message" exit $test_result + elif [ "$test_result" -eq "0" ]; then + test_result_note=" * This patch **passes all tests**." else - if [ "$test_result" -eq "0" ]; then - test_result_note=" * This patch **passes** unit tests." + if [ "$test_result" -eq "$BLOCK_GENERAL" ]; then + failing_test="some tests" + elif [ "$test_result" -eq "$BLOCK_RAT" ]; then + failing_test="RAT tests" + elif [ "$test_result" -eq "$BLOCK_SCALA_STYLE" ]; then + failing_test="Scala style tests" + elif [ "$test_result" -eq "$BLOCK_PYTHON_STYLE" ]; then + failing_test="Python style tests" + elif [ "$test_result" -eq "$BLOCK_BUILD" ]; then + failing_test="to build" + elif [ "$test_result" -eq "$BLOCK_SPARK_UNIT_TESTS" ]; then + failing_test="Spark unit tests" + elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then + failing_test="PySpark unit tests" + elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then + failing_test="MiMa tests" else - test_result_note=" * This patch **fails** unit tests." + failing_test="some tests" fi + + test_result_note=" * This patch **fails $failing_test**." fi } diff --git a/dev/scalastyle b/dev/scalastyle index efb5f291ea3b7..c3b356bcb3c06 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -26,6 +26,8 @@ echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 yarn/scalasty >> scalastyle.txt ERRORS=$(cat scalastyle.txt | grep -e "\") +rm scalastyle.txt + if test ! -z "$ERRORS"; then echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS" exit 1 diff --git a/docs/README.md b/docs/README.md index 79708c3df9106..0facecdd5f767 100644 --- a/docs/README.md +++ b/docs/README.md @@ -54,19 +54,19 @@ phase, use the following sytax: // supported languages too. 
{% endhighlight %} -## API Docs (Scaladoc and Epydoc) +## API Docs (Scaladoc and Sphinx) You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory. -Similarly, you can build just the PySpark epydoc by running `epydoc --config epydoc.conf` from the -SPARK_PROJECT_ROOT/pyspark directory. Documentation is only generated for classes that are listed as +Similarly, you can build just the PySpark docs by running `make html` from the +SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as public in `__init__.py`. When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. The jekyll plugin also generates the -PySpark docs using [epydoc](http://epydoc.sourceforge.net/). +PySpark docs using [Sphinx](http://sphinx-doc.org/). NOTE: To skip the step of building and copying over the Scala and Python API docs, run `SKIP_API=1 jekyll`. diff --git a/docs/_config.yml b/docs/_config.yml index 7bc3a78e2d265..f4bf242ac191b 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -8,6 +8,9 @@ gems: kramdown: entity_output: numeric +include: + - _static + # These allow the documentation to be updated with nerw releases # of Spark, Scala, and Mesos. SPARK_VERSION: 1.0.0-SNAPSHOT diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 3b02e090aec28..4566a2fff562b 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -63,19 +63,20 @@ puts "cp -r " + source + "/. " + dest cp_r(source + "/.", dest) - # Build Epydoc for Python - puts "Moving to python directory and building epydoc." - cd("../python") - puts `epydoc --config epydoc.conf` + # Build Sphinx docs for Python - puts "Moving back into docs dir." - cd("../docs") + puts "Moving to python/docs directory and building sphinx." + cd("../python/docs") + puts `make html` + + puts "Moving back into home dir." + cd("../../") puts "Making directory api/python" - mkdir_p "api/python" + mkdir_p "docs/api/python" - puts "cp -r ../python/docs/. api/python" - cp_r("../python/docs/.", "api/python") + puts "cp -r python/docs/_build/html/. docs/api/python" + cp_r("python/docs/_build/html/.", "docs/api/python") cd("..") end diff --git a/docs/building-spark.md b/docs/building-spark.md index 901c157162fee..b2940ee4029e8 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -171,6 +171,21 @@ can be set to control the SBT build. For example: sbt/sbt -Pyarn -Phadoop-2.3 assembly +# Testing with SBT + +Some of the tests require Spark to be packaged first, so always run `sbt/sbt assembly` the first time. 
The following is an example of a correct (build, test) sequence: + + sbt/sbt -Pyarn -Phadoop-2.3 -Phive assembly + sbt/sbt -Pyarn -Phadoop-2.3 -Phive test + +To run only a specific test suite: + + sbt/sbt -Pyarn -Phadoop-2.3 -Phive "test-only org.apache.spark.repl.ReplSuite" + +To run the test suites of a specific sub-project: + + sbt/sbt -Pyarn -Phadoop-2.3 -Phive core/test + # Speeding up Compilation with Zinc [Zinc](https://github.com/typesafehub/zinc) is a long-running server version of SBT's incremental diff --git a/docs/configuration.md b/docs/configuration.md index 1c33855365170..f311f0d2a6206 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -103,6 +103,14 @@ of the most common options to set are: (e.g. 512m, 2g). + + spark.driver.memory + 512m + + Amount of memory to use for the driver process, i.e. where SparkContext is initialized. + (e.g. 512m, 2g). + + spark.serializer org.apache.spark.serializer.
    JavaSerializer diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 941dfb988b9fb..0d6b82b4944f3 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -32,6 +32,7 @@ import tempfile import time import urllib2 +import warnings from optparse import OptionParser from sys import stderr import boto @@ -61,8 +62,8 @@ def parse_args(): "-s", "--slaves", type="int", default=1, help="Number of slaves to launch (default: %default)") parser.add_option( - "-w", "--wait", type="int", default=120, - help="Seconds to wait for nodes to start (default: %default)") + "-w", "--wait", type="int", + help="DEPRECATED (no longer necessary) - Seconds to wait for nodes to start") parser.add_option( "-k", "--key-pair", help="Key pair to use on instances") @@ -195,18 +196,6 @@ def get_or_make_group(conn, name): return conn.create_security_group(name, "Spark EC2 group") -# Wait for a set of launched instances to exit the "pending" state -# (i.e. either to start running or to fail and be terminated) -def wait_for_instances(conn, instances): - while True: - for i in instances: - i.update() - if len([i for i in instances if i.state == 'pending']) > 0: - time.sleep(5) - else: - return - - # Check whether a given EC2 instance object is in a state we consider active, # i.e. not terminating or terminated. We count both stopping and stopped as # active since we can restart stopped clusters. @@ -594,7 +583,7 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten - ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v3") + ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v4") print "Deploying files to master..." deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, modules) @@ -619,14 +608,64 @@ def setup_spark_cluster(master, opts): print "Ganglia started at http://%s:5080/ganglia" % master -# Wait for a whole cluster (masters, slaves and ZooKeeper) to start up -def wait_for_cluster(conn, wait_secs, master_nodes, slave_nodes): - print "Waiting for instances to start up..." - time.sleep(5) - wait_for_instances(conn, master_nodes) - wait_for_instances(conn, slave_nodes) - print "Waiting %d more seconds..." % wait_secs - time.sleep(wait_secs) +def is_ssh_available(host, opts): + "Checks if SSH is available on the host." + try: + with open(os.devnull, 'w') as devnull: + ret = subprocess.check_call( + ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', + '%s@%s' % (opts.user, host), stringify_command('true')], + stdout=devnull, + stderr=devnull + ) + return ret == 0 + except subprocess.CalledProcessError as e: + return False + + +def is_cluster_ssh_available(cluster_instances, opts): + for i in cluster_instances: + if not is_ssh_available(host=i.ip_address, opts=opts): + return False + else: + return True + + +def wait_for_cluster_state(cluster_instances, cluster_state, opts): + """ + cluster_instances: a list of boto.ec2.instance.Instance + cluster_state: a string representing the desired state of all the instances in the cluster + value can be 'ssh-ready' or a valid value from boto.ec2.instance.InstanceState such as + 'running', 'terminated', etc. 
+ (would be nice to replace this with a proper enum: http://stackoverflow.com/a/1695250) + """ + sys.stdout.write( + "Waiting for all instances in cluster to enter '{s}' state.".format(s=cluster_state) + ) + sys.stdout.flush() + + num_attempts = 0 + + while True: + time.sleep(3 * num_attempts) + + for i in cluster_instances: + s = i.update() # capture output to suppress print to screen in newer versions of boto + + if cluster_state == 'ssh-ready': + if all(i.state == 'running' for i in cluster_instances) and \ + is_cluster_ssh_available(cluster_instances, opts): + break + else: + if all(i.state == cluster_state for i in cluster_instances): + break + + num_attempts += 1 + + sys.stdout.write(".") + sys.stdout.flush() + + sys.stdout.write("\n") # Get number of local disks available for a given EC2 instance type. @@ -868,6 +907,16 @@ def real_main(): (opts, action, cluster_name) = parse_args() # Input parameter validation + if opts.wait is not None: + # NOTE: DeprecationWarnings are silent in 2.7+ by default. + # To show them, run Python with the -Wdefault switch. + # See: https://docs.python.org/3.5/whatsnew/2.7.html + warnings.warn( + "This option is deprecated and has no effect. " + "spark-ec2 automatically waits as long as necessary for clusters to startup.", + DeprecationWarning + ) + if opts.ebs_vol_num > 8: print >> stderr, "ebs-vol-num cannot be greater than 8" sys.exit(1) @@ -890,7 +939,11 @@ def real_main(): (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) else: (master_nodes, slave_nodes) = launch_cluster(conn, opts, cluster_name) - wait_for_cluster(conn, opts.wait, master_nodes, slave_nodes) + wait_for_cluster_state( + cluster_instances=(master_nodes + slave_nodes), + cluster_state='ssh-ready', + opts=opts + ) setup_cluster(conn, master_nodes, slave_nodes, opts, True) elif action == "destroy": @@ -919,7 +972,11 @@ def real_main(): else: group_names = [opts.security_group_prefix + "-master", opts.security_group_prefix + "-slaves"] - + wait_for_cluster_state( + cluster_instances=(master_nodes + slave_nodes), + cluster_state='terminated', + opts=opts + ) attempt = 1 while attempt <= 3: print "Attempt %d" % attempt @@ -1019,7 +1076,11 @@ def real_main(): for inst in master_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.start() - wait_for_cluster(conn, opts.wait, master_nodes, slave_nodes) + wait_for_cluster_state( + cluster_instances=(master_nodes + slave_nodes), + cluster_state='ssh-ready', + opts=opts + ) setup_cluster(conn, master_nodes, slave_nodes, opts, False) else: diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala new file mode 100644 index 0000000000000..ae6057758d6fc --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scala.reflect.runtime.universe._ + +/** + * Abstract class for parameter case classes. + * This overrides the [[toString]] method to print all case class fields by name and value. + * @tparam T Concrete parameter class. + */ +abstract class AbstractParams[T: TypeTag] { + + private def tag: TypeTag[T] = typeTag[T] + + /** + * Finds all case class fields in concrete class instance, and outputs them in JSON-style format: + * { + * [field name]:\t[field value]\n + * [field name]:\t[field value]\n + * ... + * } + */ + override def toString: String = { + val tpe = tag.tpe + val allAccessors = tpe.declarations.collect { + case m: MethodSymbol if m.isCaseAccessor => m + } + val mirror = runtimeMirror(getClass.getClassLoader) + val instanceMirror = mirror.reflect(this) + allAccessors.map { f => + val paramName = f.name.toString + val fieldMirror = instanceMirror.reflectField(f) + val paramValue = fieldMirror.get + s" $paramName:\t$paramValue" + }.mkString("{\n", ",\n", "\n}") + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index a6f78d2441db1..1edd2432a0352 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -55,7 +55,7 @@ object BinaryClassification { stepSize: Double = 1.0, algorithm: Algorithm = LR, regType: RegType = L2, - regParam: Double = 0.1) + regParam: Double = 0.1) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala index d6b2fe430e5a4..e49129c4e7844 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -35,6 +35,7 @@ import org.apache.spark.{SparkConf, SparkContext} object Correlations { case class Params(input: String = "data/mllib/sample_linear_regression_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala new file mode 100644 index 0000000000000..cb1abbd18fd4d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix} +import org.apache.spark.{SparkConf, SparkContext} + +/** + * Compute the similar columns of a matrix, using cosine similarity. + * + * The input matrix must be stored in row-oriented dense format, one line per row with its entries + * separated by space. For example, + * {{{ + * 0.5 1.0 + * 2.0 3.0 + * 4.0 5.0 + * }}} + * represents a 3-by-2 matrix, whose first row is (0.5, 1.0). + * + * Example invocation: + * + * bin/run-example mllib.CosineSimilarity \ + * --threshold 0.1 data/mllib/sample_svm_data.txt + */ +object CosineSimilarity { + case class Params(inputFile: String = null, threshold: Double = 0.1) + extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("CosineSimilarity") { + head("CosineSimilarity: an example app.") + opt[Double]("threshold") + .required() + .text(s"threshold similarity: to tradeoff computation vs quality estimate") + .action((x, c) => c.copy(threshold = x)) + arg[String]("") + .required() + .text(s"input file, one row per line, space-separated") + .action((x, c) => c.copy(inputFile = x)) + note( + """ + |For example, the following command runs this app on a dataset: + | + | ./bin/spark-submit --class org.apache.spark.examples.mllib.CosineSimilarity \ + | examplesjar.jar \ + | --threshold 0.1 data/mllib/sample_svm_data.txt + """.stripMargin) + } + + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + System.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName("CosineSimilarity") + val sc = new SparkContext(conf) + + // Load and parse the data file. + val rows = sc.textFile(params.inputFile).map { line => + val values = line.split(' ').map(_.toDouble) + Vectors.dense(values) + }.cache() + val mat = new RowMatrix(rows) + + // Compute similar columns perfectly, with brute force. 
+ val exact = mat.columnSimilarities() + + // Compute similar columns with estimation using DIMSUM + val approx = mat.columnSimilarities(params.threshold) + + val exactEntries = exact.entries.map { case MatrixEntry(i, j, u) => ((i, j), u) } + val approxEntries = approx.entries.map { case MatrixEntry(i, j, v) => ((i, j), v) } + val MAE = exactEntries.leftOuterJoin(approxEntries).values.map { + case (u, Some(v)) => + math.abs(u - v) + case (u, None) => + math.abs(u) + }.mean() + + println(s"Average absolute error in estimate is: $MAE") + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 4adc91d2fbe65..837d0591478c5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -62,7 +62,7 @@ object DecisionTreeRunner { minInfoGain: Double = 0.0, numTrees: Int = 1, featureSubsetStrategy: String = "auto", - fracTest: Double = 0.2) + fracTest: Double = 0.2) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() @@ -138,9 +138,11 @@ object DecisionTreeRunner { def run(params: Params) { - val conf = new SparkConf().setAppName("DecisionTreeRunner") + val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params") val sc = new SparkContext(conf) + println(s"DecisionTreeRunner with parameters:\n$params") + // Load training data and cache it. val origExamples = params.dataFormat match { case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache() @@ -235,7 +237,10 @@ object DecisionTreeRunner { minInstancesPerNode = params.minInstancesPerNode, minInfoGain = params.minInfoGain) if (params.numTrees == 1) { + val startTime = System.nanoTime() val model = DecisionTree.train(training, strategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") if (model.numNodes < 20) { println(model.toDebugString) // Print full model. } else { @@ -259,8 +264,11 @@ object DecisionTreeRunner { } else { val randomSeed = Utils.random.nextInt() if (params.algo == Classification) { + val startTime = System.nanoTime() val model = RandomForest.trainClassifier(training, strategy, params.numTrees, params.featureSubsetStrategy, randomSeed) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") if (model.totalNumNodes < 30) { println(model.toDebugString) // Print full model. } else { @@ -275,8 +283,11 @@ object DecisionTreeRunner { println(s"Test accuracy = $testAccuracy") } if (params.algo == Regression) { + val startTime = System.nanoTime() val model = RandomForest.trainRegressor(training, strategy, params.numTrees, params.featureSubsetStrategy, randomSeed) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") if (model.totalNumNodes < 30) { println(model.toDebugString) // Print full model. 
} else { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala index 89dfa26c2299c..11e35598baf50 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala @@ -44,7 +44,7 @@ object DenseKMeans { input: String = null, k: Int = -1, numIterations: Int = 10, - initializationMode: InitializationMode = Parallel) + initializationMode: InitializationMode = Parallel) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 05b7d66f8dffd..e1f9622350135 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -47,7 +47,7 @@ object LinearRegression extends App { numIterations: Int = 100, stepSize: Double = 1.0, regType: RegType = L2, - regParam: Double = 0.1) + regParam: Double = 0.1) extends AbstractParams[Params] val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 98aaedb9d7dc9..fc6678013b932 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -55,7 +55,7 @@ object MovieLensALS { rank: Int = 10, numUserBlocks: Int = -1, numProductBlocks: Int = -1, - implicitPrefs: Boolean = false) + implicitPrefs: Boolean = false) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala index 4532512c01f84..6e4e2d07f284b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -36,6 +36,7 @@ import org.apache.spark.{SparkConf, SparkContext} object MultivariateSummarizer { case class Params(input: String = "data/mllib/sample_linear_regression_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala index f01b8266e3fe3..663c12734af68 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -33,6 +33,7 @@ import org.apache.spark.SparkContext._ object SampledRDDs { case class Params(input: String = "data/mllib/sample_binary_classification_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index 952fa2a5109a4..f1ff4e6911f5e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ 
b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -37,7 +37,7 @@ object SparseNaiveBayes { input: String = null, minPartitions: Int = 0, numFeatures: Int = -1, - lambda: Double = 1.0) + lambda: Double = 1.0) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/mllib/pom.xml b/mllib/pom.xml index a5eeef88e9d62..696e9396f627c 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -57,7 +57,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.9 + 0.10 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index e9f41758581e3..f7251e65e04f1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -29,6 +29,8 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ +import org.apache.spark.mllib.feature.Word2Vec +import org.apache.spark.mllib.feature.Word2VecModel import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.random.{RandomRDDs => RG} @@ -42,9 +44,9 @@ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils - /** * :: DeveloperApi :: * The Java stubs necessary for the Python mllib bindings. @@ -287,6 +289,59 @@ class PythonMLLibAPI extends Serializable { ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha) } + /** + * Java stub for Python mllib Word2Vec fit(). This stub returns a + * handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. 
+ * @param dataJRDD input JavaRDD + * @param vectorSize size of vector + * @param learningRate initial learning rate + * @param numPartitions number of partitions + * @param numIterations number of iterations + * @param seed initial seed for random generator + * @return A handle to java Word2VecModelWrapper instance at python side + */ + def trainWord2Vec( + dataJRDD: JavaRDD[java.util.ArrayList[String]], + vectorSize: Int, + learningRate: Double, + numPartitions: Int, + numIterations: Int, + seed: Long): Word2VecModelWrapper = { + val data = dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER) + val word2vec = new Word2Vec() + .setVectorSize(vectorSize) + .setLearningRate(learningRate) + .setNumPartitions(numPartitions) + .setNumIterations(numIterations) + .setSeed(seed) + val model = word2vec.fit(data) + data.unpersist() + new Word2VecModelWrapper(model) + } + + private[python] class Word2VecModelWrapper(model: Word2VecModel) { + def transform(word: String): Vector = { + model.transform(word) + } + + def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = { + val vec = transform(word) + findSynonyms(vec, num) + } + + def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = { + val result = model.findSynonyms(vector, num) + val similarity = Vectors.dense(result.map(_._2)) + val words = result.map(_._1) + val ret = new java.util.LinkedList[java.lang.Object]() + ret.add(words) + ret.add(similarity) + ret + } + } + /** * Java stub for Python mllib DecisionTree.train(). * This stub returns a handle to the Java object instead of the content of the Java object. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index 3afb47767281c..4734251127bb4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.feature -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => brzNorm} import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -47,7 +47,7 @@ class Normalizer(p: Double) extends VectorTransformer { * @return normalized vector. If the norm of the input is zero, it will return the input vector. */ override def transform(vector: Vector): Vector = { - var norm = vector.toBreeze.norm(p) + var norm = brzNorm(vector.toBreeze, p) if (norm != 0.0) { // For dense vector, we've to allocate new memory for new output vector. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index fc1444705364a..d321994c2a651 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -67,7 +67,7 @@ private case class VocabWord( class Word2Vec extends Serializable with Logging { private var vectorSize = 100 - private var startingAlpha = 0.025 + private var learningRate = 0.025 private var numPartitions = 1 private var numIterations = 1 private var seed = Utils.random.nextLong() @@ -84,7 +84,7 @@ class Word2Vec extends Serializable with Logging { * Sets initial learning rate (default: 0.025). 
*/ def setLearningRate(learningRate: Double): this.type = { - this.startingAlpha = learningRate + this.learningRate = learningRate this } @@ -286,7 +286,7 @@ class Word2Vec extends Serializable with Logging { val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) - var alpha = startingAlpha + var alpha = learningRate for (k <- 1 to numIterations) { val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) @@ -300,8 +300,8 @@ class Word2Vec extends Serializable with Logging { lwc = wordCount // TODO: discount by iteration? alpha = - startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) - if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001 + learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) + if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001 logInfo("wordCount = " + wordCount + ", alpha = " + alpha) } wc += sentence.size @@ -437,7 +437,7 @@ class Word2VecModel private[mllib] ( * Find synonyms of a word * @param word a word * @param num number of synonyms to find - * @return array of (word, similarity) + * @return array of (word, cosineSimilarity) */ def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index d0fe4179685ca..00dfc86c9e0bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -75,6 +75,8 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double def predict(testData: Vector): Double = { predictPoint(testData, weights, intercept) } + + override def toString() = "(weights=%s, intercept=%s)".format(weights, intercept) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b311d10023894..03eeaa707715b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -532,6 +532,14 @@ object DecisionTree extends Serializable with Logging { Some(mutableNodeToFeatures.toMap) } + // array of nodes to train indexed by node index in group + val nodes = new Array[Node](numNodes) + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node + } + } + // Calculate best splits for all nodes in the group timer.start("chooseSplits") @@ -568,7 +576,7 @@ object DecisionTree extends Serializable with Logging { // find best split for each node val (split: Split, stats: InformationGainStats, predict: Predict) = - binsToBestSplit(aggStats, splits, featuresForNode) + binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) (nodeIndex, (split, stats, predict)) }.collectAsMap() @@ -587,17 +595,30 @@ object DecisionTree extends Serializable with Logging { // Extract info for this node. Create children if not leaf. 
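// Children are now constructed directly from the split statistics, with their Predict and
// impurity precomputed; a child that is already a leaf (the next level reaches maxDepth or
// its impurity is 0.0) is not enqueued, so no further aggregation is performed for it.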
val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth) assert(node.id == nodeIndex) - node.predict = predict.predict + node.predict = predict node.isLeaf = isLeaf node.stats = Some(stats) + node.impurity = stats.impurity logDebug("Node = " + node) if (!isLeaf) { node.split = Some(split) - node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex))) - node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex))) - nodeQueue.enqueue((treeIndex, node.leftNode.get)) - nodeQueue.enqueue((treeIndex, node.rightNode.get)) + val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth + val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) + val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) + node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex), + stats.leftPredict, stats.leftImpurity, leftChildIsLeaf)) + node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex), + stats.rightPredict, stats.rightImpurity, rightChildIsLeaf)) + + // enqueue left child and right child if they are not leaves + if (!leftChildIsLeaf) { + nodeQueue.enqueue((treeIndex, node.leftNode.get)) + } + if (!rightChildIsLeaf) { + nodeQueue.enqueue((treeIndex, node.rightNode.get)) + } + logDebug("leftChildIndex = " + node.leftNode.get.id + ", impurity = " + stats.leftImpurity) logDebug("rightChildIndex = " + node.rightNode.get.id + @@ -617,7 +638,8 @@ object DecisionTree extends Serializable with Logging { private def calculateGainForSplit( leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata): InformationGainStats = { + metadata: DecisionTreeMetadata, + impurity: Double): InformationGainStats = { val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count @@ -630,11 +652,6 @@ object DecisionTree extends Serializable with Logging { val totalCount = leftCount + rightCount - val parentNodeAgg = leftImpurityCalculator.copy - parentNodeAgg.add(rightImpurityCalculator) - - val impurity = parentNodeAgg.calculate() - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 val rightImpurity = rightImpurityCalculator.calculate() @@ -649,7 +666,18 @@ object DecisionTree extends Serializable with Logging { return InformationGainStats.invalidInformationGainStats } - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity) + // calculate left and right predict + val leftPredict = calculatePredict(leftImpurityCalculator) + val rightPredict = calculatePredict(rightImpurityCalculator) + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, + leftPredict, rightPredict) + } + + private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = { + val predict = impurityCalculator.predict + val prob = impurityCalculator.prob(predict) + new Predict(predict, prob) } /** @@ -657,17 +685,17 @@ object DecisionTree extends Serializable with Logging { * Note that this function is called only once for each node. 
* @param leftImpurityCalculator left node aggregates for a split * @param rightImpurityCalculator right node aggregates for a split - * @return predict value for current node + * @return predict value and impurity for current node */ - private def calculatePredict( + private def calculatePredictImpurity( leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator): Predict = { + rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { val parentNodeAgg = leftImpurityCalculator.copy parentNodeAgg.add(rightImpurityCalculator) - val predict = parentNodeAgg.predict - val prob = parentNodeAgg.prob(predict) + val predict = calculatePredict(parentNodeAgg) + val impurity = parentNodeAgg.calculate() - new Predict(predict, prob) + (predict, impurity) } /** @@ -678,10 +706,16 @@ object DecisionTree extends Serializable with Logging { private def binsToBestSplit( binAggregates: DTStatsAggregator, splits: Array[Array[Split]], - featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = { + featuresForNode: Option[Array[Int]], + node: Node): (Split, InformationGainStats, Predict) = { - // calculate predict only once - var predict: Option[Predict] = None + // calculate predict and impurity if current node is top node + val level = Node.indexToLevel(node.id) + var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) { + None + } else { + Some((node.predict, node.impurity)) + } // For each (feature, split), calculate the gain, and select the best (feature, split). val (bestSplit, bestSplitStats) = @@ -708,9 +742,10 @@ object DecisionTree extends Serializable with Logging { val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) - predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata) + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) (splitIdx, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) @@ -722,9 +757,10 @@ object DecisionTree extends Serializable with Logging { Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) - predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata) + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) (splitIndex, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) @@ -794,9 +830,10 @@ object DecisionTree extends Serializable with Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) - predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) 
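// The node's Predict and impurity come from the node itself (or, for the root, are computed
// once from the first split's aggregates) and are cached in predictWithImpurity; the cached
// impurity is passed to calculateGainForSplit below instead of being recomputed from the
// child aggregates for every candidate split.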
val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata) + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) (splitIndex, gainStats) }.maxBy(_._2.gain) val categoriesForSplit = @@ -807,9 +844,7 @@ object DecisionTree extends Serializable with Logging { } }.maxBy(_._2.gain) - assert(predict.isDefined, "must calculate predict for each node") - - (bestSplit, bestSplitStats, predict.get) + (bestSplit, bestSplitStats, predictWithImpurity.get._1) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index a89e71e115806..9a50ecb550c38 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -26,13 +26,17 @@ import org.apache.spark.annotation.DeveloperApi * @param impurity current node impurity * @param leftImpurity left node impurity * @param rightImpurity right node impurity + * @param leftPredict left node predict + * @param rightPredict right node predict */ @DeveloperApi class InformationGainStats( val gain: Double, val impurity: Double, val leftImpurity: Double, - val rightImpurity: Double) extends Serializable { + val rightImpurity: Double, + val leftPredict: Predict, + val rightPredict: Predict) extends Serializable { override def toString = { "gain = %f, impurity = %f, left impurity = %f, right impurity = %f" @@ -58,5 +62,6 @@ private[tree] object InformationGainStats { * denote that the current split doesn't satisfy minimum info gain or * minimum number of instances per node. */ - val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0) + val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, + new Predict(0.0, 0.0), new Predict(0.0, 0.0)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 56c3e25d9285f..2179da8dbe03e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -32,7 +32,8 @@ import org.apache.spark.mllib.linalg.Vector * * @param id integer node id, from 1 * @param predict predicted value at the node - * @param isLeaf whether the leaf is a node + * @param impurity current node impurity + * @param isLeaf whether the node is a leaf * @param split split to calculate left and right nodes * @param leftNode left child * @param rightNode right child @@ -41,7 +42,8 @@ import org.apache.spark.mllib.linalg.Vector @DeveloperApi class Node ( val id: Int, - var predict: Double, + var predict: Predict, + var impurity: Double, var isLeaf: Boolean, var split: Option[Split], var leftNode: Option[Node], @@ -49,7 +51,7 @@ class Node ( var stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + - "split = " + split + ", stats = " + stats + "impurity = " + impurity + ", split = " + split + ", stats = " + stats /** * build the left node and right nodes if not leaf @@ -62,6 +64,7 @@ class Node ( logDebug("id = " + id + ", split = " + split) logDebug("stats = " + stats) logDebug("predict = " + predict) + logDebug("impurity = " + impurity) if (!isLeaf) { leftNode = Some(nodes(Node.leftChildIndex(id)))
rightNode = Some(nodes(Node.rightChildIndex(id))) @@ -77,7 +80,7 @@ class Node ( */ def predict(features: Vector) : Double = { if (isLeaf) { - predict + predict.predict } else{ if (split.get.featureType == Continuous) { if (features(split.get.feature) <= split.get.threshold) { @@ -109,7 +112,7 @@ class Node ( } else { Some(rightNode.get.deepCopy()) } - new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats) + new Node(id, predict, impurity, isLeaf, split, leftNodeCopy, rightNodeCopy, stats) } /** @@ -154,7 +157,7 @@ class Node ( } val prefix: String = " " * indentFactor if (isLeaf) { - prefix + s"Predict: $predict\n" + prefix + s"Predict: ${predict.predict}\n" } else { prefix + s"If ${splitToString(split.get, left=true)}\n" + leftNode.get.subtreeToString(indentFactor + 1) + @@ -170,7 +173,27 @@ private[tree] object Node { /** * Return a node with the given node id (but nothing else set). */ - def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None) + def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, new Predict(Double.MinValue), -1.0, + false, None, None, None, None) + + /** + * Construct a node with nodeIndex, predict, impurity and isLeaf parameters. + * This is used in `DecisionTree.findBestSplits` to construct child nodes + * after finding the best splits for parent nodes. + * Other fields are set at next level. + * @param nodeIndex integer node id, from 1 + * @param predict predicted value at the node + * @param impurity current node impurity + * @param isLeaf whether the node is a leaf + * @return new node instance + */ + def apply( + nodeIndex: Int, + predict: Predict, + impurity: Double, + isLeaf: Boolean): Node = { + new Node(nodeIndex, predict, impurity, isLeaf, None, None, None, None) + } /** * Return the index of the left child of this node. 
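With this change `Node.predict` carries a `Predict` (the predicted value together with its probability) and every node records its `impurity`, so callers read the scalar prediction as `node.predict.predict`; the DecisionTreeSuite updates further below adjust the assertions accordingly. A minimal usage sketch under these assumptions (the tiny dataset and parameter values are illustrative only, not taken from this patch):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.Algo.Classification
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini

    val sc = new SparkContext(new SparkConf().setAppName("NodePredictSketch").setMaster("local"))
    val data = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0)),
      LabeledPoint(0.0, Vectors.dense(0.2)),
      LabeledPoint(1.0, Vectors.dense(0.8)),
      LabeledPoint(1.0, Vectors.dense(1.0))))
    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
      numClassesForClassification = 2)
    val model = DecisionTree.train(data, strategy)

    // Node.predict is now a Predict (value plus probability) and each node stores its impurity.
    println(model.topNode.predict.predict)  // predicted class at the root (was a bare Double)
    println(model.topNode.predict.prob)     // probability associated with that prediction
    println(model.topNode.impurity)         // impurity at the root
    sc.stop()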
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala index fb76dccfdf79e..2bf9d9816ae45 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.feature import org.scalatest.FunSuite +import breeze.linalg.{norm => brzNorm} + import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -50,10 +52,10 @@ class NormalizerSuite extends FunSuite with LocalSparkContext { assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) - assert(data1(0).toBreeze.norm(1) ~== 1.0 absTol 1E-5) - assert(data1(2).toBreeze.norm(1) ~== 1.0 absTol 1E-5) - assert(data1(3).toBreeze.norm(1) ~== 1.0 absTol 1E-5) - assert(data1(4).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(0).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(2).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(3).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(4).toBreeze, 1) ~== 1.0 absTol 1E-5) assert(data1(0) ~== Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))) absTol 1E-5) assert(data1(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) @@ -77,10 +79,10 @@ class NormalizerSuite extends FunSuite with LocalSparkContext { assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) - assert(data2(0).toBreeze.norm(2) ~== 1.0 absTol 1E-5) - assert(data2(2).toBreeze.norm(2) ~== 1.0 absTol 1E-5) - assert(data2(3).toBreeze.norm(2) ~== 1.0 absTol 1E-5) - assert(data2(4).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(0).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(2).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(3).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(4).toBreeze, 2) ~== 1.0 absTol 1E-5) assert(data2(0) ~== Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))) absTol 1E-5) assert(data2(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index a48ed71a1c5fc..98a72b0c4d750 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -253,7 +253,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = rootNode.stats.get assert(stats.gain > 0) - assert(rootNode.predict === 1) + assert(rootNode.predict.predict === 1) assert(stats.impurity > 0.2) } @@ -282,7 +282,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = rootNode.stats.get assert(stats.gain > 0) - assert(rootNode.predict === 0.6) + assert(rootNode.predict.predict === 0.6) assert(stats.impurity > 0.2) } @@ -352,7 +352,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.gain === 0) assert(stats.leftImpurity === 0) assert(stats.rightImpurity === 0) - assert(rootNode.predict === 1) + assert(rootNode.predict.predict === 1) } test("Binary classification stump with fixed label 0 for Entropy") { @@ -377,7 +377,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.gain === 0) 
assert(stats.leftImpurity === 0) assert(stats.rightImpurity === 0) - assert(rootNode.predict === 0) + assert(rootNode.predict.predict === 0) } test("Binary classification stump with fixed label 1 for Entropy") { @@ -402,7 +402,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.gain === 0) assert(stats.leftImpurity === 0) assert(stats.rightImpurity === 0) - assert(rootNode.predict === 1) + assert(rootNode.predict.predict === 1) } test("Second level node building with vs. without groups") { @@ -471,7 +471,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats1.impurity === stats2.impurity) assert(stats1.leftImpurity === stats2.leftImpurity) assert(stats1.rightImpurity === stats2.rightImpurity) - assert(children1(i).predict === children2(i).predict) + assert(children1(i).predict.predict === children2(i).predict.predict) } } @@ -646,7 +646,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val model = DecisionTree.train(rdd, strategy) assert(model.topNode.isLeaf) - assert(model.topNode.predict == 0.0) + assert(model.topNode.predict.predict == 0.0) val predicts = rdd.map(p => model.predict(p.features)).collect() predicts.foreach { predict => assert(predict == 0.0) @@ -693,7 +693,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val model = DecisionTree.train(input, strategy) assert(model.topNode.isLeaf) - assert(model.topNode.predict == 0.0) + assert(model.topNode.predict.predict == 0.0) val predicts = input.map(p => model.predict(p.features)).collect() predicts.foreach { predict => assert(predict == 0.0) @@ -705,6 +705,92 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val gain = rootNode.stats.get assert(gain == InformationGainStats.invalidInformationGainStats) } + + test("Avoid aggregation on the last level") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) + arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, + numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue leaf nodes into node queue + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity 
=== 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } + + test("Avoid aggregation if impurity is 0.0") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) + arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue a node into node queue if its impurity is 0.0 + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } } object DecisionTreeSuite { diff --git a/pom.xml b/pom.xml index 7756c89b00cad..3b6d4ecbae2c1 100644 --- a/pom.xml +++ b/pom.xml @@ -118,7 +118,7 @@ 0.18.1 shaded-protobuf org.spark-project.akka - 2.2.3-shaded-protobuf + 2.3.4-spark 1.7.5 1.2.17 1.0.4 diff --git a/python/docs/conf.py b/python/docs/conf.py index c368cf81a003b..8e6324f058251 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -55,9 +55,9 @@ # built documents. # # The short X.Y version. -version = '1.1' +version = '1.2-SNAPSHOT' # The full version, including alpha/beta/rc tags. -release = '' +release = '1.2-SNAPSHOT' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -102,7 +102,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'default' +html_theme = 'nature' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -121,7 +121,7 @@ # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +html_logo = "../../docs/img/spark-logo-hd.png" # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 @@ -154,10 +154,10 @@ #html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +html_domain_indices = False # If false, no index is generated. -#html_use_index = True +html_use_index = False # If true, the index is split into individual pages for each letter. 
#html_split_index = False diff --git a/python/docs/index.rst b/python/docs/index.rst index 25b3f9bd93e63..d66e051b15371 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -3,7 +3,7 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -Welcome to PySpark API reference! +Welcome to Spark Python API Docs! =================================== Contents: @@ -24,14 +24,12 @@ Core classes: Main entry point for Spark functionality. :class:`pyspark.RDD` - + A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Indices and tables ================== -* :ref:`genindex` -* :ref:`modindex` * :ref:`search` diff --git a/python/docs/modules.rst b/python/docs/modules.rst deleted file mode 100644 index 183564659fbcf..0000000000000 --- a/python/docs/modules.rst +++ /dev/null @@ -1,7 +0,0 @@ -. -= - -.. toctree:: - :maxdepth: 4 - - pyspark diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst index e95d19e97f151..4548b8739ed91 100644 --- a/python/docs/pyspark.mllib.rst +++ b/python/docs/pyspark.mllib.rst @@ -20,6 +20,14 @@ pyspark.mllib.clustering module :undoc-members: :show-inheritance: +pyspark.mllib.feature module +------------------------------- + +.. automodule:: pyspark.mllib.feature + :members: + :undoc-members: + :show-inheritance: + pyspark.mllib.linalg module --------------------------- diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 1a2e774738fe7..e39e6514d77a1 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -20,33 +20,21 @@ Public classes: - - L{SparkContext} + - :class:`SparkContext`: Main entry point for Spark functionality. - - L{RDD} + - L{RDD} A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. - - L{Broadcast} + - L{Broadcast} A broadcast variable that gets reused across tasks. - - L{Accumulator} + - L{Accumulator} An "add-only" shared variable that tasks can only add values to. - - L{SparkConf} + - L{SparkConf} For configuring Spark. - - L{SparkFiles} + - L{SparkFiles} Access files shipped with jobs. - - L{StorageLevel} + - L{StorageLevel} Finer-grained cache persistence levels. -Spark SQL: - - L{SQLContext} - Main entry point for SQL functionality. - - L{SchemaRDD} - A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, SchemaRDDs also support SQL. - - L{Row} - A Row of data returned by a Spark SQL query. - -Hive: - - L{HiveContext} - Main entry point for accessing data stored in Apache Hive.. """ # The following block allows us to import python's random instead of mllib.random for scripts in diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index b64875a3f495a..dc7cd0bce56f3 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -83,11 +83,11 @@ def __init__(self, loadDefaults=True, _jvm=None, _jconf=None): """ Create a new Spark configuration. 
- @param loadDefaults: whether to load values from Java system + :param loadDefaults: whether to load values from Java system properties (True by default) - @param _jvm: internal parameter used to pass a handle to the + :param _jvm: internal parameter used to pass a handle to the Java VM; does not need to be set by users - @param _jconf: Optionally pass in an existing SparkConf handle + :param _jconf: Optionally pass in an existing SparkConf handle to use its parameters """ if _jconf: @@ -139,7 +139,7 @@ def setAll(self, pairs): """ Set multiple parameters, passed as a list of key-value pairs. - @param pairs: list of key-value pairs to set + :param pairs: list of key-value pairs to set """ for (k, v) in pairs: self._jconf.set(k, v) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index e9418320ff781..6fb30d65c5edd 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -73,21 +73,21 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, Create a new SparkContext. At least the master and app name should be set, either through the named parameters here or through C{conf}. - @param master: Cluster URL to connect to + :param master: Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - @param appName: A name for your job, to display on the cluster web UI. - @param sparkHome: Location where Spark is installed on cluster nodes. - @param pyFiles: Collection of .zip or .py files to send to the cluster + :param appName: A name for your job, to display on the cluster web UI. + :param sparkHome: Location where Spark is installed on cluster nodes. + :param pyFiles: Collection of .zip or .py files to send to the cluster and add to PYTHONPATH. These can be paths on the local file system or HDFS, HTTP, HTTPS, or FTP URLs. - @param environment: A dictionary of environment variables to set on + :param environment: A dictionary of environment variables to set on worker nodes. - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. Set 1 to disable batching or -1 to use an unlimited batch size. - @param serializer: The serializer for RDDs. - @param conf: A L{SparkConf} object setting Spark properties. - @param gateway: Use an existing gateway and JVM, otherwise a new JVM + :param serializer: The serializer for RDDs. + :param conf: A L{SparkConf} object setting Spark properties. + :param gateway: Use an existing gateway and JVM, otherwise a new JVM will be instantiated. @@ -410,22 +410,23 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. The mechanism is as follows: + 1. A Java RDD is created from the SequenceFile or other InputFormat, and the key and value Writable classes 2. Serialization is attempted via Pyrolite pickling 3. If this fails, the fallback is to call 'toString' on each key and value 4. C{PickleSerializer} is used to deserialize pickled objects on the Python side - @param path: path to sequncefile - @param keyClass: fully qualified classname of key Writable class + :param path: path to sequncefile + :param keyClass: fully qualified classname of key Writable class (e.g. 
"org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: - @param valueConverter: - @param minSplits: minimum splits in dataset + :param keyConverter: + :param valueConverter: + :param minSplits: minimum splits in dataset (default min(2, sc.defaultParallelism)) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. (default sc._default_batch_size_for_serialized_input) """ minSplits = minSplits or min(self.defaultParallelism, 2) @@ -445,18 +446,18 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv A Hadoop configuration can be passed in as a Python dict. This will be converted into a Configuration in Java - @param path: path to Hadoop file - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param path: path to Hadoop file + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapreduce.lib.input.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) @@ -475,17 +476,17 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N This will be converted into a Configuration in Java. The mechanism is the same as for sc.sequenceFile. - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapreduce.lib.input.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. 
(default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) @@ -506,18 +507,18 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter= A Hadoop configuration can be passed in as a Python dict. This will be converted into a Configuration in Java. - @param path: path to Hadoop file - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param path: path to Hadoop file + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapred.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) @@ -536,17 +537,17 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, This will be converted into a Configuration in Java. The mechanism is the same as for sc.sequenceFile. - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapred.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index ac142fb49a90c..cd43982191702 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -79,21 +79,24 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, """ Train a logistic regression model on the given data. - @param data: The training data. - @param iterations: The number of iterations (default: 100). - @param step: The step parameter used in SGD + :param data: The training data. + :param iterations: The number of iterations (default: 100). + :param step: The step parameter used in SGD (default: 1.0). 
- @param miniBatchFraction: Fraction of data to be used for each SGD + :param miniBatchFraction: Fraction of data to be used for each SGD iteration. - @param initialWeights: The initial weights (default: None). - @param regParam: The regularizer parameter (default: 1.0). - @param regType: The type of regularizer used for training + :param initialWeights: The initial weights (default: None). + :param regParam: The regularizer parameter (default: 1.0). + :param regType: The type of regularizer used for training our model. - Allowed values: "l1" for using L1Updater, - "l2" for using - SquaredL2Updater, - "none" for no regularizer. - (default: "none") + + :Allowed values: + - "l1" for using L1Updater + - "l2" for using SquaredL2Updater + - "none" for no regularizer + + (default: "none") + @param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features @@ -148,21 +151,24 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, """ Train a support vector machine on the given data. - @param data: The training data. - @param iterations: The number of iterations (default: 100). - @param step: The step parameter used in SGD + :param data: The training data. + :param iterations: The number of iterations (default: 100). + :param step: The step parameter used in SGD (default: 1.0). - @param regParam: The regularizer parameter (default: 1.0). - @param miniBatchFraction: Fraction of data to be used for each SGD + :param regParam: The regularizer parameter (default: 1.0). + :param miniBatchFraction: Fraction of data to be used for each SGD iteration. - @param initialWeights: The initial weights (default: None). - @param regType: The type of regularizer used for training + :param initialWeights: The initial weights (default: None). + :param regType: The type of regularizer used for training our model. - Allowed values: "l1" for using L1Updater, - "l2" for using - SquaredL2Updater, - "none" for no regularizer. - (default: "none") + + :Allowed values: + - "l1" for using L1Updater + - "l2" for using SquaredL2Updater, + - "none" for no regularizer. + + (default: "none") + @param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features @@ -232,10 +238,10 @@ def train(cls, data, lambda_=1.0): classification. By making every vector a 0-1 vector, it can also be used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}). - @param data: RDD of NumPy vectors, one per element, where the first + :param data: RDD of NumPy vectors, one per element, where the first coordinate is the label and the rest is the feature vector (e.g. a count vector). - @param lambda_: The smoothing parameter + :param lambda_: The smoothing parameter """ sc = data.context jlist = sc._jvm.PythonMLLibAPI().trainNaiveBayes(data._to_java_object_rdd(), lambda_) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py new file mode 100644 index 0000000000000..a44a27fd3b6a6 --- /dev/null +++ b/python/pyspark/mllib/feature.py @@ -0,0 +1,193 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for feature in MLlib. +""" +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer + +from pyspark.mllib.linalg import _convert_to_vector + +__all__ = ['Word2Vec', 'Word2VecModel'] + + +class Word2VecModel(object): + """ + class for Word2Vec model + """ + def __init__(self, sc, java_model): + """ + :param sc: Spark context + :param java_model: Handle to Java model object + """ + self._sc = sc + self._java_model = java_model + + def __del__(self): + self._sc._gateway.detach(self._java_model) + + def transform(self, word): + """ + :param word: a word + :return: vector representation of word + Transforms a word to its vector representation + + Note: local use only + """ + # TODO: make transform usable in RDD operations from python side + result = self._java_model.transform(word) + return PickleSerializer().loads(str(self._sc._jvm.SerDe.dumps(result))) + + def findSynonyms(self, x, num): + """ + :param x: a word or a vector representation of word + :param num: number of synonyms to find + :return: array of (word, cosineSimilarity) + Find synonyms of a word + + Note: local use only + """ + # TODO: make findSynonyms usable in RDD operations from python side + ser = PickleSerializer() + if type(x) == str: + jlist = self._java_model.findSynonyms(x, num) + else: + bytes = bytearray(ser.dumps(_convert_to_vector(x))) + vec = self._sc._jvm.SerDe.loads(bytes) + jlist = self._java_model.findSynonyms(vec, num) + words, similarity = ser.loads(str(self._sc._jvm.SerDe.dumps(jlist))) + return zip(words, similarity) + + +class Word2Vec(object): + """ + Word2Vec creates vector representation of words in a text corpus. + The algorithm first constructs a vocabulary from the corpus + and then learns vector representation of words in the vocabulary. + The vector representation can be used as features in + natural language processing and machine learning algorithms. + + We used skip-gram model in our implementation and hierarchical softmax + method to train the model. The variable names in the implementation + matches the original C implementation. + For original C implementation, see https://code.google.com/p/word2vec/ + For research papers, see + Efficient Estimation of Word Representations in Vector Space + and + Distributed Representations of Words and Phrases and their Compositionality. + + >>> sentence = "a b " * 100 + "a c " * 10 + >>> localDoc = [sentence, sentence] + >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" ")) + >>> model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc) + >>> syms = model.findSynonyms("a", 2) + >>> str(syms[0][0]) + 'b' + >>> str(syms[1][0]) + 'c' + >>> len(syms) + 2 + >>> vec = model.transform("a") + >>> len(vec) + 10 + >>> syms = model.findSynonyms(vec, 2) + >>> str(syms[0][0]) + 'b' + >>> str(syms[1][0]) + 'c' + >>> len(syms) + 2 + """ + def __init__(self): + """ + Construct Word2Vec instance + """ + self.vectorSize = 100 + self.learningRate = 0.025 + self.numPartitions = 1 + self.numIterations = 1 + self.seed = 42L + + def setVectorSize(self, vectorSize): + """ + Sets vector size (default: 100). 
+ """ + self.vectorSize = vectorSize + return self + + def setLearningRate(self, learningRate): + """ + Sets initial learning rate (default: 0.025). + """ + self.learningRate = learningRate + return self + + def setNumPartitions(self, numPartitions): + """ + Sets number of partitions (default: 1). Use a small number for accuracy. + """ + self.numPartitions = numPartitions + return self + + def setNumIterations(self, numIterations): + """ + Sets number of iterations (default: 1), which should be smaller than or equal to number of + partitions. + """ + self.numIterations = numIterations + return self + + def setSeed(self, seed): + """ + Sets random seed. + """ + self.seed = seed + return self + + def fit(self, data): + """ + Computes the vector representation of each word in vocabulary. + + :param data: training data. RDD of subtype of Iterable[String] + :return: python Word2VecModel instance + """ + sc = data.context + ser = PickleSerializer() + vectorSize = self.vectorSize + learningRate = self.learningRate + numPartitions = self.numPartitions + numIterations = self.numIterations + seed = self.seed + + model = sc._jvm.PythonMLLibAPI().trainWord2Vec( + data._to_java_object_rdd(), vectorSize, + learningRate, numPartitions, numIterations, seed) + return Word2VecModel(sc, model) + + +def _test(): + import doctest + from pyspark import SparkContext + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 51014a8ceb785..24c5480b2f753 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -238,8 +238,8 @@ def __init__(self, size, *args): (index, value) pairs, or two separate arrays of indices and values (sorted by index). - @param size: Size of the vector. - @param args: Non-zero entries, as a dictionary, list of tupes, + :param size: Size of the vector. + :param args: Non-zero entries, as a dictionary, list of tupes, or two sorted lists containing indices and values. >>> print SparseVector(4, {1: 1.0, 3: 5.5}) @@ -458,8 +458,8 @@ def sparse(size, *args): (index, value) pairs, or two separate arrays of indices and values (sorted by index). - @param size: Size of the vector. - @param args: Non-zero entries, as a dictionary, list of tupes, + :param size: Size of the vector. + :param args: Non-zero entries, as a dictionary, list of tupes, or two sorted lists containing indices and values. >>> print Vectors.sparse(4, {1: 1.0, 3: 5.5}) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index cbdbc09858013..12b322aaae796 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -22,7 +22,7 @@ from pyspark.mllib.linalg import SparseVector, _convert_to_vector from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel' +__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel', 'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD'] @@ -31,8 +31,8 @@ class LabeledPoint(object): """ The features and labels of a data point. - @param label: Label for this data point. 
- @param features: Vector of features for this point (NumPy array, list, + :param label: Label for this data point. + :param features: Vector of features for this point (NumPy array, list, pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix) """ @@ -66,6 +66,9 @@ def weights(self): def intercept(self): return self._intercept + def __repr__(self): + return "(weights=%s, intercept=%s)" % (self._coeff, self._intercept) + class LinearRegressionModelBase(LinearModel): @@ -142,21 +145,24 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, """ Train a linear regression model on the given data. - @param data: The training data. - @param iterations: The number of iterations (default: 100). - @param step: The step parameter used in SGD + :param data: The training data. + :param iterations: The number of iterations (default: 100). + :param step: The step parameter used in SGD (default: 1.0). - @param miniBatchFraction: Fraction of data to be used for each SGD + :param miniBatchFraction: Fraction of data to be used for each SGD iteration. - @param initialWeights: The initial weights (default: None). - @param regParam: The regularizer parameter (default: 1.0). - @param regType: The type of regularizer used for training + :param initialWeights: The initial weights (default: None). + :param regParam: The regularizer parameter (default: 1.0). + :param regType: The type of regularizer used for training our model. - Allowed values: "l1" for using L1Updater, - "l2" for using - SquaredL2Updater, - "none" for no regularizer. - (default: "none") + + :Allowed values: + - "l1" for using L1Updater, + - "l2" for using SquaredL2Updater, + - "none" for no regularizer. + + (default: "none") + @param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index f72e88ba6e2ba..5c20e100e144f 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -32,7 +32,7 @@ from pyspark.serializers import PickleSerializer from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint -from pyspark.tests import PySparkTestCase +from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase _have_scipy = False diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index afdcdbdf3ae01..5d7abfb96b7fe 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -48,6 +48,7 @@ def __del__(self): def predict(self, x): """ Predict the label of one or more examples. + :param x: Data point (feature vector), or an RDD of data points (feature vectors). """ diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 8233d4e81f1ca..1357fd4fbc8aa 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -77,10 +77,10 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None method parses each line into a LabeledPoint, where the feature indices are converted to zero-based. - @param sc: Spark context - @param path: file or directory path in any Hadoop-supported file + :param sc: Spark context + :param path: file or directory path in any Hadoop-supported file system URI - @param numFeatures: number of features, which will be determined + :param numFeatures: number of features, which will be determined from the input data if a nonpositive value is given. 
This is useful when the dataset is already split into multiple files and you @@ -88,7 +88,7 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None features may not present in certain files, which leads to inconsistent feature dimensions. - @param minPartitions: min number of partitions + :param minPartitions: min number of partitions @return: labeled data stored as an RDD of LabeledPoint >>> from tempfile import NamedTemporaryFile @@ -126,8 +126,8 @@ def saveAsLibSVMFile(data, dir): """ Save labeled data in LIBSVM format. - @param data: an RDD of LabeledPoint to be saved - @param dir: directory to save the data + :param data: an RDD of LabeledPoint to be saved + :param dir: directory to save the data >>> from tempfile import NamedTemporaryFile >>> from fileinput import input @@ -149,10 +149,10 @@ def loadLabeledPoints(sc, path, minPartitions=None): """ Load labeled points saved using RDD.saveAsTextFile. - @param sc: Spark context - @param path: file or directory path in any Hadoop-supported file + :param sc: Spark context + :param path: file or directory path in any Hadoop-supported file system URI - @param minPartitions: min number of partitions + :param minPartitions: min number of partitions @return: labeled data stored as an RDD of LabeledPoint >>> from tempfile import NamedTemporaryFile diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index dc6497772e502..6797d50659a92 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -752,7 +752,7 @@ def max(self, key=None): """ Find the maximum item in this RDD. - @param key: A function used to generate key for comparing + :param key: A function used to generate key for comparing >>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0]) >>> rdd.max() @@ -768,7 +768,7 @@ def min(self, key=None): """ Find the minimum item in this RDD. - @param key: A function used to generate key for comparing + :param key: A function used to generate key for comparing >>> rdd = sc.parallelize([2.0, 5.0, 43.0, 10.0]) >>> rdd.min() @@ -1115,9 +1115,9 @@ def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None converted for output using either user specified converters or, by default, L{org.apache.spark.api.python.JavaToWritableConverter}. - @param conf: Hadoop job configuration, passed in as a dict - @param keyConverter: (None by default) - @param valueConverter: (None by default) + :param conf: Hadoop job configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._toPickleSerialization() @@ -1135,16 +1135,16 @@ def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueCl C{conf} is applied on top of the base Hadoop conf associated with the SparkContext of this RDD to create a merged Hadoop MapReduce job configuration for saving the data. - @param path: path to Hadoop file - @param outputFormatClass: fully qualified classname of Hadoop OutputFormat + :param path: path to Hadoop file + :param outputFormatClass: fully qualified classname of Hadoop OutputFormat (e.g. "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. 
"org.apache.hadoop.io.IntWritable", None by default) - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.Text", None by default) - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop job configuration, passed in as a dict (None by default) + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop job configuration, passed in as a dict (None by default) """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._toPickleSerialization() @@ -1161,9 +1161,9 @@ def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None): converted for output using either user specified converters or, by default, L{org.apache.spark.api.python.JavaToWritableConverter}. - @param conf: Hadoop job configuration, passed in as a dict - @param keyConverter: (None by default) - @param valueConverter: (None by default) + :param conf: Hadoop job configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._toPickleSerialization() @@ -1182,17 +1182,17 @@ def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=No C{conf} is applied on top of the base Hadoop conf associated with the SparkContext of this RDD to create a merged Hadoop MapReduce job configuration for saving the data. - @param path: path to Hadoop file - @param outputFormatClass: fully qualified classname of Hadoop OutputFormat + :param path: path to Hadoop file + :param outputFormatClass: fully qualified classname of Hadoop OutputFormat (e.g. "org.apache.hadoop.mapred.SequenceFileOutputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.IntWritable", None by default) - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.Text", None by default) - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: (None by default) - @param compressionCodecClass: (None by default) + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: (None by default) + :param compressionCodecClass: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._toPickleSerialization() @@ -1208,11 +1208,12 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None): Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file system, using the L{org.apache.hadoop.io.Writable} types that we convert from the RDD's key and value types. The mechanism is as follows: + 1. Pyrolite is used to convert pickled Python RDD into RDD of Java objects. 2. Keys and values of this Java RDD are converted to Writables and written out. - @param path: path to sequence file - @param compressionCodecClass: (None by default) + :param path: path to sequence file + :param compressionCodecClass: (None by default) """ pickledRDD = self._toPickleSerialization() batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) @@ -2008,7 +2009,7 @@ def countApproxDistinct(self, relativeSD=0.05): of The Art Cardinality Estimation Algorithm", available here. 
- @param relativeSD Relative accuracy. Smaller values create + :param relativeSD: Relative accuracy. Smaller values create counters that require more space. It must be greater than 0.000017. diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 2672da36c1f50..099fa54cf2bd7 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -211,7 +211,7 @@ def __eq__(self, other): return (isinstance(other, BatchedSerializer) and other.serializer == self.serializer) - def __str__(self): + def __repr__(self): return "BatchedSerializer<%s>" % str(self.serializer) @@ -279,7 +279,7 @@ def __eq__(self, other): return (isinstance(other, CartesianDeserializer) and self.key_ser == other.key_ser and self.val_ser == other.val_ser) - def __str__(self): + def __repr__(self): return "CartesianDeserializer<%s, %s>" % \ (str(self.key_ser), str(self.val_ser)) @@ -306,7 +306,7 @@ def __eq__(self, other): return (isinstance(other, PairDeserializer) and self.key_ser == other.key_ser and self.val_ser == other.val_ser) - def __str__(self): + def __repr__(self): return "PairDeserializer<%s, %s>" % (str(self.key_ser), str(self.val_ser)) diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index ce597cbe91e15..d57a802e4734a 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -396,7 +396,6 @@ def _external_items(self): for v in self.data.iteritems(): yield v self.data.clear() - gc.collect() # remove the merged partition for j in range(self.spills): @@ -428,7 +427,7 @@ def _recursive_merged_items(self, start): subdirs = [os.path.join(d, "parts", str(i)) for d in self.localdirs] m = ExternalMerger(self.agg, self.memory_limit, self.serializer, - subdirs, self.scale * self.partitions) + subdirs, self.scale * self.partitions, self.partitions) m.pdata = [{} for _ in range(self.partitions)] limit = self._next_limit() @@ -486,7 +485,7 @@ def sorted(self, iterator, key=None, reverse=False): goes above the limit. """ global MemoryBytesSpilled, DiskBytesSpilled - batch = 10 + batch = 100 chunks, current_chunk = [], [] iterator = iter(iterator) while True: diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 974b5e287bc00..d3d36eb995ab6 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -15,28 +15,38 @@ # limitations under the License. # +""" +Public classes of Spark SQL: + + - L{SQLContext} + Main entry point for SQL functionality. + - L{SchemaRDD} + A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In + addition to normal RDD operations, SchemaRDDs also support SQL. + - L{Row} + A Row of data returned by a Spark SQL query. + - L{HiveContext} + Main entry point for accessing data stored in Apache Hive.
+""" -import sys -import types import itertools -import warnings import decimal import datetime import keyword import warnings +import json from array import array from operator import itemgetter +from itertools import imap + +from py4j.protocol import Py4JError +from py4j.java_collections import ListConverter, MapConverter from pyspark.rdd import RDD from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync -from itertools import chain, ifilter, imap - -from py4j.protocol import Py4JError -from py4j.java_collections import ListConverter, MapConverter - __all__ = [ "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", @@ -62,6 +72,18 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + @classmethod + def typeName(cls): + return cls.__name__[:-4].lower() + + def jsonValue(self): + return self.typeName() + + def json(self): + return json.dumps(self.jsonValue(), + separators=(',', ':'), + sort_keys=True) + class PrimitiveTypeSingleton(type): @@ -201,10 +223,20 @@ def __init__(self, elementType, containsNull=True): self.elementType = elementType self.containsNull = containsNull - def __str__(self): + def __repr__(self): return "ArrayType(%s,%s)" % (self.elementType, str(self.containsNull).lower()) + def jsonValue(self): + return {"type": self.typeName(), + "elementType": self.elementType.jsonValue(), + "containsNull": self.containsNull} + + @classmethod + def fromJson(cls, json): + return ArrayType(_parse_datatype_json_value(json["elementType"]), + json["containsNull"]) + class MapType(DataType): @@ -245,6 +277,18 @@ def __repr__(self): return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, str(self.valueContainsNull).lower()) + def jsonValue(self): + return {"type": self.typeName(), + "keyType": self.keyType.jsonValue(), + "valueType": self.valueType.jsonValue(), + "valueContainsNull": self.valueContainsNull} + + @classmethod + def fromJson(cls, json): + return MapType(_parse_datatype_json_value(json["keyType"]), + _parse_datatype_json_value(json["valueType"]), + json["valueContainsNull"]) + class StructField(DataType): @@ -283,6 +327,17 @@ def __repr__(self): return "StructField(%s,%s,%s)" % (self.name, self.dataType, str(self.nullable).lower()) + def jsonValue(self): + return {"name": self.name, + "type": self.dataType.jsonValue(), + "nullable": self.nullable} + + @classmethod + def fromJson(cls, json): + return StructField(json["name"], + _parse_datatype_json_value(json["type"]), + json["nullable"]) + class StructType(DataType): @@ -312,42 +367,30 @@ def __repr__(self): return ("StructType(List(%s))" % ",".join(str(field) for field in self.fields)) + def jsonValue(self): + return {"type": self.typeName(), + "fields": [f.jsonValue() for f in self.fields]} -def _parse_datatype_list(datatype_list_string): - """Parses a list of comma separated data types.""" - index = 0 - datatype_list = [] - start = 0 - depth = 0 - while index < len(datatype_list_string): - if depth == 0 and datatype_list_string[index] == ",": - datatype_string = datatype_list_string[start:index].strip() - datatype_list.append(_parse_datatype_string(datatype_string)) - start = index + 1 - elif datatype_list_string[index] == "(": - depth += 1 - elif datatype_list_string[index] == ")": - depth -= 1 + @classmethod + def fromJson(cls, json): + return StructType([StructField.fromJson(f) for f in json["fields"]]) - index += 1 - # Handle the last 
data type - datatype_string = datatype_list_string[start:index].strip() - datatype_list.append(_parse_datatype_string(datatype_string)) - return datatype_list +_all_primitive_types = dict((v.typeName(), v) + for v in globals().itervalues() + if type(v) is PrimitiveTypeSingleton and + v.__base__ == PrimitiveType) -_all_primitive_types = dict((k, v) for k, v in globals().iteritems() - if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType) +_all_complex_types = dict((v.typeName(), v) + for v in [ArrayType, MapType, StructType]) -def _parse_datatype_string(datatype_string): - """Parses the given data type string. - +def _parse_datatype_json_string(json_string): + """Parses the given data type JSON string. >>> def check_datatype(datatype): - ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype)) - ... python_datatype = _parse_datatype_string( - ... scala_datatype.toString()) + ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) + ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) ... return datatype == python_datatype >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) True @@ -385,51 +428,14 @@ def _parse_datatype_string(datatype_string): >>> check_datatype(complex_maptype) True """ - index = datatype_string.find("(") - if index == -1: - # It is a primitive type. - index = len(datatype_string) - type_or_field = datatype_string[:index] - rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip() - - if type_or_field in _all_primitive_types: - return _all_primitive_types[type_or_field]() - - elif type_or_field == "ArrayType": - last_comma_index = rest_part.rfind(",") - containsNull = True - if rest_part[last_comma_index + 1:].strip().lower() == "false": - containsNull = False - elementType = _parse_datatype_string( - rest_part[:last_comma_index].strip()) - return ArrayType(elementType, containsNull) - - elif type_or_field == "MapType": - last_comma_index = rest_part.rfind(",") - valueContainsNull = True - if rest_part[last_comma_index + 1:].strip().lower() == "false": - valueContainsNull = False - keyType, valueType = _parse_datatype_list( - rest_part[:last_comma_index].strip()) - return MapType(keyType, valueType, valueContainsNull) - - elif type_or_field == "StructField": - first_comma_index = rest_part.find(",") - name = rest_part[:first_comma_index].strip() - last_comma_index = rest_part.rfind(",") - nullable = True - if rest_part[last_comma_index + 1:].strip().lower() == "false": - nullable = False - dataType = _parse_datatype_string( - rest_part[first_comma_index + 1:last_comma_index].strip()) - return StructField(name, dataType, nullable) - - elif type_or_field == "StructType": - # rest_part should be in the format like - # List(StructField(field1,IntegerType,false)). - field_list_string = rest_part[rest_part.find("(") + 1:-1] - fields = _parse_datatype_list(field_list_string) - return StructType(fields) + return _parse_datatype_json_value(json.loads(json_string)) + + +def _parse_datatype_json_value(json_value): + if type(json_value) is unicode and json_value in _all_primitive_types.keys(): + return _all_primitive_types[json_value]() + else: + return _all_complex_types[json_value["type"]].fromJson(json_value) # Mapping Python types to Spark SQL DateType @@ -899,8 +905,8 @@ class SQLContext(object): def __init__(self, sparkContext, sqlContext=None): """Create a new SQLContext. - @param sparkContext: The SparkContext to wrap. - @param sqlContext: An optional JVM Scala SQLContext. 
If set, we do not instatiate a new + :param sparkContext: The SparkContext to wrap. + :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new SQLContext in the JVM, instead we make all calls to this object. >>> srdd = sqlCtx.inferSchema(rdd) @@ -983,7 +989,7 @@ def registerFunction(self, name, f, returnType=StringType()): self._sc.pythonExec, broadcast_vars, self._sc._javaAccumulator, - str(returnType)) + returnType.json()) def inferSchema(self, rdd): """Infer and apply a schema to an RDD of L{Row}. @@ -1119,7 +1125,7 @@ def applySchema(self, rdd, schema): batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) jrdd = self._pythonToJava(rdd._jrdd, batched) - srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema)) + srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) return SchemaRDD(srdd.toJavaSchemaRDD(), self) def registerRDDAsTable(self, rdd, tableName): @@ -1209,7 +1215,7 @@ def jsonFile(self, path, schema=None): if schema is None: srdd = self._ssql_ctx.jsonFile(path) else: - scala_datatype = self._ssql_ctx.parseDataType(str(schema)) + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonFile(path, scala_datatype) return SchemaRDD(srdd.toJavaSchemaRDD(), self) @@ -1279,7 +1285,7 @@ def func(iterator): if schema is None: srdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) else: - scala_datatype = self._ssql_ctx.parseDataType(str(schema)) + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) return SchemaRDD(srdd.toJavaSchemaRDD(), self) @@ -1325,8 +1331,8 @@ class HiveContext(SQLContext): def __init__(self, sparkContext, hiveContext=None): """Create a new HiveContext. - @param sparkContext: The SparkContext to wrap. - @param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new + :param sparkContext: The SparkContext to wrap. + :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new HiveContext in the JVM, instead we make all calls to this object.
""" SQLContext.__init__(self, sparkContext) @@ -1614,7 +1620,7 @@ def saveAsTable(self, tableName): def schema(self): """Returns the schema of this SchemaRDD (represented by a L{StructType}).""" - return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString()) + return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json()) def schemaString(self): """Returns the output schema in the tree format.""" diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 6fb6bc998c752..7f05d48ade2b3 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -67,10 +67,10 @@ SPARK_HOME = os.environ["SPARK_HOME"] -class TestMerger(unittest.TestCase): +class MergerTests(unittest.TestCase): def setUp(self): - self.N = 1 << 16 + self.N = 1 << 14 self.l = [i for i in xrange(self.N)] self.data = zip(self.l, self.l) self.agg = Aggregator(lambda x: [x], @@ -115,7 +115,7 @@ def test_medium_dataset(self): sum(xrange(self.N)) * 3) def test_huge_dataset(self): - m = ExternalMerger(self.agg, 10) + m = ExternalMerger(self.agg, 10, partitions=3) m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10)) self.assertTrue(m.spills >= 1) self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)), @@ -123,7 +123,7 @@ def test_huge_dataset(self): m._cleanup() -class TestSorter(unittest.TestCase): +class SorterTests(unittest.TestCase): def test_in_memory_sort(self): l = range(1024) random.shuffle(l) @@ -244,16 +244,25 @@ def tearDown(self): sys.path = self._old_sys_path -class TestCheckpoint(PySparkTestCase): +class ReusedPySparkTestCase(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2) + + @classmethod + def tearDownClass(cls): + cls.sc.stop() + + +class CheckpointTests(ReusedPySparkTestCase): def setUp(self): - PySparkTestCase.setUp(self) self.checkpointDir = tempfile.NamedTemporaryFile(delete=False) os.unlink(self.checkpointDir.name) self.sc.setCheckpointDir(self.checkpointDir.name) def tearDown(self): - PySparkTestCase.tearDown(self) shutil.rmtree(self.checkpointDir.name) def test_basic_checkpointing(self): @@ -288,7 +297,7 @@ def test_checkpoint_and_restore(self): self.assertEquals([1, 2, 3, 4], recovered.collect()) -class TestAddFile(PySparkTestCase): +class AddFileTests(PySparkTestCase): def test_add_py_file(self): # To ensure that we're actually testing addPyFile's effects, check that @@ -354,7 +363,7 @@ def func(x): self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect()) -class TestRDDFunctions(PySparkTestCase): +class RDDTests(ReusedPySparkTestCase): def test_id(self): rdd = self.sc.parallelize(range(10)) @@ -365,12 +374,6 @@ def test_id(self): self.assertEqual(id + 1, id2) self.assertEqual(id2, rdd2.id()) - def test_failed_sparkcontext_creation(self): - # Regression test for SPARK-1550 - self.sc.stop() - self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) - self.sc = SparkContext("local") - def test_save_as_textfile_with_unicode(self): # Regression test for SPARK-970 x = u"\u00A1Hola, mundo!" 
@@ -636,7 +639,7 @@ def test_distinct(self): self.assertEquals(result.count(), 3) -class TestProfiler(PySparkTestCase): +class ProfilerTests(PySparkTestCase): def setUp(self): self._old_sys_path = list(sys.path) @@ -666,10 +669,9 @@ def heavy_foo(x): self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) -class TestSQL(PySparkTestCase): +class SQLTests(ReusedPySparkTestCase): def setUp(self): - PySparkTestCase.setUp(self) self.sqlCtx = SQLContext(self.sc) def test_udf(self): @@ -754,27 +756,19 @@ def test_serialize_nested_array_and_map(self): self.assertEqual("2", row.d) -class TestIO(PySparkTestCase): - - def test_stdout_redirection(self): - import subprocess - - def func(x): - subprocess.check_call('ls', shell=True) - self.sc.parallelize([1]).foreach(func) +class InputFormatTests(ReusedPySparkTestCase): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc) -class TestInputFormat(PySparkTestCase): - - def setUp(self): - PySparkTestCase.setUp(self) - self.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(self.tempdir.name) - self.sc._jvm.WriteInputFormatTestDataGenerator.generateData(self.tempdir.name, self.sc._jsc) - - def tearDown(self): - PySparkTestCase.tearDown(self) - shutil.rmtree(self.tempdir.name) + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name) def test_sequencefiles(self): basepath = self.tempdir.name @@ -954,15 +948,13 @@ def test_converters(self): self.assertEqual(maps, em) -class TestOutputFormat(PySparkTestCase): +class OutputFormatTests(ReusedPySparkTestCase): def setUp(self): - PySparkTestCase.setUp(self) self.tempdir = tempfile.NamedTemporaryFile(delete=False) os.unlink(self.tempdir.name) def tearDown(self): - PySparkTestCase.tearDown(self) shutil.rmtree(self.tempdir.name, ignore_errors=True) def test_sequencefiles(self): @@ -1243,8 +1235,7 @@ def test_malformed_RDD(self): basepath + "/malformed/sequence")) -class TestDaemon(unittest.TestCase): - +class DaemonTests(unittest.TestCase): def connect(self, port): from socket import socket, AF_INET, SOCK_STREAM sock = socket(AF_INET, SOCK_STREAM) @@ -1290,7 +1281,7 @@ def test_termination_sigterm(self): self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) -class TestWorker(PySparkTestCase): +class WorkerTests(PySparkTestCase): def test_cancel_task(self): temp = tempfile.NamedTemporaryFile(delete=True) @@ -1342,11 +1333,6 @@ def run(): rdd = self.sc.parallelize(range(100), 1) self.assertEqual(100, rdd.map(str).count()) - def test_fd_leak(self): - N = 1100 # fd limit is 1024 by default - rdd = self.sc.parallelize(range(N), N) - self.assertEquals(N, rdd.count()) - def test_after_exception(self): def raise_exception(_): raise Exception() @@ -1379,7 +1365,7 @@ def test_accumulator_when_reuse_worker(self): self.assertEqual(sum(range(100)), acc1.value) -class TestSparkSubmit(unittest.TestCase): +class SparkSubmitTests(unittest.TestCase): def setUp(self): self.programDir = tempfile.mkdtemp() @@ -1492,6 +1478,8 @@ def test_single_script_on_cluster(self): |sc = SparkContext() |print sc.parallelize([1, 2, 3]).map(foo).collect() """) + # this will fail if you have different spark.executor.memory + # in conf/spark-defaults.conf proc = subprocess.Popen( [self.sparkSubmit, "--master", "local-cluster[1,1,512]", script], stdout=subprocess.PIPE) @@ 
-1500,7 +1488,11 @@ def test_single_script_on_cluster(self): self.assertIn("[2, 4, 6]", out) -class ContextStopTests(unittest.TestCase): +class ContextTests(unittest.TestCase): + + def test_failed_sparkcontext_creation(self): + # Regression test for SPARK-1550 + self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) def test_stop(self): sc = SparkContext() diff --git a/python/run-tests b/python/run-tests index a7ec270c7da21..63395f72788f9 100755 --- a/python/run-tests +++ b/python/run-tests @@ -34,7 +34,7 @@ rm -rf metastore warehouse function run_test() { echo "Running test: $1" - SPARK_TESTING=1 "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log + SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log FAILED=$((PIPESTATUS[0]||$FAILED)) @@ -48,6 +48,38 @@ function run_test() { fi } +function run_core_tests() { + echo "Run core tests ..." + run_test "pyspark/rdd.py" + run_test "pyspark/context.py" + run_test "pyspark/conf.py" + PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" + PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" + PYSPARK_DOC_TEST=1 run_test "pyspark/serializers.py" + run_test "pyspark/shuffle.py" + run_test "pyspark/tests.py" +} + +function run_sql_tests() { + echo "Run sql tests ..." + run_test "pyspark/sql.py" +} + +function run_mllib_tests() { + echo "Run mllib tests ..." + run_test "pyspark/mllib/classification.py" + run_test "pyspark/mllib/clustering.py" + run_test "pyspark/mllib/feature.py" + run_test "pyspark/mllib/linalg.py" + run_test "pyspark/mllib/random.py" + run_test "pyspark/mllib/recommendation.py" + run_test "pyspark/mllib/regression.py" + run_test "pyspark/mllib/stat.py" + run_test "pyspark/mllib/tree.py" + run_test "pyspark/mllib/util.py" + run_test "pyspark/mllib/tests.py" +} + echo "Running PySpark tests. Output is in python/unit-tests.log." export PYSPARK_PYTHON="python" @@ -60,29 +92,9 @@ fi echo "Testing with Python version:" $PYSPARK_PYTHON --version -run_test "pyspark/rdd.py" -run_test "pyspark/context.py" -run_test "pyspark/conf.py" -run_test "pyspark/sql.py" -# These tests are included in the module-level docs, and so must -# be handled on a higher level rather than within the python file. -export PYSPARK_DOC_TEST=1 -run_test "pyspark/broadcast.py" -run_test "pyspark/accumulators.py" -run_test "pyspark/serializers.py" -unset PYSPARK_DOC_TEST -run_test "pyspark/shuffle.py" -run_test "pyspark/tests.py" -run_test "pyspark/mllib/classification.py" -run_test "pyspark/mllib/clustering.py" -run_test "pyspark/mllib/linalg.py" -run_test "pyspark/mllib/random.py" -run_test "pyspark/mllib/recommendation.py" -run_test "pyspark/mllib/regression.py" -run_test "pyspark/mllib/stat.py" -run_test "pyspark/mllib/tests.py" -run_test "pyspark/mllib/tree.py" -run_test "pyspark/mllib/util.py" +run_core_tests +run_sql_tests +run_mllib_tests # Try to test with PyPy if [ $(which pypy) ]; then @@ -90,19 +102,8 @@ if [ $(which pypy) ]; then echo "Testing with PyPy version:" $PYSPARK_PYTHON --version - run_test "pyspark/rdd.py" - run_test "pyspark/context.py" - run_test "pyspark/conf.py" - run_test "pyspark/sql.py" - # These tests are included in the module-level docs, and so must - # be handled on a higher level rather than within the python file. 
- export PYSPARK_DOC_TEST=1 - run_test "pyspark/broadcast.py" - run_test "pyspark/accumulators.py" - run_test "pyspark/serializers.py" - unset PYSPARK_DOC_TEST - run_test "pyspark/shuffle.py" - run_test "pyspark/tests.py" + run_core_tests + run_sql_tests fi if [[ $FAILED == 0 ]]; then diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 6ddb6accd696b..646c68e60c2e9 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -84,9 +84,11 @@ import org.apache.spark.util.Utils * @author Moez A. Abdel-Gawad * @author Lex Spoon */ - class SparkIMain(initialSettings: Settings, val out: JPrintWriter) - extends SparkImports with Logging { - imain => + class SparkIMain( + initialSettings: Settings, + val out: JPrintWriter, + propagateExceptions: Boolean = false) + extends SparkImports with Logging { imain => val conf = new SparkConf() @@ -816,6 +818,10 @@ import org.apache.spark.util.Utils val resultName = FixedSessionNames.resultName def bindError(t: Throwable) = { + // Immediately throw the exception if we are asked to propagate them + if (propagateExceptions) { + throw unwrap(t) + } if (!bindExceptions) // avoid looping if already binding throw t diff --git a/sql/README.md b/sql/README.md index 31f9152344086..c84534da9a3d3 100644 --- a/sql/README.md +++ b/sql/README.md @@ -44,38 +44,37 @@ Type in expressions to have them evaluated. Type :help for more information. scala> val query = sql("SELECT * FROM (SELECT * FROM src) a") -query: org.apache.spark.sql.ExecutedQuery = -SELECT * FROM (SELECT * FROM src) a -=== Query Plan === -Project [key#6:0.0,value#7:0.1] - HiveTableScan [key#6,value#7], (MetastoreRelation default, src, None), None +query: org.apache.spark.sql.SchemaRDD = +== Query Plan == +== Physical Plan == +HiveTableScan [key#10,value#11], (MetastoreRelation default, src, None), None ``` Query results are RDDs and can be operated as such. ``` scala> query.collect() -res8: Array[org.apache.spark.sql.execution.Row] = Array([238,val_238], [86,val_86], [311,val_311]... +res2: Array[org.apache.spark.sql.Row] = Array([238,val_238], [86,val_86], [311,val_311], [27,val_27]... ``` You can also build further queries on top of these RDDs using the query DSL. ``` -scala> query.where('key === 100).toRdd.collect() -res11: Array[org.apache.spark.sql.execution.Row] = Array([100,val_100], [100,val_100]) +scala> query.where('key === 100).collect() +res3: Array[org.apache.spark.sql.Row] = Array([100,val_100], [100,val_100]) ``` -From the console you can even write rules that transform query plans. For example, the above query has redundant project operators that aren't doing anything. This redundancy can be eliminated using the `transform` function that is available on all [`TreeNode`](http://databricks.github.io/catalyst/latest/api/#catalyst.trees.TreeNode) objects. +From the console you can even write rules that transform query plans. For example, the above query has redundant project operators that aren't doing anything. This redundancy can be eliminated using the `transform` function that is available on all [`TreeNode`](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala) objects. 
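As an aside, the projection-collapsing rewrite demonstrated in the console session below can also be packaged as a reusable Catalyst rule. This is only a sketch under that assumption: the object name is illustrative and not part of this patch, while the pattern match is the same one used in the session that follows.

```scala
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule

// Drops a Project node that selects exactly the output of its child,
// mirroring the ad-hoc transform shown in the console session below.
object CollapseRedundantProjects extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case Project(projectList, child) if projectList == child.output => child
  }
}
```

The session below applies the equivalent transformation interactively from the console.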
```scala -scala> query.logicalPlan -res1: catalyst.plans.logical.LogicalPlan = -Project {key#0,value#1} - Project {key#0,value#1} +scala> query.queryExecution.analyzed +res4: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan = +Project [key#10,value#11] + Project [key#10,value#11] MetastoreRelation default, src, None -scala> query.logicalPlan transform { +scala> query.queryExecution.analyzed transform { | case Project(projectList, child) if projectList == child.output => child | } -res2: catalyst.plans.logical.LogicalPlan = -Project {key#0,value#1} +res5: res17: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan = +Project [key#10,value#11] MetastoreRelation default, src, None ``` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 88a8fa7c28e0f..b3ae8e6779700 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -33,7 +33,7 @@ object ScalaReflection { /** Converts Scala objects to catalyst rows / types */ def convertToCatalyst(a: Any): Any = a match { - case o: Option[_] => o.orNull + case o: Option[_] => o.map(convertToCatalyst).orNull case s: Seq[_] => s.map(convertToCatalyst) case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) } case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 26336332c05a2..854b5b461bdc8 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -67,11 +67,12 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected implicit def asParser(k: Keyword): Parser[String] = lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) + protected val ABS = Keyword("ABS") protected val ALL = Keyword("ALL") protected val AND = Keyword("AND") + protected val APPROXIMATE = Keyword("APPROXIMATE") protected val AS = Keyword("AS") protected val ASC = Keyword("ASC") - protected val APPROXIMATE = Keyword("APPROXIMATE") protected val AVG = Keyword("AVG") protected val BETWEEN = Keyword("BETWEEN") protected val BY = Keyword("BY") @@ -80,9 +81,9 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val COUNT = Keyword("COUNT") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") + protected val EXCEPT = Keyword("EXCEPT") protected val FALSE = Keyword("FALSE") protected val FIRST = Keyword("FIRST") - protected val LAST = Keyword("LAST") protected val FROM = Keyword("FROM") protected val FULL = Keyword("FULL") protected val GROUP = Keyword("GROUP") @@ -91,42 +92,42 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val IN = Keyword("IN") protected val INNER = Keyword("INNER") protected val INSERT = Keyword("INSERT") + protected val INTERSECT = Keyword("INTERSECT") protected val INTO = Keyword("INTO") protected val IS = Keyword("IS") protected val JOIN = Keyword("JOIN") + protected val LAST = Keyword("LAST") + protected val LAZY = Keyword("LAZY") protected val LEFT = Keyword("LEFT") + protected val LIKE = Keyword("LIKE") protected val LIMIT = Keyword("LIMIT") + 
protected val LOWER = Keyword("LOWER") protected val MAX = Keyword("MAX") protected val MIN = Keyword("MIN") protected val NOT = Keyword("NOT") protected val NULL = Keyword("NULL") protected val ON = Keyword("ON") protected val OR = Keyword("OR") - protected val OVERWRITE = Keyword("OVERWRITE") - protected val LIKE = Keyword("LIKE") - protected val RLIKE = Keyword("RLIKE") - protected val UPPER = Keyword("UPPER") - protected val LOWER = Keyword("LOWER") - protected val REGEXP = Keyword("REGEXP") protected val ORDER = Keyword("ORDER") protected val OUTER = Keyword("OUTER") + protected val OVERWRITE = Keyword("OVERWRITE") + protected val REGEXP = Keyword("REGEXP") protected val RIGHT = Keyword("RIGHT") + protected val RLIKE = Keyword("RLIKE") protected val SELECT = Keyword("SELECT") protected val SEMI = Keyword("SEMI") + protected val SQRT = Keyword("SQRT") protected val STRING = Keyword("STRING") + protected val SUBSTR = Keyword("SUBSTR") + protected val SUBSTRING = Keyword("SUBSTRING") protected val SUM = Keyword("SUM") protected val TABLE = Keyword("TABLE") protected val TIMESTAMP = Keyword("TIMESTAMP") protected val TRUE = Keyword("TRUE") protected val UNCACHE = Keyword("UNCACHE") protected val UNION = Keyword("UNION") + protected val UPPER = Keyword("UPPER") protected val WHERE = Keyword("WHERE") - protected val INTERSECT = Keyword("INTERSECT") - protected val EXCEPT = Keyword("EXCEPT") - protected val SUBSTR = Keyword("SUBSTR") - protected val SUBSTRING = Keyword("SUBSTRING") - protected val SQRT = Keyword("SQRT") - protected val ABS = Keyword("ABS") // Use reflection to find the reserved words defined in this class. protected val reservedWords = @@ -183,17 +184,15 @@ class SqlParser extends StandardTokenParsers with PackratParsers { } protected lazy val cache: Parser[LogicalPlan] = - CACHE ~ TABLE ~> ident ~ opt(AS ~> select) <~ opt(";") ^^ { - case tableName ~ None => - CacheCommand(tableName, true) - case tableName ~ Some(plan) => - CacheTableAsSelectCommand(tableName, plan) + CACHE ~> opt(LAZY) ~ (TABLE ~> ident) ~ opt(AS ~> select) <~ opt(";") ^^ { + case isLazy ~ tableName ~ plan => + CacheTableCommand(tableName, plan, isLazy.isDefined) } - + protected lazy val unCache: Parser[LogicalPlan] = UNCACHE ~ TABLE ~> ident <~ opt(";") ^^ { - case tableName => CacheCommand(tableName, false) - } + case tableName => UncacheTableCommand(tableName) + } protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",") @@ -283,7 +282,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } | termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | - termExpression ~ BETWEEN ~ termExpression ~ AND ~ termExpression ^^ { + termExpression ~ BETWEEN ~ termExpression ~ AND ~ termExpression ^^ { case e ~ _ ~ el ~ _ ~ eu => And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) } | termExpression ~ RLIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 616f1e2ecb60f..2059a91ba0612 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -87,7 
+87,7 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog { tableName: String, alias: Option[String] = None): LogicalPlan = { val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) - val table = tables.get(tblName).getOrElse(sys.error(s"Table Not Found: $tableName")) + val table = tables.getOrElse(tblName, sys.error(s"Table Not Found: $tableName")) val tableWithQualifiers = Subquery(tblName, table) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 79e5283e86a37..64881854df7a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -348,8 +348,11 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e // Decimal and Double remain the same - case d: Divide if d.dataType == DoubleType => d - case d: Divide if d.dataType == DecimalType => d + case d: Divide if d.resolved && d.dataType == DoubleType => d + case d: Divide if d.resolved && d.dataType == DecimalType => d + + case Divide(l, r) if l.dataType == DecimalType => Divide(l, Cast(r, DecimalType)) + case Divide(l, r) if r.dataType == DecimalType => Divide(Cast(l, DecimalType), r) case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala index 1eb55715794a7..1a4ac06c7a79d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala @@ -24,9 +24,7 @@ import org.apache.spark.sql.catalyst.types.DataType /** * The data type representing [[DynamicRow]] values. */ -case object DynamicType extends DataType { - def simpleString: String = "dynamic" -} +case object DynamicType extends DataType /** * Wrap a [[Row]] as a [[DynamicRow]]. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a4133feae8166..636d0b95583e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -299,6 +299,18 @@ object BooleanSimplification extends Rule[LogicalPlan] { case (_, _) => or } + case not @ Not(exp) => + exp match { + case Literal(true, BooleanType) => Literal(false) + case Literal(false, BooleanType) => Literal(true) + case GreaterThan(l, r) => LessThanOrEqual(l, r) + case GreaterThanOrEqual(l, r) => LessThan(l, r) + case LessThan(l, r) => GreaterThanOrEqual(l, r) + case LessThanOrEqual(l, r) => GreaterThan(l, r) + case Not(e) => e + case _ => not + } + // Turn "if (true) a else b" into "a", and if (false) a else b" into "b". 
case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 8366639fa0e8b..9a3848cfc6b62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -56,9 +56,15 @@ case class ExplainCommand(plan: LogicalPlan, extended: Boolean = false) extends } /** - * Returned for the "CACHE TABLE tableName" and "UNCACHE TABLE tableName" command. + * Returned for the "CACHE TABLE tableName [AS SELECT ...]" command. */ -case class CacheCommand(tableName: String, doCache: Boolean) extends Command +case class CacheTableCommand(tableName: String, plan: Option[LogicalPlan], isLazy: Boolean) + extends Command + +/** + * Returned for the "UNCACHE TABLE tableName" command. + */ +case class UncacheTableCommand(tableName: String) extends Command /** * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command. @@ -75,8 +81,3 @@ case class DescribeCommand( AttributeReference("data_type", StringType, nullable = false)(), AttributeReference("comment", StringType, nullable = false)()) } - -/** - * Returned for the "CACHE TABLE tableName AS SELECT .." command. - */ -case class CacheTableAsSelectCommand(tableName: String, plan: LogicalPlan) extends Command diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index ac043d4dd8eb9..1d375b8754182 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -19,71 +19,125 @@ package org.apache.spark.sql.catalyst.types import java.sql.Timestamp -import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral} +import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral} import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror} +import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} import scala.util.parsing.combinator.RegexParsers +import org.json4s.JsonAST.JValue +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.util.Utils -/** - * Utility functions for working with DataTypes. 
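A usage sketch for the reworked caching commands in the parser and commands.scala hunks above: the CACHE grammar now takes an optional LAZY keyword (registered elsewhere in the parser) and an optional AS SELECT clause. Table and column names below are made up.

// given an existing sqlContext: org.apache.spark.sql.SQLContext with a "sales" temp table
// Eager caching: the optional AS SELECT plan is registered and materialized immediately.
sqlContext.sql("CACHE TABLE recentSales AS SELECT * FROM sales WHERE year = 2014")

// The LAZY keyword defers materialization until the cached table is first scanned.
sqlContext.sql("CACHE LAZY TABLE sales")

// UNCACHE TABLE now maps to the dedicated UncacheTableCommand node.
sqlContext.sql("UNCACHE TABLE recentSales")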
- */ -object DataType extends RegexParsers { - protected lazy val primitiveType: Parser[DataType] = - "StringType" ^^^ StringType | - "FloatType" ^^^ FloatType | - "IntegerType" ^^^ IntegerType | - "ByteType" ^^^ ByteType | - "ShortType" ^^^ ShortType | - "DoubleType" ^^^ DoubleType | - "LongType" ^^^ LongType | - "BinaryType" ^^^ BinaryType | - "BooleanType" ^^^ BooleanType | - "DecimalType" ^^^ DecimalType | - "TimestampType" ^^^ TimestampType - - protected lazy val arrayType: Parser[DataType] = - "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { - case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) - } - protected lazy val mapType: Parser[DataType] = - "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { - case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) +object DataType { + def fromJson(json: String): DataType = parseDataType(parse(json)) + + private object JSortedObject { + def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match { + case JObject(seq) => Some(seq.toList.sortBy(_._1)) + case _ => None } + } + + // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side. + private def parseDataType(json: JValue): DataType = json match { + case JString(name) => + PrimitiveType.nameToType(name) + + case JSortedObject( + ("containsNull", JBool(n)), + ("elementType", t: JValue), + ("type", JString("array"))) => + ArrayType(parseDataType(t), n) + + case JSortedObject( + ("keyType", k: JValue), + ("type", JString("map")), + ("valueContainsNull", JBool(n)), + ("valueType", v: JValue)) => + MapType(parseDataType(k), parseDataType(v), n) + + case JSortedObject( + ("fields", JArray(fields)), + ("type", JString("struct"))) => + StructType(fields.map(parseStructField)) + } - protected lazy val structField: Parser[StructField] = - ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { - case name ~ tpe ~ nullable => + private def parseStructField(json: JValue): StructField = json match { + case JSortedObject( + ("name", JString(name)), + ("nullable", JBool(nullable)), + ("type", dataType: JValue)) => + StructField(name, parseDataType(dataType), nullable) + } + + @deprecated("Use DataType.fromJson instead") + def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) + + private object CaseClassStringParser extends RegexParsers { + protected lazy val primitiveType: Parser[DataType] = + ( "StringType" ^^^ StringType + | "FloatType" ^^^ FloatType + | "IntegerType" ^^^ IntegerType + | "ByteType" ^^^ ByteType + | "ShortType" ^^^ ShortType + | "DoubleType" ^^^ DoubleType + | "LongType" ^^^ LongType + | "BinaryType" ^^^ BinaryType + | "BooleanType" ^^^ BooleanType + | "DecimalType" ^^^ DecimalType + | "TimestampType" ^^^ TimestampType + ) + + protected lazy val arrayType: Parser[DataType] = + "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { + case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) + } + + protected lazy val mapType: Parser[DataType] = + "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { + case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) + } + + protected lazy val structField: Parser[StructField] = + ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { + case name ~ tpe ~ nullable => StructField(name, tpe, nullable = nullable) - } + } - protected lazy val boolVal: Parser[Boolean] = - "true" ^^^ true | - "false" 
^^^ false + protected lazy val boolVal: Parser[Boolean] = + ( "true" ^^^ true + | "false" ^^^ false + ) - protected lazy val structType: Parser[DataType] = - "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { - case fields => new StructType(fields) - } + protected lazy val structType: Parser[DataType] = + "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { + case fields => new StructType(fields) + } - protected lazy val dataType: Parser[DataType] = - arrayType | - mapType | - structType | - primitiveType + protected lazy val dataType: Parser[DataType] = + ( arrayType + | mapType + | structType + | primitiveType + ) + + /** + * Parses a string representation of a DataType. + * + * TODO: Generate parser as pickler... + */ + def apply(asString: String): DataType = parseAll(dataType, asString) match { + case Success(result, _) => result + case failure: NoSuccess => + throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") + } - /** - * Parses a string representation of a DataType. - * - * TODO: Generate parser as pickler... - */ - def apply(asString: String): DataType = parseAll(dataType, asString) match { - case Success(result, _) => result - case failure: NoSuccess => sys.error(s"Unsupported dataType: $asString, $failure") } protected[types] def buildFormattedString( @@ -111,15 +165,19 @@ abstract class DataType { def isPrimitive: Boolean = false - def simpleString: String -} + def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase + + private[sql] def jsonValue: JValue = typeName -case object NullType extends DataType { - def simpleString: String = "null" + def json: String = compact(render(jsonValue)) + + def prettyJson: String = pretty(render(jsonValue)) } +case object NullType extends DataType + object NativeType { - def all = Seq( + val all = Seq( IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) def unapply(dt: DataType): Boolean = all.contains(dt) @@ -139,6 +197,12 @@ trait PrimitiveType extends DataType { override def isPrimitive = true } +object PrimitiveType { + private[sql] val all = Seq(DecimalType, TimestampType, BinaryType) ++ NativeType.all + + private[sql] val nameToType = all.map(t => t.typeName -> t).toMap +} + abstract class NativeType extends DataType { private[sql] type JvmType @transient private[sql] val tag: TypeTag[JvmType] @@ -154,7 +218,6 @@ case object StringType extends NativeType with PrimitiveType { private[sql] type JvmType = String @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "string" } case object BinaryType extends NativeType with PrimitiveType { @@ -166,17 +229,15 @@ case object BinaryType extends NativeType with PrimitiveType { val res = x(i).compareTo(y(i)) if (res != 0) return res } - return x.length - y.length + x.length - y.length } } - def simpleString: String = "binary" } case object BooleanType extends NativeType with PrimitiveType { private[sql] type JvmType = Boolean @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "boolean" } case object TimestampType extends NativeType { @@ -187,8 +248,6 @@ case object TimestampType extends NativeType { private[sql] val ordering = new Ordering[JvmType] { def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) 
} - - def simpleString: String = "timestamp" } abstract class NumericType extends NativeType with PrimitiveType { @@ -222,7 +281,6 @@ case object LongType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Long]] private[sql] val integral = implicitly[Integral[Long]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "long" } case object IntegerType extends IntegralType { @@ -231,7 +289,6 @@ case object IntegerType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Int]] private[sql] val integral = implicitly[Integral[Int]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "integer" } case object ShortType extends IntegralType { @@ -240,7 +297,6 @@ case object ShortType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Short]] private[sql] val integral = implicitly[Integral[Short]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "short" } case object ByteType extends IntegralType { @@ -249,7 +305,6 @@ case object ByteType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Byte]] private[sql] val integral = implicitly[Integral[Byte]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "byte" } /** Matcher for any expressions that evaluate to [[FractionalType]]s */ @@ -271,7 +326,6 @@ case object DecimalType extends FractionalType { private[sql] val fractional = implicitly[Fractional[BigDecimal]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = BigDecimalAsIfIntegral - def simpleString: String = "decimal" } case object DoubleType extends FractionalType { @@ -281,7 +335,6 @@ case object DoubleType extends FractionalType { private[sql] val fractional = implicitly[Fractional[Double]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = DoubleAsIfIntegral - def simpleString: String = "double" } case object FloatType extends FractionalType { @@ -291,12 +344,12 @@ case object FloatType extends FractionalType { private[sql] val fractional = implicitly[Fractional[Float]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = FloatAsIfIntegral - def simpleString: String = "float" } object ArrayType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. 
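A quick illustration of the class-name-derived typeName that replaces the removed simpleString methods; the values follow directly from the default definition (strip the trailing "$", drop "Type", lowercase).

import org.apache.spark.sql.catalyst.types._

IntegerType.typeName    // "integer"
TimestampType.typeName  // "timestamp"
StringType.typeName     // "string"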
*/ def apply(elementType: DataType): ArrayType = ArrayType(elementType, true) + def typeName: String = "array" } /** @@ -309,11 +362,14 @@ object ArrayType { case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append( - s"${prefix}-- element: ${elementType.simpleString} (containsNull = ${containsNull})\n") + s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n") DataType.buildFormattedString(elementType, s"$prefix |", builder) } - def simpleString: String = "array" + override private[sql] def jsonValue = + ("type" -> typeName) ~ + ("elementType" -> elementType.jsonValue) ~ + ("containsNull" -> containsNull) } /** @@ -325,14 +381,22 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT case class StructField(name: String, dataType: DataType, nullable: Boolean) { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"${prefix}-- ${name}: ${dataType.simpleString} (nullable = ${nullable})\n") + builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") DataType.buildFormattedString(dataType, s"$prefix |", builder) } + + private[sql] def jsonValue: JValue = { + ("name" -> name) ~ + ("type" -> dataType.jsonValue) ~ + ("nullable" -> nullable) + } } object StructType { protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) + + def typeName = "struct" } case class StructType(fields: Seq[StructField]) extends DataType { @@ -348,8 +412,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { * have a name matching the given name, `null` will be returned. 
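To make the new JSON wire format concrete, a round-trip sketch; the compact rendering shown in the comment follows from the jsonValue definitions above, with key order as chained in the JsonDSL expressions.

import org.apache.spark.sql.catalyst.types._

val schema = StructType(Seq(StructField("name", StringType, nullable = true)))
val asJson = schema.json
// {"type":"struct","fields":[{"name":"name","type":"string","nullable":true}]}
assert(DataType.fromJson(asJson) == schema)  // the JSON form round-trips to the same DataType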
*/ def apply(name: String): StructField = { - nameToField.get(name).getOrElse( - throw new IllegalArgumentException(s"Field ${name} does not exist.")) + nameToField.getOrElse(name, throw new IllegalArgumentException(s"Field $name does not exist.")) } /** @@ -358,7 +421,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { */ def apply(names: Set[String]): StructType = { val nonExistFields = names -- fieldNamesSet - if (!nonExistFields.isEmpty) { + if (nonExistFields.nonEmpty) { throw new IllegalArgumentException( s"Field ${nonExistFields.mkString(",")} does not exist.") } @@ -384,7 +447,9 @@ case class StructType(fields: Seq[StructField]) extends DataType { fields.foreach(field => field.buildFormattedString(prefix, builder)) } - def simpleString: String = "struct" + override private[sql] def jsonValue = + ("type" -> typeName) ~ + ("fields" -> fields.map(_.jsonValue)) } object MapType { @@ -394,6 +459,8 @@ object MapType { */ def apply(keyType: DataType, valueType: DataType): MapType = MapType(keyType: DataType, valueType: DataType, true) + + def simpleName = "map" } /** @@ -407,12 +474,16 @@ case class MapType( valueType: DataType, valueContainsNull: Boolean) extends DataType { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"${prefix}-- key: ${keyType.simpleString}\n") - builder.append(s"${prefix}-- value: ${valueType.simpleString} " + - s"(valueContainsNull = ${valueContainsNull})\n") + builder.append(s"$prefix-- key: ${keyType.typeName}\n") + builder.append(s"$prefix-- value: ${valueType.typeName} " + + s"(valueContainsNull = $valueContainsNull)\n") DataType.buildFormattedString(keyType, s"$prefix |", builder) DataType.buildFormattedString(valueType, s"$prefix |", builder) } - def simpleString: String = "map" + override private[sql] def jsonValue: JValue = + ("type" -> typeName) ~ + ("keyType" -> keyType.jsonValue) ~ + ("valueType" -> valueType.jsonValue) ~ + ("valueContainsNull" -> valueContainsNull) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 428607d8c8253..488e373854bb3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -53,7 +53,8 @@ case class OptionalData( floatField: Option[Float], shortField: Option[Short], byteField: Option[Byte], - booleanField: Option[Boolean]) + booleanField: Option[Boolean], + structField: Option[PrimitiveData]) case class ComplexData( arrayField: Seq[Int], @@ -100,7 +101,7 @@ class ScalaReflectionSuite extends FunSuite { nullable = true)) } - test("optinal data") { + test("optional data") { val schema = schemaFor[OptionalData] assert(schema === Schema( StructType(Seq( @@ -110,7 +111,8 @@ class ScalaReflectionSuite extends FunSuite { StructField("floatField", FloatType, nullable = true), StructField("shortField", ShortType, nullable = true), StructField("byteField", ByteType, nullable = true), - StructField("booleanField", BooleanType, nullable = true))), + StructField("booleanField", BooleanType, nullable = true), + StructField("structField", schemaFor[PrimitiveData].dataType, nullable = true))), nullable = true)) } @@ -228,4 +230,17 @@ class ScalaReflectionSuite extends FunSuite { assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3))) assert(ArrayType(ArrayType(IntegerType)) === 
typeOfObject3(Seq(Seq(1,2,3)))) } + + test("convert PrimitiveData to catalyst") { + val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) + val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) + assert(convertToCatalyst(data) === convertedData) + } + + test("convert Option[Product] to catalyst") { + val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true) + val data = OptionalData(Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(true), Some(primitiveData)) + val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true, convertToCatalyst(primitiveData)) + assert(convertToCatalyst(data) === convertedData) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 5809a108ff62e..7b45738c4fc95 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.{BeforeAndAfter, FunSuite} -import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.types.IntegerType +import org.apache.spark.sql.catalyst.types._ class AnalysisSuite extends FunSuite with BeforeAndAfter { val caseSensitiveCatalog = new SimpleCatalog(true) @@ -33,6 +34,12 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) + val testRelation2 = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType)(), + AttributeReference("e", ShortType)()) before { caseSensitiveCatalog.registerTable(None, "TaBlE", testRelation) @@ -74,7 +81,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { val e = intercept[RuntimeException] { caseSensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None)) } - assert(e.getMessage === "Table Not Found: tAbLe") + assert(e.getMessage == "Table Not Found: tAbLe") assert( caseSensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) === @@ -106,4 +113,31 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { } assert(e.getMessage().toLowerCase.contains("unresolved plan")) } + + test("divide should be casted into fractional types") { + val testRelation2 = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType)(), + AttributeReference("e", ShortType)()) + + val expr0 = 'a / 2 + val expr1 = 'a / 'b + val expr2 = 'a / 'c + val expr3 = 'a / 'd + val expr4 = 'e / 'e + val plan = caseInsensitiveAnalyze(Project( + Alias(expr0, s"Analyzer($expr0)")() :: + Alias(expr1, s"Analyzer($expr1)")() :: + Alias(expr2, s"Analyzer($expr2)")() :: + Alias(expr3, s"Analyzer($expr3)")() :: + Alias(expr4, s"Analyzer($expr4)")() :: Nil, testRelation2)) + val pl = plan.asInstanceOf[Project].projectList + 
assert(pl(0).dataType == DoubleType) + assert(pl(1).dataType == DoubleType) + assert(pl(2).dataType == DoubleType) + assert(pl(3).dataType == DecimalType) + assert(pl(4).dataType == DoubleType) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index aebdbb68e49b8..3bf7382ac67a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -91,14 +91,10 @@ private[sql] trait CacheManager { } /** Removes the data for the given SchemaRDD from the cache */ - private[sql] def uncacheQuery(query: SchemaRDD, blocking: Boolean = false): Unit = writeLock { + private[sql] def uncacheQuery(query: SchemaRDD, blocking: Boolean = true): Unit = writeLock { val planToCache = query.queryExecution.optimizedPlan val dataIndex = cachedData.indexWhere(_.plan.sameResult(planToCache)) - - if (dataIndex < 0) { - throw new IllegalArgumentException(s"Table $query is not cached.") - } - + require(dataIndex >= 0, s"Table $query is not cached.") cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking) cachedData.remove(dataIndex) } @@ -135,5 +131,4 @@ private[sql] trait CacheManager { case _ => } } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 7a55c5bf97a71..35561cac3e5e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration +import org.apache.spark.SparkContext import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.ScalaReflection @@ -31,12 +32,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.types.DataType import org.apache.spark.sql.columnar.InMemoryRelation -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.SparkStrategies +import org.apache.spark.sql.execution.{SparkStrategies, _} import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.{Logging, SparkContext} /** * :: AlphaComponent :: @@ -409,8 +409,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * It is only used by PySpark. 
*/ private[sql] def parseDataType(dataTypeString: String): DataType = { - val parser = org.apache.spark.sql.catalyst.types.DataType - parser(dataTypeString) + DataType.fromJson(dataTypeString) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index cec82a7f2df94..4f79173a26f88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -111,7 +111,7 @@ private[sql] case class InMemoryRelation( override def newInstance() = { new InMemoryRelation( - output.map(_.newInstance), + output.map(_.newInstance()), useCompression, batchSize, storageLevel, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index cf93d5ad7b503..bbf17b9fadf86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.parquet._ + private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => @@ -34,13 +35,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // Find left semi joins where at least some predicates can be evaluated by matching join keys case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => - val semiJoin = execution.LeftSemiJoinHash( + val semiJoin = joins.LeftSemiJoinHash( leftKeys, rightKeys, planLater(left), planLater(right)) condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil // no predicate can be evaluated by matching hash keys case logical.Join(left, right, LeftSemi, condition) => - execution.LeftSemiJoinBNL( - planLater(left), planLater(right), condition) :: Nil + joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil case _ => Nil } } @@ -50,13 +50,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * evaluated by matching hash keys. * * This strategy applies a simple optimization based on the estimates of the physical sizes of - * the two join sides. When planning a [[execution.BroadcastHashJoin]], if one side has an + * the two join sides. When planning a [[joins.BroadcastHashJoin]], if one side has an * estimated physical size smaller than the user-settable threshold * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be * ''broadcasted'' to all of the executors involved in the join, as a * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they - * will instead be used to decide the build side in a [[execution.ShuffledHashJoin]]. + * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]]. 
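For reference, a configuration sketch for the user-settable threshold mentioned in the HashJoin strategy scaladoc above; the key string mirrors SQLConf.AUTO_BROADCASTJOIN_THRESHOLD and should be treated as an assumption here.

// given sqlContext: org.apache.spark.sql.SQLContext
// Relations whose estimated size falls below this many bytes become the broadcast (build) side.
sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", (10 * 1024 * 1024).toString)
// A non-positive value disables automatic broadcast hash joins, so equi-joins fall back to
// joins.ShuffledHashJoin with the smaller estimated side as the build side.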
*/ object HashJoin extends Strategy with PredicateHelper { @@ -66,8 +66,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { left: LogicalPlan, right: LogicalPlan, condition: Option[Expression], - side: BuildSide) = { - val broadcastHashJoin = execution.BroadcastHashJoin( + side: joins.BuildSide) = { + val broadcastHashJoin = execution.joins.BroadcastHashJoin( leftKeys, rightKeys, side, planLater(left), planLater(right)) condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil } @@ -76,27 +76,26 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.autoBroadcastJoinThreshold > 0 && right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight) + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.autoBroadcastJoinThreshold > 0 && left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft) + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - BuildRight + joins.BuildRight } else { - BuildLeft + joins.BuildLeft } - val hashJoin = - execution.ShuffledHashJoin( - leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) + val hashJoin = joins.ShuffledHashJoin( + leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => - execution.HashOuterJoin( + joins.HashOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil case _ => Nil @@ -164,8 +163,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, joinType, condition) => val buildSide = - if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) BuildRight else BuildLeft - execution.BroadcastNestedLoopJoin( + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { + joins.BuildRight + } else { + joins.BuildLeft + } + joins.BroadcastNestedLoopJoin( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil case _ => Nil } @@ -174,10 +177,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, _, None) => - execution.CartesianProduct(planLater(left), planLater(right)) :: Nil + execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil case logical.Join(left, right, Inner, Some(condition)) => execution.Filter(condition, - execution.CartesianProduct(planLater(left), planLater(right))) :: Nil + execution.joins.CartesianProduct(planLater(left), planLater(right))) :: Nil case _ => Nil } } @@ -274,9 +277,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil case 
SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil case logical.LocalRelation(output, data) => + val nPartitions = if (data.isEmpty) 1 else numPartitions PhysicalRDD( output, - RDDConversions.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil + RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => @@ -304,10 +308,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { Seq(execution.SetCommand(key, value, plan.output)(context)) case logical.ExplainCommand(logicalPlan, extended) => Seq(execution.ExplainCommand(logicalPlan, plan.output, extended)(context)) - case logical.CacheCommand(tableName, cache) => - Seq(execution.CacheCommand(tableName, cache)(context)) - case logical.CacheTableAsSelectCommand(tableName, plan) => - Seq(execution.CacheTableAsSelectCommand(tableName, plan)) + case logical.CacheTableCommand(tableName, optPlan, isLazy) => + Seq(execution.CacheTableCommand(tableName, optPlan, isLazy)) + case logical.UncacheTableCommand(tableName) => + Seq(execution.UncacheTableCommand(tableName)) case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index f88099ec0761e..d49633c24ad4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -138,49 +138,54 @@ case class ExplainCommand( * :: DeveloperApi :: */ @DeveloperApi -case class CacheCommand(tableName: String, doCache: Boolean)(@transient context: SQLContext) +case class CacheTableCommand( + tableName: String, + plan: Option[LogicalPlan], + isLazy: Boolean) extends LeafNode with Command { override protected lazy val sideEffectResult = { - if (doCache) { - context.cacheTable(tableName) - } else { - context.uncacheTable(tableName) + import sqlContext._ + + plan.foreach(_.registerTempTable(tableName)) + val schemaRDD = table(tableName) + schemaRDD.cache() + + if (!isLazy) { + // Performs eager caching + schemaRDD.count() } + Seq.empty[Row] } override def output: Seq[Attribute] = Seq.empty } + /** * :: DeveloperApi :: */ @DeveloperApi -case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])( - @transient context: SQLContext) - extends LeafNode with Command { - +case class UncacheTableCommand(tableName: String) extends LeafNode with Command { override protected lazy val sideEffectResult: Seq[Row] = { - Row("# Registered as a temporary table", null, null) +: - child.output.map(field => Row(field.name, field.dataType.toString, null)) + sqlContext.table(tableName).unpersist() + Seq.empty[Row] } + + override def output: Seq[Attribute] = Seq.empty } /** * :: DeveloperApi :: */ @DeveloperApi -case class CacheTableAsSelectCommand(tableName: String, logicalPlan: LogicalPlan) +case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])( + @transient context: SQLContext) extends LeafNode with Command { - - override protected[sql] lazy val sideEffectResult = { - import sqlContext._ - logicalPlan.registerTempTable(tableName) - cacheTable(tableName) - Seq.empty[Row] - } - override def output: Seq[Attribute] = Seq.empty - + override protected lazy val sideEffectResult: Seq[Row] = { + Row("# Registered as a temporary table", null, null) +: + child.output.map(field => Row(field.name, 
field.dataType.toString, null)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala deleted file mode 100644 index 2890a563bed48..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ /dev/null @@ -1,624 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.util.{HashMap => JavaHashMap} - -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent._ -import scala.concurrent.duration._ - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.util.collection.CompactBuffer - -@DeveloperApi -sealed abstract class BuildSide - -@DeveloperApi -case object BuildLeft extends BuildSide - -@DeveloperApi -case object BuildRight extends BuildSide - -trait HashJoin { - self: SparkPlan => - - val leftKeys: Seq[Expression] - val rightKeys: Seq[Expression] - val buildSide: BuildSide - val left: SparkPlan - val right: SparkPlan - - lazy val (buildPlan, streamedPlan) = buildSide match { - case BuildLeft => (left, right) - case BuildRight => (right, left) - } - - lazy val (buildKeys, streamedKeys) = buildSide match { - case BuildLeft => (leftKeys, rightKeys) - case BuildRight => (rightKeys, leftKeys) - } - - def output = left.output ++ right.output - - @transient lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) - @transient lazy val streamSideKeyGenerator = - newMutableProjection(streamedKeys, streamedPlan.output) - - def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { - // TODO: Use Spark's HashMap implementation. - - val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() - var currentRow: Row = null - - // Create a mapping of buildKeys -> rows - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val existingMatchList = hashTable.get(rowKey) - val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[Row]() - hashTable.put(rowKey, newMatchList) - newMatchList - } else { - existingMatchList - } - matchList += currentRow.copy() - } - } - - new Iterator[Row] { - private[this] var currentStreamedRow: Row = _ - private[this] var currentHashMatches: CompactBuffer[Row] = _ - private[this] var currentMatchPosition: Int = -1 - - // Mutable per row objects. 
- private[this] val joinRow = new JoinedRow2 - - private[this] val joinKeys = streamSideKeyGenerator() - - override final def hasNext: Boolean = - (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || - (streamIter.hasNext && fetchNext()) - - override final def next() = { - val ret = buildSide match { - case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) - case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) - } - currentMatchPosition += 1 - ret - } - - /** - * Searches the streamed iterator for the next row that has at least one match in hashtable. - * - * @return true if the search is successful, and false if the streamed iterator runs out of - * tuples. - */ - private final def fetchNext(): Boolean = { - currentHashMatches = null - currentMatchPosition = -1 - - while (currentHashMatches == null && streamIter.hasNext) { - currentStreamedRow = streamIter.next() - if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatches = hashTable.get(joinKeys.currentValue) - } - } - - if (currentHashMatches == null) { - false - } else { - currentMatchPosition = 0 - true - } - } - } - } -} - -/** - * :: DeveloperApi :: - * Performs a hash based outer join for two child relations by shuffling the data using - * the join keys. This operator requires loading the associated partition in both side into memory. - */ -@DeveloperApi -case class HashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode { - - override def outputPartitioning: Partitioning = joinType match { - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - override def output = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case x => - throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - } - - @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) - @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] - - // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala - // iterator for performance purpose. 
- - private[this] def leftOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val rightNullRow = new GenericRow(right.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - leftIter.iterator.flatMap { l => - joinedRow.withLeft(l) - var matched = false - (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => - matched = true - joinedRow.copy - } else { - Nil - }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the - // records in right side. - // If we didn't get any proper row, then append a single row with empty right - joinedRow.withRight(rightNullRow).copy - }) - } - } - - private[this] def rightOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val leftNullRow = new GenericRow(left.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - rightIter.iterator.flatMap { r => - joinedRow.withRight(r) - var matched = false - (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => - matched = true - joinedRow.copy - } else { - Nil - }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the - // records in left side. - // If we didn't get any proper row, then append a single row with empty left. - joinedRow.withLeft(leftNullRow).copy - }) - } - } - - private[this] def fullOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val leftNullRow = new GenericRow(left.output.length) - val rightNullRow = new GenericRow(right.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - if (!key.anyNull) { - // Store the positions of records in right, if one of its associated row satisfy - // the join condition. - val rightMatchedSet = scala.collection.mutable.Set[Int]() - leftIter.iterator.flatMap[Row] { l => - joinedRow.withLeft(l) - var matched = false - rightIter.zipWithIndex.collect { - // 1. For those matched (satisfy the join condition) records with both sides filled, - // append them directly - - case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { - matched = true - // if the row satisfy the join condition, add its index into the matched set - rightMatchedSet.add(idx) - joinedRow.copy - } - } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // 2. For those unmatched records in left, append additional records with empty right. - - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all - // of the records in right side. - // If we didn't get any proper row, then append a single row with empty right. - joinedRow.withRight(rightNullRow).copy - }) - } ++ rightIter.zipWithIndex.collect { - // 3. For those unmatched records in right, append additional records with empty left. 
- - // Re-visiting the records in right, and append additional row with empty left, if its not - // in the matched set. - case (r, idx) if (!rightMatchedSet.contains(idx)) => { - joinedRow(leftNullRow, r).copy - } - } - } else { - leftIter.iterator.map[Row] { l => - joinedRow(l, rightNullRow).copy - } ++ rightIter.iterator.map[Row] { r => - joinedRow(leftNullRow, r).copy - } - } - } - - private[this] def buildHashTable( - iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = { - val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]() - while (iter.hasNext) { - val currentRow = iter.next() - val rowKey = keyGenerator(currentRow) - - var existingMatchList = hashTable.get(rowKey) - if (existingMatchList == null) { - existingMatchList = new CompactBuffer[Row]() - hashTable.put(rowKey, existingMatchList) - } - - existingMatchList += currentRow.copy() - } - - hashTable - } - - def execute() = { - left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - // TODO this probably can be replaced by external sort (sort merged join?) - // Build HashMap for current partition in left relation - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - // Build HashMap for current partition in right relation - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - - import scala.collection.JavaConversions._ - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - joinType match { - case LeftOuter => leftHashTable.keysIterator.flatMap { key => - leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) - } - case RightOuter => rightHashTable.keysIterator.flatMap { key => - rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) - } - case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) - } - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - } - } -} - -/** - * :: DeveloperApi :: - * Performs an inner hash join of two child relations by first shuffling the data using the join - * keys. - */ -@DeveloperApi -case class ShuffledHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - buildSide: BuildSide, - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - override def outputPartitioning: Partitioning = left.outputPartitioning - - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - def execute() = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { - (buildIter, streamIter) => joinIterators(buildIter, streamIter) - } - } -} - -/** - * :: DeveloperApi :: - * Build the right table's join keys into a HashSet, and iteratively go through the left - * table, to find the if join keys are in the Hash set. 
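As a usage sketch, the kind of query that exercises the relocated left-semi-join operators; table and column names are invented, and the planner behavior in the comments is taken from the LeftSemiJoin strategy earlier in this patch.

// given sqlContext with temporary tables orders(orderId, custId) and customers(custNo, ...)
val activeCustomerOrders = sqlContext.sql(
  "SELECT orderId FROM orders LEFT SEMI JOIN customers ON custId = custNo")
// The equality predicate lets ExtractEquiJoinKeys fire, so the strategy plans
// joins.LeftSemiJoinHash; a non-equality condition would fall back to joins.LeftSemiJoinBNL.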
- */ -@DeveloperApi -case class LeftSemiJoinHash( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - val buildSide = BuildRight - - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - override def output = left.output - - def execute() = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashSet = new java.util.HashSet[Row]() - var currentRow: Row = null - - // Create a Hash set of buildKeys - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val keyExists = hashSet.contains(rowKey) - if (!keyExists) { - hashSet.add(rowKey) - } - } - } - - val joinKeys = streamSideKeyGenerator() - streamIter.filter(current => { - !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue) - }) - } - } -} - - -/** - * :: DeveloperApi :: - * Performs an inner hash join of two child relations. When the output RDD of this operator is - * being constructed, a Spark job is asynchronously started to calculate the values for the - * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed - * relation is not shuffled. - */ -@DeveloperApi -case class BroadcastHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - buildSide: BuildSide, - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - - override def requiredChildDistribution = - UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - - @transient - val broadcastFuture = future { - sparkContext.broadcast(buildPlan.executeCollect()) - } - - def execute() = { - val broadcastRelation = Await.result(broadcastFuture, 5.minute) - - streamedPlan.execute().mapPartitions { streamedIter => - joinIterators(broadcastRelation.value.iterator, streamedIter) - } - } -} - -/** - * :: DeveloperApi :: - * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys - * for hash join. - */ -@DeveloperApi -case class LeftSemiJoinBNL( - streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) - extends BinaryNode { - // TODO: Override requiredChildDistribution. 
- - override def outputPartitioning: Partitioning = streamed.outputPartitioning - - def output = left.output - - /** The Streamed Relation */ - def left = streamed - /** The Broadcast relation */ - def right = broadcast - - @transient lazy val boundCondition = - InterpretedPredicate( - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true))) - - def execute() = { - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) - - streamed.execute().mapPartitions { streamedIter => - val joinedRow = new JoinedRow - - streamedIter.filter(streamedRow => { - var i = 0 - var matched = false - - while (i < broadcastedRelation.value.size && !matched) { - val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { - matched = true - } - i += 1 - } - matched - }) - } - } -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { - def output = left.output ++ right.output - - def execute() = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) - - leftResults.cartesian(rightResults).mapPartitions { iter => - val joinedRow = new JoinedRow - iter.map(r => joinedRow(r._1, r._2)) - } - } -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class BroadcastNestedLoopJoin( - left: SparkPlan, - right: SparkPlan, - buildSide: BuildSide, - joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { - // TODO: Override requiredChildDistribution. - - /** BuildRight means the right relation <=> the broadcast relation. */ - val (streamed, broadcast) = buildSide match { - case BuildRight => (left, right) - case BuildLeft => (right, left) - } - - override def outputPartitioning: Partitioning = streamed.outputPartitioning - - override def output = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case _ => - left.output ++ right.output - } - } - - @transient lazy val boundCondition = - InterpretedPredicate( - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true))) - - def execute() = { - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) - - /** All rows that either match both-way, or rows from streamed joined with nulls. */ - val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => - val matchedRows = new CompactBuffer[Row] - // TODO: Use Spark's BitSet. - val includedBroadcastTuples = - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - val joinedRow = new JoinedRow - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - - streamedIter.foreach { streamedRow => - var i = 0 - var streamRowMatched = false - - while (i < broadcastedRelation.value.size) { - // TODO: One bitset per partition instead of per row. 
- val broadcastedRow = broadcastedRelation.value(i) - buildSide match { - case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => - matchedRows += joinedRow(streamedRow, broadcastedRow).copy() - streamRowMatched = true - includedBroadcastTuples += i - case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => - matchedRows += joinedRow(broadcastedRow, streamedRow).copy() - streamRowMatched = true - includedBroadcastTuples += i - case _ => - } - i += 1 - } - - (streamRowMatched, joinType, buildSide) match { - case (false, LeftOuter | FullOuter, BuildRight) => - matchedRows += joinedRow(streamedRow, rightNulls).copy() - case (false, RightOuter | FullOuter, BuildLeft) => - matchedRows += joinedRow(leftNulls, streamedRow).copy() - case _ => - } - } - Iterator((matchedRows, includedBroadcastTuples)) - } - - val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) - val allIncludedBroadcastTuples = - if (includedBroadcastTuples.count == 0) { - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - } else { - includedBroadcastTuples.reduce(_ ++ _) - } - - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - /** Rows from broadcasted joined with nulls. */ - val broadcastRowsWithNulls: Seq[Row] = { - val buf: CompactBuffer[Row] = new CompactBuffer() - var i = 0 - val rel = broadcastedRelation.value - while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { - (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) - case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) - case _ => - } - } - i += 1 - } - buf.toSeq - } - - // TODO: Breaks lineage. - sparkContext.union( - matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala new file mode 100644 index 0000000000000..d88ab6367a1b3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import scala.concurrent._ +import scala.concurrent.duration._ +import scala.concurrent.ExecutionContext.Implicits.global + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Performs an inner hash join of two child relations. When the output RDD of this operator is + * being constructed, a Spark job is asynchronously started to calculate the values for the + * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed + * relation is not shuffled. + */ +@DeveloperApi +case class BroadcastHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + buildSide: BuildSide, + left: SparkPlan, + right: SparkPlan) + extends BinaryNode with HashJoin { + + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + + override def requiredChildDistribution = + UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + + @transient + private val broadcastFuture = future { + sparkContext.broadcast(buildPlan.executeCollect()) + } + + override def execute() = { + val broadcastRelation = Await.result(broadcastFuture, 5.minute) + + streamedPlan.execute().mapPartitions { streamedIter => + joinIterators(broadcastRelation.value.iterator, streamedIter) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala new file mode 100644 index 0000000000000..36aad13778bd2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class BroadcastNestedLoopJoin( + left: SparkPlan, + right: SparkPlan, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) extends BinaryNode { + // TODO: Override requiredChildDistribution. + + /** BuildRight means the right relation <=> the broadcast relation. 
*/ + private val (streamed, broadcast) = buildSide match { + case BuildRight => (left, right) + case BuildLeft => (right, left) + } + + override def outputPartitioning: Partitioning = streamed.outputPartitioning + + override def output = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case _ => + left.output ++ right.output + } + } + + @transient private lazy val boundCondition = + InterpretedPredicate( + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true))) + + override def execute() = { + val broadcastedRelation = + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + + /** All rows that either match both-way, or rows from streamed joined with nulls. */ + val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => + val matchedRows = new CompactBuffer[Row] + // TODO: Use Spark's BitSet. + val includedBroadcastTuples = + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + val joinedRow = new JoinedRow + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) + + streamedIter.foreach { streamedRow => + var i = 0 + var streamRowMatched = false + + while (i < broadcastedRelation.value.size) { + // TODO: One bitset per partition instead of per row. + val broadcastedRow = broadcastedRelation.value(i) + buildSide match { + case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => + matchedRows += joinedRow(streamedRow, broadcastedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => + matchedRows += joinedRow(broadcastedRow, streamedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case _ => + } + i += 1 + } + + (streamRowMatched, joinType, buildSide) match { + case (false, LeftOuter | FullOuter, BuildRight) => + matchedRows += joinedRow(streamedRow, rightNulls).copy() + case (false, RightOuter | FullOuter, BuildLeft) => + matchedRows += joinedRow(leftNulls, streamedRow).copy() + case _ => + } + } + Iterator((matchedRows, includedBroadcastTuples)) + } + + val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) + val allIncludedBroadcastTuples = + if (includedBroadcastTuples.count == 0) { + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + } else { + includedBroadcastTuples.reduce(_ ++ _) + } + + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) + /** Rows from broadcasted joined with nulls. */ + val broadcastRowsWithNulls: Seq[Row] = { + val buf: CompactBuffer[Row] = new CompactBuffer() + var i = 0 + val rel = broadcastedRelation.value + while (i < rel.length) { + if (!allIncludedBroadcastTuples.contains(i)) { + (joinType, buildSide) match { + case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) + case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) + case _ => + } + } + i += 1 + } + buf.toSeq + } + + // TODO: Breaks lineage. 
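For context on the outer-join bookkeeping above: each streamed partition records, in a mutable BitSet, the indices of broadcast rows that matched at least once; the per-partition sets are then unioned, and every broadcast row whose index never appears is emitted once, padded with nulls on the streamed side. A small standalone sketch of that merge-and-pad step follows; the names and the Option-free result are illustrative only, not the operator's Row types.

object UnmatchedBroadcastSketch {
  // matchedPerPartition: per-partition sets of broadcast-row indices that found a match.
  def unmatchedRows[T](broadcast: IndexedSeq[T],
                       matchedPerPartition: Seq[collection.mutable.BitSet]): Seq[T] = {
    val allMatched =
      if (matchedPerPartition.isEmpty) collection.mutable.BitSet.empty
      else matchedPerPartition.reduce(_ | _)            // union across all partitions
    broadcast.indices.collect {
      case i if !allMatched.contains(i) => broadcast(i) // these get null-padded in the join
    }
  }

  def main(args: Array[String]): Unit = {
    val broadcastSide = Vector("r0", "r1", "r2", "r3")
    val perPartition = Seq(collection.mutable.BitSet(0, 2), collection.mutable.BitSet(2))
    println(unmatchedRows(broadcastSide, perPartition)) // Vector(r1, r3)
  }
}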
+ sparkContext.union( + matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala new file mode 100644 index 0000000000000..76c14c02aab34 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { + override def output = left.output ++ right.output + + override def execute() = { + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + + leftResults.cartesian(rightResults).mapPartitions { iter => + val joinedRow = new JoinedRow + iter.map(r => joinedRow(r._1, r._2)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala new file mode 100644 index 0000000000000..472b2e6ca6b4a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow2, Row} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.util.collection.CompactBuffer + + +trait HashJoin { + self: SparkPlan => + + val leftKeys: Seq[Expression] + val rightKeys: Seq[Expression] + val buildSide: BuildSide + val left: SparkPlan + val right: SparkPlan + + protected lazy val (buildPlan, streamedPlan) = buildSide match { + case BuildLeft => (left, right) + case BuildRight => (right, left) + } + + protected lazy val (buildKeys, streamedKeys) = buildSide match { + case BuildLeft => (leftKeys, rightKeys) + case BuildRight => (rightKeys, leftKeys) + } + + override def output = left.output ++ right.output + + @transient protected lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) + @transient protected lazy val streamSideKeyGenerator = + newMutableProjection(streamedKeys, streamedPlan.output) + + protected def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = + { + // TODO: Use Spark's HashMap implementation. + + val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() + var currentRow: Row = null + + // Create a mapping of buildKeys -> rows + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if (!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[Row]() + hashTable.put(rowKey, newMatchList) + newMatchList + } else { + existingMatchList + } + matchList += currentRow.copy() + } + } + + new Iterator[Row] { + private[this] var currentStreamedRow: Row = _ + private[this] var currentHashMatches: CompactBuffer[Row] = _ + private[this] var currentMatchPosition: Int = -1 + + // Mutable per row objects. + private[this] val joinRow = new JoinedRow2 + + private[this] val joinKeys = streamSideKeyGenerator() + + override final def hasNext: Boolean = + (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || + (streamIter.hasNext && fetchNext()) + + override final def next() = { + val ret = buildSide match { + case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) + case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) + } + currentMatchPosition += 1 + ret + } + + /** + * Searches the streamed iterator for the next row that has at least one match in hashtable. + * + * @return true if the search is successful, and false if the streamed iterator runs out of + * tuples. 
+ */ + private final def fetchNext(): Boolean = { + currentHashMatches = null + currentMatchPosition = -1 + + while (currentHashMatches == null && streamIter.hasNext) { + currentStreamedRow = streamIter.next() + if (!joinKeys(currentStreamedRow).anyNull) { + currentHashMatches = hashTable.get(joinKeys.currentValue) + } + } + + if (currentHashMatches == null) { + false + } else { + currentMatchPosition = 0 + true + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala new file mode 100644 index 0000000000000..b73041d306b36 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.{HashMap => JavaHashMap} + +import scala.collection.JavaConversions._ + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + * Performs a hash based outer join for two child relations by shuffling the data using + * the join keys. This operator requires loading the associated partition in both side into memory. 
+ */ +@DeveloperApi +case class HashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def outputPartitioning: Partitioning = joinType match { + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def output = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + } + + @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) + @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] + + // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala + // iterator for performance purpose. + + private[this] def leftOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + leftIter.iterator.flatMap { l => + joinedRow.withLeft(l) + var matched = false + (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in right side. + // If we didn't get any proper row, then append a single row with empty right + joinedRow.withRight(rightNullRow).copy + }) + } + } + + private[this] def rightOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + rightIter.iterator.flatMap { r => + joinedRow.withRight(r) + var matched = false + (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in left side. + // If we didn't get any proper row, then append a single row with empty left. 
+ joinedRow.withLeft(leftNullRow).copy + }) + } + } + + private[this] def fullOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + if (!key.anyNull) { + // Store the positions of records in right, if one of its associated row satisfy + // the join condition. + val rightMatchedSet = scala.collection.mutable.Set[Int]() + leftIter.iterator.flatMap[Row] { l => + joinedRow.withLeft(l) + var matched = false + rightIter.zipWithIndex.collect { + // 1. For those matched (satisfy the join condition) records with both sides filled, + // append them directly + + case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { + matched = true + // if the row satisfy the join condition, add its index into the matched set + rightMatchedSet.add(idx) + joinedRow.copy + } + } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // 2. For those unmatched records in left, append additional records with empty right. + + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all + // of the records in right side. + // If we didn't get any proper row, then append a single row with empty right. + joinedRow.withRight(rightNullRow).copy + }) + } ++ rightIter.zipWithIndex.collect { + // 3. For those unmatched records in right, append additional records with empty left. + + // Re-visiting the records in right, and append additional row with empty left, if its not + // in the matched set. + case (r, idx) if (!rightMatchedSet.contains(idx)) => { + joinedRow(leftNullRow, r).copy + } + } + } else { + leftIter.iterator.map[Row] { l => + joinedRow(l, rightNullRow).copy + } ++ rightIter.iterator.map[Row] { r => + joinedRow(leftNullRow, r).copy + } + } + } + + private[this] def buildHashTable( + iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = { + val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]() + while (iter.hasNext) { + val currentRow = iter.next() + val rowKey = keyGenerator(currentRow) + + var existingMatchList = hashTable.get(rowKey) + if (existingMatchList == null) { + existingMatchList = new CompactBuffer[Row]() + hashTable.put(rowKey, existingMatchList) + } + + existingMatchList += currentRow.copy() + } + + hashTable + } + + override def execute() = { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // TODO this probably can be replaced by external sort (sort merged join?) 
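As background for the hash-table handling in joinIterators and buildHashTable above (and for the partition-local build that follows): one side is materialized into a multimap from join key to all rows carrying that key, and the other side is streamed against it, with null keys never producing matches. A simplified, standalone build/probe sketch over plain key/value pairs; the names and types here are illustrative, not the operators' Row machinery.

import scala.collection.mutable

object BuildProbeSketch {
  // Inner hash join of two sequences of (key, value) pairs.
  def hashJoin[K, A, B](build: Seq[(K, A)], stream: Seq[(K, B)]): Seq[(K, A, B)] = {
    // Build phase: multimap from join key to every build-side row with that key.
    val table = mutable.HashMap.empty[K, mutable.ArrayBuffer[A]]
    for ((k, a) <- build) {
      table.getOrElseUpdate(k, mutable.ArrayBuffer.empty[A]) += a
    }
    // Probe phase: stream the other side and emit one output row per match.
    for {
      (k, b) <- stream
      a <- table.getOrElse(k, mutable.ArrayBuffer.empty[A])
    } yield (k, a, b)
  }

  def main(args: Array[String]): Unit = {
    val buildSide = Seq(1 -> "a", 2 -> "b", 2 -> "c")
    val streamSide = Seq(2 -> "x", 3 -> "y")
    println(hashJoin(buildSide, streamSide)) // List((2,b,x), (2,c,x))
  }
}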
+ // Build HashMap for current partition in left relation + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + // Build HashMap for current partition in right relation + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + joinType match { + case LeftOuter => leftHashTable.keysIterator.flatMap { key => + leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) + } + case RightOuter => rightHashTable.keysIterator.flatMap { key => + rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) + } + case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => + fullOuterIterator(key, + leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) + } + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala new file mode 100644 index 0000000000000..60003d1900d85 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys + * for hash join. + */ +@DeveloperApi +case class LeftSemiJoinBNL( + streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) + extends BinaryNode { + // TODO: Override requiredChildDistribution. 
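To make the per-key outer-join logic of HashOuterJoin above concrete: for each join key, the rows from the two hash tables are combined with a nested loop, a set of matched right-row indices is maintained, and null-padded rows are appended for whichever side found no partner. A small standalone sketch of that per-key step for the FULL OUTER case, using Option padding instead of null-filled rows (purely illustrative):

object PerKeyFullOuterSketch {
  // Full outer join of the two row groups that share one join key.
  def fullOuterForKey[L, R](
      lefts: Seq[L],
      rights: Seq[R],
      condition: (L, R) => Boolean): Seq[(Option[L], Option[R])] = {
    val rightMatched = scala.collection.mutable.Set.empty[Int]
    val fromLeft: Seq[(Option[L], Option[R])] = lefts.flatMap { l =>
      val paired: Seq[(Option[L], Option[R])] = rights.zipWithIndex.collect {
        case (r, idx) if condition(l, r) =>
          rightMatched += idx            // remember that this right row matched someone
          (Some(l), Some(r))
      }
      // Left row with no partner at all: emit it once, padded on the right.
      if (paired.nonEmpty) paired else Seq((Some(l), None))
    }
    // Right rows that never matched any left row: pad on the left.
    val fromRight: Seq[(Option[L], Option[R])] = rights.zipWithIndex.collect {
      case (r, idx) if !rightMatched.contains(idx) => (None, Some(r))
    }
    fromLeft ++ fromRight
  }

  def main(args: Array[String]): Unit = {
    val result = fullOuterForKey(Seq(1, 2, 3), Seq(10, 40), (l: Int, r: Int) => l * 10 == r)
    println(result)
    // List((Some(1),Some(10)), (Some(2),None), (Some(3),None), (None,Some(40)))
  }
}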
+ + override def outputPartitioning: Partitioning = streamed.outputPartitioning + + override def output = left.output + + /** The Streamed Relation */ + override def left = streamed + /** The Broadcast relation */ + override def right = broadcast + + @transient private lazy val boundCondition = + InterpretedPredicate( + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true))) + + override def execute() = { + val broadcastedRelation = + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + + streamed.execute().mapPartitions { streamedIter => + val joinedRow = new JoinedRow + + streamedIter.filter(streamedRow => { + var i = 0 + var matched = false + + while (i < broadcastedRelation.value.size && !matched) { + val broadcastedRow = broadcastedRelation.value(i) + if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { + matched = true + } + i += 1 + } + matched + }) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala new file mode 100644 index 0000000000000..ea7babf3be948 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.{Expression, Row} +import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Build the right table's join keys into a HashSet, and iteratively go through the left + * table, to find the if join keys are in the Hash set. 
+ */ +@DeveloperApi +case class LeftSemiJoinHash( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashJoin { + + override val buildSide = BuildRight + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def output = left.output + + override def execute() = { + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => + val hashSet = new java.util.HashSet[Row]() + var currentRow: Row = null + + // Create a Hash set of buildKeys + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if (!rowKey.anyNull) { + val keyExists = hashSet.contains(rowKey) + if (!keyExists) { + hashSet.add(rowKey) + } + } + } + + val joinKeys = streamSideKeyGenerator() + streamIter.filter(current => { + !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue) + }) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala new file mode 100644 index 0000000000000..8247304c1dc2c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Performs an inner hash join of two child relations by first shuffling the data using the join + * keys. 
+ */ +@DeveloperApi +case class ShuffledHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + buildSide: BuildSide, + left: SparkPlan, + right: SparkPlan) + extends BinaryNode with HashJoin { + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def execute() = { + buildPlan.execute().zipPartitions(streamedPlan.execute()) { + (buildIter, streamIter) => joinIterators(buildIter, streamIter) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala new file mode 100644 index 0000000000000..7f2ab1765b28f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Physical execution operators for join operations. 
+ */ +package object joins { + + @DeveloperApi + sealed abstract class BuildSide + + @DeveloperApi + case object BuildRight extends BuildSide + + @DeveloperApi + case object BuildLeft extends BuildSide + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 2941b9793597f..e6389cf77a4c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.parquet import java.io.IOException +import scala.util.Try + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job @@ -323,14 +325,14 @@ private[parquet] object ParquetTypesConverter extends Logging { } def convertFromString(string: String): Seq[Attribute] = { - DataType(string) match { + Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match { case s: StructType => s.toAttributes case other => sys.error(s"Can convert $string to row") } } def convertToString(schema: Seq[Attribute]): String = { - StructType.fromAttributes(schema).toString + StructType.fromAttributes(schema).json } def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 957388e99bd85..1e624f97004f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -18,30 +18,39 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.storage.RDDBlockId case class BigData(s: String) class CachedTableSuite extends QueryTest { - import TestSQLContext._ TestData // Load test tables. - /** - * Throws a test failed exception when the number of cached tables differs from the expected - * number. 
- */ def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached } - if (cachedData.size != numCachedTables) { - fail( - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + assert( + cachedData.size == numCachedTables, + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + planWithCaching) - } + } + + def rddIdOf(tableName: String): Int = { + val executedPlan = table(tableName).queryExecution.executedPlan + executedPlan.collect { + case InMemoryColumnarTableScan(_, _, relation) => + relation.cachedColumnBuffers.id + case _ => + fail(s"Table $tableName is not cached\n" + executedPlan) + }.head + } + + def isMaterialized(rddId: Int): Boolean = { + sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty } test("too big for memory") { @@ -52,10 +61,33 @@ class CachedTableSuite extends QueryTest { uncacheTable("bigData") } - test("calling .cache() should use inmemory columnar caching") { + test("calling .cache() should use in-memory columnar caching") { table("testData").cache() + assertCached(table("testData")) + } + + test("calling .unpersist() should drop in-memory columnar cache") { + table("testData").cache() + table("testData").count() + table("testData").unpersist(true) + assertCached(table("testData"), 0) + } + + test("isCached") { + cacheTable("testData") assertCached(table("testData")) + assert(table("testData").queryExecution.withCachedData match { + case _: InMemoryRelation => true + case _ => false + }) + + uncacheTable("testData") + assert(!isCached("testData")) + assert(table("testData").queryExecution.withCachedData match { + case _: InMemoryRelation => false + case _ => true + }) } test("SPARK-1669: cacheTable should be idempotent") { @@ -64,32 +96,27 @@ class CachedTableSuite extends QueryTest { cacheTable("testData") assertCached(table("testData")) - cacheTable("testData") - table("testData").queryExecution.analyzed match { - case InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan) => - fail("cacheTable is not idempotent") + assertResult(1, "InMemoryRelation not found, testData should have been cached") { + table("testData").queryExecution.withCachedData.collect { + case r: InMemoryRelation => r + }.size + } - case _ => + cacheTable("testData") + assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { + table("testData").queryExecution.withCachedData.collect { + case r @ InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan) => r + }.size } } test("read from cached table and uncache") { cacheTable("testData") - - checkAnswer( - table("testData"), - testData.collect().toSeq - ) - + checkAnswer(table("testData"), testData.collect().toSeq) assertCached(table("testData")) uncacheTable("testData") - - checkAnswer( - table("testData"), - testData.collect().toSeq - ) - + checkAnswer(table("testData"), testData.collect().toSeq) assertCached(table("testData"), 0) } @@ -99,10 +126,12 @@ class CachedTableSuite extends QueryTest { } } - test("SELECT Star Cached Table") { + test("SELECT star from cached table") { sql("SELECT * FROM testData").registerTempTable("selectStar") cacheTable("selectStar") - sql("SELECT * FROM selectStar WHERE key = 1").collect() + checkAnswer( + sql("SELECT * FROM selectStar WHERE key = 1"), + Seq(Row(1, "1"))) uncacheTable("selectStar") } @@ -120,23 +149,57 @@ class CachedTableSuite 
extends QueryTest { sql("CACHE TABLE testData") assertCached(table("testData")) - assert(isCached("testData"), "Table 'testData' should be cached") + val rddId = rddIdOf("testData") + assert( + isMaterialized(rddId), + "Eagerly cached in-memory table should have already been materialized") sql("UNCACHE TABLE testData") - assertCached(table("testData"), 0) assert(!isCached("testData"), "Table 'testData' should not be cached") + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } - - test("CACHE TABLE tableName AS SELECT Star Table") { + + test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - sql("SELECT * FROM testCacheTable WHERE key = 1").collect() - assert(isCached("testCacheTable"), "Table 'testCacheTable' should be cached") + assertCached(table("testCacheTable")) + + val rddId = rddIdOf("testCacheTable") + assert( + isMaterialized(rddId), + "Eagerly cached in-memory table should have already been materialized") + uncacheTable("testCacheTable") + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } - - test("'CACHE TABLE tableName AS SELECT ..'") { - sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assert(isCached("testCacheTable"), "Table 'testCacheTable' should be cached") + + test("CACHE TABLE tableName AS SELECT ...") { + sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") + assertCached(table("testCacheTable")) + + val rddId = rddIdOf("testCacheTable") + assert( + isMaterialized(rddId), + "Eagerly cached in-memory table should have already been materialized") + uncacheTable("testCacheTable") + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } + + test("CACHE LAZY TABLE tableName") { + sql("CACHE LAZY TABLE testData") + assertCached(table("testData")) + + val rddId = rddIdOf("testData") + assert( + !isMaterialized(rddId), + "Lazily cached in-memory table shouldn't be materialized eagerly") + + sql("SELECT COUNT(*) FROM testData").collect() + assert( + isMaterialized(rddId), + "Lazily cached in-memory table should have been materialized") + + uncacheTable("testData") + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala index 8fb59c5830f6d..100ecb45e9e88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.types.DataType + class DataTypeSuite extends FunSuite { test("construct an ArrayType") { @@ -55,4 +57,30 @@ class DataTypeSuite extends FunSuite { struct(Set("b", "d", "e", "f")) } } + + def checkDataTypeJsonRepr(dataType: DataType): Unit = { + test(s"JSON - $dataType") { + assert(DataType.fromJson(dataType.json) === dataType) + } + } + + checkDataTypeJsonRepr(BooleanType) + checkDataTypeJsonRepr(ByteType) + checkDataTypeJsonRepr(ShortType) + checkDataTypeJsonRepr(IntegerType) + checkDataTypeJsonRepr(LongType) + checkDataTypeJsonRepr(FloatType) + checkDataTypeJsonRepr(DoubleType) + checkDataTypeJsonRepr(DecimalType) + checkDataTypeJsonRepr(TimestampType) + checkDataTypeJsonRepr(StringType) + checkDataTypeJsonRepr(BinaryType) + checkDataTypeJsonRepr(ArrayType(DoubleType, true)) + 
checkDataTypeJsonRepr(ArrayType(StringType, false)) + checkDataTypeJsonRepr(MapType(IntegerType, StringType, true)) + checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false)) + checkDataTypeJsonRepr( + StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", ArrayType(DoubleType), nullable = false)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 6c7697ece8c56..07f4d2946c1b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6fb6cb8db0c8f..b9b196ea5a46a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.{ShuffledHashJoin, BroadcastHashJoin} +import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.test._ import org.scalatest.BeforeAndAfterAll import java.util.TimeZone diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 69e0adbd3ee0d..f53acc8c9f718 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -67,10 +67,11 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3) checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2) checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4) + checkBatchPruning("NOT (i < 88)", 88 to 100, 1, 2) // With unsupported predicate checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2) - checkBatchPruning("NOT (i < 88)", 88 to 100, 5, 10) + checkBatchPruning(s"NOT (i in (${(1 to 30).mkString(",")}))", 31 to 100, 5, 10) def checkBatchPruning( filter: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index bfbf431a11913..f14ffca0e4d35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.execution import org.scalatest.FunSuite +import org.apache.spark.sql.{SQLConf, execution} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.{SQLConf, execution} +import 
org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 07adf731405af..25e41ecf28e2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -789,7 +789,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(result3(0)(1) === "the answer") Utils.deleteRecursively(tmpdir) } - + test("Querying on empty parquet throws exception (SPARK-3536)") { val tmpdir = Utils.createTempDir() Utils.deleteRecursively(tmpdir) @@ -798,4 +798,18 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(result1.size === 0) Utils.deleteRecursively(tmpdir) } + + test("DataType string parser compatibility") { + val schema = StructType(List( + StructField("c1", IntegerType, false), + StructField("c2", BinaryType, false))) + + val fromCaseClassString = ParquetTypesConverter.convertFromString(schema.toString) + val fromJson = ParquetTypesConverter.convertFromString(schema.json) + + (fromCaseClassString, fromJson).zipped.foreach { (a, b) => + assert(a.name == b.name) + assert(a.dataType === b.dataType) + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index e7e1cb980c2ae..c5844e92eaaa9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -24,11 +24,11 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.SqlLexical /** - * A parser that recognizes all HiveQL constructs together with several Spark SQL specific + * A parser that recognizes all HiveQL constructs together with several Spark SQL specific * extensions like CACHE TABLE and UNCACHE TABLE. */ -private[hive] class ExtendedHiveQlParser extends StandardTokenParsers with PackratParsers { - +private[hive] class ExtendedHiveQlParser extends StandardTokenParsers with PackratParsers { + def apply(input: String): LogicalPlan = { // Special-case out set commands since the value fields can be // complex to handle without RegexParsers. 
Also this approach @@ -54,16 +54,17 @@ private[hive] class ExtendedHiveQlParser extends StandardTokenParsers with Packr protected case class Keyword(str: String) - protected val CACHE = Keyword("CACHE") - protected val SET = Keyword("SET") protected val ADD = Keyword("ADD") - protected val JAR = Keyword("JAR") - protected val TABLE = Keyword("TABLE") protected val AS = Keyword("AS") - protected val UNCACHE = Keyword("UNCACHE") - protected val FILE = Keyword("FILE") + protected val CACHE = Keyword("CACHE") protected val DFS = Keyword("DFS") + protected val FILE = Keyword("FILE") + protected val JAR = Keyword("JAR") + protected val LAZY = Keyword("LAZY") + protected val SET = Keyword("SET") protected val SOURCE = Keyword("SOURCE") + protected val TABLE = Keyword("TABLE") + protected val UNCACHE = Keyword("UNCACHE") protected implicit def asParser(k: Keyword): Parser[String] = lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) @@ -79,57 +80,56 @@ private[hive] class ExtendedHiveQlParser extends StandardTokenParsers with Packr override val lexical = new SqlLexical(reservedWords) - protected lazy val query: Parser[LogicalPlan] = + protected lazy val query: Parser[LogicalPlan] = cache | uncache | addJar | addFile | dfs | source | hiveQl protected lazy val hiveQl: Parser[LogicalPlan] = - remainingQuery ^^ { - case r => HiveQl.createPlan(r.trim()) + restInput ^^ { + case statement => HiveQl.createPlan(statement.trim()) } - /** It returns all remaining query */ - protected lazy val remainingQuery: Parser[String] = new Parser[String] { + // Returns the whole input string + protected lazy val wholeInput: Parser[String] = new Parser[String] { def apply(in: Input) = - Success( - in.source.subSequence(in.offset, in.source.length).toString, - in.drop(in.source.length())) + Success(in.source.toString, in.drop(in.source.length())) } - /** It returns all query */ - protected lazy val allQuery: Parser[String] = new Parser[String] { + // Returns the rest of the input string that are not parsed yet + protected lazy val restInput: Parser[String] = new Parser[String] { def apply(in: Input) = - Success(in.source.toString, in.drop(in.source.length())) + Success( + in.source.subSequence(in.offset, in.source.length).toString, + in.drop(in.source.length())) } protected lazy val cache: Parser[LogicalPlan] = - CACHE ~ TABLE ~> ident ~ opt(AS ~> hiveQl) ^^ { - case tableName ~ None => CacheCommand(tableName, true) - case tableName ~ Some(plan) => - CacheTableAsSelectCommand(tableName, plan) + CACHE ~> opt(LAZY) ~ (TABLE ~> ident) ~ opt(AS ~> hiveQl) ^^ { + case isLazy ~ tableName ~ plan => + CacheTableCommand(tableName, plan, isLazy.isDefined) } protected lazy val uncache: Parser[LogicalPlan] = UNCACHE ~ TABLE ~> ident ^^ { - case tableName => CacheCommand(tableName, false) + case tableName => UncacheTableCommand(tableName) } protected lazy val addJar: Parser[LogicalPlan] = - ADD ~ JAR ~> remainingQuery ^^ { - case rq => AddJar(rq.trim()) + ADD ~ JAR ~> restInput ^^ { + case jar => AddJar(jar.trim()) } protected lazy val addFile: Parser[LogicalPlan] = - ADD ~ FILE ~> remainingQuery ^^ { - case rq => AddFile(rq.trim()) + ADD ~ FILE ~> restInput ^^ { + case file => AddFile(file.trim()) } protected lazy val dfs: Parser[LogicalPlan] = - DFS ~> allQuery ^^ { - case aq => NativeCommand(aq.trim()) + DFS ~> wholeInput ^^ { + case command => NativeCommand(command.trim()) } protected lazy val source: Parser[LogicalPlan] = - SOURCE ~> remainingQuery ^^ { - case rq => SourceCommand(rq.trim()) + SOURCE ~> 
restInput ^^ { + case file => SourceCommand(file.trim()) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 8bcc098bbb620..fad3b39f81413 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -268,7 +268,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { */ protected[sql] def runSqlHive(sql: String): Seq[String] = { val maxResults = 100000 - val results = runHive(sql, 100000) + val results = runHive(sql, maxResults) // It is very confusing when you only get back some of the results... if (results.size == maxResults) sys.error("RESULTS POSSIBLY TRUNCATED") results diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index cc0605b0adb35..addd5bed8426d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -19,31 +19,28 @@ package org.apache.spark.sql.hive import scala.util.parsing.combinator.RegexParsers -import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, SerDeInfo} -import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition} +import org.apache.hadoop.hive.metastore.api.{FieldSchema, SerDeInfo, StorageDescriptor, Partition => TPartition, Table => TTable} import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.Deserializer -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, Catalog} +import org.apache.spark.sql.catalyst.analysis.Catalog import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.columnar.InMemoryRelation -import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.util.Utils /* Implicit conversions */ import scala.collection.JavaConversions._ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging { - import HiveMetastoreTypes._ + import org.apache.spark.sql.hive.HiveMetastoreTypes._ /** Connection to hive metastore. Usages should lock on `this`. */ protected[hive] val client = Hive.get(hive.hiveconf) @@ -137,10 +134,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) = { val childOutputDataTypes = child.output.map(_.dataType) - // Only check attributes, not partitionKeys since they are always strings. - // TODO: Fully support inserting into partitioned tables. 
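Relating back to the ExtendedHiveQlParser changes earlier in this patch: the new cache rule boils down to an optional LAZY keyword, a mandatory TABLE plus identifier, and an optional AS followed by the rest of the statement. A minimal sketch of that shape follows; it is a toy grammar with made-up names (Cache, CacheGrammarSketch), not the actual HiveQL parser, and it assumes the standard scala.util.parsing combinator library is available.

import scala.util.parsing.combinator.JavaTokenParsers

object CacheGrammarSketch extends JavaTokenParsers {
  // Toy result type: (lazy?, table name, optional SELECT text).
  case class Cache(isLazy: Boolean, table: String, as: Option[String])

  def cache: Parser[Cache] =
    "CACHE" ~> opt("LAZY") ~ ("TABLE" ~> ident) ~ opt("AS" ~> ".+".r) ^^ {
      case isLazy ~ table ~ as => Cache(isLazy.isDefined, table, as)
    }

  def main(args: Array[String]): Unit = {
    println(parseAll(cache, "CACHE LAZY TABLE src"))              // parses to Cache(true,src,None)
    println(parseAll(cache, "CACHE TABLE t AS SELECT * FROM src")) // Cache(false,t,Some(SELECT * FROM src))
  }
}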
val tableOutputDataTypes = - table.attributes.map(_.dataType) ++ table.partitionKeys.map(_.dataType) + (table.attributes ++ table.partitionKeys).take(child.output.length).map(_.dataType) if (childOutputDataTypes == tableOutputDataTypes) { p diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index c0e69393cc2e3..a4354c1379c63 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -32,7 +32,7 @@ import org.apache.hadoop.hive.serde2.avro.AvroSerDe import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.plans.logical.{CacheCommand, LogicalPlan, NativeCommand} +import org.apache.spark.sql.catalyst.plans.logical.{CacheTableCommand, LogicalPlan, NativeCommand} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.hive._ import org.apache.spark.sql.SQLConf @@ -67,7 +67,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath /** Sets up the system initially or after a RESET command */ - protected def configure() { + protected def configure(): Unit = { setConf("javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=$metastorePath;create=true") setConf("hive.metastore.warehouse.dir", warehousePath) @@ -154,7 +154,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { override lazy val analyzed = { val describedTables = logical match { case NativeCommand(describedTable(tbl)) => tbl :: Nil - case CacheCommand(tbl, _) => tbl :: Nil + case CacheTableCommand(tbl, _, _) => tbl :: Nil case _ => Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index f8b4e898ec41d..f0785d8882636 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -69,33 +69,36 @@ case class InsertIntoHiveTable( * Wraps with Hive types based on object inspector. * TODO: Consolidate all hive OI/data interface code. 
*/ - protected def wrap(a: (Any, ObjectInspector)): Any = a match { - case (s: String, oi: JavaHiveVarcharObjectInspector) => - new HiveVarchar(s, s.size) - - case (bd: BigDecimal, oi: JavaHiveDecimalObjectInspector) => - new HiveDecimal(bd.underlying()) - - case (row: Row, oi: StandardStructObjectInspector) => - val struct = oi.create() - row.zip(oi.getAllStructFieldRefs: Seq[StructField]).foreach { - case (data, field) => - oi.setStructFieldData(struct, field, wrap(data, field.getFieldObjectInspector)) + protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match { + case _: JavaHiveVarcharObjectInspector => + (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) + + case _: JavaHiveDecimalObjectInspector => + (o: Any) => new HiveDecimal(o.asInstanceOf[BigDecimal].underlying()) + + case soi: StandardStructObjectInspector => + val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) + (o: Any) => { + val struct = soi.create() + (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach { + (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) + } + struct } - struct - case (s: Seq[_], oi: ListObjectInspector) => - val wrappedSeq = s.map(wrap(_, oi.getListElementObjectInspector)) - seqAsJavaList(wrappedSeq) + case loi: ListObjectInspector => + val wrapper = wrapperFor(loi.getListElementObjectInspector) + (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) - case (m: Map[_, _], oi: MapObjectInspector) => - val keyOi = oi.getMapKeyObjectInspector - val valueOi = oi.getMapValueObjectInspector - val wrappedMap = m.map { case (key, value) => wrap(key, keyOi) -> wrap(value, valueOi) } - mapAsJavaMap(wrappedMap) + case moi: MapObjectInspector => + val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector) + val valueWrapper = wrapperFor(moi.getMapValueObjectInspector) + (o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) => + keyWrapper(key) -> valueWrapper(value) + }) - case (obj, _) => - obj + case _ => + identity[Any] } def saveAsHiveFile( @@ -103,7 +106,7 @@ case class InsertIntoHiveTable( valueClass: Class[_], fileSinkConf: FileSinkDesc, conf: SerializableWritable[JobConf], - writerContainer: SparkHiveWriterContainer) { + writerContainer: SparkHiveWriterContainer): Unit = { assert(valueClass != null, "Output value class not set") conf.value.setOutputValueClass(valueClass) @@ -122,7 +125,7 @@ case class InsertIntoHiveTable( writerContainer.commitJob() // Note that this function is executed on executor side - def writeToFile(context: TaskContext, iterator: Iterator[Row]) { + def writeToFile(context: TaskContext, iterator: Iterator[Row]): Unit = { val serializer = newSerializer(fileSinkConf.getTableInfo) val standardOI = ObjectInspectorUtils .getStandardObjectInspector( @@ -131,6 +134,7 @@ case class InsertIntoHiveTable( .asInstanceOf[StructObjectInspector] val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + val wrappers = fieldOIs.map(wrapperFor) val outputData = new Array[Any](fieldOIs.length) // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it @@ -141,13 +145,13 @@ case class InsertIntoHiveTable( iterator.foreach { row => var i = 0 while (i < fieldOIs.length) { - // TODO (lian) avoid per row dynamic dispatching and pattern matching cost in `wrap` - outputData(i) = wrap(row(i), fieldOIs(i)) + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row(i)) i += 1 } - val 
writer = writerContainer.getLocalFileWriter(row) - writer.write(serializer.serialize(outputData, standardOI)) + writerContainer + .getLocalFileWriter(row) + .write(serializer.serialize(outputData, standardOI)) } writerContainer.close() @@ -207,7 +211,7 @@ case class InsertIntoHiveTable( // Report error if any static partition appears after a dynamic partition val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty) - isDynamic.init.zip(isDynamic.tail).find(_ == (true, false)).foreach { _ => + if (isDynamic.init.zip(isDynamic.tail).contains((true, false))) { throw new SparkException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ac5c7a8220296..6ccbc22a4acfb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -55,8 +55,8 @@ private[hive] class SparkHiveWriterContainer( private var taID: SerializableWritable[TaskAttemptID] = null @transient private var writer: FileSinkOperator.RecordWriter = null - @transient private lazy val committer = conf.value.getOutputCommitter - @transient private lazy val jobContext = newJobContext(conf.value, jID.value) + @transient protected lazy val committer = conf.value.getOutputCommitter + @transient protected lazy val jobContext = newJobContext(conf.value, jID.value) @transient private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value) @transient private lazy val outputFormat = conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef,Writable]] @@ -122,8 +122,6 @@ private[hive] class SparkHiveWriterContainer( } } - // ********* Private Functions ********* - private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { jobID = jobId splitID = splitId @@ -157,12 +155,18 @@ private[hive] object SparkHiveWriterContainer { } } +private[spark] object SparkHiveDynamicPartitionWriterContainer { + val SUCCESSFUL_JOB_OUTPUT_DIR_MARKER = "mapreduce.fileoutputcommitter.marksuccessfuljobs" +} + private[spark] class SparkHiveDynamicPartitionWriterContainer( @transient jobConf: JobConf, fileSinkConf: FileSinkDesc, dynamicPartColNames: Array[String]) extends SparkHiveWriterContainer(jobConf, fileSinkConf) { + import SparkHiveDynamicPartitionWriterContainer._ + private val defaultPartName = jobConf.get( ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultVal) @@ -179,6 +183,20 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( commit() } + override def commitJob(): Unit = { + // This is a hack to avoid writing _SUCCESS mark file. In lower versions of Hadoop (e.g. 1.0.4), + // semantics of FileSystem.globStatus() is different from higher versions (e.g. 2.4.1) and will + // include _SUCCESS file when glob'ing for dynamic partition data files. + // + // Better solution is to add a step similar to what Hive FileSinkOperator.jobCloseOp does: + // calling something like Utilities.mvFileToFinalPath to cleanup the output directory and then + // load it with loadDynamicPartitions/loadPartition/loadTable. 
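The commitJob override whose rationale is spelled out in the comment above (and whose body follows) works by remembering the current value of mapreduce.fileoutputcommitter.marksuccessfuljobs, forcing it to false around super.commitJob() so no _SUCCESS file is written, and then putting the old value back. A generic sketch of that save/flip/restore pattern; Conf and withFlagDisabled are made-up stand-ins rather than the Hadoop JobConf API, and the try/finally is an extra safety net the override itself does not use.

object FlagScopeSketch {
  // A toy mutable configuration standing in for a Hadoop JobConf.
  final class Conf {
    private val values = scala.collection.mutable.Map.empty[String, Boolean]
    def getBoolean(key: String, default: Boolean): Boolean = values.getOrElse(key, default)
    def setBoolean(key: String, value: Boolean): Unit = values(key) = value
  }

  // Run `body` with `key` forced to false, then restore whatever was there before.
  def withFlagDisabled[T](conf: Conf, key: String)(body: => T): T = {
    val old = conf.getBoolean(key, true)
    conf.setBoolean(key, false)
    try body
    finally conf.setBoolean(key, old)
  }

  def main(args: Array[String]): Unit = {
    val conf = new Conf
    withFlagDisabled(conf, "marksuccessfuljobs") {
      println(conf.getBoolean("marksuccessfuljobs", true)) // false while "committing"
    }
    println(conf.getBoolean("marksuccessfuljobs", true))   // restored afterwards (true)
  }
}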
+ val oldMarker = jobConf.getBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, true) + jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, false) + super.commitJob() + jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) + } + override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = { val dynamicPartPath = dynamicPartColNames .zip(row.takeRight(dynamicPartColNames.length)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 158cfb5bbee7c..2060e1f1a7a4b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{QueryTest, SchemaRDD} -import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} +import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.{QueryTest, SchemaRDD} +import org.apache.spark.storage.RDDBlockId class CachedTableSuite extends QueryTest { - import TestHive._ - /** * Throws a test failed exception when the number of cached tables differs from the expected * number. @@ -34,11 +34,24 @@ class CachedTableSuite extends QueryTest { case cached: InMemoryRelation => cached } - if (cachedData.size != numCachedTables) { - fail( - s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + - planWithCaching) - } + assert( + cachedData.size == numCachedTables, + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + planWithCaching) + } + + def rddIdOf(tableName: String): Int = { + val executedPlan = table(tableName).queryExecution.executedPlan + executedPlan.collect { + case InMemoryColumnarTableScan(_, _, relation) => + relation.cachedColumnBuffers.id + case _ => + fail(s"Table $tableName is not cached\n" + executedPlan) + }.head + } + + def isMaterialized(rddId: Int): Boolean = { + sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty } test("cache table") { @@ -102,16 +115,47 @@ class CachedTableSuite extends QueryTest { assert(!TestHive.isCached("src"), "Table 'src' should not be cached") } - test("CACHE TABLE AS SELECT") { - assertCached(sql("SELECT * FROM src"), 0) - sql("CACHE TABLE test AS SELECT key FROM src") + test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { + sql("CACHE TABLE testCacheTable AS SELECT * FROM src") + assertCached(table("testCacheTable")) - checkAnswer( - sql("SELECT * FROM test"), - sql("SELECT key FROM src").collect().toSeq) + val rddId = rddIdOf("testCacheTable") + assert( + isMaterialized(rddId), + "Eagerly cached in-memory table should have already been materialized") - assertCached(sql("SELECT * FROM test")) + uncacheTable("testCacheTable") + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } + + test("CACHE TABLE tableName AS SELECT ...") { + sql("CACHE TABLE testCacheTable AS SELECT key FROM src LIMIT 10") + assertCached(table("testCacheTable")) + + val rddId = rddIdOf("testCacheTable") + assert( + isMaterialized(rddId), + "Eagerly cached in-memory table should have already been materialized") + + uncacheTable("testCacheTable") + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } - 
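[editor's note] The rddIdOf/isMaterialized helpers introduced in this suite let the tests assert not only that a plan contains an InMemoryRelation but that its backing RDD blocks were (or were not) actually computed. A hedged sketch of how those helpers compose, built from the code in this diff; it assumes the Spark 1.x Hive test harness (TestHive) on the classpath and is shown for illustration, not as a standalone program:

```scala
import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.storage.RDDBlockId

object CacheCheckSketch {
  // Id of the RDD backing a cached table; throws if the table is not cached at all.
  def rddIdOf(tableName: String): Int =
    table(tableName).queryExecution.executedPlan.collect {
      case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id
    }.head

  // Materialized once at least the first partition block is present in the block manager.
  def isMaterialized(rddId: Int): Boolean =
    sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty

  def main(args: Array[String]): Unit = {
    sql("CACHE LAZY TABLE src")                  // lazy: cached plan, nothing computed yet
    val id = rddIdOf("src")
    assert(!isMaterialized(id))
    sql("SELECT COUNT(*) FROM src").collect()    // first scan materializes the cache
    assert(isMaterialized(id))
    uncacheTable("src")
  }
}
```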
assertCached(sql("SELECT * FROM test JOIN test"), 2) + test("CACHE LAZY TABLE tableName") { + sql("CACHE LAZY TABLE src") + assertCached(table("src")) + + val rddId = rddIdOf("src") + assert( + !isMaterialized(rddId), + "Lazily cached in-memory table shouldn't be materialized eagerly") + + sql("SELECT COUNT(*) FROM src").collect() + assert( + isMaterialized(rddId), + "Lazily cached in-memory table should have been materialized") + + uncacheTable("src") + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index a35c40efdc207..14e791fe0f0ee 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.sql.{SQLConf, QueryTest} import org.apache.spark.sql.catalyst.plans.logical.NativeCommand -import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala index 9644b707eb1a0..46b11b582b26d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala @@ -25,34 +25,30 @@ import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.sql.api.java.JavaSchemaRDD import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.test.TestSQLContext // Implicits import scala.collection.JavaConversions._ class JavaHiveQLSuite extends FunSuite { - lazy val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext) + lazy val javaCtx = new JavaSparkContext(TestHive.sparkContext) // There is a little trickery here to avoid instantiating two HiveContexts in the same JVM lazy val javaHiveCtx = new JavaHiveContext(javaCtx) { override val sqlContext = TestHive } - ignore("SELECT * FROM src") { + test("SELECT * FROM src") { assert( javaHiveCtx.sql("SELECT * FROM src").collect().map(_.getInt(0)) === TestHive.sql("SELECT * FROM src").collect().map(_.getInt(0)).toSeq) } - private val explainCommandClassName = - classOf[ExplainCommand].getSimpleName.stripSuffix("$") - def isExplanation(result: JavaSchemaRDD) = { val explanation = result.collect().map(_.getString(0)) - explanation.size > 1 && explanation.head.startsWith(explainCommandClassName) + explanation.size > 1 && explanation.head.startsWith("== Physical Plan ==") } - ignore("Query Hive native command execution result") { + test("Query Hive native command execution result") { val tableName = "test_native_commands" assertResult(0) { @@ -63,23 +59,18 @@ class JavaHiveQLSuite extends FunSuite { javaHiveCtx.sql(s"CREATE TABLE $tableName(key INT, value STRING)").count() } - javaHiveCtx.sql("SHOW TABLES").registerTempTable("show_tables") - assert( javaHiveCtx - .sql("SELECT result FROM show_tables") + .sql("SHOW TABLES") .collect() .map(_.getString(0)) .contains(tableName)) - assertResult(Array(Array("key", "int", "None"), Array("value", "string", 
"None"))) { - javaHiveCtx.sql(s"DESCRIBE $tableName").registerTempTable("describe_table") - - + assertResult(Array(Array("key", "int"), Array("value", "string"))) { javaHiveCtx - .sql("SELECT result FROM describe_table") + .sql(s"describe $tableName") .collect() - .map(_.getString(0).split("\t").map(_.trim)) + .map(row => Array(row.get(0).asInstanceOf[String], row.get(1).asInstanceOf[String])) .toArray } @@ -89,7 +80,7 @@ class JavaHiveQLSuite extends FunSuite { TestHive.reset() } - ignore("Exactly once semantics for DDL and command statements") { + test("Exactly once semantics for DDL and command statements") { val tableName = "test_exactly_once" val q0 = javaHiveCtx.sql(s"CREATE TABLE $tableName(key INT, value STRING)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2e282a9ade40c..2829105f43716 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -22,6 +22,7 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -675,6 +676,41 @@ class HiveQuerySuite extends HiveComparisonTest { sql("SELECT * FROM boom").queryExecution.analyzed } + test("SPARK-3810: PreInsertionCasts static partitioning support") { + val analyzedPlan = { + loadTestTable("srcpart") + sql("DROP TABLE IF EXISTS withparts") + sql("CREATE TABLE withparts LIKE srcpart") + sql("INSERT INTO TABLE withparts PARTITION(ds='1', hr='2') SELECT key, value FROM src") + .queryExecution.analyzed + } + + assertResult(1, "Duplicated project detected\n" + analyzedPlan) { + analyzedPlan.collect { + case _: Project => () + }.size + } + } + + test("SPARK-3810: PreInsertionCasts dynamic partitioning support") { + val analyzedPlan = { + loadTestTable("srcpart") + sql("DROP TABLE IF EXISTS withparts") + sql("CREATE TABLE withparts LIKE srcpart") + sql("SET hive.exec.dynamic.partition.mode=nonstrict") + + sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart") + sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM src") + .queryExecution.analyzed + } + + assertResult(1, "Duplicated project detected\n" + analyzedPlan) { + analyzedPlan.collect { + case _: Project => () + }.size + } + } + test("parse HQL set commands") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 374848358e700..7d73ada12d107 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -217,7 +217,6 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Generate jobs and perform checkpoint for the given `time`. 
*/ private def generateJobs(time: Time) { - SparkEnv.set(ssc.env) Try(graph.generateJobs(time)) match { case Success(jobs) => val receivedBlockInfo = graph.getReceiverInputStreams.map { stream => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 1b034b9fb187c..cfa3cd8925c80 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -138,7 +138,6 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } jobSet.handleJobStart(job) logInfo("Starting job " + job.id + " from job set of time " + jobSet.time) - SparkEnv.set(ssc.env) } private def handleJobCompletion(job: Job) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 5307fe189d717..7149dbc12a365 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -202,7 +202,6 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging { @transient val thread = new Thread() { override def run() { try { - SparkEnv.set(env) startReceivers() } catch { case ie: InterruptedException => logInfo("ReceiverLauncher interrupted") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 952a74fd5f6de..6107fcdc447b6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -18,8 +18,6 @@ package org.apache.spark.streaming import akka.actor.Actor -import akka.actor.IO -import akka.actor.IOManager import akka.actor.Props import akka.util.ByteString @@ -144,59 +142,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") } - // TODO: This test works in IntelliJ but not through SBT - ignore("actor input stream") { - // Start the server - val testServer = new TestServer() - val port = testServer.port - testServer.start() - - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val networkStream = ssc.actorStream[String](Props(new TestActor(port)), "TestActor", - // Had to pass the local value of port to prevent from closing over entire scope - StorageLevel.MEMORY_AND_DISK) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(networkStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) - outputStream.register() - ssc.start() - - // Feed data to the server to send to the network receiver - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = 1 to 9 - val expectedOutput = input.map(x => x.toString) - Thread.sleep(1000) - for (i <- 0 until input.size) { - testServer.send(input(i).toString) - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) - } - Thread.sleep(1000) - logInfo("Stopping server") - testServer.stop() - logInfo("Stopping context") - ssc.stop() - - // Verify whether data received was as expected - 
logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) - logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output.size = " + expectedOutput.size) - logInfo("expected output") - expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("--------------------------------") - - // Verify whether all the elements received are as expected - // (whether the elements were received one in each interval is not verified) - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i) === expectedOutput(i)) - } - } - - test("multi-thread receiver") { // set up the test receiver val numThreads = 10 @@ -378,22 +323,6 @@ class TestServer(portToBind: Int = 0) extends Logging { def port = serverSocket.getLocalPort } -/** This is an actor for testing actor input stream */ -class TestActor(port: Int) extends Actor with ActorHelper { - - def bytesToString(byteString: ByteString) = byteString.utf8String - - override def preStart(): Unit = { - @deprecated("suppress compile time deprecation warning", "1.0.0") - val unit = IOManager(context.system).connect(new InetSocketAddress(port)) - } - - def receive = { - case IO.Read(socket, bytes) => - store(bytesToString(bytes)) - } -} - /** This is a receiver to test multiple threads inserting data using block generator */ class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY_SER) with Logging { diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 6c93d8582330b..abd37834ed3cc 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -43,7 +43,7 @@ private[yarn] class YarnAllocationHandler( args: ApplicationMasterArguments, preferredNodes: collection.Map[String, collection.Set[SplitInfo]], securityMgr: SecurityManager) - extends YarnAllocator(conf, sparkConf, args, preferredNodes, securityMgr) { + extends YarnAllocator(conf, sparkConf, appAttemptId, args, preferredNodes, securityMgr) { private val lastResponseId = new AtomicInteger() private val releaseList: CopyOnWriteArrayList[ContainerId] = new CopyOnWriteArrayList() diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala index 9bd1719cb1808..7faf55bc63372 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala @@ -40,6 +40,7 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC private var rpc: YarnRPC = null private var resourceManager: AMRMProtocol = _ private var uiHistoryAddress: String = _ + private var registered: Boolean = false override def register( conf: YarnConfiguration, @@ -51,8 +52,11 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC this.rpc = YarnRPC.create(conf) this.uiHistoryAddress = uiHistoryAddress - resourceManager = registerWithResourceManager(conf) - registerApplicationMaster(uiAddress) + synchronized { + resourceManager = registerWithResourceManager(conf) + registerApplicationMaster(uiAddress) + registered = true + } new 
YarnAllocationHandler(conf, sparkConf, resourceManager, getAttemptId(), args, preferredNodeLocations, securityMgr) @@ -66,14 +70,16 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC appAttemptId } - override def shutdown(status: FinalApplicationStatus, diagnostics: String = "") = { - val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) - .asInstanceOf[FinishApplicationMasterRequest] - finishReq.setAppAttemptId(getAttemptId()) - finishReq.setFinishApplicationStatus(status) - finishReq.setDiagnostics(diagnostics) - finishReq.setTrackingUrl(uiHistoryAddress) - resourceManager.finishApplicationMaster(finishReq) + override def unregister(status: FinalApplicationStatus, diagnostics: String = "") = synchronized { + if (registered) { + val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) + .asInstanceOf[FinishApplicationMasterRequest] + finishReq.setAppAttemptId(getAttemptId()) + finishReq.setFinishApplicationStatus(status) + finishReq.setDiagnostics(diagnostics) + finishReq.setTrackingUrl(uiHistoryAddress) + resourceManager.finishApplicationMaster(finishReq) + } } override def getAmIpFilterParams(conf: YarnConfiguration, proxyBase: String) = { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index caceef5d4b5b0..a3c43b43848d2 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -33,6 +33,7 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} +import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend @@ -56,8 +57,11 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) + @volatile private var exitCode = 0 + @volatile private var unregistered = false @volatile private var finished = false @volatile private var finalStatus = FinalApplicationStatus.UNDEFINED + @volatile private var finalMsg: String = "" @volatile private var userClassThread: Thread = _ private var reporterThread: Thread = _ @@ -71,80 +75,107 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, private val sparkContextRef = new AtomicReference[SparkContext](null) final def run(): Int = { - val appAttemptId = client.getAttemptId() + try { + val appAttemptId = client.getAttemptId() - if (isDriver) { - // Set the web ui port to be ephemeral for yarn so we don't conflict with - // other spark processes running on the same box - System.setProperty("spark.ui.port", "0") + if (isDriver) { + // Set the web ui port to be ephemeral for yarn so we don't conflict with + // other spark processes running on the same box + System.setProperty("spark.ui.port", "0") - // Set the master property to match the requested mode. - System.setProperty("spark.master", "yarn-cluster") + // Set the master property to match the requested mode. 
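[editor's note] Both YarnRMClientImpl variants now remember whether registration with the ResourceManager actually happened, so unregister becomes a no-op when the AM fails before registering. A generic, self-contained sketch of that guard pattern; the Rm trait here is a hypothetical stand-in, not the YARN protocol:

```scala
object RegistrationGuardSketch {
  // Hypothetical stand-in for the ResourceManager protocol.
  trait Rm {
    def register(): Unit
    def unregister(status: String): Unit
  }

  class Client(rm: Rm) {
    private var registered = false

    def register(): Unit = synchronized {
      rm.register()
      registered = true            // flipped only after the call actually succeeded
    }

    // Safe to call from a shutdown hook even if register() never ran or threw.
    def unregister(status: String): Unit = synchronized {
      if (registered) rm.unregister(status)
    }
  }

  def main(args: Array[String]): Unit = {
    val calls = scala.collection.mutable.Buffer.empty[String]
    val client = new Client(new Rm {
      def register(): Unit = calls += "register"
      def unregister(status: String): Unit = calls += s"unregister($status)"
    })

    client.unregister("FAILED")    // ignored: never registered
    client.register()
    client.unregister("SUCCEEDED")
    println(calls.mkString(", "))  // register, unregister(SUCCEEDED)
  }
}
```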
+ System.setProperty("spark.master", "yarn-cluster") - // Propagate the application ID so that YarnClusterSchedulerBackend can pick it up. - System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString()) - } + // Propagate the application ID so that YarnClusterSchedulerBackend can pick it up. + System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString()) + } - logInfo("ApplicationAttemptId: " + appAttemptId) + logInfo("ApplicationAttemptId: " + appAttemptId) - val cleanupHook = new Runnable { - override def run() { - // If the SparkContext is still registered, shut it down as a best case effort in case - // users do not call sc.stop or do System.exit(). - val sc = sparkContextRef.get() - if (sc != null) { - logInfo("Invoking sc stop from shutdown hook") - sc.stop() - finish(FinalApplicationStatus.SUCCEEDED) - } + val cleanupHook = new Runnable { + override def run() { + // If the SparkContext is still registered, shut it down as a best case effort in case + // users do not call sc.stop or do System.exit(). + val sc = sparkContextRef.get() + if (sc != null) { + logInfo("Invoking sc stop from shutdown hook") + sc.stop() + } + val maxAppAttempts = client.getMaxRegAttempts(yarnConf) + val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts + + if (!finished) { + // this shouldn't ever happen, but if it does assume weird failure + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, + "shutdown hook called without cleanly finishing") + } - // Cleanup the staging dir after the app is finished, or if it's the last attempt at - // running the AM. - val maxAppAttempts = client.getMaxRegAttempts(yarnConf) - val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts - if (finished || isLastAttempt) { - cleanupStagingDir() + if (!unregistered) { + // we only want to unregister if we don't want the RM to retry + if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) { + unregister(finalStatus, finalMsg) + cleanupStagingDir() + } + } } } - } - // Use higher priority than FileSystem. - assert(ApplicationMaster.SHUTDOWN_HOOK_PRIORITY > FileSystem.SHUTDOWN_HOOK_PRIORITY) - ShutdownHookManager - .get().addShutdownHook(cleanupHook, ApplicationMaster.SHUTDOWN_HOOK_PRIORITY) + // Use higher priority than FileSystem. + assert(ApplicationMaster.SHUTDOWN_HOOK_PRIORITY > FileSystem.SHUTDOWN_HOOK_PRIORITY) + ShutdownHookManager + .get().addShutdownHook(cleanupHook, ApplicationMaster.SHUTDOWN_HOOK_PRIORITY) - // Call this to force generation of secret so it gets populated into the - // Hadoop UGI. This has to happen before the startUserClass which does a - // doAs in order for the credentials to be passed on to the executor containers. - val securityMgr = new SecurityManager(sparkConf) + // Call this to force generation of secret so it gets populated into the + // Hadoop UGI. This has to happen before the startUserClass which does a + // doAs in order for the credentials to be passed on to the executor containers. 
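[editor's note] The rewritten cleanup hook only unregisters (and removes the staging dir) when the attempt either succeeded or has no retries left, so a failed first attempt still lets YARN restart the AM. The decision reduces to a small predicate; a sketch with illustrative names, not Spark API:

```scala
object UnregisterPolicySketch {
  sealed trait FinalStatus
  case object Succeeded extends FinalStatus
  case object Failed extends FinalStatus

  // Unregistering tells the RM not to retry, so only do it when retrying is pointless.
  def shouldUnregister(status: FinalStatus, attemptId: Int, maxAttempts: Int): Boolean = {
    val isLastAttempt = attemptId >= maxAttempts
    status == Succeeded || isLastAttempt
  }

  def main(args: Array[String]): Unit = {
    assert(shouldUnregister(Succeeded, attemptId = 1, maxAttempts = 2))  // success: always unregister
    assert(!shouldUnregister(Failed, attemptId = 1, maxAttempts = 2))    // let YARN retry the AM
    assert(shouldUnregister(Failed, attemptId = 2, maxAttempts = 2))     // out of attempts
    println("policy checks passed")
  }
}
```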
+ val securityMgr = new SecurityManager(sparkConf) - if (isDriver) { - runDriver(securityMgr) - } else { - runExecutorLauncher(securityMgr) + if (isDriver) { + runDriver(securityMgr) + } else { + runExecutorLauncher(securityMgr) + } + } catch { + case e: Exception => + // catch everything else if not specifically handled + logError("Uncaught exception: ", e) + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, + "Uncaught exception: " + e.getMessage()) } + exitCode + } - if (finalStatus != FinalApplicationStatus.UNDEFINED) { - finish(finalStatus) - 0 - } else { - 1 + /** + * unregister is used to completely unregister the application from the ResourceManager. + * This means the ResourceManager will not retry the application attempt on your behalf if + * a failure occurred. + */ + final def unregister(status: FinalApplicationStatus, diagnostics: String = null) = synchronized { + if (!unregistered) { + logInfo(s"Unregistering ApplicationMaster with $status" + + Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) + unregistered = true + client.unregister(status, Option(diagnostics).getOrElse("")) } } - final def finish(status: FinalApplicationStatus, diagnostics: String = null) = synchronized { + final def finish(status: FinalApplicationStatus, code: Int, msg: String = null) = synchronized { if (!finished) { - logInfo(s"Finishing ApplicationMaster with $status" + - Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) - finished = true + logInfo(s"Final app status: ${status}, exitCode: ${code}" + + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) + exitCode = code finalStatus = status - try { - if (Thread.currentThread() != reporterThread) { - reporterThread.interrupt() - reporterThread.join() - } - } finally { - client.shutdown(status, Option(diagnostics).getOrElse("")) + finalMsg = msg + finished = true + if (Thread.currentThread() != reporterThread && reporterThread != null) { + logDebug("shutting down reporter thread") + reporterThread.interrupt() + } + if (Thread.currentThread() != userClassThread && userClassThread != null) { + logDebug("shutting down user thread") + userClassThread.interrupt() } } } @@ -182,7 +213,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, private def runDriver(securityMgr: SecurityManager): Unit = { addAmIpFilter() - val userThread = startUserClass() + setupSystemSecurityManager() + userClassThread = startUserClass() // This a bit hacky, but we need to wait until the spark.driver.port property has // been set by the Thread executing the user class. @@ -190,15 +222,12 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, // If there is no SparkContext at this point, just fail the app. if (sc == null) { - finish(FinalApplicationStatus.FAILED, "Timed out waiting for SparkContext.") + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_SC_NOT_INITED, + "Timed out waiting for SparkContext.") } else { registerAM(sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) - try { - userThread.join() - } finally { - // In cluster mode, ask the reporter thread to stop since the user app is finished. - reporterThread.interrupt() - } + userClassThread.join() } } @@ -211,7 +240,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, // In client mode the actor will stop the reporter thread. 
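[editor's note] finish() now only records the outcome and interrupts the helper threads, leaving unregistration to the shutdown hook; note the Thread.currentThread() checks, which keep a thread from interrupting itself when it is the one reporting the failure. A compact, runnable sketch of that idempotent guard, distilled from the change above with illustrative names:

```scala
object FinishGuardSketch {
  @volatile private var finished = false
  @volatile private var exitCode = 0

  private var reporterThread: Thread = _

  // Idempotent: the first caller wins; later calls (e.g. from a shutdown hook) are no-ops.
  def finish(code: Int): Unit = synchronized {
    if (!finished) {
      finished = true
      exitCode = code
      // Never interrupt the calling thread itself, and tolerate threads that never started.
      if (reporterThread != null && (Thread.currentThread() ne reporterThread)) {
        reporterThread.interrupt()
      }
    }
  }

  def main(args: Array[String]): Unit = {
    reporterThread = new Thread {
      override def run(): Unit =
        try Thread.sleep(60000)
        catch { case _: InterruptedException => () }   // woken up by finish()
    }
    reporterThread.start()
    finish(0)
    finish(15)                       // ignored: already finished
    reporterThread.join()
    println(s"exitCode = $exitCode") // 0
  }
}
```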
reporterThread.join() - finalStatus = FinalApplicationStatus.SUCCEEDED } private def launchReporterThread(): Thread = { @@ -231,33 +259,26 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, val t = new Thread { override def run() { var failureCount = 0 - while (!finished) { try { - checkNumExecutorsFailed() - if (!finished) { + if (allocator.getNumExecutorsFailed >= maxNumExecutorFailures) { + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES, + "Max number of executor failures reached") + } else { logDebug("Sending progress") allocator.allocateResources() } failureCount = 0 } catch { + case i: InterruptedException => case e: Throwable => { failureCount += 1 if (!NonFatal(e) || failureCount >= reporterMaxFailures) { - logError("Exception was thrown from Reporter thread.", e) - finish(FinalApplicationStatus.FAILED, "Exception was thrown" + - s"${failureCount} time(s) from Reporter thread.") - - /** - * If exception is thrown from ReporterThread, - * interrupt user class to stop. - * Without this interrupting, if exception is - * thrown before allocating enough executors, - * YarnClusterScheduler waits until timeout even though - * we cannot allocate executors. - */ - logInfo("Interrupting user class to stop.") - userClassThread.interrupt + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_REPORTER_FAILURE, "Exception was thrown " + + s"${failureCount} time(s) from Reporter thread.") + } else { logWarning(s"Reporter thread fails ${failureCount} time(s) in a row.", e) } @@ -308,7 +329,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, sparkContextRef.synchronized { var count = 0 val waitTime = 10000L - val numTries = sparkConf.getInt("spark.yarn.ApplicationMaster.waitTries", 10) + val numTries = sparkConf.getInt("spark.yarn.applicationMaster.waitTries", 10) while (sparkContextRef.get() == null && count < numTries && !finished) { logInfo("Waiting for spark context initialization ... " + count) count = count + 1 @@ -328,10 +349,19 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, private def waitForSparkDriver(): ActorRef = { logInfo("Waiting for Spark driver to be reachable.") var driverUp = false + var count = 0 val hostport = args.userArgs(0) val (driverHost, driverPort) = Utils.parseHostPort(hostport) - while (!driverUp) { + + // spark driver should already be up since it launched us, but we don't want to + // wait forever, so wait 100 seconds max to match the cluster mode setting. + // Leave this config unpublished for now. SPARK-3779 to investigating changing + // this config to be time based. 
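[editor's note] waitForSparkDriver above now gives up after a bounded number of probes instead of spinning forever when the driver never comes up. A self-contained sketch of that style of bounded port probe; it is a generic example, not the Spark helper itself:

```scala
import java.net.{InetSocketAddress, Socket}

object PortProbeSketch {
  // Returns true once host:port accepts a connection, false after numTries failed probes.
  def waitForPort(host: String, port: Int, numTries: Int, sleepMs: Long = 100L): Boolean = {
    var up = false
    var count = 0
    while (!up && count < numTries) {
      count += 1
      val socket = new Socket()
      try {
        socket.connect(new InetSocketAddress(host, port), 200) // short connect timeout
        up = true
      } catch {
        case _: java.io.IOException => Thread.sleep(sleepMs)   // not up yet, retry
      } finally {
        socket.close()
      }
    }
    up
  }

  def main(args: Array[String]): Unit = {
    // Probe a port that is almost certainly closed; expect the loop to give up.
    val up = waitForPort("localhost", 65533, numTries = 3, sleepMs = 10L)
    if (!up) println("driver did not come up within 3 tries")
  }
}
```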
+ val numTries = sparkConf.getInt("spark.yarn.applicationMaster.waitTries", 1000) + + while (!driverUp && !finished && count < numTries) { try { + count = count + 1 val socket = new Socket(driverHost, driverPort) socket.close() logInfo("Driver now available: %s:%s".format(driverHost, driverPort)) @@ -343,6 +373,11 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, Thread.sleep(100) } } + + if (!driverUp) { + throw new SparkException("Failed to connect to driver!") + } + sparkConf.set("spark.driver.host", driverHost) sparkConf.set("spark.driver.port", driverPort.toString) @@ -354,18 +389,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM") } - private def checkNumExecutorsFailed() = { - if (allocator.getNumExecutorsFailed >= maxNumExecutorFailures) { - finish(FinalApplicationStatus.FAILED, "Max number of executor failures reached.") - - val sc = sparkContextRef.get() - if (sc != null) { - logInfo("Invoking sc stop from checkNumExecutorsFailed") - sc.stop() - } - } - } - /** Add the Yarn IP filter that is required for properly securing the UI. */ private def addAmIpFilter() = { val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) @@ -379,40 +402,81 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, } } + /** + * This system security manager applies to the entire process. + * It's main purpose is to handle the case if the user code does a System.exit. + * This allows us to catch that and properly set the YARN application status and + * cleanup if needed. + */ + private def setupSystemSecurityManager(): Unit = { + try { + var stopped = false + System.setSecurityManager(new java.lang.SecurityManager() { + override def checkExit(paramInt: Int) { + if (!stopped) { + logInfo("In securityManager checkExit, exit code: " + paramInt) + if (paramInt == 0) { + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + } else { + finish(FinalApplicationStatus.FAILED, + paramInt, + "User class exited with non-zero exit code") + } + stopped = true + } + } + // required for the checkExit to work properly + override def checkPermission(perm: java.security.Permission): Unit = {} + }) + } + catch { + case e: SecurityException => + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_SECURITY, + "Error in setSecurityManager") + logError("Error in setSecurityManager:", e) + } + } + + /** + * Start the user class, which contains the spark driver, in a separate Thread. + * If the main routine exits cleanly or exits with System.exit(0) we + * assume it was successful, for all other cases we assume failure. + * + * Returns the user thread that was started. + */ private def startUserClass(): Thread = { logInfo("Starting the user JAR in a separate Thread") System.setProperty("spark.executor.instances", args.numExecutors.toString) val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]]) - userClassThread = new Thread { + val userThread = new Thread { override def run() { - var status = FinalApplicationStatus.FAILED try { - // Copy val mainArgs = new Array[String](args.userArgs.size) args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size) mainMethod.invoke(null, mainArgs) - // Some apps have "System.exit(0)" at the end. The user thread will stop here unless - // it has an uncaught exception thrown out. 
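[editor's note] setupSystemSecurityManager above intercepts System.exit from the user class so the AM can translate the exit code into a final YARN status before the JVM dies. A minimal standalone demonstration of the interception mechanism (plain JVM API, no YARN involved, and assuming a JDK of the era Spark supported here; recent JDKs restrict SecurityManager). Letting checkExit return normally still allows the exit to proceed:

```scala
object ExitInterceptSketch {
  def main(args: Array[String]): Unit = {
    System.setSecurityManager(new SecurityManager() {
      override def checkExit(status: Int): Unit = {
        // Record the exit code before the JVM goes away; a real AM would set its final status here.
        println(s"intercepted System.exit($status)")
      }
      // Permit everything else; required so ordinary operations keep working.
      override def checkPermission(perm: java.security.Permission): Unit = {}
    })

    println("calling System.exit(15)")
    System.exit(15)   // checkExit runs first, then the JVM exits with status 15
  }
}
```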
It needs a shutdown hook to set SUCCEEDED. - status = FinalApplicationStatus.SUCCEEDED + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + logDebug("Done running users class") } catch { case e: InvocationTargetException => e.getCause match { case _: InterruptedException => // Reporter thread can interrupt to stop user class - - case e => throw e + case e: Exception => + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_EXCEPTION_USER_CLASS, + "User class threw exception: " + e.getMessage) + // re-throw to get it logged + throw e } - } finally { - logDebug("Finishing main") - finalStatus = status } } } - userClassThread.setName("Driver") - userClassThread.start() - userClassThread + userThread.setName("Driver") + userThread.start() + userThread } // Actor used to monitor the driver when running in client deploy mode. @@ -432,7 +496,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, override def receive = { case x: DisassociatedEvent => logInfo(s"Driver terminated or disconnected! Shutting down. $x") - finish(FinalApplicationStatus.SUCCEEDED) + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) case x: AddWebUIFilter => logInfo(s"Add WebUI Filter. $x") driver ! x @@ -446,6 +510,15 @@ object ApplicationMaster extends Logging { val SHUTDOWN_HOOK_PRIORITY: Int = 30 + // exit codes for different causes, no reason behind the values + private val EXIT_SUCCESS = 0 + private val EXIT_UNCAUGHT_EXCEPTION = 10 + private val EXIT_MAX_EXECUTOR_FAILURES = 11 + private val EXIT_REPORTER_FAILURE = 12 + private val EXIT_SC_NOT_INITED = 13 + private val EXIT_SECURITY = 14 + private val EXIT_EXCEPTION_USER_CLASS = 15 + private var master: ApplicationMaster = _ def main(args: Array[String]) = { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 6ecac6eae6e03..14a0386b78978 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -23,6 +23,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, ListBuffer, Map} import scala.util.{Try, Success, Failure} +import com.google.common.base.Objects import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission @@ -64,12 +65,12 @@ private[spark] trait ClientBase extends Logging { s"memory capability of the cluster ($maxMem MB per container)") val executorMem = args.executorMemory + executorMemoryOverhead if (executorMem > maxMem) { - throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" + + throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" + s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") } val amMem = args.amMemory + amMemoryOverhead if (amMem > maxMem) { - throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" + + throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" + s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") } logInfo("Will allocate AM container, with %d MB memory including %d MB overhead".format( @@ -771,15 +772,17 @@ private[spark] object ClientBase extends Logging { private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { val srcUri = srcFs.getUri() val 
dstUri = destFs.getUri() - if (srcUri.getScheme() == null) { - return false - } - if (!srcUri.getScheme().equals(dstUri.getScheme())) { + if (srcUri.getScheme() == null || srcUri.getScheme() != dstUri.getScheme()) { return false } + var srcHost = srcUri.getHost() var dstHost = dstUri.getHost() - if ((srcHost != null) && (dstHost != null)) { + + // In HA or when using viewfs, the host part of the URI may not actually be a host, but the + // name of the HDFS namespace. Those names won't resolve, so avoid even trying if they + // match. + if (srcHost != null && dstHost != null && srcHost != dstHost) { try { srcHost = InetAddress.getByName(srcHost).getCanonicalHostName() dstHost = InetAddress.getByName(dstHost).getCanonicalHostName() @@ -787,19 +790,9 @@ private[spark] object ClientBase extends Logging { case e: UnknownHostException => return false } - if (!srcHost.equals(dstHost)) { - return false - } - } else if (srcHost == null && dstHost != null) { - return false - } else if (srcHost != null && dstHost == null) { - return false - } - if (srcUri.getPort() != dstUri.getPort()) { - false - } else { - true } + + Objects.equal(srcHost, dstHost) && srcUri.getPort() == dstUri.getPort() } } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 943dc56202a37..2510b9c9cef68 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -49,12 +49,12 @@ trait YarnRMClient { securityMgr: SecurityManager): YarnAllocator /** - * Shuts down the AM. Guaranteed to only be called once. + * Unregister the AM. Guaranteed to only be called once. * * @param status The final status of the AM. * @param diagnostics Diagnostics message to include in the final status. */ - def shutdown(status: FinalApplicationStatus, diagnostics: String = ""): Unit + def unregister(status: FinalApplicationStatus, diagnostics: String = ""): Unit /** Returns the attempt ID. 
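[editor's note] The compareFs rewrite above collapses the null and equality handling into one pass: same scheme, hosts that match after canonical resolution (skipping DNS when the raw strings already match, which also covers HA and viewfs namespace names that would not resolve), and the same port. A standalone sketch of the same comparison over plain java.net.URI values, using java.util.Objects.equals in place of Guava's Objects.equal to stay dependency-free; it is illustrative, not the Spark helper itself:

```scala
import java.net.{InetAddress, URI, UnknownHostException}
import java.util.Objects

object SameFsSketch {
  // Two filesystem URIs refer to the same FS if scheme, (canonical) host and port all match.
  def sameFileSystem(src: URI, dst: URI): Boolean = {
    if (src.getScheme == null || src.getScheme != dst.getScheme) return false

    var srcHost = src.getHost
    var dstHost = dst.getHost

    // Only resolve when the raw host strings differ; HA/viewfs "hosts" are namespace names
    // that won't resolve, but they compare equal textually when they match.
    if (srcHost != null && dstHost != null && srcHost != dstHost) {
      try {
        srcHost = InetAddress.getByName(srcHost).getCanonicalHostName
        dstHost = InetAddress.getByName(dstHost).getCanonicalHostName
      } catch {
        case _: UnknownHostException => return false
      }
    }

    Objects.equals(srcHost, dstHost) && src.getPort == dst.getPort
  }

  def main(args: Array[String]): Unit = {
    println(sameFileSystem(new URI("hdfs://nameservice1/a"), new URI("hdfs://nameservice1/b"))) // true
    println(sameFileSystem(new URI("hdfs://nn1:8020/a"), new URI("s3n://bucket/a")))            // false
  }
}
```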
*/ def getAttemptId(): ApplicationAttemptId diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index 97eb0548e77c3..fe55d70ccc370 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -41,4 +41,55 @@ + + + + hadoop-2.2 + + 1.9 + + + + org.mortbay.jetty + jetty + 6.1.26 + + + org.mortbay.jetty + servlet-api + + + test + + + com.sun.jersey + jersey-core + ${jersey.version} + test + + + com.sun.jersey + jersey-json + ${jersey.version} + test + + + stax + stax-api + + + + + com.sun.jersey + jersey-server + ${jersey.version} + test + + + + + diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala index b581790e158ac..8d4b96ed79933 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala @@ -45,6 +45,7 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC private var amClient: AMRMClient[ContainerRequest] = _ private var uiHistoryAddress: String = _ + private var registered: Boolean = false override def register( conf: YarnConfiguration, @@ -59,13 +60,19 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC this.uiHistoryAddress = uiHistoryAddress logInfo("Registering the ApplicationMaster") - amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) + synchronized { + amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) + registered = true + } new YarnAllocationHandler(conf, sparkConf, amClient, getAttemptId(), args, preferredNodeLocations, securityMgr) } - override def shutdown(status: FinalApplicationStatus, diagnostics: String = "") = - amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress) + override def unregister(status: FinalApplicationStatus, diagnostics: String = "") = synchronized { + if (registered) { + amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress) + } + } override def getAttemptId() = { val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name())