Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Creates an empty column for the null LIST Scalar [skip ci] #8173

Merged
merged 3 commits into from
May 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -1433,4 +1433,48 @@ public static ColumnVector timestampNanoSecondsFromBoxedLongs(Long... values) {
return build(DType.TIMESTAMP_NANOSECONDS, values.length, (b) -> b.appendBoxed(values));
}

/**
* Creates an empty column according to the data type.
*
* It will create all the nested columns by iterating all the children in the input
* type object 'colType'.
*
* The performance is not good, so use it carefully. We may want to move this implementation
* to the native once figuring out a way to pass the nested data type to the native.
jlowe marked this conversation as resolved.
Show resolved Hide resolved
*
* @param colType the data type of the empty column
* @return an empty ColumnVector with its children. Each children contains zero elements.
* Users should close the ColumnVector to avoid memory leak.
*/
public static ColumnVector empty(HostColumnVector.DataType colType) {
if (colType == null || colType.getType() == null) {
throw new IllegalArgumentException("The data type and its 'DType' should NOT be null.");
}
if (colType instanceof HostColumnVector.BasicType) {
// Non nested type
DType dt = colType.getType();
return new ColumnVector(makeEmptyCudfColumn(dt.typeId.getNativeId(), dt.getScale()));
} else if (colType instanceof HostColumnVector.ListType) {
// List type
assert colType.getNumChildren() == 1 : "List type requires one child type";
try (ColumnVector child = empty(colType.getChild(0))) {
return makeList(child);
}
} else if (colType instanceof HostColumnVector.StructType) {
// Struct type
ColumnVector[] children = new ColumnVector[colType.getNumChildren()];
try {
for (int i = 0; i < children.length; i++) {
children[i] = empty(colType.getChild(i));
}
return makeStruct(children);
} finally {
for (ColumnVector cv : children) {
if (cv != null) cv.close();
}
}
} else {
throw new IllegalArgumentException("Unsupported data type: " + colType);
}
}
}
20 changes: 18 additions & 2 deletions java/src/main/java/ai/rapids/cudf/Scalar.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public static Scalar fromNull(DType type) {
case DECIMAL64:
return new Scalar(type, makeDecimal64Scalar(0L, type.getScale(), false));
case LIST:
return new Scalar(type, makeListScalar(0L, false));
throw new IllegalArgumentException("Please call 'listFromNull' to create a null list scalar.");
default:
throw new IllegalArgumentException("Unexpected type: " + type);
}
Expand Down Expand Up @@ -335,6 +335,21 @@ public static Scalar fromString(String value) {
return new Scalar(DType.STRING, makeStringScalar(value.getBytes(StandardCharsets.UTF_8), true));
}

/**
* Creates a null scalar of list type.
*
* Having this special API because the element type is required to build an empty
* nested column as the underlying column of the list scalar.
*
* @param elementType the data type of the element in the list.
* @return a null scalar of list type
*/
public static Scalar listFromNull(HostColumnVector.DataType elementType) {
try (ColumnVector col = ColumnVector.empty(elementType)) {
return new Scalar(DType.LIST, makeListScalar(col.getNativeView(), false));
}
}

/**
* Creates a scalar of list from a ColumnView.
*
Expand All @@ -343,7 +358,8 @@ public static Scalar fromString(String value) {
*/
public static Scalar listFromColumnView(ColumnView list) {
if (list == null) {
return Scalar.fromNull(DType.LIST);
throw new IllegalArgumentException("'list' should NOT be null." +
" Please call 'listFromNull' to create a null list scalar.");
}
return new Scalar(DType.LIST, makeListScalar(list.getNativeView(), true));
}
Expand Down
53 changes: 2 additions & 51 deletions java/src/main/native/src/ColumnVectorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,55 +31,6 @@
#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"

