From 599f62d1f3aea59fd6429911bfeb394349428c83 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Mon, 29 Mar 2021 21:10:12 -0500 Subject: [PATCH] Add Java bindings for row_bit_count (#7749) Adds Java bindings for `cudf::row_bit_count`. This depends on #7534. Authors: - Jason Lowe (@jlowe) Approvers: - Robert (Bobby) Evans (@revans2) URL: https://github.com/rapidsai/cudf/pull/7749 --- java/src/main/java/ai/rapids/cudf/Table.java | 24 +++++++++++++++++ java/src/main/native/src/TableJni.cpp | 11 ++++++++ .../test/java/ai/rapids/cudf/TableTest.java | 26 +++++++++++++++++++ 3 files changed, 61 insertions(+) diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index fc6ad55044a..8f256987dd2 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -529,6 +529,8 @@ private static native long[] repeatColumnCount(long tableHandle, long columnHandle, boolean checkCount); + private static native long rowBitCount(long tableHandle) throws CudfException; + private static native long[] explode(long tableHandle, int index); private static native long[] explodePosition(long tableHandle, int index); @@ -1906,6 +1908,28 @@ public Table explodeOuterPosition(int index) { return new Table(explodeOuterPosition(nativeHandle, index)); } + /** + * Returns an approximate cumulative size in bits of all columns in the `table_view` for each row. + * This function counts bits instead of bytes to account for the null mask which only has one + * bit per row. Each row in the returned column is the sum of the per-row bit size for each column + * in the table. + * + * In some cases, this is an inexact approximation. Specifically, columns of lists and strings + * require N+1 offsets to represent N rows. It is up to the caller to calculate the small + * additional overhead of the terminating offset for any group of rows being considered. + * + * This function returns the per-row bit sizes as the columns are currently formed. This can + * end up being larger than the number you would get by gathering the rows. Specifically, + * the push-down of struct column validity masks can nullify rows that contain data for + * string or list columns. In these cases, the size returned is conservative such that: + * row_bit_count(column(x)) >= row_bit_count(gather(column(x))) + * + * @return INT32 column of bit size per row of the table + */ + public ColumnVector rowBitCount() { + return new ColumnVector(rowBitCount(getNativeView())); + } + /** * Gathers the rows of this table according to `gatherMap` such that row "i" * in the resulting table's columns will contain row "gatherMap[i]" from this table. diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 0e66cde3ee1..346ae8435cc 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -2366,4 +2366,15 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_explodeOuterPosition(JNIE CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_rowBitCount(JNIEnv* env, jclass, jlong j_table) { + JNI_NULL_CHECK(env, j_table, "table is null", 0); + try { + cudf::jni::auto_set_device(env); + auto t = reinterpret_cast(j_table); + std::unique_ptr result = cudf::row_bit_count(*t); + return reinterpret_cast(result.release()); + } + CATCH_STD(env, 0); +} + } // extern "C" diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index ac71f96d3c3..9c67966c16c 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -4344,6 +4344,32 @@ void testGroupByNoAggs() { } } + @Test + void testRowBitCount() { + try (Table t = new Table.TestBuilder() + .column(0, 1, null, 3) // 33 bits per row (4 bytes + valid bit) + .column(0.0, null, 2.0, 3.0) // 65 bits per row (8 bytes + valid bit) + .column("zero", null, "two", "three") // 33 bits (4 byte offset + valid bit) + char bits + .build(); + ColumnVector expected = ColumnVector.fromInts(163, 131, 155, 171); + ColumnVector actual = t.rowBitCount()) { + assertColumnsAreEqual(expected, actual); + } + } + + @Test + void testRowBitCountEmpty() { + try (Table t = new Table.TestBuilder() + .column(new Integer[0]) + .column(new Double[0]) + .column(new String[0]) + .build(); + ColumnVector c = t.rowBitCount()) { + assertEquals(DType.INT32, c.getType()); + assertEquals(0, c.getRowCount()); + } + } + @Test void testSimpleGather() { try (Table testTable = new Table.TestBuilder()