Skip to content

Commit

Permalink
Add nested type support to ColumnVector#getDeviceMemorySize (#6786)
Browse files Browse the repository at this point in the history
Adds nested type support to ColumnVector device size calculations. This also moves the calculation to JNI code to avoid manifesting Java objects for any child columns.
  • Loading branch information
jlowe authored Nov 17, 2020
1 parent 8ae857c commit 01b8b5c
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
- PR #6748 Add Java API to concatenate serialized tables to ContiguousTable
- PR #6734 Binary operations support for decimal type in cudf Java
- PR #6761 Add Java/JNI bindings for round
- PR #6786 Add nested type support to ColumnVector#getDeviceMemorySize
- PR #6780 Move `cudf::cast` tests to separate test file

## Bug Fixes
Expand Down
19 changes: 4 additions & 15 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ public long getRowCount() {
* Returns the amount of device memory used.
*/
public long getDeviceMemorySize() {
// NOTE(review): this is a rendered diff hunk, so both sides of the change are listed.
// The next line is the previous implementation: it asked the off-heap state for its
// size and reported 0 when there was no off-heap state.
return offHeap != null ? offHeap.getDeviceMemorySize() : 0;
// Replacement implementation: passes the native view handle to the static native
// overload, which totals the device memory of this column including any child columns.
return getDeviceMemorySize(getNativeView());
}

/**
Expand Down Expand Up @@ -3174,6 +3174,9 @@ static native long makeCudfColumnView(int type, int scale, long data, long dataS

// NOTE(review): presumably returns the number of child columns of the column view behind
// viewHandle -- confirm against the JNI implementation, which is not visible here.
private static native int getNativeNumChildren(long viewHandle) throws CudfException;

// Calculate the amount of device memory used by this column including any child columns.
// viewHandle is the native address of a cudf::column_view; the walk over the column and
// its children happens entirely in native code, so no Java objects are created for the
// child columns.
private static native long getDeviceMemorySize(long viewHandle) throws CudfException;

////////
// Native methods specific to cudf::column. These either take or create a cudf::column
// instead of a cudf::column_view so they need to be used with caution. These should
Expand Down Expand Up @@ -3504,20 +3507,6 @@ protected boolean cleanImpl(boolean logErrorIfNotClean) {
/**
 * Reports whether nothing is left to clean up.
 * @return true only when both the view handle and the column handle are 0 and the
 *         toClose collection is empty
 */
public boolean isClean() {
  boolean noHandles = viewHandle == 0 && columnHandle == 0;
  return noHandles && toClose.isEmpty();
}

/**
 * Totals the lengths of the validity, offsets and data buffers of this column.
 * A buffer that is absent (null) contributes nothing to the total.
 * @return combined length in bytes of the buffers that are present
 */
public long getDeviceMemorySize() {
  // Fetch all three buffers up front, in the same order as before.
  BaseDeviceMemoryBuffer validity = getValid();
  BaseDeviceMemoryBuffer payload = getData();
  BaseDeviceMemoryBuffer offsetBuffer = getOffsets();

  long total = 0;
  if (validity != null) {
    total += validity.getLength();
  }
  if (offsetBuffer != null) {
    total += offsetBuffer.getLength();
  }
  if (payload != null) {
    total += payload.getLength();
  }
  return total;
}
}

public static ColumnVector createNestedColumnVector(DType type, int rows, HostMemoryBuffer data, HostMemoryBuffer valid, HostMemoryBuffer offsets,
Expand Down
36 changes: 36 additions & 0 deletions java/src/main/native/src/ColumnVectorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

#include <numeric>

#include <cudf/aggregation.hpp>
#include <cudf/binaryop.hpp>
#include <cudf/column/column_factories.hpp>
Expand All @@ -24,6 +26,7 @@
#include <cudf/datetime.hpp>
#include <cudf/filling.hpp>
#include <cudf/hashing.hpp>
#include <cudf/null_mask.hpp>
#include <cudf/quantiles.hpp>
#include <cudf/reduction.hpp>
#include <cudf/replace.hpp>
Expand Down Expand Up @@ -59,6 +62,28 @@
#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"

namespace {

/**
 * Computes the device memory footprint of a column view, in bytes.
 *
 * The total is the validity bitmask allocation (only when the view is nullable), plus the
 * element storage (only for fixed-width types), plus the footprint of every child column,
 * computed recursively.
 *
 * @param view the column view to measure
 * @return number of bytes attributed to the view and all of its descendants
 */
std::size_t calc_device_memory_size(cudf::column_view const &view) {
  auto const num_rows = view.size();
  std::size_t bytes = 0;

  // Validity mask: counted at its allocation size, not just the bits in use.
  if (view.nullable()) {
    bytes += cudf::bitmask_allocation_size_bytes(num_rows);
  }

  // Only fixed-width types keep their elements in this view's own data buffer.
  auto const col_type = view.type();
  if (cudf::is_fixed_width(col_type)) {
    bytes += cudf::size_of(col_type) * view.size();
  }

  // Nested data lives in the children, so add each child's footprint.
  for (auto child = view.child_begin(); child != view.child_end(); ++child) {
    bytes += calc_device_memory_size(*child);
  }
  return bytes;
}

} // anonymous namespace

extern "C" {

Expand Down Expand Up @@ -1604,6 +1629,17 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_getNativeValidPointerSi
CATCH_STD(env, 0);
}

// JNI entry point backing ColumnVector.getDeviceMemorySize(long).
// `handle` is the address of a cudf::column_view; the result is the byte count produced by
// calc_device_memory_size for that view. 0 is the value handed to the null-check and
// catch macros as the error return.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_getDeviceMemorySize(JNIEnv *env, jclass,
                                                                             jlong handle) {
  JNI_NULL_CHECK(env, handle, "native handle is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    auto const *column = reinterpret_cast<cudf::column_view const *>(handle);
    std::size_t const bytes = calc_device_memory_size(*column);
    return static_cast<jlong>(bytes);
  }
  CATCH_STD(env, 0);
}

