- * Some methods are copied from `org.apache.spark.unsafe.memory.TaskMemoryManager` with
- * modifications. Most modifications are to remove the dependency on the configured memory mode.
+ * store serialized rows. This class is simply an implementation of `MemoryConsumer` that delegates
+ * memory allocation to the `TaskMemoryManager`. This requires that the `TaskMemoryManager` is
+ * configured with `MemoryMode.OFF_HEAP`, i.e. it is using off-heap memory.
*/
-public final class CometShuffleMemoryAllocator extends MemoryConsumer {
- private final UnsafeMemoryAllocator allocator = new UnsafeMemoryAllocator();
-
- private final long pageSize;
- private final long totalMemory;
- private long allocatedMemory = 0L;
-
- /** The number of bits used to address the page table. */
- private static final int PAGE_NUMBER_BITS = 13;
-
- /** The number of entries in the page table. */
- private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS;
-
- private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE];
- private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE);
+public final class CometShuffleMemoryAllocator extends CometShuffleMemoryAllocatorTrait {
+ private static CometShuffleMemoryAllocatorTrait INSTANCE;
- private static final int OFFSET_BITS = 51;
- private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
-
- private static CometShuffleMemoryAllocator INSTANCE;
-
- public static synchronized CometShuffleMemoryAllocator getInstance(
+ /**
+ * Returns the `CometShuffleMemoryAllocatorTrait` to use for Comet columnar shuffle. This method
+ * should be used instead of the constructors. For Spark tests (unless the unified allocator is
+ * enabled), it returns the singleton `CometTestShuffleMemoryAllocator`, a test-only allocator
+ * that should not be used in production. Otherwise it creates a new `CometShuffleMemoryAllocator`
+ * per task, which delegates allocation to the task's `TaskMemoryManager`.
+ */
+ public static CometShuffleMemoryAllocatorTrait getInstance(
SparkConf conf, TaskMemoryManager taskMemoryManager, long pageSize) {
- if (INSTANCE == null) {
- INSTANCE = new CometShuffleMemoryAllocator(conf, taskMemoryManager, pageSize);
+ boolean isSparkTesting = Utils.isTesting();
+ boolean useUnifiedMemAllocator =
+ (boolean)
+ CometConf$.MODULE$.COMET_COLUMNAR_SHUFFLE_UNIFIED_MEMORY_ALLOCATOR_IN_TEST().get();
+
+ if (isSparkTesting && !useUnifiedMemAllocator) {
+ synchronized (CometShuffleMemoryAllocator.class) {
+ if (INSTANCE == null) {
+ // CometTestShuffleMemoryAllocator handles pages by itself so it can be a singleton.
+ INSTANCE = new CometTestShuffleMemoryAllocator(conf, taskMemoryManager, pageSize);
+ }
+ }
+ return INSTANCE;
+ } else {
+ if (taskMemoryManager.getTungstenMemoryMode() != MemoryMode.OFF_HEAP) {
+ throw new IllegalArgumentException(
+ "CometShuffleMemoryAllocator should be used with off-heap "
+ + "memory mode, but got "
+ + taskMemoryManager.getTungstenMemoryMode());
+ }
+
+ // CometShuffleMemoryAllocator stores pages in the TaskMemoryManager, which is not a singleton
+ // but one instance per task. So we need to create a new instance for each task.
+ return new CometShuffleMemoryAllocator(taskMemoryManager, pageSize);
}
-
- return INSTANCE;
}
- CometShuffleMemoryAllocator(SparkConf conf, TaskMemoryManager taskMemoryManager, long pageSize) {
+ CometShuffleMemoryAllocator(TaskMemoryManager taskMemoryManager, long pageSize) {
super(taskMemoryManager, pageSize, MemoryMode.OFF_HEAP);
- this.pageSize = pageSize;
- this.totalMemory =
- CometSparkSessionExtensions$.MODULE$.getCometShuffleMemorySize(conf, SQLConf.get());
- }
-
- public synchronized long acquireMemory(long size) {
- if (allocatedMemory >= totalMemory) {
- throw new SparkOutOfMemoryError(
- "Unable to acquire "
- + size
- + " bytes of memory, current usage "
- + "is "
- + allocatedMemory
- + " bytes and max memory is "
- + totalMemory
- + " bytes");
- }
- long allocationSize = Math.min(size, totalMemory - allocatedMemory);
- allocatedMemory += allocationSize;
- return allocationSize;
}
public long spill(long l, MemoryConsumer memoryConsumer) throws IOException {
+ // JVM shuffle writer does not support spilling for other memory consumers
return 0;
}
- public synchronized LongArray allocateArray(long size) {
- long required = size * 8L;
- MemoryBlock page = allocate(required);
- return new LongArray(page);
- }
-
- public synchronized void freeArray(LongArray array) {
- if (array == null) {
- return;
- }
- free(array.memoryBlock());
- }
-
- public synchronized MemoryBlock allocatePage(long required) {
- long size = Math.max(pageSize, required);
- return allocate(size);
- }
-
- private synchronized MemoryBlock allocate(long required) {
- if (required > TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES) {
- throw new TooLargePageException(required);
- }
-
- long got = acquireMemory(required);
-
- if (got < required) {
- allocatedMemory -= got;
-
- throw new SparkOutOfMemoryError(
- "Unable to acquire "
- + required
- + " bytes of memory, got "
- + got
- + " bytes. Available: "
- + (totalMemory - allocatedMemory));
- }
-
- int pageNumber = allocatedPages.nextClearBit(0);
- if (pageNumber >= PAGE_TABLE_SIZE) {
- allocatedMemory -= got;
-
- throw new IllegalStateException(
- "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages");
- }
-
- MemoryBlock block = allocator.allocate(got);
-
- block.pageNumber = pageNumber;
- pageTable[pageNumber] = block;
- allocatedPages.set(pageNumber);
-
- return block;
+ public synchronized MemoryBlock allocate(long required) {
+ return this.allocatePage(required);
}
public synchronized void free(MemoryBlock block) {
- if (block.pageNumber == MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) {
- // Already freed block
- return;
- }
- allocatedMemory -= block.size();
-
- pageTable[block.pageNumber] = null;
- allocatedPages.clear(block.pageNumber);
- block.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
-
- allocator.free(block);
- }
-
- public synchronized long getAvailableMemory() {
- return totalMemory - allocatedMemory;
+ this.freePage(block);
}
/**
@@ -178,21 +96,11 @@ public synchronized long getAvailableMemory() {
* method assumes that the page number is valid.
*/
public long getOffsetInPage(long pagePlusOffsetAddress) {
- long offsetInPage = decodeOffset(pagePlusOffsetAddress);
- int pageNumber = TaskMemoryManager.decodePageNumber(pagePlusOffsetAddress);
- assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
- MemoryBlock page = pageTable[pageNumber];
- assert (page != null);
- return page.getBaseOffset() + offsetInPage;
- }
-
- public long decodeOffset(long pagePlusOffsetAddress) {
- return pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS;
+ return taskMemoryManager.getOffsetInPage(pagePlusOffsetAddress);
}
public long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
- assert (pageNumber >= 0);
- return ((long) pageNumber) << OFFSET_BITS | offsetInPage & MASK_LONG_LOWER_51_BITS;
+ return TaskMemoryManager.encodePageNumberAndOffset(pageNumber, offsetInPage);
}
public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
diff --git a/spark/src/main/java/org/apache/spark/shuffle/comet/CometShuffleMemoryAllocatorTrait.java b/spark/src/main/java/org/apache/spark/shuffle/comet/CometShuffleMemoryAllocatorTrait.java
new file mode 100644
index 000000000..6831396b3
--- /dev/null
+++ b/spark/src/main/java/org/apache/spark/shuffle/comet/CometShuffleMemoryAllocatorTrait.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.shuffle.comet;
+
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+/** The base class for Comet JVM shuffle memory allocators. */
+public abstract class CometShuffleMemoryAllocatorTrait extends MemoryConsumer {
+ protected CometShuffleMemoryAllocatorTrait(
+ TaskMemoryManager taskMemoryManager, long pageSize, MemoryMode mode) {
+ super(taskMemoryManager, pageSize, mode);
+ }
+
+ public abstract MemoryBlock allocate(long required);
+
+ public abstract void free(MemoryBlock block);
+
+ public abstract long getOffsetInPage(long pagePlusOffsetAddress);
+
+ public abstract long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage);
+}
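The trait above is the allocator API that the sorters and spill writers program against. Below is a minimal sketch, not part of this patch, of how a caller with access to a SparkConf and the task's TaskMemoryManager would obtain an allocator via getInstance and run the allocate / encode / decode / free cycle; the object and method names (AllocatorUsageSketch, writeOneRecord) are illustrative only.

import org.apache.spark.SparkConf
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.shuffle.comet.{CometShuffleMemoryAllocator, CometShuffleMemoryAllocatorTrait}
import org.apache.spark.unsafe.memory.MemoryBlock

object AllocatorUsageSketch {
  def writeOneRecord(
      conf: SparkConf,
      taskMemoryManager: TaskMemoryManager,
      pageSize: Long,
      required: Long): Long = {
    val allocator: CometShuffleMemoryAllocatorTrait =
      CometShuffleMemoryAllocator.getInstance(conf, taskMemoryManager, pageSize)

    // May throw SparkOutOfMemoryError; real callers such as SpillWriter catch it and spill first.
    val page: MemoryBlock = allocator.allocate(required)

    // Record addresses are (pageNumber, offsetInPage) pairs packed into a single long.
    val cursor = page.getBaseOffset
    val recordAddress = allocator.encodePageNumberAndOffset(page, cursor)

    // Decoding returns the absolute off-heap address the record was written at.
    assert(allocator.getOffsetInPage(recordAddress) == cursor)

    allocator.free(page)
    recordAddress
  }
}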
diff --git a/spark/src/main/java/org/apache/spark/shuffle/comet/CometTestShuffleMemoryAllocator.java b/spark/src/main/java/org/apache/spark/shuffle/comet/CometTestShuffleMemoryAllocator.java
new file mode 100644
index 000000000..084e82b2b
--- /dev/null
+++ b/spark/src/main/java/org/apache/spark/shuffle/comet/CometTestShuffleMemoryAllocator.java
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.shuffle.comet;
+
+import java.io.IOException;
+import java.util.BitSet;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.memory.SparkOutOfMemoryError;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.sql.internal.SQLConf;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.UnsafeMemoryAllocator;
+
+import org.apache.comet.CometSparkSessionExtensions$;
+
+/**
+ * A simple memory allocator used by `CometShuffleExternalSorter` to allocate memory blocks which
+ * store serialized rows. It does not rely on Spark's memory allocator because it must allocate
+ * off-heap memory regardless of whether the configured memory mode is on-heap or off-heap. The
+ * allocator is configured with a fixed amount of memory and throws `SparkOutOfMemoryError` when
+ * that memory is exhausted.
+ *
+ * Some methods are copied from `org.apache.spark.unsafe.memory.TaskMemoryManager` with
+ * modifications, mostly to remove the dependency on the configured memory mode.
+ *
+ * This allocator is test-only and must not be used in production. It exists so that Comet JVM
+ * shuffle and execution can be exercised by Spark tests, which generally run with an on-heap
+ * memory configuration; it provides a separate pool of off-heap memory for Comet JVM shuffle and
+ * execution, independent of Spark's on-heap memory configuration.
+ */
+public final class CometTestShuffleMemoryAllocator extends CometShuffleMemoryAllocatorTrait {
+ private final UnsafeMemoryAllocator allocator = new UnsafeMemoryAllocator();
+
+ private final long pageSize;
+ private final long totalMemory;
+ private long allocatedMemory = 0L;
+
+ /** The number of bits used to address the page table. */
+ private static final int PAGE_NUMBER_BITS = 13;
+
+ /** The number of entries in the page table. */
+ private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS;
+
+ private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE];
+ private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE);
+
+ private static final int OFFSET_BITS = 51;
+ private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
+
+ private static CometTestShuffleMemoryAllocator INSTANCE;
+
+ CometTestShuffleMemoryAllocator(
+ SparkConf conf, TaskMemoryManager taskMemoryManager, long pageSize) {
+ super(taskMemoryManager, pageSize, MemoryMode.OFF_HEAP);
+ this.pageSize = pageSize;
+ this.totalMemory =
+ CometSparkSessionExtensions$.MODULE$.getCometShuffleMemorySize(conf, SQLConf.get());
+ }
+
+ private synchronized long _acquireMemory(long size) {
+ if (allocatedMemory >= totalMemory) {
+ throw new SparkOutOfMemoryError(
+ "Unable to acquire "
+ + size
+ + " bytes of memory, current usage "
+ + "is "
+ + allocatedMemory
+ + " bytes and max memory is "
+ + totalMemory
+ + " bytes");
+ }
+ long allocationSize = Math.min(size, totalMemory - allocatedMemory);
+ allocatedMemory += allocationSize;
+ return allocationSize;
+ }
+
+ public long spill(long l, MemoryConsumer memoryConsumer) throws IOException {
+ return 0;
+ }
+
+ public synchronized LongArray allocateArray(long size) {
+ long required = size * 8L;
+ MemoryBlock page = allocateMemoryBlock(required);
+ return new LongArray(page);
+ }
+
+ public synchronized void freeArray(LongArray array) {
+ if (array == null) {
+ return;
+ }
+ free(array.memoryBlock());
+ }
+
+ public synchronized MemoryBlock allocate(long required) {
+ long size = Math.max(pageSize, required);
+ return allocateMemoryBlock(size);
+ }
+
+ private synchronized MemoryBlock allocateMemoryBlock(long required) {
+ if (required > TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES) {
+ throw new TooLargePageException(required);
+ }
+
+ long got = _acquireMemory(required);
+
+ if (got < required) {
+ allocatedMemory -= got;
+
+ throw new SparkOutOfMemoryError(
+ "Unable to acquire "
+ + required
+ + " bytes of memory, got "
+ + got
+ + " bytes. Available: "
+ + (totalMemory - allocatedMemory));
+ }
+
+ int pageNumber = allocatedPages.nextClearBit(0);
+ if (pageNumber >= PAGE_TABLE_SIZE) {
+ allocatedMemory -= got;
+
+ throw new IllegalStateException(
+ "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages");
+ }
+
+ MemoryBlock block = allocator.allocate(got);
+
+ block.pageNumber = pageNumber;
+ pageTable[pageNumber] = block;
+ allocatedPages.set(pageNumber);
+
+ return block;
+ }
+
+ public synchronized void free(MemoryBlock block) {
+ if (block.pageNumber == MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) {
+ // Already freed block
+ return;
+ }
+ allocatedMemory -= block.size();
+
+ pageTable[block.pageNumber] = null;
+ allocatedPages.clear(block.pageNumber);
+ block.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
+
+ allocator.free(block);
+ }
+
+ /**
+ * Returns the offset in the page for the given page plus base offset address. Note that this
+ * method assumes that the page number is valid.
+ */
+ public long getOffsetInPage(long pagePlusOffsetAddress) {
+ long offsetInPage = decodeOffset(pagePlusOffsetAddress);
+ int pageNumber = TaskMemoryManager.decodePageNumber(pagePlusOffsetAddress);
+ assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
+ MemoryBlock page = pageTable[pageNumber];
+ assert (page != null);
+ return page.getBaseOffset() + offsetInPage;
+ }
+
+ public long decodeOffset(long pagePlusOffsetAddress) {
+ return pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS;
+ }
+
+ public long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
+ assert (pageNumber >= 0);
+ return ((long) pageNumber) << OFFSET_BITS | offsetInPage & MASK_LONG_LOWER_51_BITS;
+ }
+
+ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
+ return encodePageNumberAndOffset(page.pageNumber, offsetInPage - page.getBaseOffset());
+ }
+}
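The encode/decode helpers above pack a 13-bit page number and a 51-bit in-page offset into a single long, mirroring Spark's TaskMemoryManager layout. The following standalone sketch shows the same arithmetic with the same constants; the object name PageAddressPackingDemo is hypothetical and for illustration only.

object PageAddressPackingDemo {
  private val OffsetBits = 51
  private val MaskLower51Bits = 0x7FFFFFFFFFFFFL

  // Same arithmetic as encodePageNumberAndOffset(int, long) above.
  def encode(pageNumber: Int, offsetInPage: Long): Long =
    (pageNumber.toLong << OffsetBits) | (offsetInPage & MaskLower51Bits)

  def main(args: Array[String]): Unit = {
    val address = encode(5, 1024L)
    val pageNumber = (address >>> OffsetBits).toInt // upper 13 bits -> 5
    val offset = address & MaskLower51Bits          // lower 51 bits -> 1024
    println(s"$pageNumber / $offset")               // prints "5 / 1024"
  }
}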
diff --git a/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java b/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
index ed3e2be66..cc4495570 100644
--- a/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
+++ b/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java
@@ -38,6 +38,7 @@
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.comet.CometShuffleChecksumSupport;
import org.apache.spark.shuffle.comet.CometShuffleMemoryAllocator;
+import org.apache.spark.shuffle.comet.CometShuffleMemoryAllocatorTrait;
import org.apache.spark.shuffle.comet.TooLargePageException;
import org.apache.spark.sql.comet.execution.shuffle.CometUnsafeShuffleWriter;
import org.apache.spark.sql.comet.execution.shuffle.ShuffleThreadPool;
@@ -110,7 +111,7 @@ public final class CometShuffleExternalSorter implements CometShuffleChecksumSup
// The memory allocator for this sorter. It is used to allocate/free memory pages for this sorter.
// Because we need to allocate off-heap memory regardless of configured Spark memory mode
// (on-heap/off-heap), we need a separate memory allocator.
- private final CometShuffleMemoryAllocator allocator;
+ private final CometShuffleMemoryAllocatorTrait allocator;
/** Whether to write shuffle spilling file in async mode */
private final boolean isAsync;
diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java
index f793874d7..dcb9d99d3 100644
--- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java
+++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometDiskBlockWriter.java
@@ -41,6 +41,7 @@
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.comet.CometShuffleMemoryAllocator;
+import org.apache.spark.shuffle.comet.CometShuffleMemoryAllocatorTrait;
import org.apache.spark.shuffle.sort.RowPartition;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.types.StructType;
@@ -87,7 +88,7 @@ public final class CometDiskBlockWriter {
static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27;
/** The Comet allocator used to allocate pages. */
- private final CometShuffleMemoryAllocator allocator;
+ private final CometShuffleMemoryAllocatorTrait allocator;
/** The serializer used to write rows to memory page. */
private final SerializerInstance serializer;
@@ -435,12 +436,17 @@ public int compare(CometDiskBlockWriter lhs, CometDiskBlockWriter rhs) {
}
});
+ long totalFreed = 0;
for (CometDiskBlockWriter writer : currentWriters) {
// Force to spill the writer in a synchronous way, otherwise, we may not be able to
// acquire enough memory.
+ long used = writer.getActiveMemoryUsage();
+
writer.doSpill(true);
- if (allocator.getAvailableMemory() >= required) {
+ totalFreed += used;
+
+ if (totalFreed >= required) {
break;
}
}
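This hunk changes the spill path in CometDiskBlockWriter: the allocator trait no longer exposes getAvailableMemory (free space is now tracked by Spark's TaskMemoryManager across all consumers), so the writer instead sums the active memory of each writer it spills and stops once enough has been freed. A hedged sketch of that loop, with the hypothetical `Writer` trait standing in for CometDiskBlockWriter:

object SpillLoopSketch {
  trait Writer {
    def getActiveMemoryUsage: Long
    def doSpill(sync: Boolean): Unit
  }

  def spillUntilFreed(writers: Seq[Writer], required: Long): Long = {
    var totalFreed = 0L
    val it = writers.iterator
    while (totalFreed < required && it.hasNext) {
      val w = it.next()
      val used = w.getActiveMemoryUsage
      // Spill synchronously so the freed memory is available before the next allocation attempt.
      w.doSpill(true)
      totalFreed += used
    }
    totalFreed
  }
}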
diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
index cc8c04fdd..3dc86b05b 100644
--- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
+++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/SpillWriter.java
@@ -31,7 +31,7 @@
import org.apache.spark.memory.SparkOutOfMemoryError;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
-import org.apache.spark.shuffle.comet.CometShuffleMemoryAllocator;
+import org.apache.spark.shuffle.comet.CometShuffleMemoryAllocatorTrait;
import org.apache.spark.shuffle.sort.RowPartition;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.memory.MemoryBlock;
@@ -62,7 +62,7 @@ public abstract class SpillWriter {
// The memory allocator for this sorter. It is used to allocate/free memory pages for this sorter.
// Because we need to allocate off-heap memory regardless of configured Spark memory mode
// (on-heap/off-heap), we need a separate memory allocator.
- protected CometShuffleMemoryAllocator allocator;
+ protected CometShuffleMemoryAllocatorTrait allocator;
protected Native nativeLib;
@@ -134,7 +134,7 @@ public boolean acquireNewPageIfNecessary(int required) {
|| pageCursor + required > currentPage.getBaseOffset() + currentPage.size()) {
// TODO: try to find space in previous pages
try {
- currentPage = allocator.allocatePage(required);
+ currentPage = allocator.allocate(required);
} catch (SparkOutOfMemoryError error) {
try {
// Cannot allocate enough memory, spill
@@ -155,7 +155,7 @@ public boolean acquireNewPageIfNecessary(int required) {
public void initialCurrentPage(int required) {
assert (currentPage == null);
try {
- currentPage = allocator.allocatePage(required);
+ currentPage = allocator.allocate(required);
} catch (SparkOutOfMemoryError e) {
logger.error("Unable to acquire {} bytes of memory", required);
throw e;
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index b2eef5d09..08b24e029 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -60,60 +60,39 @@ class CometExecIterator(
new CometBatchIterator(iterator, nativeUtil)
}.toArray
private val plan = {
- val configs = createNativeConf
+ val conf = SparkEnv.get.conf
+ // Only enable the unified memory manager when off-heap mode is enabled. Otherwise,
+ // we use DataFusion's built-in memory pool, initialized with the `memory_limit`
+ // and `memory_fraction` values passed below.
nativeLib.createPlan(
id,
- configs,
cometBatchIterators,
protobufQueryPlan,
numParts,
nativeMetrics,
- new CometTaskMemoryManager(id))
+ new CometTaskMemoryManager(id),
+ batchSize = COMET_BATCH_SIZE.get(),
+ use_unified_memory_manager = conf.getBoolean("spark.memory.offHeap.enabled", false),
+ memory_limit = CometSparkSessionExtensions.getCometMemoryOverhead(conf),
+ memory_fraction = COMET_EXEC_MEMORY_FRACTION.get(),
+ debug = COMET_DEBUG_ENABLED.get(),
+ explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
+ workerThreads = COMET_WORKER_THREADS.get(),
+ blockingThreads = COMET_BLOCKING_THREADS.get())
}
private var nextBatch: Option[ColumnarBatch] = None
private var currentBatch: ColumnarBatch = null
private var closed: Boolean = false
- /**
- * Creates a new configuration map to be passed to the native side.
- */
- private def createNativeConf: java.util.HashMap[String, String] = {
- val result = new java.util.HashMap[String, String]()
- val conf = SparkEnv.get.conf
-
- val maxMemory = CometSparkSessionExtensions.getCometMemoryOverhead(conf)
- // Only enable unified memory manager when off-heap mode is enabled. Otherwise,
- // we'll use the built-in memory pool from DF, and initializes with `memory_limit`
- // and `memory_fraction` below.
- result.put(
- "use_unified_memory_manager",
- String.valueOf(conf.get("spark.memory.offHeap.enabled", "false")))
- result.put("memory_limit", String.valueOf(maxMemory))
- result.put("memory_fraction", String.valueOf(COMET_EXEC_MEMORY_FRACTION.get()))
- result.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get()))
- result.put("debug_native", String.valueOf(COMET_DEBUG_ENABLED.get()))
- result.put("explain_native", String.valueOf(COMET_EXPLAIN_NATIVE_ENABLED.get()))
- result.put("worker_threads", String.valueOf(COMET_WORKER_THREADS.get()))
- result.put("blocking_threads", String.valueOf(COMET_BLOCKING_THREADS.get()))
-
- // Strip mandatory prefix spark. which is not required for DataFusion session params
- conf.getAll.foreach {
- case (k, v) if k.startsWith("spark.datafusion") =>
- result.put(k.replaceFirst("spark\\.", ""), v)
- case _ =>
- }
-
- result
- }
-
def getNextBatch(): Option[ColumnarBatch] = {
assert(partitionIndex >= 0 && partitionIndex < numParts)
nativeUtil.getNextBatch(
numOutputCols,
(arrayAddrs, schemaAddrs) => {
- nativeLib.executePlan(plan, partitionIndex, arrayAddrs, schemaAddrs)
+ val ctx = TaskContext.get()
+ nativeLib.executePlan(ctx.stageId(), partitionIndex, plan, arrayAddrs, schemaAddrs)
})
}
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 32668f0dd..522da0f58 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -53,7 +53,7 @@ import org.apache.spark.sql.types.{DoubleType, FloatType}
import org.apache.comet.CometConf._
import org.apache.comet.CometExplainInfo.getActualPlan
-import org.apache.comet.CometSparkSessionExtensions.{createMessage, getCometBroadcastNotEnabledReason, getCometShuffleNotEnabledReason, isANSIEnabled, isCometBroadCastForceEnabled, isCometEnabled, isCometExecEnabled, isCometJVMShuffleMode, isCometNativeShuffleMode, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSpark34Plus, isSpark40Plus, shouldApplySparkToColumnar, withInfo, withInfos}
+import org.apache.comet.CometSparkSessionExtensions.{createMessage, getCometBroadcastNotEnabledReason, getCometShuffleNotEnabledReason, isANSIEnabled, isCometBroadCastForceEnabled, isCometEnabled, isCometExecEnabled, isCometJVMShuffleMode, isCometNativeShuffleMode, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isOffHeapEnabled, isSpark34Plus, isSpark40Plus, isTesting, shouldApplySparkToColumnar, withInfo, withInfos}
import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
import org.apache.comet.rules.RewriteJoin
import org.apache.comet.serde.OperatorOuterClass.Operator
@@ -207,7 +207,7 @@ class CometSparkSessionExtensions
// data source V1
case scanExec @ FileSourceScanExec(
- HadoopFsRelation(_, partitionSchema, _, _, _: ParquetFileFormat, _),
+ HadoopFsRelation(_, partitionSchema, _, _, fileFormat, _),
_: Seq[_],
requiredSchema,
_,
@@ -216,14 +216,15 @@ class CometSparkSessionExtensions
_,
_,
_)
- if CometScanExec.isSchemaSupported(requiredSchema)
+ if CometScanExec.isFileFormatSupported(fileFormat)
+ && CometScanExec.isSchemaSupported(requiredSchema)
&& CometScanExec.isSchemaSupported(partitionSchema) =>
logInfo("Comet extension enabled for v1 Scan")
CometScanExec(scanExec, session)
// data source v1 not supported case
case scanExec @ FileSourceScanExec(
- HadoopFsRelation(_, partitionSchema, _, _, _: ParquetFileFormat, _),
+ HadoopFsRelation(_, partitionSchema, _, _, fileFormat, _),
_: Seq[_],
requiredSchema,
_,
@@ -233,12 +234,15 @@ class CometSparkSessionExtensions
_,
_) =>
val info1 = createMessage(
+ !CometScanExec.isFileFormatSupported(fileFormat),
+ s"File format $fileFormat is not supported")
+ val info2 = createMessage(
!CometScanExec.isSchemaSupported(requiredSchema),
s"Schema $requiredSchema is not supported")
- val info2 = createMessage(
+ val info3 = createMessage(
!CometScanExec.isSchemaSupported(partitionSchema),
s"Partition schema $partitionSchema is not supported")
- withInfo(scanExec, Seq(info1, info2).flatten.mkString(","))
+ withInfo(scanExec, Seq(info1, info2, info3).flatten.mkString(","))
scanExec
}
}
@@ -938,6 +942,14 @@ class CometSparkSessionExtensions
}
override def apply(plan: SparkPlan): SparkPlan = {
+
+ // Comet requires off-heap memory to be enabled
+ if (!isOffHeapEnabled(conf) && !isTesting) {
+ logWarning("Comet native exec disabled because spark.memory.offHeap.enabled=false")
+ withInfo(plan, "Comet native exec disabled because spark.memory.offHeap.enabled=false")
+ return plan
+ }
+
// DataFusion doesn't have ANSI mode. For now we just disable CometExec if ANSI mode is
// enabled.
if (isANSIEnabled(conf)) {
@@ -1194,8 +1206,21 @@ object CometSparkSessionExtensions extends Logging {
}
}
+ private[comet] def isOffHeapEnabled(conf: SQLConf): Boolean =
+ conf.getConfString("spark.memory.offHeap.enabled", "false").toBoolean
+
+ // Copied from org.apache.spark.util.Utils which is private to Spark.
+ private[comet] def isTesting: Boolean = {
+ System.getenv("SPARK_TESTING") != null || System.getProperty("spark.testing") != null
+ }
+
+ // Check whether Comet shuffle is enabled:
+ // 1. `COMET_EXEC_SHUFFLE_ENABLED` is true
+ // 2. `spark.shuffle.manager` is set to `CometShuffleManager`
+ // 3. Off-heap memory is enabled, or we are running Spark/Comet unit tests
private[comet] def isCometShuffleEnabled(conf: SQLConf): Boolean =
- COMET_EXEC_SHUFFLE_ENABLED.get(conf) && isCometShuffleManagerEnabled(conf)
+ COMET_EXEC_SHUFFLE_ENABLED.get(conf) && isCometShuffleManagerEnabled(conf) &&
+ (isOffHeapEnabled(conf) || isTesting)
private[comet] def getCometShuffleNotEnabledReason(conf: SQLConf): Option[String] = {
if (!COMET_EXEC_SHUFFLE_ENABLED.get(conf)) {
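The net effect of the changes in this file is that Comet native execution and Comet shuffle now require Spark's off-heap memory to be enabled outside of tests. Below is a minimal sketch (e.g. for spark-shell) of a session configuration that satisfies these checks; the plugin and shuffle-manager class names follow the Comet documentation and should be treated as indicative rather than authoritative.

import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  .master("local[*]")
  .config("spark.plugins", "org.apache.spark.CometPlugin")
  .config(
    "spark.shuffle.manager",
    "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
  // Off-heap memory is now required for native exec outside of tests.
  .config("spark.memory.offHeap.enabled", "true")
  .config("spark.memory.offHeap.size", "4g")
  .config("spark.comet.exec.enabled", "true")
  .config("spark.comet.exec.shuffle.enabled", "true")
  .getOrCreate()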
diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala
index ce0e26129..82c0373f4 100644
--- a/spark/src/main/scala/org/apache/comet/Native.scala
+++ b/spark/src/main/scala/org/apache/comet/Native.scala
@@ -19,13 +19,12 @@
package org.apache.comet
-import java.util.Map
-
import org.apache.spark.CometTaskMemoryManager
import org.apache.spark.sql.comet.CometMetricNode
class Native extends NativeBase {
+ // scalastyle:off
/**
* Create a native query plan from execution SparkPlan serialized in bytes.
* @param id
@@ -45,18 +44,31 @@ class Native extends NativeBase {
* @return
* the address to native query plan.
*/
+ // scalastyle:off
@native def createPlan(
id: Long,
- configMap: Map[String, String],
iterators: Array[CometBatchIterator],
plan: Array[Byte],
partitionCount: Int,
metrics: CometMetricNode,
- taskMemoryManager: CometTaskMemoryManager): Long
+ taskMemoryManager: CometTaskMemoryManager,
+ batchSize: Int,
+ use_unified_memory_manager: Boolean,
+ memory_limit: Long,
+ memory_fraction: Double,
+ debug: Boolean,
+ explain: Boolean,
+ workerThreads: Int,
+ blockingThreads: Int): Long
+ // scalastyle:on
/**
* Execute a native query plan based on given input Arrow arrays.
*
+ * @param stage
+ * the stage ID, for informational purposes
+ * @param partition
+ * the partition ID, for informational purposes
* @param plan
* the address to native query plan.
* @param arrayAddrs
@@ -67,8 +79,9 @@ class Native extends NativeBase {
* the number of rows, if -1, it means end of the output.
*/
@native def executePlan(
+ stage: Int,
+ partition: Int,
plan: Long,
- partitionId: Int,
arrayAddrs: Array[Long],
schemaAddrs: Array[Long]): Long
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 11d6d049f..859cb13be 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -70,9 +70,13 @@ object CometCast {
case _ =>
Unsupported
}
- case (_: DecimalType, _: DecimalType) =>
- // https://github.com/apache/datafusion-comet/issues/375
- Incompatible()
+ case (from: DecimalType, to: DecimalType) =>
+ if (to.precision < from.precision) {
+ // https://github.com/apache/datafusion/issues/13492
+ Incompatible(Some("Casting to smaller precision is not supported"))
+ } else {
+ Compatible()
+ }
case (DataTypes.StringType, _) =>
canCastFromString(toType, timeZoneId, evalMode)
case (_, DataTypes.StringType) =>
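The decimal-to-decimal branch above now only rejects casts that narrow the precision (tracked by the linked DataFusion issue); widening or same-precision casts are reported as Compatible and can run natively. A small hedged illustration of which side of the rule a given cast falls on (the helper name is illustrative, not code from this patch):

import org.apache.spark.sql.types.DecimalType

// Mirrors the `to.precision < from.precision` test above; true means the cast stays Incompatible.
def narrowsPrecision(from: DecimalType, to: DecimalType): Boolean =
  to.precision < from.precision

narrowsPrecision(DecimalType(10, 2), DecimalType(12, 4)) // false -> Compatible, eligible for native cast
narrowsPrecision(DecimalType(38, 18), DecimalType(6, 2)) // true  -> Incompatible, falls back to Spark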
diff --git a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala b/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala
index 17844aba8..bcb23986f 100644
--- a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala
+++ b/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala
@@ -723,20 +723,22 @@ class ParquetFilters(
.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldNames, value))
- case sources.LessThan(name, value) if canMakeFilterOn(name, value) =>
+ case sources.LessThan(name, value) if (value != null) && canMakeFilterOn(name, value) =>
makeLt
.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldNames, value))
- case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
+ case sources.LessThanOrEqual(name, value)
+ if (value != null) && canMakeFilterOn(name, value) =>
makeLtEq
.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldNames, value))
- case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) =>
+ case sources.GreaterThan(name, value) if (value != null) && canMakeFilterOn(name, value) =>
makeGt
.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldNames, value))
- case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
+ case sources.GreaterThanOrEqual(name, value)
+ if (value != null) && canMakeFilterOn(name, value) =>
makeGtEq
.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldNames, value))
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 5ee16bd7b..a92ffa668 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2194,6 +2194,35 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
None
}
+ case expr if expr.prettyName == "array_insert" =>
+ val srcExprProto = exprToProto(expr.children(0), inputs, binding)
+ val posExprProto = exprToProto(expr.children(1), inputs, binding)
+ val itemExprProto = exprToProto(expr.children(2), inputs, binding)
+ val legacyNegativeIndex =
+ SQLConf.get.getConfString("spark.sql.legacy.negativeIndexInArrayInsert").toBoolean
+ if (srcExprProto.isDefined && posExprProto.isDefined && itemExprProto.isDefined) {
+ val arrayInsertBuilder = ExprOuterClass.ArrayInsert
+ .newBuilder()
+ .setSrcArrayExpr(srcExprProto.get)
+ .setPosExpr(posExprProto.get)
+ .setItemExpr(itemExprProto.get)
+ .setLegacyNegativeIndex(legacyNegativeIndex)
+
+ Some(
+ ExprOuterClass.Expr
+ .newBuilder()
+ .setArrayInsert(arrayInsertBuilder)
+ .build())
+ } else {
+ withInfo(
+ expr,
+ "unsupported arguments for ArrayInsert",
+ expr.children(0),
+ expr.children(1),
+ expr.children(2))
+ None
+ }
+
case ElementAt(child, ordinal, defaultValue, failOnError)
if child.dataType.isInstanceOf[ArrayType] =>
val childExpr = exprToProto(child, inputs, binding)
@@ -2239,7 +2268,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
None
}
-
+ case _ if expr.prettyName == "array_append" =>
+ createBinaryExpr(
+ expr.children(0),
+ expr.children(1),
+ inputs,
+ (builder, binaryExpr) => builder.setArrayAppend(binaryExpr))
case _ =>
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
None
@@ -2476,7 +2510,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
*/
def operator2Proto(op: SparkPlan, childOp: Operator*): Option[Operator] = {
val conf = op.conf
- val result = OperatorOuterClass.Operator.newBuilder()
+ val result = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
childOp.foreach(result.addChildren)
op match {
@@ -2952,7 +2986,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
case op if isCometSink(op) && op.output.forall(a => supportedDataType(a.dataType, true)) =>
// These operators are source of Comet native execution chain
val scanBuilder = OperatorOuterClass.Scan.newBuilder()
- scanBuilder.setSource(op.simpleStringWithNodeId())
+ val source = op.simpleStringWithNodeId()
+ if (source.isEmpty) {
+ scanBuilder.setSource(op.getClass.getSimpleName)
+ } else {
+ scanBuilder.setSource(source)
+ }
val scanTypes = op.output.flatten { attr =>
serializeDataType(attr.dataType)
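The QueryPlanSerde changes above add serde support for array_insert (forwarding the spark.sql.legacy.negativeIndexInArrayInsert flag to the native plan) and route array_append through the generic binary-expression builder. A small usage sketch in the spirit of the tests added later in this diff; the DataFrame and column names are hypothetical and `spark` is an existing session.

import org.apache.spark.sql.functions.expr

val withArrays = spark
  .range(5)
  .selectExpr("array(id, id + 1) AS arr", "id")

withArrays.select(expr("array_append(arr, 42L)")).show()
// How negative positions behave depends on spark.sql.legacy.negativeIndexInArrayInsert,
// which the serde reads from SQLConf and forwards to the native ArrayInsert expression.
withArrays.select(expr("array_insert(arr, -1, 99L)")).show()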
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala
index 8ea0b1765..f75af5076 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala
@@ -57,7 +57,8 @@ case class CometCollectLimitExec(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
"numPartitions" -> SQLMetrics.createMetric(
sparkContext,
- "number of partitions")) ++ readMetrics ++ writeMetrics
+ "number of partitions")) ++ readMetrics ++ writeMetrics ++ CometMetricNode.shuffleMetrics(
+ sparkContext)
private lazy val serializer: Serializer =
new UnsafeRowSerializer(child.output.size, longMetric("dataSize"))
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
index 9698dc98b..2fc73bb7c 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
@@ -88,7 +88,7 @@ object CometExecUtils {
* child partition
*/
def getLimitNativePlan(outputAttributes: Seq[Attribute], limit: Int): Option[Operator] = {
- val scanBuilder = OperatorOuterClass.Scan.newBuilder()
+ val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("LimitInput")
val scanOpBuilder = OperatorOuterClass.Operator.newBuilder()
val scanTypes = outputAttributes.flatten { attr =>
@@ -118,7 +118,7 @@ object CometExecUtils {
sortOrder: Seq[SortOrder],
child: SparkPlan,
limit: Int): Option[Operator] = {
- val scanBuilder = OperatorOuterClass.Scan.newBuilder()
+ val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("TopKInput")
val scanOpBuilder = OperatorOuterClass.Operator.newBuilder()
val scanTypes = outputAttributes.flatten { attr =>
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala
index 47c89d943..a26fa28c8 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala
@@ -130,6 +130,17 @@ object CometMetricNode {
"spilled_rows" -> SQLMetrics.createMetric(sc, "Total spilled rows"))
}
+ def shuffleMetrics(sc: SparkContext): Map[String, SQLMetric] = {
+ Map(
+ "elapsed_compute" -> SQLMetrics.createNanoTimingMetric(sc, "native shuffle time"),
+ "mempool_time" -> SQLMetrics.createNanoTimingMetric(sc, "memory pool time"),
+ "repart_time" -> SQLMetrics.createNanoTimingMetric(sc, "repartition time"),
+ "ipc_time" -> SQLMetrics.createNanoTimingMetric(sc, "encoding and compression time"),
+ "spill_count" -> SQLMetrics.createMetric(sc, "number of spills"),
+ "spilled_bytes" -> SQLMetrics.createMetric(sc, "spilled bytes"),
+ "input_batches" -> SQLMetrics.createMetric(sc, "number of input batches"))
+ }
+
/**
* Creates a [[CometMetricNode]] from a [[CometPlan]].
*/
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
index 5d28b4b72..352d4a656 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.comet.shims.ShimCometScanExec
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
+import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions}
import org.apache.spark.sql.execution.datasources.v2.DataSourceRDD
import org.apache.spark.sql.execution.metric._
import org.apache.spark.sql.types._
@@ -510,4 +510,10 @@ object CometScanExec extends DataTypeSupport {
scanExec.logicalLink.foreach(batchScanExec.setLogicalLink)
batchScanExec
}
+
+ def isFileFormatSupported(fileFormat: FileFormat): Boolean = {
+ // Only support Spark's built-in Parquet scans; other formats such as Delta use a subclass of
+ // ParquetFileFormat and are not supported here.
+ fileFormat.getClass().equals(classOf[ParquetFileFormat])
+ }
}
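isFileFormatSupported compares the runtime class to ParquetFileFormat itself, so subclasses (such as Delta's Parquet format) are deliberately rejected. A hedged sketch of that behavior; `CustomParquetFileFormat` is a hypothetical subclass used only for illustration.

import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat

class CustomParquetFileFormat extends ParquetFileFormat

CometScanExec.isFileFormatSupported(new ParquetFileFormat())       // true
CometScanExec.isFileFormatSupported(new CustomParquetFileFormat()) // false: subclass is rejected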
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
index 5582f4d68..19586628a 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
@@ -57,7 +57,8 @@ case class CometTakeOrderedAndProjectExec(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
"numPartitions" -> SQLMetrics.createMetric(
sparkContext,
- "number of partitions")) ++ readMetrics ++ writeMetrics
+ "number of partitions")) ++ readMetrics ++ writeMetrics ++ CometMetricNode.shuffleMetrics(
+ sparkContext)
private lazy val serializer: Serializer =
new UnsafeRowSerializer(child.output.size, longMetric("dataSize"))
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 388c07a27..0cd8a9ce6 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -79,7 +79,8 @@ case class CometShuffleExchangeExec(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
"numPartitions" -> SQLMetrics.createMetric(
sparkContext,
- "number of partitions")) ++ readMetrics ++ writeMetrics
+ "number of partitions")) ++ readMetrics ++ writeMetrics ++ CometMetricNode.shuffleMetrics(
+ sparkContext)
override def nodeName: String = if (shuffleType == CometNativeShuffle) {
"CometExchange"
@@ -477,11 +478,21 @@ class CometShuffleWriteProcessor(
// Call native shuffle write
val nativePlan = getNativePlan(tempDataFilename, tempIndexFilename)
+ val detailedMetrics = Seq(
+ "elapsed_compute",
+ "ipc_time",
+ "repart_time",
+ "mempool_time",
+ "input_batches",
+ "spill_count",
+ "spilled_bytes")
+
// Maps native metrics to SQL metrics
val nativeSQLMetrics = Map(
"output_rows" -> metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN),
"data_size" -> metrics("dataSize"),
- "elapsed_compute" -> metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_WRITE_TIME))
+ "write_time" -> metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_WRITE_TIME)) ++
+ metrics.filterKeys(detailedMetrics.contains)
val nativeMetrics = CometMetricNode(nativeSQLMetrics)
// Getting rid of the fake partitionId
@@ -528,7 +539,7 @@ class CometShuffleWriteProcessor(
}
def getNativePlan(dataFile: String, indexFile: String): Operator = {
- val scanBuilder = OperatorOuterClass.Scan.newBuilder()
+ val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput")
val opBuilder = OperatorOuterClass.Operator.newBuilder()
val scanTypes = outputAttributes.flatten { attr =>
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index db9a870dc..f8c1a8b09 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -861,10 +861,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
// primitives
checkSparkAnswerAndOperator(
"SELECT CAST(struct(_1, _2, _3, _4, _5, _6, _7, _8) as string) FROM tbl")
- // TODO: enable tests for unsigned ints (_9, _10, _11, _12) once
- // https://github.com/apache/datafusion-comet/issues/1067 is resolved
- // checkSparkAnswerAndOperator(
- // "SELECT CAST(struct(_9, _10, _11, _12) as string) FROM tbl")
+ checkSparkAnswerAndOperator("SELECT CAST(struct(_9, _10, _11, _12) as string) FROM tbl")
// decimals
// TODO add _16 when https://github.com/apache/datafusion-comet/issues/1068 is resolved
checkSparkAnswerAndOperator("SELECT CAST(struct(_15, _17) as string) FROM tbl")
@@ -895,6 +892,34 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ test("cast between decimals with different precision and scale") {
+ // cast from the default Decimal(38, 18) to Decimal(6, 2)
+ val values = Seq(BigDecimal("12345.6789"), BigDecimal("9876.5432"), BigDecimal("123.4567"))
+ val df = withNulls(values)
+ .toDF("b")
+ .withColumn("a", col("b").cast(DecimalType(6, 2)))
+ checkSparkAnswer(df)
+ }
+
+ test("cast between decimals with higher precision than source") {
+ // cast from Decimal(10, 2) to Decimal(10, 4)
+ castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 4))
+ }
+
+ test("cast between decimals with negative precision") {
+ // cast to negative scale
+ checkSparkMaybeThrows(
+ spark.sql("select a, cast(a as DECIMAL(10,-4)) from t order by a")) match {
+ case (expected, actual) =>
+ assert(expected.contains("PARSE_SYNTAX_ERROR") === actual.contains("PARSE_SYNTAX_ERROR"))
+ }
+ }
+
+ test("cast between decimals with zero precision") {
+ // cast from Decimal(10, 2) to Decimal(10, 0)
+ castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 0))
+ }
+
private def generateFloats(): DataFrame = {
withNulls(gen.generateFloats(dataSize)).toDF("a")
}
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index e65feb6b2..f9e2c44c6 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
import org.apache.spark.sql.types.{Decimal, DecimalType}
-import org.apache.comet.CometSparkSessionExtensions.{isSpark33Plus, isSpark34Plus, isSpark40Plus}
+import org.apache.comet.CometSparkSessionExtensions.{isSpark33Plus, isSpark34Plus, isSpark35Plus, isSpark40Plus}
class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
import testImplicits._
@@ -119,10 +119,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
withParquetTable(path.toString, "tbl") {
- // TODO: enable test for unsigned ints
- checkSparkAnswerAndOperator(
- "select _1, _2, _3, _4, _5, _6, _7, _8, _13, _14, _15, _16, _17, " +
- "_18, _19, _20 FROM tbl WHERE _2 > 100")
+ checkSparkAnswerAndOperator("select * FROM tbl WHERE _2 > 100")
}
}
}
@@ -1115,7 +1112,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 100)
withParquetTable(path.toString, "tbl") {
- Seq(2, 3, 4, 5, 6, 7, 15, 16, 17).foreach { col =>
+ Seq(2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 15, 16, 17).foreach { col =>
checkSparkAnswerAndOperator(s"SELECT abs(_${col}) FROM tbl")
}
}
@@ -1239,9 +1236,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
withParquetTable(path.toString, "tbl") {
for (s <- Seq(-5, -1, 0, 1, 5, -1000, 1000, -323, -308, 308, -15, 15, -16, 16, null)) {
// array tests
- // TODO: enable test for unsigned ints (_9, _10, _11, _12)
// TODO: enable test for floats (_6, _7, _8, _13)
- for (c <- Seq(2, 3, 4, 5, 15, 16, 17)) {
+ for (c <- Seq(2, 3, 4, 5, 9, 10, 11, 12, 15, 16, 17)) {
checkSparkAnswerAndOperator(s"select _${c}, round(_${c}, ${s}) FROM tbl")
}
// scalar tests
@@ -1452,9 +1448,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
withParquetTable(path.toString, "tbl") {
- // _9 and _10 (uint8 and uint16) not supported
checkSparkAnswerAndOperator(
- "SELECT hex(_1), hex(_2), hex(_3), hex(_4), hex(_5), hex(_6), hex(_7), hex(_8), hex(_11), hex(_12), hex(_13), hex(_14), hex(_15), hex(_16), hex(_17), hex(_18), hex(_19), hex(_20) FROM tbl")
+ "SELECT hex(_1), hex(_2), hex(_3), hex(_4), hex(_5), hex(_6), hex(_7), hex(_8), hex(_9), hex(_10), hex(_11), hex(_12), hex(_13), hex(_14), hex(_15), hex(_16), hex(_17), hex(_18), hex(_19), hex(_20) FROM tbl")
}
}
}
@@ -2200,6 +2195,133 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ ignore("get_struct_field - select primitive fields") {
+ withTempPath { dir =>
+ // create input file with Comet disabled
+ withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+ val df = spark
+ .range(5)
+ // Add both a null struct and null inner value
+ .select(when(col("id") > 1, struct(when(col("id") > 2, col("id")).alias("id")))
+ .alias("nested1"))
+
+ df.write.parquet(dir.toString())
+ }
+
+ Seq("", "parquet").foreach { v1List =>
+ withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> v1List) {
+ val df = spark.read.parquet(dir.toString())
+ checkSparkAnswerAndOperator(df.select("nested1.id"))
+ }
+ }
+ }
+ }
+
+ ignore("get_struct_field - select subset of struct") {
+ withTempPath { dir =>
+ // create input file with Comet disabled
+ withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+ val df = spark
+ .range(5)
+ // Add both a null struct and null inner value
+ .select(
+ when(
+ col("id") > 1,
+ struct(
+ when(col("id") > 2, col("id")).alias("id"),
+ when(col("id") > 2, struct(when(col("id") > 3, col("id")).alias("id")))
+ .as("nested2")))
+ .alias("nested1"))
+
+ df.write.parquet(dir.toString())
+ }
+
+ Seq("", "parquet").foreach { v1List =>
+ withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> v1List) {
+ val df = spark.read.parquet(dir.toString())
+ checkSparkAnswerAndOperator(df.select("nested1.id"))
+ checkSparkAnswerAndOperator(df.select("nested1.nested2"))
+ checkSparkAnswerAndOperator(df.select("nested1.nested2.id"))
+ checkSparkAnswerAndOperator(df.select("nested1.id", "nested1.nested2.id"))
+ }
+ }
+ }
+ }
+
+ ignore("get_struct_field - read entire struct") {
+ withTempPath { dir =>
+ // create input file with Comet disabled
+ withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+ val df = spark
+ .range(5)
+ // Add both a null struct and null inner value
+ .select(
+ when(
+ col("id") > 1,
+ struct(
+ when(col("id") > 2, col("id")).alias("id"),
+ when(col("id") > 2, struct(when(col("id") > 3, col("id")).alias("id")))
+ .as("nested2")))
+ .alias("nested1"))
+
+ df.write.parquet(dir.toString())
+ }
+
+ Seq("", "parquet").foreach { v1List =>
+ withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> v1List) {
+ val df = spark.read.parquet(dir.toString())
+ checkSparkAnswerAndOperator(df.select("nested1"))
+ }
+ }
+ }
+ }
+
+ ignore("read map[int, int] from parquet") {
+ withTempPath { dir =>
+ // create input file with Comet disabled
+ withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+ val df = spark
+ .range(5)
+ // Spark does not allow null as a map key but does allow null as a
+ // value, and the entire map may be null
+ .select(
+ when(col("id") > 1, map(col("id"), when(col("id") > 2, col("id")))).alias("map1"))
+ df.write.parquet(dir.toString())
+ }
+
+ Seq("", "parquet").foreach { v1List =>
+ withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> v1List) {
+ val df = spark.read.parquet(dir.toString())
+ checkSparkAnswerAndOperator(df.select("map1"))
+ checkSparkAnswerAndOperator(df.select(map_keys(col("map1"))))
+ checkSparkAnswerAndOperator(df.select(map_values(col("map1"))))
+ }
+ }
+ }
+ }
+
+ ignore("read array[int] from parquet") {
+ withTempPath { dir =>
+ // create input file with Comet disabled
+ withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+ val df = spark
+ .range(5)
+ // The entire array may be null
+ .select(when(col("id") > 1, sequence(lit(0), col("id") * 2)).alias("array1"))
+ df.write.parquet(dir.toString())
+ }
+
+ Seq("", "parquet").foreach { v1List =>
+ withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> v1List) {
+ val df = spark.read.parquet(dir.toString())
+ checkSparkAnswerAndOperator(df.select("array1"))
+ checkSparkAnswerAndOperator(df.select(element_at(col("array1"), lit(1))))
+ }
+ }
+ }
+ }
+
test("get_struct_field with DataFusion ParquetExec - simple case") {
withTempPath { dir =>
// create input file with Comet disabled
@@ -2412,4 +2534,86 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}
+
+ test("array_append") {
+ assume(isSpark34Plus)
+ Seq(true, false).foreach { dictionaryEnabled =>
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+ checkSparkAnswerAndOperator(spark.sql("Select array_append(array(_1),false) from t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_append(array(_2, _3, _4), 4) FROM t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_append(array(_2, _3, _4), null) FROM t1"));
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_append(array(_6, _7), CAST(6.5 AS DOUBLE)) FROM t1"));
+ checkSparkAnswerAndOperator(spark.sql("SELECT array_append(array(_8), 'test') FROM t1"));
+ checkSparkAnswerAndOperator(spark.sql("SELECT array_append(array(_19), _19) FROM t1"));
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_append((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
+ }
+ }
+ }
+
+ test("array_prepend") {
+ assume(isSpark35Plus) // in Spark 3.5 array_prepend is implemented via array_insert
+ Seq(true, false).foreach { dictionaryEnabled =>
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
+ spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+ checkSparkAnswerAndOperator(spark.sql("Select array_prepend(array(_1),false) from t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_prepend(array(_2, _3, _4), 4) FROM t1"))
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_prepend(array(_2, _3, _4), null) FROM t1"));
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_prepend(array(_6, _7), CAST(6.5 AS DOUBLE)) FROM t1"));
+ checkSparkAnswerAndOperator(spark.sql("SELECT array_prepend(array(_8), 'test') FROM t1"));
+ checkSparkAnswerAndOperator(spark.sql("SELECT array_prepend(array(_19), _19) FROM t1"));
+ checkSparkAnswerAndOperator(
+ spark.sql("SELECT array_prepend((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
+ }
+ }
+ }
+
+ test("ArrayInsert") {
+ assume(isSpark34Plus)
+ Seq(true, false).foreach(dictionaryEnabled =>
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
+ val df = spark.read
+ .parquet(path.toString)
+ .withColumn("arr", array(col("_4"), lit(null), col("_4")))
+ .withColumn("arrInsertResult", expr("array_insert(arr, 1, 1)"))
+ .withColumn("arrInsertNegativeIndexResult", expr("array_insert(arr, -1, 1)"))
+ .withColumn("arrPosGreaterThanSize", expr("array_insert(arr, 8, 1)"))
+ .withColumn("arrNegPosGreaterThanSize", expr("array_insert(arr, -8, 1)"))
+ .withColumn("arrInsertNone", expr("array_insert(arr, 1, null)"))
+ checkSparkAnswerAndOperator(df.select("arrInsertResult"))
+ checkSparkAnswerAndOperator(df.select("arrInsertNegativeIndexResult"))
+ checkSparkAnswerAndOperator(df.select("arrPosGreaterThanSize"))
+ checkSparkAnswerAndOperator(df.select("arrNegPosGreaterThanSize"))
+ checkSparkAnswerAndOperator(df.select("arrInsertNone"))
+ })
+ }
+
+ test("ArrayInsertUnsupportedArgs") {
+ // This test checks that the else branch in the ArrayInsert mapping to Comet is valid
+ // and that the fallback to Spark works as expected.
+ assume(isSpark34Plus)
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllTypes(path, dictionaryEnabled = false, 10000)
+ val df = spark.read
+ .parquet(path.toString)
+ .withColumn("arr", array(col("_4"), lit(null), col("_4")))
+ .withColumn("idx", udf((_: Int) => 1).apply(col("_4")))
+ .withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)"))
+ checkSparkAnswer(df.select("arrUnsupportedArgs"))
+ }
+ }
}
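// Aside (not part of the patch): a minimal, self-contained sketch of the array_insert
// semantics the ArrayInsert tests above exercise, assuming only a local Spark session.
// The object name and the expected-result comments are illustrative, not taken from the
// Comet test suite.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.expr

object ArrayInsertSemanticsSketch {
  def main(args: Array[String]): Unit = {
    val spark =
      SparkSession.builder().master("local[1]").appName("array_insert-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq(Seq(1, 2, 3)).toDF("arr")
    // Positions are 1-based: inserting at position 1 prepends the element.
    df.select(expr("array_insert(arr, 1, 0)")).show() // [0, 1, 2, 3]
    // Positions past the end pad the gap with nulls before placing the element.
    df.select(expr("array_insert(arr, 6, 9)")).show() // [1, 2, 3, null, null, 9]
    // Negative positions count from the end; their exact placement changed between Spark
    // releases, which is why the tests compare against Spark's own answer.
    df.select(expr("array_insert(arr, -1, 9)")).show()
    // In Spark 3.5+, array_prepend reduces to an insert at position 1 (see the assume above).
    df.select(expr("array_prepend(arr, 0)")).show() // [0, 1, 2, 3]

    spark.stop()
  }
}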
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
index ecc056ddd..6130e4cd5 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
@@ -40,6 +40,7 @@ import org.apache.comet.CometConf
abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper {
protected val adaptiveExecutionEnabled: Boolean
protected val numElementsForceSpillThreshold: Int = 10
+ protected val useUnifiedMemoryAllocator: Boolean = true
override protected def sparkConf: SparkConf = {
val conf = super.sparkConf
@@ -57,6 +58,8 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
CometConf.COMET_COLUMNAR_SHUFFLE_SPILL_THRESHOLD.key -> numElementsForceSpillThreshold.toString,
CometConf.COMET_EXEC_ENABLED.key -> "false",
CometConf.COMET_SHUFFLE_MODE.key -> "jvm",
+ CometConf.COMET_COLUMNAR_SHUFFLE_UNIFIED_MEMORY_ALLOCATOR_IN_TEST.key ->
+ useUnifiedMemoryAllocator.toString,
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
CometConf.COMET_COLUMNAR_SHUFFLE_MEMORY_SIZE.key -> "1536m") {
testFun
@@ -747,6 +750,10 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
$"_6",
$"_7",
$"_8",
+ $"_9",
+ $"_10",
+ $"_11",
+ $"_12",
$"_13",
$"_14",
$"_15",
@@ -968,6 +975,13 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
}
}
+class CometTestMemoryAllocatorShuffleSuite extends CometColumnarShuffleSuite {
+ override protected val asyncShuffleEnable: Boolean = false
+ override protected val adaptiveExecutionEnabled: Boolean = true
+ // Explicitly test with `CometTestShuffleMemoryAllocator`
+ override protected val useUnifiedMemoryAllocator: Boolean = false
+}
+
class CometAsyncShuffleSuite extends CometColumnarShuffleSuite {
override protected val asyncShuffleEnable: Boolean = true
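// Aside (not part of the patch): CometTestMemoryAllocatorShuffleSuite above flips the
// allocator for a whole suite; a single test could do the same locally. This sketch
// assumes the CometTestBase helpers (withSQLConf, checkSparkAnswer, sql) and a
// hypothetical temp view "tbl" are in scope, so it is illustrative rather than runnable
// on its own.
import org.apache.comet.CometConf

withSQLConf(
  CometConf.COMET_COLUMNAR_SHUFFLE_UNIFIED_MEMORY_ALLOCATOR_IN_TEST.key -> "false",
  CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
  CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
  // Setting the flag to false selects CometTestShuffleMemoryAllocator, mirroring what the
  // suite above does by overriding useUnifiedMemoryAllocator = false.
  checkSparkAnswer(sql("SELECT _1, count(*) FROM tbl GROUP BY _1"))
}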
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index a54b70ea4..1d1af7b3e 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -39,6 +39,7 @@ import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHas
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
@@ -129,7 +130,7 @@ class CometExecSuite extends CometTestBase {
sql(
"CREATE VIEW lv_noalias AS SELECT myTab.* FROM src " +
"LATERAL VIEW explode(map('key1', 100, 'key2', 200)) myTab LIMIT 2")
- val df = sql("SELECT * FROM lv_noalias a JOIN lv_noalias b ON a.key=b.key")
+ val df = sql("SELECT * FROM lv_noalias a JOIN lv_noalias b ON a.key=b.key");
checkSparkAnswer(df)
}
}
@@ -1889,6 +1890,14 @@ class CometExecSuite extends CometTestBase {
}
}
}
+
+ test("Supported file formats for CometScanExec") {
+ assert(CometScanExec.isFileFormatSupported(new ParquetFileFormat()))
+
+ class CustomParquetFileFormat extends ParquetFileFormat {}
+
+ assert(!CometScanExec.isFileFormatSupported(new CustomParquetFileFormat()))
+ }
}
case class BucketedTableTestSpec(
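// Aside (not part of the patch): the "Supported file formats for CometScanExec" test above
// implies an exact-class check, since a subclass of ParquetFileFormat is rejected. Below is
// a sketch of a predicate with that behavior; the real CometScanExec.isFileFormatSupported
// may be implemented differently.
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat

def isFileFormatSupportedSketch(format: FileFormat): Boolean =
  format.getClass == classOf[ParquetFileFormat]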
diff --git a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala
index 65fe94591..b97865a1f 100644
--- a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala
@@ -433,8 +433,8 @@ abstract class ParquetReadSuite extends CometTestBase {
i.toFloat,
i.toDouble,
i.toString * 48,
- java.lang.Byte.toUnsignedInt((-i).toByte),
- java.lang.Short.toUnsignedInt((-i).toShort),
+ (-i).toByte,
+ (-i).toShort,
java.lang.Integer.toUnsignedLong(-i),
new BigDecimal(UnsignedLong.fromLongBits((-i).toLong).bigIntegerValue()),
i.toString,
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index e997c5bfd..39af52e90 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -234,11 +234,9 @@ abstract class CometTestBase
df: => DataFrame): (Option[Throwable], Option[Throwable]) = {
var expected: Option[Throwable] = None
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
- val dfSpark = Dataset.ofRows(spark, df.logicalPlan)
- expected = Try(dfSpark.collect()).failed.toOption
+ expected = Try(Dataset.ofRows(spark, df.logicalPlan).collect()).failed.toOption
}
- val dfComet = Dataset.ofRows(spark, df.logicalPlan)
- val actual = Try(dfComet.collect()).failed.toOption
+ val actual = Try(Dataset.ofRows(spark, df.logicalPlan).collect()).failed.toOption
(expected, actual)
}
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala
index 3dd930f67..3ee37bd66 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala
@@ -274,23 +274,23 @@ object CometExecBenchmark extends CometBenchmarkBase {
}
override def runCometBenchmark(mainArgs: Array[String]): Unit = {
-// runBenchmarkWithTable("Subquery", 1024 * 1024 * 10) { v =>
-// subqueryExecBenchmark(v)
-// }
-//
-// runBenchmarkWithTable("Expand", 1024 * 1024 * 10) { v =>
-// expandExecBenchmark(v)
-// }
-//
-// runBenchmarkWithTable("Project + Filter", 1024 * 1024 * 10) { v =>
-// for (fractionOfZeros <- List(0.0, 0.50, 0.95)) {
-// numericFilterExecBenchmark(v, fractionOfZeros)
-// }
-// }
-//
-// runBenchmarkWithTable("Sort", 1024 * 1024 * 10) { v =>
-// sortExecBenchmark(v)
-// }
+ runBenchmarkWithTable("Subquery", 1024 * 1024 * 10) { v =>
+ subqueryExecBenchmark(v)
+ }
+
+ runBenchmarkWithTable("Expand", 1024 * 1024 * 10) { v =>
+ expandExecBenchmark(v)
+ }
+
+ runBenchmarkWithTable("Project + Filter", 1024 * 1024 * 10) { v =>
+ for (fractionOfZeros <- List(0.0, 0.50, 0.95)) {
+ numericFilterExecBenchmark(v, fractionOfZeros)
+ }
+ }
+
+ runBenchmarkWithTable("Sort", 1024 * 1024 * 10) { v =>
+ sortExecBenchmark(v)
+ }
runBenchmarkWithTable("BloomFilterAggregate", 1024 * 1024 * 10) { v =>
for (card <- List(100, 1024, 1024 * 1024)) {
diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
index 080655fe2..c3513e59e 100644
--- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
@@ -26,6 +26,7 @@ import scala.collection.mutable
import org.apache.commons.io.FileUtils
import org.apache.spark.SparkContext
+import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE}
import org.apache.spark.sql.TPCDSBase
import org.apache.spark.sql.catalyst.expressions.AttributeSet
import org.apache.spark.sql.catalyst.util.resourceToString
@@ -293,6 +294,8 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa
conf.set(
"spark.shuffle.manager",
"org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
+ conf.set(MEMORY_OFFHEAP_ENABLED.key, "true")
+ conf.set(MEMORY_OFFHEAP_SIZE.key, "2g")
conf.set(CometConf.COMET_ENABLED.key, "true")
conf.set(CometConf.COMET_EXEC_ENABLED.key, "true")
conf.set(CometConf.COMET_NATIVE_SCAN_ENABLED.key, "true")
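// Aside (not part of the patch): the two conf.set calls added above map to Spark's standard
// off-heap settings. Expressed as a standalone SparkConf they would look like this; the 2g
// size is the suite's choice for TPC-DS plan stability runs, not a requirement.
import org.apache.spark.SparkConf

val planStabilityConf = new SparkConf()
  .set("spark.memory.offHeap.enabled", "true")
  .set("spark.memory.offHeap.size", "2g")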