diff --git a/CHANGELOG.md b/CHANGELOG.md index 1bd964756e4..2d311431a55 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -82,6 +82,7 @@ - PR #6748 Add Java API to concatenate serialized tables to ContiguousTable - PR #6734 Binary operations support for decimal type in cudf Java - PR #6761 Add Java/JNI bindings for round +- PR #6786 Add nested type support to ColumnVector#getDeviceMemorySize - PR #6780 Move `cudf::cast` tests to separate test file ## Bug Fixes diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index b295b104936..cd785f550a5 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -322,7 +322,7 @@ public long getRowCount() { * Returns the amount of device memory used. */ public long getDeviceMemorySize() { - return offHeap != null ? offHeap.getDeviceMemorySize() : 0; + return getDeviceMemorySize(getNativeView()); } /** @@ -3174,6 +3174,9 @@ static native long makeCudfColumnView(int type, int scale, long data, long dataS private static native int getNativeNumChildren(long viewHandle) throws CudfException; + // calculate the amount of device memory used by this column including any child columns + private static native long getDeviceMemorySize(long viewHandle) throws CudfException; + //////// // Native methods specific to cudf::column. These either take or create a cudf::column // instead of a cudf::column_view so they need to be used with caution. These should @@ -3504,20 +3507,6 @@ protected boolean cleanImpl(boolean logErrorIfNotClean) { public boolean isClean() { return viewHandle == 0 && columnHandle == 0 && toClose.isEmpty(); } - - /** - * This returns total memory allocated in device for the ColumnVector. - * @return number of device bytes allocated for this column - */ - public long getDeviceMemorySize() { - BaseDeviceMemoryBuffer valid = getValid(); - BaseDeviceMemoryBuffer data = getData(); - BaseDeviceMemoryBuffer offsets = getOffsets(); - long size = valid != null ? valid.getLength() : 0; - size += offsets != null ? offsets.getLength() : 0; - size += data != null ? data.getLength() : 0; - return size; - } } public static ColumnVector createNestedColumnVector(DType type, int rows, HostMemoryBuffer data, HostMemoryBuffer valid, HostMemoryBuffer offsets, diff --git a/java/src/main/native/src/ColumnVectorJni.cpp b/java/src/main/native/src/ColumnVectorJni.cpp index 06e534b26f4..0c9158f10fc 100644 --- a/java/src/main/native/src/ColumnVectorJni.cpp +++ b/java/src/main/native/src/ColumnVectorJni.cpp @@ -14,6 +14,8 @@ * limitations under the License. */ +#include + #include #include #include @@ -24,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -59,6 +62,28 @@ #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" +namespace { + +std::size_t calc_device_memory_size(cudf::column_view const &view) { + std::size_t total = 0; + auto row_count = view.size(); + + if (view.nullable()) { + total += cudf::bitmask_allocation_size_bytes(row_count); + } + + auto dtype = view.type(); + if (cudf::is_fixed_width(dtype)) { + total += cudf::size_of(dtype) * view.size(); + } + + return std::accumulate(view.child_begin(), view.child_end(), total, + [](std::size_t t, cudf::column_view const &v) { + return t + calc_device_memory_size(v); + }); +} + +} // anonymous namespace extern "C" { @@ -1604,6 +1629,17 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_getNativeValidPointerSi CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_getDeviceMemorySize(JNIEnv *env, jclass, + jlong handle) { + JNI_NULL_CHECK(env, handle, "native handle is null", 0); + try { + cudf::jni::auto_set_device(env); + auto view = reinterpret_cast(handle); + return calc_device_memory_size(*view); + } + CATCH_STD(env, 0); +} + //////// // Native methods specific to cudf::column. These either take or return a cudf::column // instead of a cudf::column_view so they need to be used with caution. These should diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 08ee8034471..4b358ff55c7 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -18,6 +18,12 @@ package ai.rapids.cudf; +import ai.rapids.cudf.HostColumnVector.BasicType; +import ai.rapids.cudf.HostColumnVector.DataType; +import ai.rapids.cudf.HostColumnVector.ListType; +import ai.rapids.cudf.HostColumnVector.StructData; +import ai.rapids.cudf.HostColumnVector.StructType; + import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; @@ -25,6 +31,7 @@ import java.math.RoundingMode; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -36,9 +43,14 @@ import static ai.rapids.cudf.QuantileMethod.MIDPOINT; import static ai.rapids.cudf.QuantileMethod.NEAREST; import static ai.rapids.cudf.TableTest.assertColumnsAreEqual; -import static ai.rapids.cudf.TableTest.assertTablesAreEqual; import static ai.rapids.cudf.TableTest.assertStructColumnsAreEqual; -import static org.junit.jupiter.api.Assertions.*; +import static ai.rapids.cudf.TableTest.assertTablesAreEqual; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assumptions.assumeTrue; public class ColumnVectorTest extends CudfTestBase { @@ -695,6 +707,61 @@ void testGetDeviceMemorySizeStrings() { } } + @SuppressWarnings("unchecked") + @Test + void testGetDeviceMemorySizeLists() { + DataType svType = new ListType(true, new BasicType(true, DType.STRING)); + DataType ivType = new ListType(false, new BasicType(false, DType.INT32)); + try (ColumnVector sv = ColumnVector.fromLists(svType, + Arrays.asList("first", "second", "third"), + Arrays.asList("fourth", null), + null); + ColumnVector iv = ColumnVector.fromLists(ivType, + Arrays.asList(1, 2, 3), + Collections.singletonList(4), + Arrays.asList(5, 6), + Collections.singletonList(7))) { + // 64 bytes for validity of list column + // 16 bytes for offsets of list column + // 64 bytes for validity of string column + // 24 bytes for offsets of of string column + // 22 bytes of string character size + assertEquals(64+16+64+24+22, sv.getDeviceMemorySize()); + + // 20 bytes for offsets of list column + // 28 bytes for data of INT32 column + assertEquals(20+28, iv.getDeviceMemorySize()); + } + } + + @Test + void testGetDeviceMemorySizeStructs() { + DataType structType = new StructType(true, + new ListType(true, new BasicType(true, DType.STRING)), + new BasicType(true, DType.INT64)); + try (ColumnVector v = ColumnVector.fromStructs(structType, + new StructData( + Arrays.asList("first", "second", "third"), + 10L), + new StructData( + Arrays.asList("fourth", null), + 20L), + new StructData( + null, + null), + null)) { + // 64 bytes for validity of the struct column + // 64 bytes for validity of list column + // 20 bytes for offsets of list column + // 64 bytes for validity of string column + // 28 bytes for offsets of of string column + // 22 bytes of string character size + // 64 bytes for validity of int64 column + // 28 bytes for data of the int64 column + assertEquals(64+64+20+64+28+22+64+28, v.getDeviceMemorySize()); + } + } + @Test void testSequenceInt() { try (Scalar zero = Scalar.fromInt(0);