Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add java API to get size of host memory needed to copy column view #13919

Merged
merged 6 commits into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,15 @@ public final int getNumChildren() {
* Returns the amount of device memory used.
*/
public long getDeviceMemorySize() {
return getDeviceMemorySize(getNativeView());
return getDeviceMemorySize(getNativeView(), false);
}

/**
* Returns the amount of memory used by this, but padded for 64-bit alignment. This makes it
* so it could be used as the amount of memory needed to copy the data to the host.
*/
public long getDeviceMemorySizeAligned() {
revans2 marked this conversation as resolved.
Show resolved Hide resolved
return getDeviceMemorySize(getNativeView(), true);
}

@Override
Expand Down Expand Up @@ -4789,7 +4797,7 @@ static native long makeCudfColumnView(int type, int scale, long data, long dataS
static native int getNativeNumChildren(long viewHandle) throws CudfException;

// calculate the amount of device memory used by this column including any child columns
static native long getDeviceMemorySize(long viewHandle) throws CudfException;
static native long getDeviceMemorySize(long viewHandle, boolean aligned) throws CudfException;

static native long copyColumnViewToCV(long viewHandle) throws CudfException;

Expand Down Expand Up @@ -5160,6 +5168,13 @@ public HostColumnVector copyToHost() {
}
}

/**
* Calculate the total space required to copy the data to the host.
*/
public long getHostBytesRequired() {
razajafri marked this conversation as resolved.
Show resolved Hide resolved
return getDeviceMemorySizeAligned();
}

/**
* Exact check if a column or its descendants have non-empty null rows
*
Expand Down
22 changes: 16 additions & 6 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,22 +91,31 @@ using cudf::jni::release_as_jlong;

namespace {

std::size_t calc_device_memory_size(cudf::column_view const &view) {
std::size_t align_size(std::size_t size, bool const do_it) {
revans2 marked this conversation as resolved.
Show resolved Hide resolved
if (do_it) {
constexpr std::size_t ALIGN = 1 << 6; // 64-bit alignment
return (size + (ALIGN - 1)) & ~(ALIGN - 1);
} else {
return size;
}
}

std::size_t calc_device_memory_size(cudf::column_view const &view, bool const aligned) {
std::size_t total = 0;
auto row_count = view.size();

if (view.nullable()) {
total += cudf::bitmask_allocation_size_bytes(row_count);
total += align_size(cudf::bitmask_allocation_size_bytes(row_count), aligned);
}

auto dtype = view.type();
if (cudf::is_fixed_width(dtype)) {
total += cudf::size_of(dtype) * view.size();
total += align_size(cudf::size_of(dtype) * view.size(), aligned);
}

return std::accumulate(
view.child_begin(), view.child_end(), total,
[](std::size_t t, cudf::column_view const &v) { return t + calc_device_memory_size(v); });
[aligned](std::size_t t, cudf::column_view const &v) { return t + calc_device_memory_size(v, aligned); });
}

} // anonymous namespace
Expand Down Expand Up @@ -2217,12 +2226,13 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_getNativeValidityLength(J
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_getDeviceMemorySize(JNIEnv *env, jclass,
jlong handle) {
jlong handle,
jboolean aligned) {
JNI_NULL_CHECK(env, handle, "native handle is null", 0);
try {
cudf::jni::auto_set_device(env);
auto view = reinterpret_cast<cudf::column_view const *>(handle);
return calc_device_memory_size(*view);
return calc_device_memory_size(*view, aligned);
}
CATCH_STD(env, 0);
}
Expand Down
8 changes: 8 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1031,7 +1031,9 @@ void testGetDeviceMemorySizeNonStrings() {
try (ColumnVector v0 = ColumnVector.fromBoxedInts(1, 2, 3, 4, 5, 6);
ColumnVector v1 = ColumnVector.fromBoxedInts(1, 2, 3, null, null, 4, 5, 6)) {
assertEquals(24, v0.getDeviceMemorySize()); // (6*4B)
assertEquals(64, v0.getHostBytesRequired()); // account for alignment padding
razajafri marked this conversation as resolved.
Show resolved Hide resolved
assertEquals(96, v1.getDeviceMemorySize()); // (8*4B) + 64B(for validity vector)
assertEquals(64 + 64, v1.getHostBytesRequired());
}
}

Expand All @@ -1040,7 +1042,9 @@ void testGetDeviceMemorySizeStrings() {
try (ColumnVector v0 = ColumnVector.fromStrings("onetwothree", "four", "five");
ColumnVector v1 = ColumnVector.fromStrings("onetwothree", "four", null, "five")) {
assertEquals(35, v0.getDeviceMemorySize()); //19B data + 4*4B offsets = 35
assertEquals(64 + 64, v0.getHostBytesRequired()); // account for alignment padding
assertEquals(103, v1.getDeviceMemorySize()); //19B data + 5*4B + 64B validity vector = 103B
assertEquals(64+64+64, v1.getHostBytesRequired()); // account for alignment padding
}
}

Expand All @@ -1064,10 +1068,12 @@ void testGetDeviceMemorySizeLists() {
// 24 bytes for offsets of of string column
// 22 bytes of string character size
assertEquals(64+16+64+24+22, sv.getDeviceMemorySize());
assertEquals(64+64+64+64+64, sv.getHostBytesRequired()); // account for alignment padding

// 20 bytes for offsets of list column
// 28 bytes for data of INT32 column
assertEquals(20+28, iv.getDeviceMemorySize());
assertEquals(64+64, iv.getHostBytesRequired()); // account for alignment padding
}
}

Expand Down Expand Up @@ -1096,6 +1102,8 @@ void testGetDeviceMemorySizeStructs() {
// 64 bytes for validity of int64 column
// 28 bytes for data of the int64 column
assertEquals(64+64+20+64+28+22+64+28, v.getDeviceMemorySize());
// account for alignment padding
assertEquals(64+64+64+64+64+64+64+64, v.getHostBytesRequired());
}
}

Expand Down