Skip to content

Commit

Permalink
Add java API to get size of host memory needed to copy column view (#…
Browse files Browse the repository at this point in the history
…13919)

To help with host memory management work in Java, this adds an API that reports how much host memory is needed to copy a column's data before the copy happens.

This was written by @jbrennan333 but I am taking over the patch to get it in.

Authors:
  - Robert (Bobby) Evans (https://github.com/revans2)

Approvers:
  - Nghia Truong (https://github.com/ttnghia)
  - Gera Shegalov (https://github.com/gerashegalov)
  - Raza Jafri (https://github.com/razajafri)

URL: #13919
  • Loading branch information
revans2 authored Aug 22, 2023
1 parent 261bcb2 commit 595308b
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 19 deletions.
17 changes: 15 additions & 2 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ public final int getNumChildren() {
* Returns the amount of device memory used.
*/
public long getDeviceMemorySize() {
// NOTE(review): this listing shows two return statements back to back; it appears to be a
// diff rendering that interleaves the one-argument call with the two-argument call.
// Confirm against the real file which one is current.
// The two-argument form passes false for shouldPadForCpu, i.e. it asks the native side
// for the size without any host padding applied.
return getDeviceMemorySize(getNativeView());
return getDeviceMemorySize(getNativeView(), false);
}

@Override
Expand Down Expand Up @@ -4789,7 +4789,7 @@ static native long makeCudfColumnView(int type, int scale, long data, long dataS
static native int getNativeNumChildren(long viewHandle) throws CudfException;

// Calculate the amount of memory used by this column, including any child columns.
// viewHandle is the native column view handle. In the two-argument overload,
// shouldPadForCpu = true asks the native code to round each buffer's size up to the host
// padding size before summing; false leaves the sizes unpadded.
// NOTE(review): both overloads are listed here; this looks like a diff rendering of the
// declaration before and after the boolean parameter was added - confirm which exist.
static native long getDeviceMemorySize(long viewHandle) throws CudfException;
static native long getDeviceMemorySize(long viewHandle, boolean shouldPadForCpu) throws CudfException;

static native long copyColumnViewToCV(long viewHandle) throws CudfException;

Expand Down Expand Up @@ -5160,6 +5160,19 @@ public HostColumnVector copyToHost() {
}
}

/**
 * Compute the number of bytes of host memory needed to hold a copy of this column's data.
 * The size is requested from the native layer with CPU padding enabled, so it may be
 * larger than the value reported by {@link #getDeviceMemorySize()}.
 *
 * @return the padded size in bytes
 */
public long getHostBytesRequired() {
  final boolean padForCpu = true;
  return getDeviceMemorySize(getNativeView(), padForCpu);
}

/**
* Get the size, in bytes, that the host will align memory allocations to.
* Implemented natively; the JNI implementation returns sizeof(std::max_align_t).
*
* @return the host padding size in bytes
*/
public static native long hostPaddingSizeInBytes();

/**
* Exact check if a column or its descendants have non-empty null rows
*
Expand Down
31 changes: 23 additions & 8 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,22 +91,32 @@ using cudf::jni::release_as_jlong;

namespace {

std::size_t calc_device_memory_size(cudf::column_view const &view) {
/**
 * @brief Optionally round a byte count up to the host padding granularity.
 *
 * @param size Number of bytes in a single buffer.
 * @param should_pad_for_cpu When true, `size` is rounded up to the next multiple of
 *        `sizeof(std::max_align_t)`; when false, `size` is returned unchanged.
 * @return The (possibly rounded-up) byte count.
 */
std::size_t pad_size(std::size_t size, bool const should_pad_for_cpu) {
  if (!should_pad_for_cpu) { return size; }

  constexpr std::size_t ALIGN = sizeof(std::max_align_t);
  // The mask-based rounding below is only correct when ALIGN is a power of two. The
  // standard guarantees that for alignof, not for sizeof, so fail the build instead of
  // silently computing wrong sizes on a platform where it does not hold.
  static_assert(ALIGN != 0 && (ALIGN & (ALIGN - 1)) == 0,
                "sizeof(std::max_align_t) must be a power of two for mask-based rounding");
  return (size + (ALIGN - 1)) & ~(ALIGN - 1);
}

// Computes the number of bytes used by a column view and all of its descendants.
//
// The total is built from:
//   - the validity bitmask allocation size for the row count, when the view is nullable;
//   - element size times row count, when the type is fixed-width;
//   - the same computation applied recursively to every child view.
//
// When pad_for_cpu is true, each individual buffer size is passed through pad_size()
// before being added, so every contribution is rounded up separately.
//
// NOTE(review): this listing contains each accumulation statement twice (an unpadded form
// followed by a pad_size(...) form) and two return statements. It appears to be a diff
// rendering that interleaves pre-change and post-change lines; confirm against the real
// file before relying on it.
std::size_t calc_device_memory_size(cudf::column_view const &view, bool const pad_for_cpu) {
std::size_t total = 0;
auto row_count = view.size();

// Validity (null mask) buffer, only present for nullable views.
if (view.nullable()) {
total += cudf::bitmask_allocation_size_bytes(row_count);
total += pad_size(cudf::bitmask_allocation_size_bytes(row_count), pad_for_cpu);
}

// Data buffer for fixed-width types; other types contribute only through their children.
auto dtype = view.type();
if (cudf::is_fixed_width(dtype)) {
total += cudf::size_of(dtype) * view.size();
total += pad_size(cudf::size_of(dtype) * view.size(), pad_for_cpu);
}

// Fold the children in, starting from the running total for this view.
return std::accumulate(
view.child_begin(), view.child_end(), total,
[](std::size_t t, cudf::column_view const &v) { return t + calc_device_memory_size(v); });
return std::accumulate(view.child_begin(), view.child_end(), total,
[pad_for_cpu](std::size_t t, cudf::column_view const &v) {
return t + calc_device_memory_size(v, pad_for_cpu);
});
}

} // anonymous namespace
Expand Down Expand Up @@ -2217,16 +2227,21 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_getNativeValidityLength(J
}

// JNI entry point backing ColumnView.getDeviceMemorySize.
//
// handle is reinterpreted as a pointer to a cudf::column_view, and the result of
// calc_device_memory_size() for that view is returned. pad_for_cpu is forwarded to
// calc_device_memory_size() unchanged.
//
// NOTE(review): the parameter list and the return statement each appear twice below (a
// one-argument and a two-argument form); this looks like a diff rendering that interleaves
// pre-change and post-change lines - confirm against the real file.
// NOTE(review): JNI_NULL_CHECK and CATCH_STD are macros defined elsewhere; presumably they
// yield 0 to the caller on a null handle or a caught exception - verify.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_getDeviceMemorySize(JNIEnv *env, jclass,
jlong handle) {
jlong handle,
jboolean pad_for_cpu) {
JNI_NULL_CHECK(env, handle, "native handle is null", 0);
try {
cudf::jni::auto_set_device(env);
auto view = reinterpret_cast<cudf::column_view const *>(handle);
return calc_device_memory_size(*view);
return calc_device_memory_size(*view, pad_for_cpu);
}
CATCH_STD(env, 0);
}

// JNI entry point backing ColumnView.hostPaddingSizeInBytes: reports the byte granularity
// used when padding sizes for the host, which is sizeof(std::max_align_t).
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_hostPaddingSizeInBytes(JNIEnv *env, jclass) {
  constexpr std::size_t host_padding_bytes = sizeof(std::max_align_t);
  return static_cast<jlong>(host_padding_bytes);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_clamper(JNIEnv *env, jobject j_object,
jlong handle, jlong j_lo_scalar,
jlong j_lo_replace_scalar,
Expand Down
33 changes: 24 additions & 9 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1026,21 +1026,36 @@ void decimal128Cv() {
}
}

// Host padding size in bytes, as reported by ColumnView.hostPaddingSizeInBytes().
static final long HOST_ALIGN_BYTES = ColumnView.hostPaddingSizeInBytes();

/**
 * Check both size queries of a column view in one call.
 * <ul>
 *   <li>{@code getDeviceMemorySize()} must equal {@code expectedDeviceSize} exactly.</li>
 *   <li>{@code getHostBytesRequired()} must be at least the device size.</li>
 *   <li>{@code getHostBytesRequired()} must be a multiple of {@code HOST_ALIGN_BYTES}.</li>
 * </ul>
 *
 * @param expectedDeviceSize the exact device memory size expected, in bytes
 * @param cv the column view to inspect
 */
static void assertHostAligned(long expectedDeviceSize, ColumnView cv) {
  long deviceSize = cv.getDeviceMemorySize();
  assertEquals(expectedDeviceSize, deviceSize);
  long hostSize = cv.getHostBytesRequired();
  // A plain Java `assert` statement is skipped unless the JVM runs with -ea, so fail
  // explicitly to make sure this check can never be silently disabled.
  if (hostSize < deviceSize) {
    throw new AssertionError("The host size (" + hostSize +
        ") should be at least the device size (" + deviceSize + ")");
  }
  // Rounding down to a multiple of the alignment must be a no-op for an aligned size.
  long roundedHostSize = (hostSize / HOST_ALIGN_BYTES) * HOST_ALIGN_BYTES;
  assertEquals(hostSize, roundedHostSize, "The host size should be a multiple of " +
      HOST_ALIGN_BYTES);
}

// Checks the reported sizes for INT32 columns, one without nulls and one with nulls.
// NOTE(review): each expectation appears twice below, once via assertEquals on
// getDeviceMemorySize() and once via assertHostAligned(); this looks like a diff rendering
// that interleaves pre-change and post-change lines - confirm against the real file.
@Test
void testGetDeviceMemorySizeNonStrings() {
try (ColumnVector v0 = ColumnVector.fromBoxedInts(1, 2, 3, 4, 5, 6);
ColumnVector v1 = ColumnVector.fromBoxedInts(1, 2, 3, null, null, 4, 5, 6)) {
assertEquals(24, v0.getDeviceMemorySize()); // (6*4B)
assertEquals(96, v1.getDeviceMemorySize()); // (8*4B) + 64B(for validity vector)
assertHostAligned(24, v0); // (6*4B)
assertHostAligned(96, v1); // (8*4B) + 64B(for validity vector)
}
}

// Checks the reported sizes for string columns, one without nulls and one with a null.
// NOTE(review): each expectation appears twice below, once via assertEquals on
// getDeviceMemorySize() and once via assertHostAligned(); this looks like a diff rendering
// that interleaves pre-change and post-change lines - confirm against the real file.
@Test
void testGetDeviceMemorySizeStrings() {
// NOTE(review): this only prints to stderr when the padding size is not 8 and asserts
// nothing; it looks like leftover diagnostic output - consider removing it.
if (ColumnView.hostPaddingSizeInBytes() != 8) {
System.err.println("HOST PADDING SIZE: " + ColumnView.hostPaddingSizeInBytes());
}
try (ColumnVector v0 = ColumnVector.fromStrings("onetwothree", "four", "five");
ColumnVector v1 = ColumnVector.fromStrings("onetwothree", "four", null, "five")) {
assertEquals(35, v0.getDeviceMemorySize()); //19B data + 4*4B offsets = 35
assertEquals(103, v1.getDeviceMemorySize()); //19B data + 5*4B + 64B validity vector = 103B
assertHostAligned(35, v0); //19B data + 4*4B offsets = 35
assertHostAligned(103, v1); //19B data + 5*4B + 64B validity vector = 103B
}
}

Expand All @@ -1061,13 +1076,13 @@ void testGetDeviceMemorySizeLists() {
// 64 bytes for validity of list column
// 16 bytes for offsets of list column
// 64 bytes for validity of string column
// 24 bytes for offsets of of string column
// 24 bytes for offsets of string column
// 22 bytes of string character size
assertEquals(64+16+64+24+22, sv.getDeviceMemorySize());
assertHostAligned(64+16+64+24+22, sv);

// 20 bytes for offsets of list column
// 28 bytes for data of INT32 column
assertEquals(20+28, iv.getDeviceMemorySize());
assertHostAligned(20+28, iv);
}
}

Expand All @@ -1091,11 +1106,11 @@ void testGetDeviceMemorySizeStructs() {
// 64 bytes for validity of list column
// 20 bytes for offsets of list column
// 64 bytes for validity of string column
// 28 bytes for offsets of of string column
// 28 bytes for offsets of string column
// 22 bytes of string character size
// 64 bytes for validity of int64 column
// 28 bytes for data of the int64 column
assertEquals(64+64+20+64+28+22+64+28, v.getDeviceMemorySize());
assertHostAligned(64+64+20+64+28+22+64+28, v);
}
}

Expand Down

0 comments on commit 595308b

Please sign in to comment.