diff --git a/dist/pom.xml b/dist/pom.xml
index ce3cb054d7b..c8643f14559 100644
--- a/dist/pom.xml
+++ b/dist/pom.xml
@@ -354,6 +354,11 @@
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-avro_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ </dependency>
diff --git a/docs/compatibility.md b/docs/compatibility.md
index 4393cb1cf51..ada04f79e93 100644
--- a/docs/compatibility.md
+++ b/docs/compatibility.md
@@ -558,6 +558,18 @@ parse some variants of `NaN` and `Infinity` even when this option is disabled
([SPARK-38060](https://issues.apache.org/jira/browse/SPARK-38060)). The RAPIDS Accelerator behavior is consistent with
Spark version 3.3.0 and later.
+## Avro
+
+GPU reading of the Avro format is an experimental feature and is expected to have some issues, so it
+is disabled by default. If you would like to test it, set both `spark.rapids.sql.format.avro.enabled`
+and `spark.rapids.sql.format.avro.read.enabled` to `true`.
+
+Currently, the GPU-accelerated Avro reader does not support reading Avro version 1.2 files.
+
+### Supported types
+
+The boolean, byte, short, int, long, float, double, and string types are supported in the current version.
+
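+A minimal GPU-accelerated read could look like the following Scala snippet. This is only a sketch:
+it assumes the `spark-avro` package is already on the classpath and that `/path/to/data.avro` is an
+existing Avro dataset.
+
+```scala
+// Enable the experimental Avro read support in the RAPIDS Accelerator
+spark.conf.set("spark.rapids.sql.format.avro.enabled", "true")
+spark.conf.set("spark.rapids.sql.format.avro.read.enabled", "true")
+
+// The scan runs on the GPU when all of the types in the schema are supported,
+// otherwise it falls back to the CPU reader.
+val df = spark.read.format("avro").load("/path/to/data.avro")
+df.show()
+```
+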
## Regular Expressions
Regular expression evaluation on the GPU can potentially have high memory overhead and cause out-of-memory errors so
diff --git a/docs/configs.md b/docs/configs.md
index 61ff2d91d2a..dd7ab036bd2 100644
--- a/docs/configs.md
+++ b/docs/configs.md
@@ -72,6 +72,8 @@ Name | Description | Default Value
spark.rapids.sql.enabled|Enable (true) or disable (false) sql operations on the GPU|true
spark.rapids.sql.explain|Explain why some parts of a query were not placed on a GPU or not. Possible values are ALL: print everything, NONE: print nothing, NOT_ON_GPU: print only parts of a query that did not go on the GPU|NONE
spark.rapids.sql.fast.sample|Option to turn on fast sample. If enable it is inconsistent with CPU sample because of GPU sample algorithm is inconsistent with CPU.|false
+spark.rapids.sql.format.avro.enabled|When set to true enables all avro input and output acceleration. (only input is currently supported anyways)|false
+spark.rapids.sql.format.avro.read.enabled|When set to true enables avro input acceleration|false
spark.rapids.sql.format.csv.enabled|When set to false disables all csv input and output acceleration. (only input is currently supported anyways)|true
spark.rapids.sql.format.csv.read.enabled|When set to false disables csv input acceleration|true
spark.rapids.sql.format.json.enabled|When set to true enables all json input and output acceleration. (only input is currently supported anyways)|false
@@ -390,6 +392,7 @@ Name | Description | Default Value | Notes
spark.rapids.sql.input.JsonScan|Json parsing|true|None|
spark.rapids.sql.input.OrcScan|ORC parsing|true|None|
spark.rapids.sql.input.ParquetScan|Parquet parsing|true|None|
+spark.rapids.sql.input.AvroScan|Avro parsing|true|None|
### Partitioning
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index b94adb770b5..16ed1d7ec45 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -17914,6 +17914,49 @@ dates or timestamps, or for a lack of type coercion support.
UDT |
+Avro |
+Read |
+S |
+S |
+S |
+S |
+S |
+S |
+S |
+NS |
+NS |
+S |
+NS |
+ |
+NS |
+ |
+NS |
+NS |
+NS |
+NS |
+
+
+Write |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+ |
+NS |
+ |
+NS |
+NS |
+NS |
+NS |
+
+
CSV |
Read |
S |
diff --git a/integration_tests/pom.xml b/integration_tests/pom.xml
index eb1feb329c6..cd994ffda17 100644
--- a/integration_tests/pom.xml
+++ b/integration_tests/pom.xml
@@ -297,14 +297,19 @@
copy
-
- true
+
+ true
ai.rapids
cudf
${cuda.version}
+ <artifactItem>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-avro_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ </artifactItem>
diff --git a/integration_tests/run_pyspark_from_build.sh b/integration_tests/run_pyspark_from_build.sh
index f81c4cef728..b1417aec9c2 100755
--- a/integration_tests/run_pyspark_from_build.sh
+++ b/integration_tests/run_pyspark_from_build.sh
@@ -40,18 +40,35 @@ else
# support alternate local jars NOT building from the source code
if [ -d "$LOCAL_JAR_PATH" ]; then
CUDF_JARS=$(echo "$LOCAL_JAR_PATH"/cudf-*.jar)
+ AVRO_JARS=$(echo "$LOCAL_JAR_PATH"/spark-avro*.jar)
PLUGIN_JARS=$(echo "$LOCAL_JAR_PATH"/rapids-4-spark_*.jar)
# the integration-test-spark3xx.jar, should not include the integration-test-spark3xxtest.jar
TEST_JARS=$(echo "$LOCAL_JAR_PATH"/rapids-4-spark-integration-tests*-$SPARK_SHIM_VER.jar)
else
CUDF_JARS=$(echo "$SCRIPTPATH"/target/dependency/cudf-*.jar)
+ AVRO_JARS=$(echo "$SCRIPTPATH"/target/dependency/spark-avro*.jar)
PLUGIN_JARS=$(echo "$SCRIPTPATH"/../dist/target/rapids-4-spark_*.jar)
# the integration-test-spark3xx.jar, should not include the integration-test-spark3xxtest.jar
TEST_JARS=$(echo "$SCRIPTPATH"/target/rapids-4-spark-integration-tests*-$SPARK_SHIM_VER.jar)
fi
+ # By default `./run_pyspark_from_build.sh` runs the tests WITHOUT spark-avro.jar on the
+ # classpath, so avro_test.py is skipped by its module-level skip marker.
+ #
+ # `INCLUDE_SPARK_AVRO_JAR=true ./run_pyspark_from_build.sh` adds spark-avro.jar to the
+ # classpath and enables avro_test.py. Any TEST filter still applies in both modes.
+ if [[ $( echo "${INCLUDE_SPARK_AVRO_JAR}" | tr '[:upper:]' '[:lower:]' ) == "true" ]];
+ then
+ export INCLUDE_SPARK_AVRO_JAR=true
+ else
+ export INCLUDE_SPARK_AVRO_JAR=false
+ AVRO_JARS=""
+ fi
+
# Only 3 jars: cudf.jar dist.jar integration-test.jar
- ALL_JARS="$CUDF_JARS $PLUGIN_JARS $TEST_JARS"
+ ALL_JARS="$CUDF_JARS $PLUGIN_JARS $TEST_JARS $AVRO_JARS"
echo "AND PLUGIN JARS: $ALL_JARS"
if [[ "${TEST}" != "" ]];
then
diff --git a/integration_tests/src/main/python/avro_test.py b/integration_tests/src/main/python/avro_test.py
new file mode 100644
index 00000000000..418701d8e8e
--- /dev/null
+++ b/integration_tests/src/main/python/avro_test.py
@@ -0,0 +1,90 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+from spark_session import with_cpu_session
+import pytest
+
+from asserts import assert_gpu_and_cpu_are_equal_collect
+from data_gen import *
+from marks import *
+from pyspark.sql.types import *
+
+if os.environ.get('INCLUDE_SPARK_AVRO_JAR', 'false') == 'false':
+ pytestmark = pytest.mark.skip(reason="INCLUDE_SPARK_AVRO_JAR is disabled")
+
+support_gens = numeric_gens + [string_gen, boolean_gen]
+
+_enable_all_types_conf = {
+ 'spark.rapids.sql.format.avro.enabled': 'true',
+ 'spark.rapids.sql.format.avro.read.enabled': 'true'}
+
+
+@pytest.mark.parametrize('gen', support_gens)
+@pytest.mark.parametrize('v1_enabled_list', ["avro", ""])
+def test_basic_read(spark_tmp_path, gen, v1_enabled_list):
+ data_path = spark_tmp_path + '/AVRO_DATA'
+ with_cpu_session(
+ lambda spark: unary_op_df(spark, gen).write.format("avro").save(data_path)
+ )
+
+ all_confs = copy_and_update(_enable_all_types_conf, {
+ 'spark.sql.sources.useV1SourceList': v1_enabled_list})
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark: spark.read.format("avro").load(data_path),
+ conf=all_confs)
+
+
+@pytest.mark.parametrize('v1_enabled_list', ["", "avro"])
+def test_avro_simple_partitioned_read(spark_tmp_path, v1_enabled_list):
+ gen_list = [('_c' + str(i), gen) for i, gen in enumerate(support_gens)]
+ first_data_path = spark_tmp_path + '/AVRO_DATA/key=0/key2=20'
+ with_cpu_session(
+ lambda spark: gen_df(spark, gen_list).write.format("avro").save(first_data_path))
+ second_data_path = spark_tmp_path + '/AVRO_DATA/key=1/key2=21'
+ with_cpu_session(
+ lambda spark: gen_df(spark, gen_list).write.format("avro").save(second_data_path))
+ third_data_path = spark_tmp_path + '/AVRO_DATA/key=2/key2=22'
+ with_cpu_session(
+ lambda spark: gen_df(spark, gen_list).write.format("avro").save(third_data_path))
+
+ data_path = spark_tmp_path + '/AVRO_DATA'
+
+ all_confs = copy_and_update(_enable_all_types_conf, {
+ 'spark.sql.sources.useV1SourceList': v1_enabled_list})
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark: spark.read.format("avro").load(data_path),
+ conf=all_confs)
+
+
+@pytest.mark.parametrize('v1_enabled_list', ["", "avro"])
+def test_avro_input_meta(spark_tmp_path, v1_enabled_list):
+ first_data_path = spark_tmp_path + '/AVRO_DATA/key=0'
+ with_cpu_session(
+ lambda spark: unary_op_df(spark, long_gen).write.format("avro").save(first_data_path))
+ second_data_path = spark_tmp_path + '/AVRO_DATA/key=1'
+ with_cpu_session(
+ lambda spark: unary_op_df(spark, long_gen).write.format("avro").save(second_data_path))
+ data_path = spark_tmp_path + '/AVRO_DATA'
+
+ all_confs = copy_and_update(_enable_all_types_conf, {
+ 'spark.sql.sources.useV1SourceList': v1_enabled_list})
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark: spark.read.format("avro").load(data_path)
+ .filter(f.col('a') > 0)
+ .selectExpr('a',
+ 'input_file_name()',
+ 'input_file_block_start()',
+ 'input_file_block_length()'),
+ conf=all_confs)
diff --git a/jenkins/databricks/build.sh b/jenkins/databricks/build.sh
index b6d72fb67d2..89ab24d0d68 100755
--- a/jenkins/databricks/build.sh
+++ b/jenkins/databricks/build.sh
@@ -158,6 +158,17 @@ JACKSONANNOTATION=----workspace_${SPARK_MAJOR_VERSION_STRING}--maven-trees--hive
HADOOPCOMMON=----workspace_${SPARK_MAJOR_VERSION_STRING}--maven-trees--hive-2.3__hadoop-${HADOOP_VERSION}--org.apache.hadoop--hadoop-common--org.apache.hadoop__hadoop-common__2.7.4.jar
HADOOPMAPRED=----workspace_${SPARK_MAJOR_VERSION_STRING}--maven-trees--hive-2.3__hadoop-${HADOOP_VERSION}--org.apache.hadoop--hadoop-mapreduce-client-core--org.apache.hadoop__hadoop-mapreduce-client-core__2.7.4.jar
+if [[ $BASE_SPARK_VERSION == "3.2.1" ]]
+then
+ AVROSPARKJAR=----workspace_${SPARK_MAJOR_VERSION_STRING}--vendor--avro--avro-hive-2.3__hadoop-3.2_2.12_deploy_shaded.jar
+ AVROMAPRED=----workspace_${SPARK_MAJOR_VERSION_STRING}--maven-trees--hive-2.3__hadoop-3.2--org.apache.avro--avro-mapred--org.apache.avro__avro-mapred__1.10.2.jar
+ AVROJAR=----workspace_${SPARK_MAJOR_VERSION_STRING}--maven-trees--hive-2.3__hadoop-3.2--org.apache.avro--avro--org.apache.avro__avro__1.10.2.jar
+else
+ AVROSPARKJAR=----workspace_${SPARK_MAJOR_VERSION_STRING}--vendor--avro--avro_2.12_deploy_shaded.jar
+ AVROMAPRED=----workspace_${SPARK_MAJOR_VERSION_STRING}--maven-trees--hive-2.3__hadoop-2.7--org.apache.avro--avro-mapred-hadoop2--org.apache.avro__avro-mapred-hadoop2__1.8.2.jar
+ AVROJAR=----workspace_${SPARK_MAJOR_VERSION_STRING}--maven-trees--hive-2.3__hadoop-2.7--org.apache.avro--avro--org.apache.avro__avro__1.8.2.jar
+fi
+
# Please note we are installing all of these dependencies using the Spark version (SPARK_VERSION_TO_INSTALL_DATABRICKS_JARS) to make it easier
# to specify the dependencies in the pom files
@@ -177,6 +188,30 @@ mvn -B install:install-file \
-Dversion=$SPARK_VERSION_TO_INSTALL_DATABRICKS_JARS \
-Dpackaging=jar
+mvn -B install:install-file \
+ -Dmaven.repo.local=$M2DIR \
+ -Dfile=$JARDIR/$AVROSPARKJAR\
+ -DgroupId=org.apache.spark \
+ -DartifactId=spark-avro_$SCALA_VERSION \
+ -Dversion=$SPARK_VERSION_TO_INSTALL_DATABRICKS_JARS \
+ -Dpackaging=jar
+
+mvn -B install:install-file \
+ -Dmaven.repo.local=$M2DIR \
+ -Dfile=$JARDIR/$AVROMAPRED\
+ -DgroupId=org.apache.avro\
+ -DartifactId=avro-mapred \
+ -Dversion=$SPARK_VERSION_TO_INSTALL_DATABRICKS_JARS \
+ -Dpackaging=jar
+
+mvn -B install:install-file \
+ -Dmaven.repo.local=$M2DIR \
+ -Dfile=$JARDIR/$AVROJAR \
+ -DgroupId=org.apache.avro\
+ -DartifactId=avro \
+ -Dversion=$SPARK_VERSION_TO_INSTALL_DATABRICKS_JARS \
+ -Dpackaging=jar
+
mvn -B install:install-file \
-Dmaven.repo.local=$M2DIR \
-Dfile=$JARDIR/$ANNOTJAR \
diff --git a/jenkins/spark-premerge-build.sh b/jenkins/spark-premerge-build.sh
index 8eccfa3bee1..be7ab5f5ace 100755
--- a/jenkins/spark-premerge-build.sh
+++ b/jenkins/spark-premerge-build.sh
@@ -113,6 +113,7 @@ ci_2() {
TEST_PARALLEL=5 TEST='struct_test or time_window_test' ./integration_tests/run_pyspark_from_build.sh
TEST='not conditionals_test and not window_function_test and not struct_test and not time_window_test' \
./integration_tests/run_pyspark_from_build.sh
+ INCLUDE_SPARK_AVRO_JAR=true TEST='avro_test.py' ./integration_tests/run_pyspark_from_build.sh
}
diff --git a/sql-plugin/pom.xml b/sql-plugin/pom.xml
index 9897e7c7293..5ef4ff7fcf8 100644
--- a/sql-plugin/pom.xml
+++ b/sql-plugin/pom.xml
@@ -56,6 +56,12 @@
com.google.flatbuffers
flatbuffers-java
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-avro_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
@@ -119,6 +125,24 @@
org.apache.spark
spark-sql_${scala.binary.version}
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-avro_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro-mapred</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
org.apache.hive
hive-exec
diff --git a/sql-plugin/src/main/311until320-all/scala/org/apache/spark/sql/rapids/shims/AvroUtils.scala b/sql-plugin/src/main/311until320-all/scala/org/apache/spark/sql/rapids/shims/AvroUtils.scala
new file mode 100644
index 00000000000..a7c15ebab41
--- /dev/null
+++ b/sql-plugin/src/main/311until320-all/scala/org/apache/spark/sql/rapids/shims/AvroUtils.scala
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids.shims
+
+import com.nvidia.spark.rapids.RapidsMeta
+
+import org.apache.spark.sql.avro.AvroOptions
+
+object AvroUtils {
+
+ def tagSupport(
+ parsedOptions: AvroOptions,
+ meta: RapidsMeta[_, _, _]): Unit = {
+
+ }
+
+}
diff --git a/sql-plugin/src/main/320+/scala/org/apache/spark/sql/rapids/shims/AvroUtils.scala b/sql-plugin/src/main/320+/scala/org/apache/spark/sql/rapids/shims/AvroUtils.scala
new file mode 100644
index 00000000000..464ef92f54c
--- /dev/null
+++ b/sql-plugin/src/main/320+/scala/org/apache/spark/sql/rapids/shims/AvroUtils.scala
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids.shims
+
+import com.nvidia.spark.rapids.RapidsMeta
+
+import org.apache.spark.sql.avro.AvroOptions
+
+object AvroUtils {
+
+ def tagSupport(
+ parsedOptions: AvroOptions,
+ meta: RapidsMeta[_, _, _]): Unit = {
+
+ if (parsedOptions.positionalFieldMatching) {
+ meta.willNotWorkOnGpu("positional field matching is not supported")
+ }
+ }
+
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AvroDataReader.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AvroDataReader.scala
new file mode 100644
index 00000000000..1b999ee40ad
--- /dev/null
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AvroDataReader.scala
@@ -0,0 +1,198 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids
+
+import java.io.{InputStream, IOException}
+import java.nio.charset.StandardCharsets
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.avro.Schema
+import org.apache.avro.file.{DataFileConstants, SeekableInput}
+import org.apache.avro.file.DataFileConstants.{MAGIC, SYNC_SIZE}
+import org.apache.avro.io.{BinaryData, BinaryDecoder, DecoderFactory}
+
+private class SeekableInputStream(in: SeekableInput) extends InputStream with SeekableInput {
+ var oneByte = new Array[Byte](1)
+
+ override def read(): Int = {
+ val n = read(oneByte, 0, 1)
+ if (n == 1) return oneByte(0) & 0xff else return n
+ }
+
+ override def read(b: Array[Byte]): Int = in.read(b, 0, b.length)
+
+ override def read(b: Array[Byte], off: Int, len: Int): Int = in.read(b, off, len)
+
+ override def seek(p: Long): Unit = {
+ if (p < 0) throw new IOException("Illegal seek: " + p)
+ in.seek(p)
+ }
+
+ override def tell(): Long = in.tell()
+
+ override def length(): Long = in.length()
+
+ override def close(): Unit = {
+ in.close()
+ super.close()
+ }
+
+ override def available(): Int = {
+ val remaining = in.length() - in.tell()
+ if (remaining > Int.MaxValue) Int.MaxValue else remaining.toInt
+ }
+}
+
+/**
+ * The header information of an Avro file
+ */
+class Header {
+ var meta = Map[String, Array[Byte]]()
+ var metaKeyList = ArrayBuffer[String]()
+ var sync = new Array[Byte](DataFileConstants.SYNC_SIZE)
+ var schema: Schema = _
+ private var firstBlockStart: Long = _
+
+ private[rapids] def update(schemaValue: String, firstBlockStart: Long) = {
+ schema = new Schema.Parser().setValidate(false).setValidateDefaults(false)
+ .parse(schemaValue)
+ this.firstBlockStart = firstBlockStart
+ }
+
+ def getFirstBlockStart: Long = firstBlockStart
+}
+
+/**
+ * Metadata of a single Avro block
+ *
+ * @param blockStart the start offset of the block in the file
+ * @param blockLength the total block length: the data between two sync markers plus the sync marker
+ * @param blockDataSize the size of the block data
+ * @param count how many entries are in this block
+ */
+case class BlockInfo(blockStart: Long, blockLength: Long, blockDataSize: Long, count: Long)
+
+/**
+ * AvroDataFileReader parses the Avro file to get the header and all block information
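+ *
+ * An Avro container file starts with the 4-byte magic, a metadata map (which includes the
+ * schema) and a 16-byte sync marker, followed by a sequence of data blocks, each consisting
+ * of a long entry count, a long data size, the serialized entries, and another sync marker.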
+ */
+class AvroDataFileReader(si: SeekableInput) extends AutoCloseable {
+ private val sin = new SeekableInputStream(si)
+ sin.seek(0) // seek to the start of file and get some meta info.
+ // There is no existing decoder to reuse on the first call, so pass null as the reuse argument.
+ private var vin: BinaryDecoder = DecoderFactory.get.binaryDecoder(sin, null)
+ private val header: Header = new Header()
+ private var firstBlockStart: Long = 0
+
+ // store all blocks info
+ private val blocks: ArrayBuffer[BlockInfo] = ArrayBuffer.empty
+
+ initialize()
+
+ def getBlocks(): ArrayBuffer[BlockInfo] = {
+ blocks
+ }
+
+ def getHeader(): Header = header
+
+ private def initialize() = {
+ val magic = new Array[Byte](MAGIC.length)
+ vin.readFixed(magic)
+
+ magic match {
+ case Array(79, 98, 106, 1) => // current avro format
+ case Array(79, 98, 106, 0) => // old format
+ throw new UnsupportedOperationException("the Avro 1.2 format is not supported by the GPU")
+ case _ => throw new RuntimeException("Not an Avro data file.")
+ }
+
+ var l = vin.readMapStart().toInt
+ if (l > 0) {
+ do {
+ for (i <- 1 to l) {
+ val key = vin.readString(null).toString
+ val value = vin.readBytes(null)
+ val bb = new Array[Byte](value.remaining())
+ value.get(bb)
+ header.meta += (key -> bb)
+ header.metaKeyList += key
+ }
+ l = vin.mapNext().toInt
+ } while (l != 0)
+ }
+ vin.readFixed(header.sync)
+ firstBlockStart = sin.tell - vin.inputStream.available // get the first block Start address
+ header.update(getMetaString(DataFileConstants.SCHEMA), firstBlockStart)
+ parseBlocks()
+ }
+
+ private def seek(position: Long): Unit = {
+ sin.seek(position)
+ vin = DecoderFactory.get().binaryDecoder(this.sin, vin)
+ }
+
+ private def parseBlocks(): Unit = {
+ if (firstBlockStart >= sin.length() || vin.isEnd()) {
+ // no blocks
+ return
+ }
+ // buf is used for writing long
+ val buf = new Array[Byte](12)
+ var blockStart = firstBlockStart
+ while (blockStart < sin.length()) {
+ seek(blockStart)
+ if (vin.isEnd()) {
+ return
+ }
+ val blockCount = vin.readLong()
+ val blockDataSize = vin.readLong()
+ if (blockDataSize > Integer.MAX_VALUE || blockDataSize < 0) {
+ throw new IOException("Block size invalid or too large: " + blockDataSize)
+ }
+
+ // Get how many bytes used to store the value of count and block data size.
+ val blockCountLen = BinaryData.encodeLong(blockCount, buf, 0)
+ val blockDataSizeLen: Int = BinaryData.encodeLong(blockDataSize, buf, 0)
+
+ // (len of entries) + (len of block size) + (block size) + (sync size)
+ val blockLength = blockCountLen + blockDataSizeLen + blockDataSize + SYNC_SIZE
+ blocks += BlockInfo(blockStart, blockLength, blockDataSize, blockCount)
+
+ // Do we need to check the SYNC BUFFER, or just let cudf do it?
+ blockStart += blockLength
+ }
+ }
+
+ /** Return the value of a metadata property. */
+ private def getMeta(key: String): Array[Byte] = header.meta.getOrElse(key, new Array[Byte](1))
+
+ private def getMetaString(key: String): String = {
+ val value = getMeta(key)
+ if (value == null) return null
+ new String(value, StandardCharsets.UTF_8)
+ }
+
+ override def close(): Unit = {
+ vin.inputStream().close()
+ }
+}
+
+object AvroDataFileReader {
+
+ def openReader(si: SeekableInput): AvroDataFileReader = {
+ new AvroDataFileReader(si)
+ }
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 61c66adb941..ef0342f2fe8 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -429,6 +429,9 @@ object OrcFormatType extends FileFormatType {
object JsonFormatType extends FileFormatType {
override def toString = "JSON"
}
+object AvroFormatType extends FileFormatType {
+ override def toString = "Avro"
+}
sealed trait FileFormatOp
object ReadFileOp extends FileFormatOp {
@@ -825,6 +828,12 @@ object GpuOverrides extends Logging {
(JsonFormatType, FileFormatChecks(
cudfRead = TypeSig.commonCudfTypes + TypeSig.DECIMAL_128,
cudfWrite = TypeSig.none,
+ sparkSig = (TypeSig.cpuAtomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP +
+ TypeSig.UDT).nested())),
+ (AvroFormatType, FileFormatChecks(
+ cudfRead = TypeSig.BOOLEAN + TypeSig.BYTE + TypeSig.SHORT + TypeSig.INT + TypeSig.LONG +
+ TypeSig.FLOAT + TypeSig.DOUBLE + TypeSig.STRING,
+ cudfWrite = TypeSig.none,
sparkSig = (TypeSig.cpuAtomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP +
TypeSig.UDT).nested())))
@@ -3473,7 +3482,7 @@ object GpuOverrides extends Logging {
})).map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap
val scans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] =
- commonScans ++ SparkShimImpl.getScans
+ commonScans ++ SparkShimImpl.getScans ++ ExternalSource.getScans
def wrapPart[INPUT <: Partitioning](
part: INPUT,
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index 59f83e001a6..d03614b66af 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -908,6 +908,17 @@ object RapidsConf {
.booleanConf
.createWithDefault(false)
+ val ENABLE_AVRO = conf("spark.rapids.sql.format.avro.enabled")
+ .doc("When set to true enables all avro input and output acceleration. " +
+ "(only input is currently supported anyways)")
+ .booleanConf
+ .createWithDefault(false)
+
+ val ENABLE_AVRO_READ = conf("spark.rapids.sql.format.avro.read.enabled")
+ .doc("When set to true enables avro input acceleration")
+ .booleanConf
+ .createWithDefault(false)
+
val ENABLE_RANGE_WINDOW_BYTES = conf("spark.rapids.sql.window.range.byte.enabled")
.doc("When the order-by column of a range based window is byte type and " +
"the range boundary calculated for a value has overflow, CPU and GPU will get " +
@@ -1658,6 +1669,10 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
lazy val isJsonReadEnabled: Boolean = get(ENABLE_JSON_READ)
+ lazy val isAvroEnabled: Boolean = get(ENABLE_AVRO)
+
+ lazy val isAvroReadEnabled: Boolean = get(ENABLE_AVRO_READ)
+
lazy val shuffleManagerEnabled: Boolean = get(SHUFFLE_MANAGER_ENABLED)
lazy val shuffleTransportEnabled: Boolean = get(SHUFFLE_TRANSPORT_ENABLE)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index 37ddd8b910d..44c1d2ce866 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -2137,6 +2137,7 @@ object SupportedOpsForTools {
case "parquet" => conf.isParquetEnabled && conf.isParquetReadEnabled
case "orc" => conf.isOrcEnabled && conf.isOrcReadEnabled
case "json" => conf.isJsonEnabled && conf.isJsonReadEnabled
+ case "avro" => conf.isAvroEnabled && conf.isAvroReadEnabled
case _ =>
throw new IllegalArgumentException("Format is unknown we need to add it here!")
}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala
new file mode 100644
index 00000000000..84b44fc502c
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala
@@ -0,0 +1,96 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids
+
+import scala.util.{Failure, Success, Try}
+
+import com.nvidia.spark.rapids._
+
+import org.apache.spark.sql.avro.AvroFileFormat
+import org.apache.spark.sql.connector.read.Scan
+import org.apache.spark.sql.execution.FileSourceScanExec
+import org.apache.spark.sql.execution.datasources.FileFormat
+import org.apache.spark.sql.v2.avro.AvroScan
+import org.apache.spark.util.Utils
+
+object ExternalSource {
+
+ lazy val hasSparkAvroJar = {
+ val loader = Utils.getContextOrSparkClassLoader
+
+ /** spark-avro is an optional package for Spark, so the RAPIDS Accelerator
+ * must run successfully without it. */
+ Try(loader.loadClass("org.apache.spark.sql.v2.avro.AvroScan")) match {
+ case Failure(_) => false
+ case Success(_) => true
+ }
+ }
+
+ def tagSupportForGpuFileSourceScanExec(meta: SparkPlanMeta[FileSourceScanExec]): Unit = {
+ if (hasSparkAvroJar) {
+ meta.wrapped.relation.fileFormat match {
+ case _: AvroFileFormat => GpuReadAvroFileFormat.tagSupport(meta)
+ case f =>
+ meta.willNotWorkOnGpu(s"unsupported file format: ${f.getClass.getCanonicalName}")
+ }
+ } else {
+ meta.wrapped.relation.fileFormat match {
+ case f =>
+ meta.willNotWorkOnGpu(s"unsupported file format: ${f.getClass.getCanonicalName}")
+ }
+ }
+ }
+
+ def convertFileFormatForGpuFileSourceScanExec(format: FileFormat): FileFormat = {
+ if (hasSparkAvroJar) {
+ format match {
+ case _: AvroFileFormat => new GpuReadAvroFileFormat
+ case f =>
+ throw new IllegalArgumentException(s"${f.getClass.getCanonicalName} is not supported")
+ }
+ } else {
+ format match {
+ case f =>
+ throw new IllegalArgumentException(s"${f.getClass.getCanonicalName} is not supported")
+ }
+ }
+ }
+
+ def getScans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] = {
+ if (hasSparkAvroJar) {
+ Seq(
+ GpuOverrides.scan[AvroScan](
+ "Avro parsing",
+ (a, conf, p, r) => new ScanMeta[AvroScan](a, conf, p, r) {
+ override def tagSelfForGpu(): Unit = GpuAvroScan.tagSupport(this)
+
+ override def convertToGpu(): Scan =
+ GpuAvroScan(a.sparkSession,
+ a.fileIndex,
+ a.dataSchema,
+ a.readDataSchema,
+ a.readPartitionSchema,
+ a.options,
+ conf,
+ a.partitionFilters,
+ a.dataFilters)
+ })
+ ).map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap
+ } else Map.empty
+ }
+
+}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuAvroScan.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuAvroScan.scala
new file mode 100644
index 00000000000..b2aaa3e3b26
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuAvroScan.scala
@@ -0,0 +1,472 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids
+
+import java.io.OutputStream
+import java.net.URI
+
+import scala.annotation.tailrec
+import scala.collection.JavaConverters.mapAsScalaMapConverter
+import scala.collection.mutable.ArrayBuffer
+import scala.math.max
+
+import ai.rapids.cudf.{AvroOptions => CudfAvroOptions, HostMemoryBuffer, NvtxColor, NvtxRange, Table}
+import com.nvidia.spark.rapids.{Arm, AvroDataFileReader, AvroFormatType, BlockInfo, ColumnarPartitionReaderWithPartitionValues, FileFormatChecks, FilePartitionReaderBase, GpuBatchUtils, GpuColumnVector, GpuMetric, GpuSemaphore, Header, HostMemoryOutputStream, NvtxWithMetrics, PartitionReaderWithBytesRead, RapidsConf, RapidsMeta, ReadFileOp, ScanMeta, ScanWithMetrics}
+import com.nvidia.spark.rapids.GpuMetric.{GPU_DECODE_TIME, NUM_OUTPUT_BATCHES, PEAK_DEVICE_MEMORY, READ_FS_TIME, SEMAPHORE_WAIT_TIME, WRITE_BUFFER_TIME}
+import org.apache.avro.file.DataFileConstants.SYNC_SIZE
+import org.apache.avro.mapred.FsInput
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FSDataInputStream, Path}
+
+import org.apache.spark.TaskContext
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.avro.AvroOptions
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.connector.read.{PartitionReader, PartitionReaderFactory}
+import org.apache.spark.sql.execution.QueryExecutionException
+import org.apache.spark.sql.execution.datasources.{PartitionedFile, PartitioningAwareFileIndex}
+import org.apache.spark.sql.execution.datasources.v2.{FilePartitionReaderFactory, FileScan}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.rapids.shims.AvroUtils
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.sql.v2.avro.AvroScan
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.SerializableConfiguration
+
+object GpuAvroScan {
+
+ def tagSupport(scanMeta: ScanMeta[AvroScan]) : Unit = {
+ val scan = scanMeta.wrapped
+ tagSupport(
+ scan.sparkSession,
+ scan.readDataSchema,
+ scan.options.asScala.toMap,
+ scanMeta)
+ }
+
+ def tagSupport(
+ sparkSession: SparkSession,
+ readSchema: StructType,
+ options: Map[String, String],
+ meta: RapidsMeta[_, _, _]): Unit = {
+
+ val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options)
+ val parsedOptions = new AvroOptions(options, hadoopConf)
+
+ if (!meta.conf.isAvroEnabled) {
+ meta.willNotWorkOnGpu("Avro input and output has been disabled. To enable set " +
+ s"${RapidsConf.ENABLE_AVRO} to true")
+ }
+
+ if (!meta.conf.isAvroReadEnabled) {
+ meta.willNotWorkOnGpu("Avro input has been disabled. To enable set " +
+ s"${RapidsConf.ENABLE_AVRO_READ} to true")
+ }
+
+ AvroUtils.tagSupport(parsedOptions, meta)
+
+ FileFormatChecks.tag(meta, readSchema, AvroFormatType, ReadFileOp)
+ }
+
+}
+
+case class GpuAvroScan(
+ sparkSession: SparkSession,
+ fileIndex: PartitioningAwareFileIndex,
+ dataSchema: StructType,
+ readDataSchema: StructType,
+ readPartitionSchema: StructType,
+ options: CaseInsensitiveStringMap,
+ rapidsConf: RapidsConf,
+ partitionFilters: Seq[Expression] = Seq.empty,
+ dataFilters: Seq[Expression] = Seq.empty) extends FileScan with ScanWithMetrics {
+ override def isSplitable(path: Path): Boolean = true
+
+ @scala.annotation.nowarn(
+ "msg=value ignoreExtension in class AvroOptions is deprecated*"
+ )
+ override def createReaderFactory(): PartitionReaderFactory = {
+ val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
+ // Hadoop Configurations are case sensitive.
+ val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
+ val broadcastedConf = sparkSession.sparkContext.broadcast(
+ new SerializableConfiguration(hadoopConf))
+ val parsedOptions = new AvroOptions(caseSensitiveMap, hadoopConf)
+ // The partition values are already truncated in `FileScan.partitions`.
+ // We should use `readPartitionSchema` as the partition schema here.
+ GpuAvroPartitionReaderFactory(
+ sparkSession.sessionState.conf,
+ broadcastedConf,
+ dataSchema,
+ readDataSchema,
+ readPartitionSchema,
+ rapidsConf,
+ parsedOptions.ignoreExtension,
+ metrics)
+ }
+
+ // overrides nothing in 330
+ def withFilters(
+ partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
+ this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
+
+}
+
+/** Avro partition reader factory to build columnar reader */
+case class GpuAvroPartitionReaderFactory(
+ sqlConf: SQLConf,
+ broadcastedConf: Broadcast[SerializableConfiguration],
+ dataSchema: StructType,
+ readDataSchema: StructType,
+ partitionSchema: StructType,
+ @transient rapidsConf: RapidsConf,
+ ignoreExtension: Boolean,
+ metrics: Map[String, GpuMetric]) extends FilePartitionReaderFactory with Logging {
+
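+ // Note: this reuses the Parquet debug dump prefix config for dumping the fabricated Avro data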
+ private val debugDumpPrefix = rapidsConf.parquetDebugDumpPrefix
+ private val maxReadBatchSizeRows = rapidsConf.maxReadBatchSizeRows
+ private val maxReadBatchSizeBytes = rapidsConf.maxReadBatchSizeBytes
+
+ override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = {
+ throw new IllegalStateException("ROW BASED PARSING IS NOT SUPPORTED ON THE GPU...")
+ }
+
+ override def buildColumnarReader(partFile: PartitionedFile): PartitionReader[ColumnarBatch] = {
+ val conf = broadcastedConf.value.value
+ val blockMeta = GpuAvroFileFilterHandler(sqlConf, broadcastedConf,
+ ignoreExtension, broadcastedConf.value.value).filterBlocks(partFile)
+ val reader = new PartitionReaderWithBytesRead(new AvroPartitionReader(conf, partFile, blockMeta,
+ readDataSchema, debugDumpPrefix, maxReadBatchSizeRows,
+ maxReadBatchSizeBytes, metrics))
+ ColumnarPartitionReaderWithPartitionValues.newReader(partFile, reader, partitionSchema)
+ }
+}
+
+/**
+ * A tool to filter Avro blocks
+ *
+ * @param sqlConf the SQLConf
+ * @param broadcastedConf the broadcast Hadoop configuration
+ * @param ignoreExtension whether to also read files without the ".avro" extension
+ * @param hadoopConf the Hadoop configuration
+ */
+private case class GpuAvroFileFilterHandler(
+ @transient sqlConf: SQLConf,
+ broadcastedConf: Broadcast[SerializableConfiguration],
+ ignoreExtension: Boolean,
+ hadoopConf: Configuration) extends Arm with Logging {
+
+ def filterBlocks(partFile: PartitionedFile): AvroBlockMeta = {
+
+ def passSync(blockStart: Long, position: Long): Boolean = {
+ blockStart >= position + SYNC_SIZE
+ }
+
+ if (ignoreExtension || partFile.filePath.endsWith(".avro")) {
+ val in = new FsInput(new Path(new URI(partFile.filePath)), hadoopConf)
+ closeOnExcept(in) { _ =>
+ withResource(AvroDataFileReader.openReader(in)) { reader =>
+ val blocks = reader.getBlocks()
+ val filteredBlocks = new ArrayBuffer[BlockInfo]()
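+ // A block belongs to this partition only when the start of its preceding sync marker
+ // (blockStart - SYNC_SIZE) falls within [partFile.start, partFile.start + partFile.length),
+ // so each block is assigned to a single partition.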
+ blocks.foreach(block => {
+ if (partFile.start <= block.blockStart - SYNC_SIZE &&
+ !passSync(block.blockStart, partFile.start + partFile.length)) {
+ filteredBlocks.append(block)
+ }
+ })
+ AvroBlockMeta(reader.getHeader(), filteredBlocks)
+ }
+ }
+ } else {
+ AvroBlockMeta(new Header(), Seq.empty)
+ }
+ }
+}
+
+/**
+ * Avro block meta info
+ *
+ * @param header the header of the Avro file
+ * @param blocks the block information of the Avro file
+ */
+case class AvroBlockMeta(header: Header, blocks: Seq[BlockInfo])
+
+/**
+ * CopyRange to indicate from where to copy.
+ *
+ * @param offset from where to copy
+ * @param length how many bytes to copy
+ */
+case class CopyRange(offset: Long, length: Long)
+
+/**
+ *
+ * @param conf the Hadoop configuration
+ * @param partFile the partitioned files to read
+ * @param blockMeta the block meta info of partFile
+ * @param readDataSchema the Spark schema describing what will be read
+ * @param debugDumpPrefix a path prefix to use for dumping the fabricated avro data or null
+ * @param maxReadBatchSizeRows soft limit on the maximum number of rows the reader reads per batch
+ * @param maxReadBatchSizeBytes soft limit on the maximum number of bytes the reader reads per batch
+ * @param execMetrics metrics
+ */
+class AvroPartitionReader(
+ conf: Configuration,
+ partFile: PartitionedFile,
+ blockMeta: AvroBlockMeta,
+ readDataSchema: StructType,
+ debugDumpPrefix: String,
+ maxReadBatchSizeRows: Integer,
+ maxReadBatchSizeBytes: Long,
+ execMetrics: Map[String, GpuMetric]) extends FilePartitionReaderBase(conf, execMetrics) {
+
+ val filePath = new Path(new URI(partFile.filePath))
+ private val blockIterator: BufferedIterator[BlockInfo] = blockMeta.blocks.iterator.buffered
+
+ override def next(): Boolean = {
+ batch.foreach(_.close())
+ batch = None
+ if (!isDone) {
+ if (!blockIterator.hasNext) {
+ isDone = true
+ metrics(PEAK_DEVICE_MEMORY) += maxDeviceMemory
+ } else {
+ batch = readBatch()
+ }
+ }
+
+ // NOTE: At this point, the task may not have yet acquired the semaphore if `batch` is `None`.
+ // We are not acquiring the semaphore here since this next() is getting called from
+ // the `PartitionReaderIterator` which implements a standard iterator pattern, and
+ // advertises `hasNext` as false when we return false here. No downstream tasks should
+ // try to call next after `hasNext` returns false, and any task that produces some kind of
+ // data when `hasNext` is false is responsible to get the semaphore themselves.
+ batch.isDefined
+ }
+
+ private def readBatch(): Option[ColumnarBatch] = {
+ withResource(new NvtxRange("Avro readBatch", NvtxColor.GREEN)) { _ =>
+ val currentChunkedBlocks = populateCurrentBlockChunk(blockIterator,
+ maxReadBatchSizeRows, maxReadBatchSizeBytes)
+ if (readDataSchema.isEmpty) {
+ // not reading any data, so return a degenerate ColumnarBatch with the row count
+ val numRows = currentChunkedBlocks.map(_.count).sum.toInt
+ if (numRows == 0) {
+ None
+ } else {
+ Some(new ColumnarBatch(Array.empty, numRows.toInt))
+ }
+ } else {
+ val table = readToTable(currentChunkedBlocks)
+ try {
+ val colTypes = readDataSchema.fields.map(f => f.dataType)
+ val maybeBatch = table.map(t => GpuColumnVector.from(t, colTypes))
+ maybeBatch.foreach { batch =>
+ logDebug(s"GPU batch size: ${GpuColumnVector.getTotalDeviceMemoryUsed(batch)} bytes")
+ }
+ maybeBatch
+ } finally {
+ table.foreach(_.close())
+ }
+ }
+ }
+ }
+
+ private def readToTable(currentChunkedBlocks: Seq[BlockInfo]): Option[Table] = {
+ if (currentChunkedBlocks.isEmpty) {
+ return None
+ }
+ val (dataBuffer, dataSize) = readPartFile(currentChunkedBlocks, filePath)
+ try {
+ if (dataSize == 0) {
+ None
+ } else {
+
+ // Dump the buffer to a file for debugging when a debug dump prefix is configured
+ dumpDataToFile(dataBuffer, dataSize, Array(partFile), Option(debugDumpPrefix), Some("avro"))
+
+ val includeColumns = readDataSchema.fieldNames.toSeq
+
+ val parseOpts = CudfAvroOptions.builder()
+ .includeColumn(includeColumns: _*).build()
+
+ // about to start using the GPU
+ GpuSemaphore.acquireIfNecessary(TaskContext.get(), metrics(SEMAPHORE_WAIT_TIME))
+
+ val table = withResource(new NvtxWithMetrics("Avro decode", NvtxColor.DARK_GREEN,
+ metrics(GPU_DECODE_TIME))) { _ =>
+ Table.readAvro(parseOpts, dataBuffer, 0, dataSize)
+ }
+ closeOnExcept(table) { _ =>
+ maxDeviceMemory = max(GpuColumnVector.getTotalDeviceMemoryUsed(table), maxDeviceMemory)
+ if (readDataSchema.length < table.getNumberOfColumns) {
+ throw new QueryExecutionException(s"Expected ${readDataSchema.length} columns " +
+ s"but read ${table.getNumberOfColumns} from $filePath")
+ }
+ }
+ metrics(NUM_OUTPUT_BATCHES) += 1
+ Some(table)
+ }
+ } finally {
+ dataBuffer.close()
+ }
+ }
+
+ /** Copy the data into HMB */
+ protected def copyDataRange(
+ range: CopyRange,
+ in: FSDataInputStream,
+ out: OutputStream,
+ copyBuffer: Array[Byte]): Unit = {
+ var readTime = 0L
+ var writeTime = 0L
+ if (in.getPos != range.offset) {
+ in.seek(range.offset)
+ }
+ var bytesLeft = range.length
+ while (bytesLeft > 0) {
+ // downcast is safe because copyBuffer.length is an int
+ val readLength = Math.min(bytesLeft, copyBuffer.length).toInt
+ val start = System.nanoTime()
+ in.readFully(copyBuffer, 0, readLength)
+ val mid = System.nanoTime()
+ out.write(copyBuffer, 0, readLength)
+ val end = System.nanoTime()
+ readTime += (mid - start)
+ writeTime += (end - mid)
+ bytesLeft -= readLength
+ }
+ execMetrics.get(READ_FS_TIME).foreach(_.add(readTime))
+ execMetrics.get(WRITE_BUFFER_TIME).foreach(_.add(writeTime))
+ }
+
+ /**
+ * Try to combine sequential blocks into larger copy ranges, e.g. two adjacent blocks covering
+ * bytes [100, 300) and [300, 500) end up in a single CopyRange(100, 400).
+ * @param blocks blocks to be combined
+ * @param blocksRange the output list of combined copy ranges
+ */
+ private def combineBlocks(blocks: Seq[BlockInfo],
+ blocksRange: ArrayBuffer[CopyRange]) = {
+ var currentCopyStart = 0L
+ var currentCopyEnd = 0L
+
+ // Combine the meta and blocks into a seq to get the copy range
+ val metaAndBlocks: Seq[BlockInfo] =
+ Seq(BlockInfo(0, blockMeta.header.getFirstBlockStart, 0, 0)) ++ blocks
+
+ metaAndBlocks.foreach { block =>
+ if (currentCopyEnd != block.blockStart) {
+ if (currentCopyEnd != 0) {
+ blocksRange.append(CopyRange(currentCopyStart, currentCopyEnd - currentCopyStart))
+ }
+ currentCopyStart = block.blockStart
+ currentCopyEnd = currentCopyStart
+ }
+ currentCopyEnd += block.blockLength
+ }
+
+ if (currentCopyEnd != currentCopyStart) {
+ blocksRange.append(CopyRange(currentCopyStart, currentCopyEnd - currentCopyStart))
+ }
+ }
+
+ protected def readPartFile(
+ blocks: Seq[BlockInfo],
+ filePath: Path): (HostMemoryBuffer, Long) = {
+ withResource(new NvtxWithMetrics("Avro buffer file split", NvtxColor.YELLOW,
+ metrics("bufferTime"))) { _ =>
+ withResource(filePath.getFileSystem(conf).open(filePath)) { in =>
+ val estTotalSize = calculateOutputSize(blocks)
+ closeOnExcept(HostMemoryBuffer.allocate(estTotalSize)) { hmb =>
+ val out = new HostMemoryOutputStream(hmb)
+ val copyRanges = new ArrayBuffer[CopyRange]()
+ combineBlocks(blocks, copyRanges)
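+ // 8 MiB staging buffer used to copy the selected ranges from the file into host memory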
+ val copyBuffer = new Array[Byte](8 * 1024 * 1024)
+ copyRanges.foreach(copyRange => copyDataRange(copyRange, in, out, copyBuffer))
+ // check we didn't go over memory
+ if (out.getPos > estTotalSize) {
+ throw new QueryExecutionException(s"Calculated buffer size $estTotalSize is to " +
+ s"small, actual written: ${out.getPos}")
+ }
+ (hmb, out.getPos)
+ }
+ }
+ }
+ }
+
+ /**
+ * Calculate the combined size
+ * @param currentChunkedBlocks the blocks to be included in the size calculation
+ * @return the total size of blocks + header
+ */
+ protected def calculateOutputSize(currentChunkedBlocks: Seq[BlockInfo]): Long = {
+ var totalSize: Long = 0
+ // For simplicity, we just copy the whole meta of AVRO
+ totalSize += blockMeta.header.getFirstBlockStart
+ // Add all blocks
+ totalSize += currentChunkedBlocks.map(_.blockLength).sum
+ totalSize
+ }
+
+ /**
+ * Get the block chunk according to the max batch size and max rows.
+ *
+ * @param blockIter blocks to be evaluated
+ * @param maxReadBatchSizeRows soft limit on the maximum number of rows the reader
+ * reads per batch
+ * @param maxReadBatchSizeBytes soft limit on the maximum number of bytes the reader
+ * reads per batch
+ * @return
+ */
+ protected def populateCurrentBlockChunk(
+ blockIter: BufferedIterator[BlockInfo],
+ maxReadBatchSizeRows: Int,
+ maxReadBatchSizeBytes: Long): Seq[BlockInfo] = {
+ val currentChunk = new ArrayBuffer[BlockInfo]
+ var numRows: Long = 0
+ var numBytes: Long = 0
+ var numAvroBytes: Long = 0
+
+ @tailrec
+ def readNextBatch(): Unit = {
+ if (blockIter.hasNext) {
+ val peekedBlock = blockIter.head
+ if (peekedBlock.count > Integer.MAX_VALUE) {
+ throw new UnsupportedOperationException("Too many rows in split")
+ }
+ if (numRows == 0 || numRows + peekedBlock.count <= maxReadBatchSizeRows) {
+ val estimatedBytes = GpuBatchUtils.estimateGpuMemory(readDataSchema,
+ peekedBlock.count)
+ if (numBytes == 0 || numBytes + estimatedBytes <= maxReadBatchSizeBytes) {
+ currentChunk += blockIter.next()
+ numRows += currentChunk.last.count
+ numAvroBytes += currentChunk.last.blockDataSize
+ numBytes += estimatedBytes
+ readNextBatch()
+ }
+ }
+ }
+ }
+
+ readNextBatch()
+ logDebug(s"Loaded $numRows rows from Avro. bytes read: $numAvroBytes. " +
+ s"Estimated GPU bytes: $numBytes")
+ currentChunk
+ }
+}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala
index 0c00ae44c0a..16716d090bf 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala
@@ -595,8 +595,7 @@ object GpuFileSourceScanExec {
case f if GpuOrcFileFormat.isSparkOrcFormat(f) => GpuReadOrcFileFormat.tagSupport(meta)
case _: ParquetFileFormat => GpuReadParquetFileFormat.tagSupport(meta)
case _: JsonFileFormat => GpuReadJsonFileFormat.tagSupport(meta)
- case f =>
- meta.willNotWorkOnGpu(s"unsupported file format: ${f.getClass.getCanonicalName}")
+ case _ => ExternalSource.tagSupportForGpuFileSourceScanExec(meta)
}
}
@@ -606,8 +605,7 @@ object GpuFileSourceScanExec {
case f if GpuOrcFileFormat.isSparkOrcFormat(f) => new GpuReadOrcFileFormat
case _: ParquetFileFormat => new GpuReadParquetFileFormat
case _: JsonFileFormat => new GpuReadJsonFileFormat
- case f =>
- throw new IllegalArgumentException(s"${f.getClass.getCanonicalName} is not supported")
+ case _ => ExternalSource.convertFileFormatForGpuFileSourceScanExec(format)
}
}
}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuReadAvroFileFormat.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuReadAvroFileFormat.scala
new file mode 100644
index 00000000000..28f67cde860
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuReadAvroFileFormat.scala
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids
+
+import com.nvidia.spark.rapids.{GpuMetric, GpuReadFileFormatWithMetrics, PartitionReaderIterator, RapidsConf, SparkPlanMeta}
+import com.nvidia.spark.rapids.shims.SparkShimImpl
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.avro.{AvroFileFormat, AvroOptions}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.FileSourceScanExec
+import org.apache.spark.sql.execution.datasources.PartitionedFile
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.SerializableConfiguration
+
+/**
+ * A FileFormat that allows reading Avro files with the GPU.
+ */
+class GpuReadAvroFileFormat extends AvroFileFormat with GpuReadFileFormatWithMetrics {
+
+ @scala.annotation.nowarn(
+ "msg=value ignoreExtension in class AvroOptions is deprecated*"
+ )
+ override def buildReaderWithPartitionValuesAndMetrics(
+ sparkSession: SparkSession,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ requiredSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String],
+ hadoopConf: Configuration,
+ metrics: Map[String, GpuMetric]): PartitionedFile => Iterator[InternalRow] = {
+ val sqlConf = sparkSession.sessionState.conf
+ val broadcastedHadoopConf =
+ sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
+
+ val parsedOptions = new AvroOptions(options, hadoopConf)
+ val ignoreExtension = parsedOptions.ignoreExtension
+
+ val factory = GpuAvroPartitionReaderFactory(
+ sqlConf,
+ broadcastedHadoopConf,
+ dataSchema,
+ requiredSchema,
+ partitionSchema,
+ new RapidsConf(sqlConf),
+ ignoreExtension,
+ metrics)
+ PartitionReaderIterator.buildReader(factory)
+ }
+}
+
+object GpuReadAvroFileFormat {
+ def tagSupport(meta: SparkPlanMeta[FileSourceScanExec]): Unit = {
+ val fsse = meta.wrapped
+ GpuAvroScan.tagSupport(
+ SparkShimImpl.sessionFromPlan(fsse),
+ fsse.requiredSchema,
+ fsse.relation.options,
+ meta
+ )
+ }
+}
diff --git a/tools/src/main/resources/supportedDataSource.csv b/tools/src/main/resources/supportedDataSource.csv
index bef3ceae4df..821acaa19cb 100644
--- a/tools/src/main/resources/supportedDataSource.csv
+++ b/tools/src/main/resources/supportedDataSource.csv
@@ -1,4 +1,5 @@
Format,Direction,BOOLEAN,BYTE,SHORT,INT,LONG,FLOAT,DOUBLE,DATE,TIMESTAMP,STRING,DECIMAL,NULL,BINARY,CALENDAR,ARRAY,MAP,STRUCT,UDT
+Avro,read,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO
CSV,read,S,S,S,S,S,S,S,S,CO,S,S,NA,NS,NA,NA,NA,NA,NA
JSON,read,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO
ORC,read,S,S,S,S,S,S,S,S,PS,S,S,NA,NS,NA,PS,PS,PS,NS