
Commit

support Dictionary Late Materialization (push down to GPU)
sperlingxx committed Apr 7, 2024
1 parent 5f1af86 commit 76639b0
Showing 10 changed files with 811 additions and 145 deletions.
@@ -2610,7 +2610,7 @@ class MultiFileCloudParquetPartitionReader(
       hostBuffer, 0, dataSize, metrics,
       dateRebaseMode, timestampRebaseMode, hasInt96Timestamps,
       clippedSchema, readDataSchema,
-      slotAcquired, hybridOpts.async)
+      slotAcquired, hybridOpts.enableDictLateMat, hybridOpts.async)

     val batchIter = HostParquetIterator(asyncReader, hybridOpts, colTypes, metrics)

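For reference, the new argument threads a reader-level switch through to the host parquet reader. A minimal sketch of the options carrier this call site assumes is shown below; the name HybridParquetOpts and its exact shape are hypothetical, with only the enableDictLateMat and async fields taken from the call site above:

// Hypothetical sketch of the options referenced above -- not the plugin's actual definition.
case class HybridParquetOpts(
    enableDictLateMat: Boolean, // allow pushing dictionary decode down to the GPU
    async: Boolean)             // read host buffers asynchronously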

Large diffs are not rendered by default.

@@ -0,0 +1,131 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.parquet.rapids

import java.util.Optional

import scala.collection.mutable

import ai.rapids.cudf.{DType, HostColumnVector, HostColumnVectorCore, HostMemoryBuffer}
import org.apache.parquet.column.ColumnDescriptor
import org.apache.parquet.column.page.{DataPage, DataPageV1, DataPageV2, DictionaryPage, PageReader, PageReadStore}
import org.apache.parquet.column.values.dictionary.PlainValuesDictionary.PlainBinaryDictionary

import org.apache.spark.internal.Logging

case class DictLatMatInfo(dictVector: HostColumnVector, dictPageOffsets: Array[Int])

object DictLateMatUtils extends Logging {

  def extractDict(rowGroups: Seq[PageReadStore],
      descriptor: ColumnDescriptor): Option[DictLatMatInfo] = {

    val dictPages = mutable.ArrayBuffer[DictionaryPage]()

    // Go through each RowGroup and each page inside it to check whether every page is
    // dictionary-encoded. Dictionary late materialization only works if all pages use
    // the dictionary.
    rowGroups.foreach { rowGroup =>
      val pageReader = rowGroup.getPageReader(descriptor)
      val dictPage = pageReader.readDictionaryPage()
      if (dictPage == null || !isAllDictEncoded(pageReader)) {
        return None
      }
      dictPages += dictPage
    }

    Some(combineDictPages(dictPages, descriptor))
  }

  private def combineDictPages(dictPages: Seq[DictionaryPage],
      descriptor: ColumnDescriptor): DictLatMatInfo = {
    val pageOffsets = mutable.ArrayBuffer[Int](0)
    var rowNum: Int = 0

    val dictionaries = dictPages.map { dictPage =>
      val dictionary = dictPage.getEncoding.initDictionary(descriptor, dictPage)
        .asInstanceOf[PlainBinaryDictionary]
      rowNum += dictionary.getMaxId + 1
      pageOffsets += rowNum
      dictionary
    }

    var charNum: Int = 0
    val offsetBuf = HostMemoryBuffer.allocate((rowNum + 1) * 4L)
    offsetBuf.setInt(0, 0)
    var i = 1
    dictionaries.foreach { dict =>
      (0 to dict.getMaxId).foreach { j =>
        charNum += dict.decodeToBinary(j).length()
        offsetBuf.setInt(i * 4L, charNum)
        i += 1
      }
    }

    val charBuf = HostMemoryBuffer.allocate(charNum)
    i = 0
    dictionaries.foreach { dict =>
      (0 to dict.getMaxId).foreach { j =>
        val ba = dict.decodeToBinary(j).getBytes
        charBuf.setBytes(offsetBuf.getInt(i * 4L), ba, 0, ba.length)
        i += 1
      }
    }

    val dictVector = new HostColumnVector(DType.STRING, rowNum, Optional.of(0L),
      charBuf, null, offsetBuf, new java.util.ArrayList[HostColumnVectorCore]())

    DictLatMatInfo(dictVector, pageOffsets.toArray)
  }

  private def isAllDictEncoded(pageReader: PageReader): Boolean = {
    require(ccPageReader.isInstance(pageReader),
      "Only supports org.apache.parquet.hadoop.ColumnChunkPageReadStore.ColumnChunkPageReader")
    val rawPagesField = ccPageReader.getDeclaredField("compressedPages")
    rawPagesField.setAccessible(true)

    val pageQueue = rawPagesField.get(pageReader).asInstanceOf[java.util.ArrayDeque[DataPage]]
    val swapQueue = new java.util.ArrayDeque[DataPage]()
    var allDictEncoded = true

    while (!pageQueue.isEmpty) {
      swapQueue.addLast(pageQueue.pollFirst())
      if (allDictEncoded) {
        allDictEncoded = swapQueue.getLast match {
          case p: DataPageV1 =>
            p.getValueEncoding.usesDictionary()
          case p: DataPageV2 =>
            p.getDataEncoding.usesDictionary()
        }
      }
    }
    while (!swapQueue.isEmpty) {
      pageQueue.addLast(swapQueue.pollFirst())
    }

    allDictEncoded
  }

  private val ccPageReader: Class[_] = {
    val ccPageReadStore = Class.forName("org.apache.parquet.hadoop.ColumnChunkPageReadStore")
    val ccPageReader = ccPageReadStore.getDeclaredClasses.find { memberClz =>
      memberClz.getSimpleName.equals("ColumnChunkPageReader")
    }
    require(ccPageReader.nonEmpty, "can NOT find the Class definition of ColumnChunkPageReader")
    ccPageReader.get
  }

}
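As a rough illustration of how dictPageOffsets is meant to be consumed (the helper below is not part of this commit): dictPageOffsets(g) is the index in the combined dictVector where row group g's dictionary begins, so a dictionary ID that is local to one row group can be rebased into the combined dictionary as in this sketch:

// Hypothetical helper, assuming `pageOffsets` is DictLatMatInfo.dictPageOffsets.
object DictIdRebase {
  def rebaseDictId(localId: Int, rowGroupIndex: Int, pageOffsets: Array[Int]): Int = {
    require(rowGroupIndex >= 0 && rowGroupIndex < pageOffsets.length - 1,
      s"row group $rowGroupIndex out of range")
    val globalId = pageOffsets(rowGroupIndex) + localId
    require(globalId < pageOffsets(rowGroupIndex + 1),
      s"dictionary ID $localId exceeds the dictionary size of row group $rowGroupIndex")
    globalId
  }
}

With IDs rebased this way, the string payload only needs to be gathered once against dictVector, which appears to be what allows the actual materialization to happen on the GPU rather than on the host.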
@@ -0,0 +1,65 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.parquet.rapids;

import java.io.Closeable;

import ai.rapids.cudf.HostMemoryBuffer;
import org.apache.parquet.column.Dictionary;
import org.apache.parquet.column.values.dictionary.PlainValuesDictionary.PlainBinaryDictionary;

public class OffHeapBinaryDictionary extends Dictionary implements Closeable {

  public OffHeapBinaryDictionary(PlainBinaryDictionary binDict) {
    super(binDict.getEncoding());
    this.size = binDict.getMaxId() + 1;
    offsets = new int[this.size + 1];
    for (int i = 0; i < this.size; i++) {
      offsets[i + 1] = offsets[i] + binDict.decodeToBinary(i).length();
    }
    data = HostMemoryBuffer.allocate(offsets[this.size]);
    for (int i = 0; i < this.size; i++) {
      byte[] ba = binDict.decodeToBinary(i).getBytes();
      data.setBytes(offsets[i], ba, 0, ba.length);
    }
  }

  public HostMemoryBuffer getData() {
    return data;
  }

  public int[] getOffsets() {
    return offsets;
  }

  @Override
  public int getMaxId() {
    return this.size;
  }

  @Override
  public void close() {
    if (data != null) {
      data.close();
    }
  }

  private final int size;
  private final int[] offsets;
  private final HostMemoryBuffer data;

}
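To make the off-heap layout concrete, the sketch below reads entry i back into a JVM string by slicing the data buffer between consecutive offsets. It is illustration only (the point of the commit is to avoid host-side decode), and it deliberately uses only HostMemoryBuffer.getByte so as not to assume any bulk-copy signature:

import java.nio.charset.StandardCharsets

import org.apache.spark.sql.execution.datasources.parquet.rapids.OffHeapBinaryDictionary

// Sketch only: decode one entry of an OffHeapBinaryDictionary back into a String.
object OffHeapDictDebug {
  def decodeEntry(dict: OffHeapBinaryDictionary, i: Int): String = {
    val offsets = dict.getOffsets
    val start = offsets(i)
    val len = offsets(i + 1) - start
    val bytes = new Array[Byte](len)
    var p = 0
    while (p < len) {
      bytes(p) = dict.getData.getByte(start.toLong + p) // copy byte-at-a-time for simplicity
      p += 1
    }
    new String(bytes, StandardCharsets.UTF_8)
  }
}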
@@ -57,9 +57,11 @@ final class ParquetColumnVector {
       Set<ParquetColumn> missingColumns,
       boolean isTopLevel,
       int maxRepetitiveDefLevel,
-      Object defaultValue) {
+      Object defaultValue,
+      boolean dictLateMaterialize) {

     DataType sparkType = column.sparkType();
-    if (!sparkType.sameType(vector.dataType())) {
+    if (!dictLateMaterialize && !sparkType.sameType(vector.dataType())) {
       throw new IllegalArgumentException("Spark type: " + sparkType +
         " doesn't match the type: " + vector.dataType() + " in column vector");
     }
@@ -119,7 +121,8 @@ final class ParquetColumnVector {
     for (int i = 0; i < column.children().size(); i++) {
       ParquetColumnVector childCv = new ParquetColumnVector(column.children().apply(i),
         vector.getChild(i), capacity, missingColumns, false,
-        childMaxRepetitiveDefLevel, null);
+        childMaxRepetitiveDefLevel, null,
+        false);
       children.add(childCv);


@@ -251,6 +254,9 @@ void setColumnReader(VectorizedColumnReader reader) {
     if (!isPrimitive) {
       throw new IllegalStateException("Can't set reader for non-primitive column");
     }
+    if (this.columnReader != null) {
+      this.columnReader.close();
+    }
     this.columnReader = reader;
   }

@@ -1,12 +1,11 @@
 /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
+ * Copyright (c) 2024, NVIDIA CORPORATION.
  *
- * http://www.apache.org/licenses/LICENSE-2.0
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
@@ -19,9 +18,10 @@

 import org.apache.parquet.column.Dictionary;

+import org.apache.spark.sql.execution.vectorized.rapids.HostWritableColumnVector;
 import org.apache.spark.sql.execution.vectorized.rapids.WritableColumnVector;

-public interface ParquetVectorUpdater {
+public abstract class ParquetVectorUpdater {
   /**
    * Read a batch of `total` values from `valuesReader` into `values`, starting from `offset`.
    *
@@ -30,7 +30,7 @@ public interface ParquetVectorUpdater {
    * @param values destination values vector
    * @param valuesReader reader to read values from
    */
-  void readValues(
+  abstract void readValues(
       int total,
       int offset,
       WritableColumnVector values,
@@ -42,7 +42,7 @@ void readValues(
    * @param total total number of values to skip
    * @param valuesReader reader to skip values from
    */
-  void skipValues(int total, VectorizedValuesReader valuesReader);
+  abstract void skipValues(int total, VectorizedValuesReader valuesReader);

   /**
    * Read a single value from `valuesReader` into `values`, at `offset`.
@@ -51,7 +51,7 @@ void readValues(
    * @param values destination value vector
    * @param valuesReader reader to read values from
    */
-  void readValue(int offset, WritableColumnVector values, VectorizedValuesReader valuesReader);
+  abstract void readValue(int offset, WritableColumnVector values, VectorizedValuesReader valuesReader);

   /**
    * Process a batch of `total` values starting from `offset` in `values`, whose null slots
@@ -64,15 +64,24 @@ void readValues(
    * @param dictionaryIds vector storing the dictionary IDs
    * @param dictionary Parquet dictionary used to decode a dictionary ID to its value
    */
-  default void decodeDictionaryIds(
+  public void decodeDictionaryIds(
       int total,
       int offset,
       WritableColumnVector values,
       WritableColumnVector dictionaryIds,
       Dictionary dictionary) {
+
+    HostWritableColumnVector cv = (HostWritableColumnVector) values;
+
+    if (!cv.hasNullMask()) {
+      for (int i = offset; i < offset + total; i++) {
+        decodeSingleDictionaryId(i, cv, dictionaryIds, dictionary);
+      }
+      return;
+    }
     for (int i = offset; i < offset + total; i++) {
-      if (!values.isNullAt(i)) {
-        decodeSingleDictionaryId(i, values, dictionaryIds, dictionary);
+      if (!cv.isNullAt(i)) {
+        decodeSingleDictionaryId(i, cv, dictionaryIds, dictionary);
       }
     }
   }
@@ -86,7 +95,7 @@ default void decodeDictionaryIds(
    * @param dictionaryIds vector storing the dictionary IDs
    * @param dictionary Parquet dictionary used to decode a dictionary ID to its value
    */
-  void decodeSingleDictionaryId(
+  abstract void decodeSingleDictionaryId(
       int offset,
       WritableColumnVector values,
       WritableColumnVector dictionaryIds,
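The decodeDictionaryIds change above keeps two host-side decode paths: when the destination vector has no null mask every slot is decoded, otherwise null slots are skipped. A simplified, standalone model of that control flow is sketched below; it uses plain arrays instead of HostWritableColumnVector and is not the plugin's code:

// Simplified model of the two decode paths (sketch only).
object DecodePathsModel {
  def decodeDictionaryIds(
      total: Int,
      offset: Int,
      dictIds: Array[Int],          // dictionary IDs, one per row
      dict: Array[String],          // decoded dictionary entries
      isNullAt: Option[Int => Boolean]): Array[String] = {
    val out = new Array[String](offset + total)
    isNullAt match {
      case None =>                  // fast path: no null mask to consult
        for (i <- offset until offset + total) out(i) = dict(dictIds(i))
      case Some(nullAt) =>          // skip null slots, leaving them unset
        for (i <- offset until offset + total) if (!nullAt(i)) out(i) = dict(dictIds(i))
    }
    out
  }
}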
