diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 55bd5ec5ff9..4d9991d0dd9 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -101,11 +101,39 @@ public ColumnView(DType type, long rows, Optional nullCount, || !nullCount.isPresent(); } + /** + * Create a new column view based off of data already on the device. Ref count on the buffers + * is not incremented and none of the underlying buffers are owned by this view. The returned + * ColumnView is only valid as long as the underlying buffers remain valid. If the buffers are + * closed before this ColumnView is closed, it will result in undefined behavior. + * + * If ownership is needed, call {@link ColumnView#copyToColumnVector} + * + * @param type the type of the vector + * @param rows the number of rows in this vector. + * @param nullCount the number of nulls in the dataset. + * @param dataBuffer a host buffer required for nested types including strings and string + * categories. The ownership doesn't change on this buffer + * @param validityBuffer an optional validity buffer. Must be provided if nullCount != 0. + * The ownership doesn't change on this buffer + * @param offsetBuffer The offsetbuffer for columns that need an offset buffer + */ + public ColumnView(DType type, long rows, Optional nullCount, + BaseDeviceMemoryBuffer dataBuffer, + BaseDeviceMemoryBuffer validityBuffer, BaseDeviceMemoryBuffer offsetBuffer) { + this(type, (int) rows, nullCount.orElse(UNKNOWN_NULL_COUNT).intValue(), + dataBuffer, validityBuffer, offsetBuffer, null); + assert (!type.isNestedType()); + assert (nullCount.isPresent() && nullCount.get() <= Integer.MAX_VALUE) + || !nullCount.isPresent(); + } + private ColumnView(DType type, long rows, int nullCount, BaseDeviceMemoryBuffer dataBuffer, BaseDeviceMemoryBuffer validityBuffer, BaseDeviceMemoryBuffer offsetBuffer, ColumnView[] children) { this(ColumnVector.initViewHandle(type, (int) rows, nullCount, dataBuffer, validityBuffer, - offsetBuffer, Arrays.stream(children).mapToLong(c -> c.getNativeView()).toArray())); + offsetBuffer, children == null ? new long[]{} : + Arrays.stream(children).mapToLong(c -> c.getNativeView()).toArray())); } /** Creates a ColumnVector from a column view handle @@ -140,6 +168,32 @@ public final DType getType() { return type; } + /** + * Returns the child column views for this view + * Please note that it is the responsibility of the caller to close these views. + * @return an array of child column views + */ + public final ColumnView[] getChildColumnViews() { + int numChildren = getNumChildren(); + if (!getType().isNestedType()) { + return null; + } + ColumnView[] views = new ColumnView[numChildren]; + try { + for (int i = 0; i < numChildren; i++) { + views[i] = getChildColumnView(i); + } + return views; + } catch(Throwable t) { + for (ColumnView v: views) { + if (v != null) { + v.close(); + } + } + throw t; + } + } + /** * Returns the child column view at a given index. * Please note that it is the responsibility of the caller to close this view. diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index e725932ed5e..eeb2d308f1a 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -170,10 +170,19 @@ public long getDeviceMemorySize() { return total; } + /** + * This method is internal and exposed purely for testing purpopses + */ + static Table removeNullMasksIfNeeded(Table table) { + return new Table(removeNullMasksIfNeeded(table.nativeHandle)); + } + ///////////////////////////////////////////////////////////////////////////// // NATIVE APIs ///////////////////////////////////////////////////////////////////////////// - + + private static native long[] removeNullMasksIfNeeded(long tableView) throws CudfException; + private static native ContiguousTable[] contiguousSplit(long inputTable, int[] indices); private static native long[] partition(long inputTable, long partitionView, diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index f642a87b445..2bb56565f7a 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -929,6 +929,45 @@ jlongArray combine_join_results(JNIEnv *env, cudf::table &left_results, return combine_join_results(env, std::move(left_cols), std::move(right_cols)); } +cudf::column_view remove_validity_from_col(cudf::column_view column_view) { + if (!cudf::is_compound(column_view.type())) { + if (column_view.nullable() && column_view.null_count() == 0) { + // null_mask is allocated but no nulls present therefore we create a new column_view without + // the null_mask to avoid things blowing up in reading the parquet file + return cudf::column_view(column_view.type(), column_view.size(), column_view.head(), nullptr, + 0, column_view.offset()); + } else { + return cudf::column_view(column_view); + } + } else { + std::unique_ptr ret; + std::vector children; + children.reserve(column_view.num_children()); + for (auto it = column_view.child_begin(); it != column_view.child_end(); it++) { + children.push_back(remove_validity_from_col(*it)); + } + if (!column_view.nullable() || column_view.null_count() != 0) { + ret.reset(new cudf::column_view(column_view.type(), column_view.size(), nullptr, + column_view.null_mask(), column_view.null_count(), + column_view.offset(), children)); + } else { + ret.reset(new cudf::column_view(column_view.type(), column_view.size(), nullptr, nullptr, 0, + column_view.offset(), children)); + } + return *ret.release(); + } +} + +cudf::table_view remove_validity_if_needed(cudf::table_view *input_table_view) { + std::vector views; + views.reserve(input_table_view->num_columns()); + for (auto it = input_table_view->begin(); it != input_table_view->end(); it++) { + views.push_back(remove_validity_from_col(*it)); + } + + return cudf::table_view(views); +} + } // namespace } // namespace jni @@ -936,6 +975,25 @@ jlongArray combine_join_results(JNIEnv *env, cudf::table &left_results, extern "C" { +// This is a method purely added for testing remove_validity_if_needed method +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_removeNullMasksIfNeeded(JNIEnv *env, jclass, + jlong j_table_view) { + JNI_NULL_CHECK(env, j_table_view, "table view handle is null", 0); + try { + cudf::table_view *tview = reinterpret_cast(j_table_view); + cudf::table_view result = cudf::jni::remove_validity_if_needed(tview); + cudf::table m_tbl(result); + std::vector> cols = m_tbl.release(); + auto results = cudf::jni::native_jlongArray(env, cols.size()); + int i = 0; + for (auto it = cols.begin(); it != cols.end(); it++) { + results[i++] = reinterpret_cast(it->release()); + } + return results.get_jArray(); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_createCudfTableView(JNIEnv *env, jclass, jlongArray j_cudf_columns) { JNI_NULL_CHECK(env, j_cudf_columns, "columns are null", 0); @@ -1357,7 +1415,8 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_Table_writeParquetChunk(JNIEnv *env, JNI_NULL_CHECK(env, j_state, "null state", ); using namespace cudf::io; - cudf::table_view *tview = reinterpret_cast(j_table); + cudf::table_view *tview_with_empty_nullmask = reinterpret_cast(j_table); + cudf::table_view tview = cudf::jni::remove_validity_if_needed(tview_with_empty_nullmask); cudf::jni::native_parquet_writer_handle *state = reinterpret_cast(j_state); @@ -1367,7 +1426,7 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_Table_writeParquetChunk(JNIEnv *env, } try { cudf::jni::auto_set_device(env); - state->writer->write(*tview); + state->writer->write(tview); } CATCH_STD(env, ) } diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index aeb94e4824a..cc030c392cb 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -49,19 +49,14 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.nio.file.Files; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Comparator; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.stream.Collectors; import static ai.rapids.cudf.ParquetColumnWriterOptions.mapColumn; import static ai.rapids.cudf.ParquetWriterOptions.listBuilder; import static ai.rapids.cudf.ParquetWriterOptions.structBuilder; import static ai.rapids.cudf.Table.TestBuilder; +import static ai.rapids.cudf.Table.removeNullMasksIfNeeded; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -111,7 +106,7 @@ public static void assertColumnsAreEqual(ColumnView expect, ColumnView cv) { * @param colName The name of the column */ public static void assertColumnsAreEqual(ColumnView expected, ColumnView cv, String colName) { - assertPartialColumnsAreEqual(expected, 0, expected.getRowCount(), cv, colName, true); + assertPartialColumnsAreEqual(expected, 0, expected.getRowCount(), cv, colName, true, false); } /** @@ -121,7 +116,7 @@ public static void assertColumnsAreEqual(ColumnView expected, ColumnView cv, Str * @param colName The name of the host column */ public static void assertColumnsAreEqual(HostColumnVector expected, HostColumnVector cv, String colName) { - assertPartialColumnsAreEqual(expected, 0, expected.getRowCount(), cv, colName, true); + assertPartialColumnsAreEqual(expected, 0, expected.getRowCount(), cv, colName, true, false); } /** @@ -130,7 +125,7 @@ public static void assertColumnsAreEqual(HostColumnVector expected, HostColumnVe * @param cv The input Struct column */ public static void assertStructColumnsAreEqual(ColumnView expected, ColumnView cv) { - assertPartialStructColumnsAreEqual(expected, 0, expected.getRowCount(), cv, "unnamed", true); + assertPartialStructColumnsAreEqual(expected, 0, expected.getRowCount(), cv, "unnamed", true, false); } /** @@ -140,13 +135,14 @@ public static void assertStructColumnsAreEqual(ColumnView expected, ColumnView c * @param length The number of rows to consider * @param cv The input Struct column * @param colName The name of the column - * @param enableNullCheck Whether to check for nulls in the Struct column + * @param enableNullCountCheck Whether to check for nulls in the Struct column + * @param enableNullabilityCheck Whether the table have a validity mask */ public static void assertPartialStructColumnsAreEqual(ColumnView expected, long rowOffset, long length, - ColumnView cv, String colName, boolean enableNullCheck) { + ColumnView cv, String colName, boolean enableNullCountCheck, boolean enableNullabilityCheck) { try (HostColumnVector hostExpected = expected.copyToHost(); HostColumnVector hostcv = cv.copyToHost()) { - assertPartialColumnsAreEqual(hostExpected, rowOffset, length, hostcv, colName, enableNullCheck); + assertPartialColumnsAreEqual(hostExpected, rowOffset, length, hostcv, colName, enableNullCountCheck, enableNullabilityCheck); } } @@ -156,12 +152,13 @@ public static void assertPartialStructColumnsAreEqual(ColumnView expected, long * @param cv The input column * @param colName The name of the column * @param enableNullCheck Whether to check for nulls in the column + * @param enableNullabilityCheck Whether the table have a validity mask */ public static void assertPartialColumnsAreEqual(ColumnView expected, long rowOffset, long length, - ColumnView cv, String colName, boolean enableNullCheck) { + ColumnView cv, String colName, boolean enableNullCheck, boolean enableNullabilityCheck) { try (HostColumnVector hostExpected = expected.copyToHost(); HostColumnVector hostcv = cv.copyToHost()) { - assertPartialColumnsAreEqual(hostExpected, rowOffset, length, hostcv, colName, enableNullCheck); + assertPartialColumnsAreEqual(hostExpected, rowOffset, length, hostcv, colName, enableNullCheck, enableNullabilityCheck); } } @@ -172,18 +169,21 @@ public static void assertPartialColumnsAreEqual(ColumnView expected, long rowOff * @param length number of rows from starting offset * @param cv The input host column * @param colName The name of the host column - * @param enableNullCheck Whether to check for nulls in the host column + * @param enableNullCountCheck Whether to check for nulls in the host column */ public static void assertPartialColumnsAreEqual(HostColumnVectorCore expected, long rowOffset, long length, - HostColumnVectorCore cv, String colName, boolean enableNullCheck) { + HostColumnVectorCore cv, String colName, boolean enableNullCountCheck, boolean enableNullabilityCheck) { assertEquals(expected.getType(), cv.getType(), "Type For Column " + colName); assertEquals(length, cv.getRowCount(), "Row Count For Column " + colName); assertEquals(expected.getNumChildren(), cv.getNumChildren(), "Child Count for Column " + colName); - if (enableNullCheck) { + if (enableNullCountCheck) { assertEquals(expected.getNullCount(), cv.getNullCount(), "Null Count For Column " + colName); } else { // TODO add in a proper check when null counts are supported by serializing a partitioned column } + if (enableNullabilityCheck) { + assertEquals(expected.hasValidityVector(), cv.hasValidityVector(), "Column nullability is different than expected"); + } DType type = expected.getType(); for (long expectedRow = rowOffset; expectedRow < (rowOffset + length); expectedRow++) { long tableRow = expectedRow - rowOffset; @@ -269,7 +269,7 @@ public static void assertPartialColumnsAreEqual(HostColumnVectorCore expected, l } assertPartialColumnsAreEqual(expected.getNestedChildren().get(0), expectedChildRowOffset, numChildRows, cv.getNestedChildren().get(0), colName + " list child", - enableNullCheck); + enableNullCountCheck, enableNullabilityCheck); break; case STRUCT: List expectedChildren = expected.getNestedChildren(); @@ -280,7 +280,7 @@ public static void assertPartialColumnsAreEqual(HostColumnVectorCore expected, l String childName = colName + " child " + i; assertEquals(length, cvChild.getRowCount(), "Row Count for Column " + colName); assertPartialColumnsAreEqual(expectedChild, rowOffset, length, cvChild, - colName, enableNullCheck); + colName, enableNullCountCheck, enableNullabilityCheck); } break; default: @@ -296,9 +296,10 @@ public static void assertPartialColumnsAreEqual(HostColumnVectorCore expected, l * @param length the number of rows to check * @param table the input table to compare against expected * @param enableNullCheck whether to check for nulls or not + * @param enableNullabilityCheck whether the table have a validity mask */ public static void assertPartialTablesAreEqual(Table expected, long rowOffset, long length, Table table, - boolean enableNullCheck) { + boolean enableNullCheck, boolean enableNullabilityCheck) { assertEquals(expected.getNumberOfColumns(), table.getNumberOfColumns()); assertEquals(length, table.getRowCount(), "ROW COUNT"); for (int col = 0; col < expected.getNumberOfColumns(); col++) { @@ -308,7 +309,7 @@ public static void assertPartialTablesAreEqual(Table expected, long rowOffset, l if (rowOffset != 0 || length != expected.getRowCount()) { name = name + " PART " + rowOffset + "-" + (rowOffset + length - 1); } - assertPartialColumnsAreEqual(expect, rowOffset, length, cv, name, enableNullCheck); + assertPartialColumnsAreEqual(expect, rowOffset, length, cv, name, enableNullCheck, enableNullabilityCheck); } } @@ -318,7 +319,7 @@ public static void assertPartialTablesAreEqual(Table expected, long rowOffset, l * @param table the input table to compare against expected */ public static void assertTablesAreEqual(Table expected, Table table) { - assertPartialTablesAreEqual(expected, 0, expected.getRowCount(), table, true); + assertPartialTablesAreEqual(expected, 0, expected.getRowCount(), table, true, false); } void assertTablesHaveSameValues(HashMap[] expectedTable, Table table) { @@ -3235,7 +3236,7 @@ void testSerializationRoundTripConcatHostSide() throws IOException { try (Table found = JCudfSerialization.readAndConcat( headers.toArray(new JCudfSerialization.SerializedTableHeader[headers.size()]), buffers.toArray(new HostMemoryBuffer[buffers.size()]))) { - assertPartialTablesAreEqual(t, 0, t.getRowCount(), found, false); + assertPartialTablesAreEqual(t, 0, t.getRowCount(), found, false, false); } } finally { for (HostMemoryBuffer buff: buffers) { @@ -3288,7 +3289,7 @@ void testConcatHost() throws IOException { try (Table result = JCudfSerialization.readAndConcat( new JCudfSerialization.SerializedTableHeader[] {header, header}, new HostMemoryBuffer[] {buff, buff})) { - assertPartialTablesAreEqual(expected, 0, expected.getRowCount(), result, false); + assertPartialTablesAreEqual(expected, 0, expected.getRowCount(), result, false, false); } } } @@ -3329,7 +3330,7 @@ void testSerializationRoundTripSlicedHostSide() throws IOException { buffers.toArray(new HostMemoryBuffer[buffers.size()]), bout2); ByteArrayInputStream bin2 = new ByteArrayInputStream(bout2.toByteArray()); try (JCudfSerialization.TableAndRowCountPair found = JCudfSerialization.readTableFrom(bin2)) { - assertPartialTablesAreEqual(t, 0, t.getRowCount(), found.getTable(), false); + assertPartialTablesAreEqual(t, 0, t.getRowCount(), found.getTable(), false, false); assertEquals(found.getTable(), found.getContiguousTable().getTable()); assertNotNull(found.getContiguousTable().getBuffer()); } @@ -3355,7 +3356,7 @@ void testSerializationRoundTripSliced() throws IOException { JCudfSerialization.writeToStream(t, bout, i, len); ByteArrayInputStream bin = new ByteArrayInputStream(bout.toByteArray()); try (JCudfSerialization.TableAndRowCountPair found = JCudfSerialization.readTableFrom(bin)) { - assertPartialTablesAreEqual(t, i, len, found.getTable(), i == 0 && len == t.getRowCount()); + assertPartialTablesAreEqual(t, i, len, found.getTable(), i == 0 && len == t.getRowCount(), false); assertEquals(found.getTable(), found.getContiguousTable().getTable()); assertNotNull(found.getContiguousTable().getBuffer()); } @@ -6360,6 +6361,121 @@ void testAllFilteredFromValidity() { } } + ColumnView replaceValidity(ColumnView cv, DeviceMemoryBuffer validity, long nullCount) { + assert (validity.length >= BitVectorHelper.getValidityAllocationSizeInBytes(cv.rows)); + if (cv.type.isNestedType()) { + ColumnView[] children = cv.getChildColumnViews(); + try { + return new ColumnView(cv.type, + cv.rows, + Optional.of(nullCount), + validity, + cv.getOffsets(), + children); + } finally { + for (ColumnView v : children) { + if (v != null) { + v.close(); + } + } + } + } else { + return new ColumnView(cv.type, cv.rows, Optional.of(nullCount), cv.getData(), validity, cv.getOffsets()); + } + } + + @Test + void testRemoveNullMasksIfNeeded() { + ListType nestedType = new ListType(true, new StructType(false, + new BasicType(true, DType.INT32), + new BasicType(true, DType.INT64))); + + List data1 = Arrays.asList(10, 20L); + List data2 = Arrays.asList(50, 60L); + HostColumnVector.StructData structData1 = new HostColumnVector.StructData(data1); + HostColumnVector.StructData structData2 = new HostColumnVector.StructData(data2); + + //First we create ColumnVectors + try (ColumnVector nonNullVector0 = ColumnVector.fromBoxedInts(1, 2, 3); + ColumnVector nonNullVector2 = ColumnVector.fromStrings("1", "2", "3"); + ColumnVector nonNullVector1 = ColumnVector.fromLists(nestedType, + Arrays.asList(structData1, structData2), + Arrays.asList(structData1, structData2), + Arrays.asList(structData1, structData2))) { + //Then we take the created ColumnVectors and add validity masks even though the nullCount = 0 + long allocSize = BitVectorHelper.getValidityAllocationSizeInBytes(nonNullVector0.rows); + try (DeviceMemoryBuffer dm0 = DeviceMemoryBuffer.allocate(allocSize); + DeviceMemoryBuffer dm1 = DeviceMemoryBuffer.allocate(allocSize); + DeviceMemoryBuffer dm2 = DeviceMemoryBuffer.allocate(allocSize); + DeviceMemoryBuffer dm3_child = + DeviceMemoryBuffer.allocate(BitVectorHelper.getValidityAllocationSizeInBytes(2))) { + Cuda.memset(dm0.address, (byte) 0xFF, allocSize); + Cuda.memset(dm1.address, (byte) 0xFF, allocSize); + Cuda.memset(dm2.address, (byte) 0xFF, allocSize); + Cuda.memset(dm3_child.address, (byte) 0xFF, + BitVectorHelper.getValidityAllocationSizeInBytes(2)); + + try (ColumnView cv0View = replaceValidity(nonNullVector0, dm0, 0); + ColumnVector cv0 = cv0View.copyToColumnVector(); + ColumnView struct = nonNullVector1.getChildColumnView(0); + ColumnView structChild0 = struct.getChildColumnView(0); + ColumnView newStructChild0 = replaceValidity(structChild0, dm3_child, 0); + ColumnView newStruct = struct.replaceChildrenWithViews(new int[]{0}, new ColumnView[]{newStructChild0}); + ColumnView list = nonNullVector1.replaceChildrenWithViews(new int[]{0}, new ColumnView[]{newStruct}); + ColumnView cv1View = replaceValidity(list, dm1, 0); + ColumnVector cv1 = cv1View.copyToColumnVector(); + ColumnView cv2View = replaceValidity(nonNullVector2, dm2, 0); + ColumnVector cv2 = cv2View.copyToColumnVector()) { + + try (Table t = new Table(new ColumnVector[]{cv0, cv1, cv2}); + Table tableWithoutNullMask = removeNullMasksIfNeeded(t); + ColumnView tableStructChild0 = t.getColumn(1).getChildColumnView(0).getChildColumnView(0); + ColumnVector tableStructChild0Cv = tableStructChild0.copyToColumnVector(); + Table expected = new Table(new ColumnVector[]{nonNullVector0, nonNullVector1, + nonNullVector2})) { + assertTrue(t.getColumn(0).hasValidityVector()); + assertTrue(t.getColumn(1).hasValidityVector()); + assertTrue(t.getColumn(2).hasValidityVector()); + assertTrue(tableStructChild0Cv.hasValidityVector()); + + assertPartialTablesAreEqual(expected, + 0, + expected.getRowCount(), + tableWithoutNullMask, + true, + true); + } + } + } + } + } + + @Test + void testRemoveNullMasksIfNeededWithNulls() { + ListType nestedType = new ListType(true, new StructType(true, + new BasicType(true, DType.INT32), + new BasicType(true, DType.INT64))); + + List data1 = Arrays.asList(0, 10L); + List data2 = Arrays.asList(50, null); + HostColumnVector.StructData structData1 = new HostColumnVector.StructData(data1); + HostColumnVector.StructData structData2 = new HostColumnVector.StructData(data2); + + //First we create ColumnVectors + try (ColumnVector nonNullVector0 = ColumnVector.fromBoxedInts(1, null, 2, 3); + ColumnVector nonNullVector1 = ColumnVector.fromStrings("1", "2", null, "3"); + ColumnVector nonNullVector2 = ColumnVector.fromLists(nestedType, + Arrays.asList(structData1, structData2), + null, + Arrays.asList(structData1, structData2), + Arrays.asList(structData1, structData2))) { + try (Table expected = new Table(new ColumnVector[]{nonNullVector0, nonNullVector1, nonNullVector2}); + Table unchangedTable = removeNullMasksIfNeeded(expected)) { + assertTablesAreEqual(expected, unchangedTable); + } + } + } + @Test void testMismatchedSizesForFilter() { Boolean[] maskVals = new Boolean[3];