diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 0a7346d1cbc..7db40278d4e 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -307,7 +307,7 @@ public final int getNumChildren() { * Returns the amount of device memory used. */ public long getDeviceMemorySize() { - return getDeviceMemorySize(getNativeView()); + return getDeviceMemorySize(getNativeView(), false); } @Override @@ -4789,7 +4789,7 @@ static native long makeCudfColumnView(int type, int scale, long data, long dataS static native int getNativeNumChildren(long viewHandle) throws CudfException; // calculate the amount of device memory used by this column including any child columns - static native long getDeviceMemorySize(long viewHandle) throws CudfException; + static native long getDeviceMemorySize(long viewHandle, boolean shouldPadForCpu) throws CudfException; static native long copyColumnViewToCV(long viewHandle) throws CudfException; @@ -5160,6 +5160,19 @@ public HostColumnVector copyToHost() { } } + /** + * Calculate the total space required to copy the data to the host. This should be padded to + * the alignment that the CPU requires. + */ + public long getHostBytesRequired() { + return getDeviceMemorySize(getNativeView(), true); + } + + /** + * Get the size that the host will align memory allocations to in bytes. + */ + public static native long hostPaddingSizeInBytes(); + /** * Exact check if a column or its descendants have non-empty null rows * diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 1cb51a22bf3..d5aad03645f 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -91,22 +91,32 @@ using cudf::jni::release_as_jlong; namespace { -std::size_t calc_device_memory_size(cudf::column_view const &view) { +std::size_t pad_size(std::size_t size, bool const should_pad_for_cpu) { + if (should_pad_for_cpu) { + constexpr std::size_t ALIGN = sizeof(std::max_align_t); + return (size + (ALIGN - 1)) & ~(ALIGN - 1); + } else { + return size; + } +} + +std::size_t calc_device_memory_size(cudf::column_view const &view, bool const pad_for_cpu) { std::size_t total = 0; auto row_count = view.size(); if (view.nullable()) { - total += cudf::bitmask_allocation_size_bytes(row_count); + total += pad_size(cudf::bitmask_allocation_size_bytes(row_count), pad_for_cpu); } auto dtype = view.type(); if (cudf::is_fixed_width(dtype)) { - total += cudf::size_of(dtype) * view.size(); + total += pad_size(cudf::size_of(dtype) * view.size(), pad_for_cpu); } - return std::accumulate( - view.child_begin(), view.child_end(), total, - [](std::size_t t, cudf::column_view const &v) { return t + calc_device_memory_size(v); }); + return std::accumulate(view.child_begin(), view.child_end(), total, + [pad_for_cpu](std::size_t t, cudf::column_view const &v) { + return t + calc_device_memory_size(v, pad_for_cpu); + }); } } // anonymous namespace @@ -2217,16 +2227,21 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_getNativeValidityLength(J } JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_getDeviceMemorySize(JNIEnv *env, jclass, - jlong handle) { + jlong handle, + jboolean pad_for_cpu) { JNI_NULL_CHECK(env, handle, "native handle is null", 0); try { cudf::jni::auto_set_device(env); auto view = reinterpret_cast(handle); - return calc_device_memory_size(*view); + return calc_device_memory_size(*view, pad_for_cpu); } CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_hostPaddingSizeInBytes(JNIEnv *env, jclass) { + return sizeof(std::max_align_t); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_clamper(JNIEnv *env, jobject j_object, jlong handle, jlong j_lo_scalar, jlong j_lo_replace_scalar, diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 0e1fbad6129..1062a765800 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -1026,21 +1026,36 @@ void decimal128Cv() { } } + static final long HOST_ALIGN_BYTES = ColumnView.hostPaddingSizeInBytes(); + + static void assertHostAligned(long expectedDeviceSize, ColumnView cv) { + long deviceSize = cv.getDeviceMemorySize(); + assertEquals(expectedDeviceSize, deviceSize); + long hostSize = cv.getHostBytesRequired(); + assert(hostSize >= deviceSize); + long roundedHostSize = (hostSize / HOST_ALIGN_BYTES) * HOST_ALIGN_BYTES; + assertEquals(hostSize, roundedHostSize, "The host size should be a multiple of " + + HOST_ALIGN_BYTES); + } + @Test void testGetDeviceMemorySizeNonStrings() { try (ColumnVector v0 = ColumnVector.fromBoxedInts(1, 2, 3, 4, 5, 6); ColumnVector v1 = ColumnVector.fromBoxedInts(1, 2, 3, null, null, 4, 5, 6)) { - assertEquals(24, v0.getDeviceMemorySize()); // (6*4B) - assertEquals(96, v1.getDeviceMemorySize()); // (8*4B) + 64B(for validity vector) + assertHostAligned(24, v0); // (6*4B) + assertHostAligned(96, v1); // (8*4B) + 64B(for validity vector) } } @Test void testGetDeviceMemorySizeStrings() { + if (ColumnView.hostPaddingSizeInBytes() != 8) { + System.err.println("HOST PADDING SIZE: " + ColumnView.hostPaddingSizeInBytes()); + } try (ColumnVector v0 = ColumnVector.fromStrings("onetwothree", "four", "five"); ColumnVector v1 = ColumnVector.fromStrings("onetwothree", "four", null, "five")) { - assertEquals(35, v0.getDeviceMemorySize()); //19B data + 4*4B offsets = 35 - assertEquals(103, v1.getDeviceMemorySize()); //19B data + 5*4B + 64B validity vector = 103B + assertHostAligned(35, v0); //19B data + 4*4B offsets = 35 + assertHostAligned(103, v1); //19B data + 5*4B + 64B validity vector = 103B } } @@ -1061,13 +1076,13 @@ void testGetDeviceMemorySizeLists() { // 64 bytes for validity of list column // 16 bytes for offsets of list column // 64 bytes for validity of string column - // 24 bytes for offsets of of string column + // 24 bytes for offsets of string column // 22 bytes of string character size - assertEquals(64+16+64+24+22, sv.getDeviceMemorySize()); + assertHostAligned(64+16+64+24+22, sv); // 20 bytes for offsets of list column // 28 bytes for data of INT32 column - assertEquals(20+28, iv.getDeviceMemorySize()); + assertHostAligned(20+28, iv); } } @@ -1091,11 +1106,11 @@ void testGetDeviceMemorySizeStructs() { // 64 bytes for validity of list column // 20 bytes for offsets of list column // 64 bytes for validity of string column - // 28 bytes for offsets of of string column + // 28 bytes for offsets of string column // 22 bytes of string character size // 64 bytes for validity of int64 column // 28 bytes for data of the int64 column - assertEquals(64+64+20+64+28+22+64+28, v.getDeviceMemorySize()); + assertHostAligned(64+64+20+64+28+22+64+28, v); } }