////////
// Native methods specific to cudf::column. These either take or return a cudf::column
// instead of a cudf::column_view so they need to be used with caution. These should
Expand Down
71 changes: 69 additions & 2 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,20 @@

package ai.rapids.cudf;

import ai.rapids.cudf.HostColumnVector.BasicType;
import ai.rapids.cudf.HostColumnVector.DataType;
import ai.rapids.cudf.HostColumnVector.ListType;
import ai.rapids.cudf.HostColumnVector.StructData;
import ai.rapids.cudf.HostColumnVector.StructType;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors;
Expand All @@ -36,9 +43,14 @@
import static ai.rapids.cudf.QuantileMethod.MIDPOINT;
import static ai.rapids.cudf.QuantileMethod.NEAREST;
import static ai.rapids.cudf.TableTest.assertColumnsAreEqual;
import static ai.rapids.cudf.TableTest.assertTablesAreEqual;
import static ai.rapids.cudf.TableTest.assertStructColumnsAreEqual;
import static org.junit.jupiter.api.Assertions.*;
import static ai.rapids.cudf.TableTest.assertTablesAreEqual;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

public class ColumnVectorTest extends CudfTestBase {
Expand Down Expand Up @@ -695,6 +707,61 @@ void testGetDeviceMemorySizeStrings() {
}
}

@SuppressWarnings("unchecked")
@Test
void testGetDeviceMemorySizeLists() {
  DataType nullableStringLists = new ListType(true, new BasicType(true, DType.STRING));
  DataType plainIntLists = new ListType(false, new BasicType(false, DType.INT32));

  // Nullable list of nullable strings:
  //   64 bytes validity of the list column
  //   16 bytes offsets of the list column
  //   64 bytes validity of the string column
  //   24 bytes offsets of the string column
  //   22 bytes of string character data
  long expectedStringListBytes = 64 + 16 + 64 + 24 + 22;

  // Non-nullable list of non-nullable INT32:
  //   20 bytes offsets of the list column
  //   28 bytes data of the INT32 column
  long expectedIntListBytes = 20 + 28;

  try (ColumnVector stringLists = ColumnVector.fromLists(nullableStringLists,
           Arrays.asList("first", "second", "third"),
           Arrays.asList("fourth", null),
           null);
       ColumnVector intLists = ColumnVector.fromLists(plainIntLists,
           Arrays.asList(1, 2, 3),
           Collections.singletonList(4),
           Arrays.asList(5, 6),
           Collections.singletonList(7))) {
    assertEquals(expectedStringListBytes, stringLists.getDeviceMemorySize());
    assertEquals(expectedIntListBytes, intLists.getDeviceMemorySize());
  }
}

/**
 * Verifies the device memory size reported for a nullable struct column whose children are
 * a nullable list of nullable strings and a nullable INT64 column. The expected value is
 * spelled out term by term so each buffer's contribution can be checked independently.
 */
@Test
void testGetDeviceMemorySizeStructs() {
  DataType structType = new StructType(true,
      new ListType(true, new BasicType(true, DType.STRING)),
      new BasicType(true, DType.INT64));
  try (ColumnVector v = ColumnVector.fromStructs(structType,
      new StructData(
          Arrays.asList("first", "second", "third"),
          10L),
      new StructData(
          Arrays.asList("fourth", null),
          20L),
      new StructData(
          null,
          null),
      null)) {
    // 64 bytes for validity of the struct column
    // 64 bytes for validity of list column
    // 20 bytes for offsets of list column (4 rows -> 5 INT32 offsets)
    // 64 bytes for validity of string column
    // 24 bytes for offsets of string column (5 strings -> 6 INT32 offsets)
    // 22 bytes of string character size
    // 64 bytes for validity of int64 column
    // 32 bytes for data of the int64 column (4 rows * 8 bytes)
    assertEquals(64+64+20+64+24+22+64+32, v.getDeviceMemorySize());
  }
}

@Test
void testSequenceInt() {
try (Scalar zero = Scalar.fromInt(0);
Expand Down

0 comments on commit 01b8b5c

Please sign in to comment.