namespace cudf {
namespace jni {
/**
* @brief Creates an empty column according to the type tree specified
* in @p view
*
* An empty column contains zero elements and no validity mask.
*
* Unlike the 'cudf::make_empty_column', it takes care of the nested type by
* iterating the children columns in the @p view
*
* @param[in] view The input column view
* @return An empty column for the input column view
*/
std::unique_ptr<cudf::column> make_empty_column(JNIEnv *env, cudf::column_view const& view) {
auto tid = view.type().id();
if (tid == cudf::type_id::LIST) {
// List
if (view.num_children() != 2) {
throw jni_exception("List type requires two children(offset, data).");
}
// Only needs the second child.
auto data_view = view.child(1);
// offsets: [0]
auto offsets_buffer = rmm::device_buffer(sizeof(cudf::size_type));
device_memset_async(env, offsets_buffer, 0);
auto offsets = std::make_unique<column>(cudf::data_type{cudf::type_id::INT32}, 1,
std::move(offsets_buffer),
rmm::device_buffer(), 0);
auto data_col = make_empty_column(env, data_view);
return cudf::make_lists_column(0, std::move(offsets), std::move(data_col),
0, rmm::device_buffer());
} else if (tid == cudf::type_id::STRUCT) {
// Struct
std::vector<std::unique_ptr<column>> children(view.num_children());
std::transform(view.child_begin(), view.child_end(), children.begin(),
[env](auto const& child_v) {
return make_empty_column(env, child_v);
});
return cudf::make_structs_column(0, std::move(children), 0, rmm::device_buffer());
} else {
// Non nested types
return cudf::make_empty_column(view.type());
}
}

} // namespace jni
} // namespace cudf

