diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java
index 6e0b7d3bb94..ca5abc4b55c 100644
--- a/java/src/main/java/ai/rapids/cudf/Table.java
+++ b/java/src/main/java/ai/rapids/cudf/Table.java
@@ -514,6 +514,8 @@ private static native long[] repeatColumnCount(long tableHandle,
                                                  long columnHandle,
                                                  boolean checkCount);
 
+  private static native long rowBitCount(long tableHandle) throws CudfException;
+
   private static native long[] explode(long tableHandle, int index);
 
   private static native long[] explodePosition(long tableHandle, int index);
@@ -1891,6 +1893,28 @@ public Table explodeOuterPosition(int index) {
     return new Table(explodeOuterPosition(nativeHandle, index));
   }
 
+  /**
+   * Returns an approximate cumulative size in bits of all columns in this table for each row.
+   * This function counts bits instead of bytes to account for the null mask which only has one
+   * bit per row. Each row in the returned column is the sum of the per-row bit size for each column
+   * in the table.
+   *
+   * In some cases, this is an inexact approximation. Specifically, columns of lists and strings
+   * require N+1 offsets to represent N rows. It is up to the caller to calculate the small
+   * additional overhead of the terminating offset for any group of rows being considered.
+   *
+   * This function returns the per-row bit sizes as the columns are currently formed. This can
+   * end up being larger than the number you would get by gathering the rows. Specifically,
+   * the push-down of struct column validity masks can nullify rows that contain data for
+   * string or list columns. In these cases, the size returned is conservative such that:
+   * {@code row_bit_count(column(x)) >= row_bit_count(gather(column(x)))}
+   *
+   * @return a new INT32 column holding the bit size of each row of this table
+   */
+  public ColumnVector rowBitCount() {
+    return new ColumnVector(rowBitCount(getNativeView()));
+  }
+
   /**
    * Gathers the rows of this table according to `gatherMap` such that row "i"
    * in the resulting table's columns will contain row "gatherMap[i]" from this table.
diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp
index 6beedf54f5a..f67581f000f 100644
--- a/java/src/main/native/src/TableJni.cpp
+++ b/java/src/main/native/src/TableJni.cpp
@@ -2180,4 +2180,17 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_explodeOuterPosition(JNIE
   CATCH_STD(env, 0);
 }
 
+// Native side of Table.rowBitCount(): calls cudf::row_bit_count on the table view behind
+// j_table and hands the released result column back to Java as a jlong handle.
+JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_rowBitCount(JNIEnv* env, jclass, jlong j_table) {
+  JNI_NULL_CHECK(env, j_table, "table is null", 0);
+  try {
+    cudf::jni::auto_set_device(env);
+    auto t = reinterpret_cast<cudf::table_view const *>(j_table);
+    std::unique_ptr<cudf::column> result = cudf::row_bit_count(*t);
+    return reinterpret_cast<jlong>(result.release());
+  }
+  CATCH_STD(env, 0);
+}
+
 } // extern "C"
diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java
index b6350a207c1..2ab9d7486fe 100644
--- a/java/src/test/java/ai/rapids/cudf/TableTest.java
+++ b/java/src/test/java/ai/rapids/cudf/TableTest.java
@@ -3968,6 +3968,32 @@ void testGroupByNoAggs() {
     }
   }
 
+  @Test
+  void testRowBitCount() {
+    try (Table t = new Table.TestBuilder()
+        .column(0, 1, null, 3) // 33 bits per row (4 bytes + valid bit)
+        .column(0.0, null, 2.0, 3.0) // 65 bits per row (8 bytes + valid bit)
+        .column("zero", null, "two", "three") // 33 bits (4 byte offset + valid bit) + char bits
+        .build();
+         ColumnVector expected = ColumnVector.fromInts(163, 131, 155, 171);
+         ColumnVector actual = t.rowBitCount()) {
+      assertColumnsAreEqual(expected, actual);
+    }
+  }
+
+  @Test
+  void testRowBitCountEmpty() {
+    try (Table t = new Table.TestBuilder()
+        .column(new Integer[0])
+        .column(new Double[0])
+        .column(new String[0])
+        .build();
+         ColumnVector c = t.rowBitCount()) {
+      assertEquals(DType.INT32, c.getType());
+      assertEquals(0, c.getRowCount());
+    }
+  }
+
   @Test
   void testSimpleGather() {
     try (Table testTable = new Table.TestBuilder()