extern "C" {

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_sequence(JNIEnv *env, jclass,
Expand Down Expand Up @@ -246,10 +197,10 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromScalar(JNIEnv *env,
// 'row_count' times, because 'cudf::make_column_from_scalar' does not support list
// type.
// (Assumes the `row_count` is not big, otherwise there would be a performance issue.)
// Checks the `row_count` because `cudf::concatenate` does not support no columns.
// Checks the `row_count` because `cudf::concatenate` does not support no rows.
auto data_col = row_count > 0
? cudf::concatenate(std::vector<cudf::column_view>(row_count, s_val))
: cudf::jni::make_empty_column(env, s_val);
: cudf::empty_like(s_val);
col = cudf::make_lists_column(row_count, std::move(offsets), std::move(data_col),
cudf::state_null_count(mask_state, row_count),
cudf::create_null_mask(row_count, mask_state));
Expand Down
12 changes: 6 additions & 6 deletions java/src/main/native/src/ScalarJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,16 +461,16 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_binaryOpSV(JNIEnv *env, jclas
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeListScalar(JNIEnv *env, jclass,
jlong view_handle,
jboolean is_valid) {
if (is_valid) {
JNI_NULL_CHECK(env, view_handle,
"list scalar is set to `valid` but column view is null", 0);
}
JNI_NULL_CHECK(env, view_handle, "Column view should NOT be null", 0);
try {
cudf::jni::auto_set_device(env);
auto col_view = reinterpret_cast<cudf::column_view *>(view_handle);

cudf::scalar* s = is_valid ? new cudf::list_scalar(*col_view)
: new cudf::list_scalar();
// Instead of calling the `cudf::empty_like` to create an empty column when `is_valid`
// is false, always passes the input view to the scalar, to avoid copying the column
// twice.
// Let the Java layer make sure the view is empty when `is_valid` is false.
cudf::scalar* s = new cudf::list_scalar(*col_view);
s->set_valid(is_valid);
return reinterpret_cast<jlong>(s);
}
Expand Down
46 changes: 45 additions & 1 deletion java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1391,7 +1391,7 @@ void testFromScalar() {
void testFromScalarNull() {
final int rowCount = 4;
for (DType.DTypeEnum typeEnum : DType.DTypeEnum.values()) {
if (typeEnum == DType.DTypeEnum.EMPTY || typeEnum == DType.DTypeEnum.STRUCT) {
if (typeEnum == DType.DTypeEnum.EMPTY || typeEnum == DType.DTypeEnum.LIST || typeEnum == DType.DTypeEnum.STRUCT) {
continue;
}
DType dType;
Expand Down Expand Up @@ -1428,6 +1428,50 @@ void testFromScalarNullByte() {
}
}

@Test
void testFromScalarNullList() {
final int rowCount = 4;
for (DType.DTypeEnum typeEnum : DType.DTypeEnum.values()) {
DType dType = typeEnum.isDecimalType() ? DType.create(typeEnum, -8): DType.create(typeEnum);
DataType hDataType;
if (DType.EMPTY.equals(dType)) {
continue;
} else if (DType.LIST.equals(dType)) {
// list of list of int32
hDataType = new ListType(true, new BasicType(true, DType.INT32));
} else if (DType.STRUCT.equals(dType)) {
// list of struct of int32
hDataType = new StructType(true, new BasicType(true, DType.INT32));
} else {
// list of non nested type
hDataType = new BasicType(true, dType);
}
try (Scalar s = Scalar.listFromNull(hDataType);
ColumnVector c = ColumnVector.fromScalar(s, rowCount);
HostColumnVector hc = c.copyToHost()) {
assertEquals(DType.LIST, c.getType());
assertEquals(rowCount, c.getRowCount());
assertEquals(rowCount, c.getNullCount());
for (int i = 0; i < rowCount; ++i) {
assertTrue(hc.isNull(i));
}

try (ColumnView child = c.getChildColumnView(0)) {
assertEquals(dType, child.getType());
assertEquals(0L, child.getRowCount());
assertEquals(0L, child.getNullCount());
if (child.getType().isNestedType()) {
try (ColumnView grandson = child.getChildColumnView(0)) {
assertEquals(DType.INT32, grandson.getType());
assertEquals(0L, grandson.getRowCount());
assertEquals(0L, grandson.getNullCount());
}
}
}
}
}
}

@Test
void testFromScalarListOfList() {
HostColumnVector.DataType childType = new HostColumnVector.ListType(true,
Expand Down
38 changes: 37 additions & 1 deletion java/src/test/java/ai/rapids/cudf/ScalarTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@

package ai.rapids.cudf;

import ai.rapids.cudf.HostColumnVector.BasicType;
import ai.rapids.cudf.HostColumnVector.DataType;
import ai.rapids.cudf.HostColumnVector.ListType;
import ai.rapids.cudf.HostColumnVector.StructData;
import ai.rapids.cudf.HostColumnVector.StructType;

import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
Expand Down Expand Up @@ -56,12 +62,42 @@ public void testNull() {
} else {
type = DType.create(dataType);
}
if (!type.isNestedType() || type.equals(DType.LIST)) {
if (!type.isNestedType()) {
try (Scalar s = Scalar.fromNull(type)) {
assertEquals(type, s.getType());
assertFalse(s.isValid(), "null validity for " + type);
}
}

// list scalar
HostColumnVector.DataType hDataType;
if (DType.EMPTY.equals(type)) {
continue;
} else if (DType.LIST.equals(type)) {
// list of list of int32
hDataType = new ListType(true, new BasicType(true, DType.INT32));
} else if (DType.STRUCT.equals(type)) {
// list of struct of int32
hDataType = new StructType(true, new BasicType(true, DType.INT32));
} else {
// list of non nested type
hDataType = new BasicType(true, type);
}
try (Scalar s = Scalar.listFromNull(hDataType);
ColumnView listCv = s.getListAsColumnView()) {
assertFalse(s.isValid(), "null validity for " + type);
assertEquals(DType.LIST, s.getType());
assertEquals(type, listCv.getType());
assertEquals(0L, listCv.getRowCount());
assertEquals(0L, listCv.getNullCount());
if (type.isNestedType()) {
try (ColumnView child = listCv.getChildColumnView(0)) {
assertEquals(DType.INT32, child.getType());
assertEquals(0L, child.getRowCount());
assertEquals(0L, child.getNullCount());
}
}
}
}
}

Expand Down