Skip to content

Commit

Permalink
Creates an empty column for the null LIST Scalar (#8173)
Browse files Browse the repository at this point in the history
This PR adds support for creating a null `LIST Scalar` from a `HostColumnVector.DataType`. The null scalar is backed by an empty column that still contains all the necessary child columns.

Also removed the function `cudf::jni::make_empty_column`, which is a duplicate of `cudf::empty_like`, and replaced it with `cudf::empty_like` when building a column from a list scalar.

closes #8170

Signed-off-by: Firestarman <[email protected]>

Authors:
  - Liangcai Li (https://github.com/firestarman)

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)
  - Jason Lowe (https://github.com/jlowe)

URL: #8173
  • Loading branch information
firestarman authored May 8, 2021
1 parent bb62cf1 commit 3813d9b
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 61 deletions.
44 changes: 44 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -1489,4 +1489,48 @@ public static ColumnVector timestampNanoSecondsFromBoxedLongs(Long... values) {
return build(DType.TIMESTAMP_NANOSECONDS, values.length, (b) -> b.appendBoxed(values));
}

/**
 * Builds a zero-row column whose layout follows the given data type.
 *
 * Nested types are handled by recursing over the children of 'colType', so the
 * result carries every nested child column that the type describes.
 *
 * This is slow, so call it sparingly. The implementation may move to the native
 * layer once there is a way to pass a nested data type down to it.
 *
 * @param colType the data type the empty column should have
 * @return an empty ColumnVector including its children, each holding zero elements.
 *         The caller is responsible for closing it to avoid a memory leak.
 * @throws IllegalArgumentException if 'colType' or its 'DType' is null, or if
 *         'colType' is neither a basic, a list nor a struct type
 */
public static ColumnVector empty(HostColumnVector.DataType colType) {
  if (colType == null || colType.getType() == null) {
    throw new IllegalArgumentException("The data type and its 'DType' should NOT be null.");
  }

  // Leaf (non nested) type: a single empty column of that type and scale.
  if (colType instanceof HostColumnVector.BasicType) {
    DType leafType = colType.getType();
    return new ColumnVector(
        makeEmptyCudfColumn(leafType.typeId.getNativeId(), leafType.getScale()));
  }

  // List type: build the empty element column first, then wrap it in a list.
  if (colType instanceof HostColumnVector.ListType) {
    assert colType.getNumChildren() == 1 : "List type requires one child type";
    // The temporary element column is closed as soon as makeList has returned.
    try (ColumnVector elements = empty(colType.getChild(0))) {
      return makeList(elements);
    }
  }

  // Struct type: build one empty column per field, then combine them.
  if (colType instanceof HostColumnVector.StructType) {
    ColumnVector[] fields = new ColumnVector[colType.getNumChildren()];
    try {
      int idx = 0;
      while (idx < fields.length) {
        fields[idx] = empty(colType.getChild(idx));
        idx++;
      }
      return makeStruct(fields);
    } finally {
      // Close every field column that was created, including after a partial failure.
      for (int i = 0; i < fields.length; i++) {
        if (fields[i] != null) {
          fields[i].close();
        }
      }
    }
  }

  throw new IllegalArgumentException("Unsupported data type: " + colType);
}
}
20 changes: 18 additions & 2 deletions java/src/main/java/ai/rapids/cudf/Scalar.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public static Scalar fromNull(DType type) {
case DECIMAL64:
return new Scalar(type, makeDecimal64Scalar(0L, type.getScale(), false));
case LIST:
return new Scalar(type, makeListScalar(0L, false));
throw new IllegalArgumentException("Please call 'listFromNull' to create a null list scalar.");
default:
throw new IllegalArgumentException("Unexpected type: " + type);
}
Expand Down Expand Up @@ -335,6 +335,21 @@ public static Scalar fromString(String value) {
return new Scalar(DType.STRING, makeStringScalar(value.getBytes(StandardCharsets.UTF_8), true));
}

/**
 * Builds a null (invalid) scalar of list type.
 *
 * A dedicated API is needed because the list scalar is backed by an empty,
 * possibly nested, column, and that column can only be built when the element
 * type is known.
 *
 * @param elementType the data type of the elements in the list.
 * @return a null scalar of list type
 */
public static Scalar listFromNull(HostColumnVector.DataType elementType) {
  // The empty element column is only needed while the native scalar is created.
  try (ColumnVector emptyElements = ColumnVector.empty(elementType)) {
    long scalarHandle = makeListScalar(emptyElements.getNativeView(), false);
    return new Scalar(DType.LIST, scalarHandle);
  }
}

/**
* Creates a scalar of list from a ColumnView.
*
Expand All @@ -343,7 +358,8 @@ public static Scalar fromString(String value) {
*/
public static Scalar listFromColumnView(ColumnView list) {
if (list == null) {
return Scalar.fromNull(DType.LIST);
throw new IllegalArgumentException("'list' should NOT be null." +
" Please call 'listFromNull' to create a null list scalar.");
}
return new Scalar(DType.LIST, makeListScalar(list.getNativeView(), true));
}
Expand Down
53 changes: 2 additions & 51 deletions java/src/main/native/src/ColumnVectorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,55 +33,6 @@
#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"

namespace cudf {
namespace jni {
/**
* @brief Creates an empty column according to the type tree specified
* in @p view
*
* An empty column contains zero elements and no validity mask.
*
* Unlike the 'cudf::make_empty_column', it takes care of the nested type by
* iterating the children columns in the @p view: LIST and STRUCT views are
* rebuilt recursively, every other type is delegated to
* 'cudf::make_empty_column'.
*
* @param[in] env The JNI environment, forwarded to 'device_memset_async' and
* to the recursive calls
* @param[in] view The input column view
* @return An empty column for the input column view
* @throws jni_exception if a LIST view does not have exactly two children
*/
std::unique_ptr<cudf::column> make_empty_column(JNIEnv *env, cudf::column_view const& view) {
auto tid = view.type().id();
if (tid == cudf::type_id::LIST) {
// List: rebuilt from a fresh offsets column plus an empty data column.
if (view.num_children() != 2) {
throw jni_exception("List type requires two children(offset, data).");
}
// Only needs the second child (the data); the offsets are recreated below.
auto data_view = view.child(1);
// offsets: [0], i.e. a buffer sized for one cudf::size_type that is memset to 0
// and wrapped in a one-row INT32 column with no null mask.
auto offsets_buffer = rmm::device_buffer(sizeof(cudf::size_type));
device_memset_async(env, offsets_buffer, 0);
auto offsets = std::make_unique<column>(cudf::data_type{cudf::type_id::INT32}, 1,
std::move(offsets_buffer),
rmm::device_buffer(), 0);
// Recurse so that a nested element type keeps its own children.
auto data_col = make_empty_column(env, data_view);
// Zero rows, zero nulls and an empty null mask.
return cudf::make_lists_column(0, std::move(offsets), std::move(data_col),
0, rmm::device_buffer());
} else if (tid == cudf::type_id::STRUCT) {
// Struct: one recursively built empty column per child of the view.
std::vector<std::unique_ptr<column>> children(view.num_children());
std::transform(view.child_begin(), view.child_end(), children.begin(),
[env](auto const& child_v) {
return make_empty_column(env, child_v);
});
// Zero rows, zero nulls and an empty null mask.
return cudf::make_structs_column(0, std::move(children), 0, rmm::device_buffer());
} else {
// Non nested types
return cudf::make_empty_column(view.type());
}
}

} // namespace jni
} // namespace cudf

extern "C" {

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_sequence(JNIEnv *env, jclass,
Expand Down Expand Up @@ -298,10 +249,10 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromScalar(JNIEnv *env,
// 'row_count' times, because 'cudf::make_column_from_scalar' does not support list
// type.
// (Assumes the `row_count` is not big, otherwise there would be a performance issue.)
// Checks the `row_count` because `cudf::concatenate` does not support no columns.
// Checks the `row_count` because `cudf::concatenate` does not support no rows.
auto data_col = row_count > 0
? cudf::concatenate(std::vector<cudf::column_view>(row_count, s_val))
: cudf::jni::make_empty_column(env, s_val);
: cudf::empty_like(s_val);
col = cudf::make_lists_column(row_count, std::move(offsets), std::move(data_col),
cudf::state_null_count(mask_state, row_count),
cudf::create_null_mask(row_count, mask_state));
Expand Down
12 changes: 6 additions & 6 deletions java/src/main/native/src/ScalarJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,16 +461,16 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_binaryOpSV(JNIEnv *env, jclas
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeListScalar(JNIEnv *env, jclass,
jlong view_handle,
jboolean is_valid) {
if (is_valid) {
JNI_NULL_CHECK(env, view_handle,
"list scalar is set to `valid` but column view is null", 0);
}
JNI_NULL_CHECK(env, view_handle, "Column view should NOT be null", 0);
try {
cudf::jni::auto_set_device(env);
auto col_view = reinterpret_cast<cudf::column_view *>(view_handle);

cudf::scalar* s = is_valid ? new cudf::list_scalar(*col_view)
: new cudf::list_scalar();
// Instead of calling the `cudf::empty_like` to create an empty column when `is_valid`
// is false, always passes the input view to the scalar, to avoid copying the column
// twice.
// Let the Java layer make sure the view is empty when `is_valid` is false.
cudf::scalar* s = new cudf::list_scalar(*col_view);
s->set_valid(is_valid);
return reinterpret_cast<jlong>(s);
}
Expand Down
46 changes: 45 additions & 1 deletion java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1391,7 +1391,7 @@ void testFromScalar() {
void testFromScalarNull() {
final int rowCount = 4;
for (DType.DTypeEnum typeEnum : DType.DTypeEnum.values()) {
if (typeEnum == DType.DTypeEnum.EMPTY || typeEnum == DType.DTypeEnum.STRUCT) {
if (typeEnum == DType.DTypeEnum.EMPTY || typeEnum == DType.DTypeEnum.LIST || typeEnum == DType.DTypeEnum.STRUCT) {
continue;
}
DType dType;
Expand Down Expand Up @@ -1428,6 +1428,50 @@ void testFromScalarNullByte() {
}
}

/**
 * Checks that a column built from a null list scalar has the requested number of
 * rows, all of them null, and that its child column (and, for nested element
 * types, the child of that child) is empty but has the expected type.
 */
@Test
void testFromScalarNullList() {
  final int rowCount = 4;
  for (DType.DTypeEnum typeEnum : DType.DTypeEnum.values()) {
    // Decimal types need an explicit scale.
    final DType elementDType;
    if (typeEnum.isDecimalType()) {
      elementDType = DType.create(typeEnum, -8);
    } else {
      elementDType = DType.create(typeEnum);
    }
    if (DType.EMPTY.equals(elementDType)) {
      continue;
    }

    final DataType elementType;
    if (DType.LIST.equals(elementDType)) {
      // list of list of int32
      elementType = new ListType(true, new BasicType(true, DType.INT32));
    } else if (DType.STRUCT.equals(elementDType)) {
      // list of struct of int32
      elementType = new StructType(true, new BasicType(true, DType.INT32));
    } else {
      // list of non nested type
      elementType = new BasicType(true, elementDType);
    }

    try (Scalar nullList = Scalar.listFromNull(elementType);
         ColumnVector column = ColumnVector.fromScalar(nullList, rowCount);
         HostColumnVector hostColumn = column.copyToHost()) {
      assertEquals(DType.LIST, column.getType());
      assertEquals(rowCount, column.getRowCount());
      assertEquals(rowCount, column.getNullCount());
      for (int row = 0; row < rowCount; ++row) {
        assertTrue(hostColumn.isNull(row));
      }

      try (ColumnView elements = column.getChildColumnView(0)) {
        assertEquals(elementDType, elements.getType());
        assertEquals(0L, elements.getRowCount());
        assertEquals(0L, elements.getNullCount());
        if (elements.getType().isNestedType()) {
          try (ColumnView nested = elements.getChildColumnView(0)) {
            assertEquals(DType.INT32, nested.getType());
            assertEquals(0L, nested.getRowCount());
            assertEquals(0L, nested.getNullCount());
          }
        }
      }
    }
  }
}

@Test
void testFromScalarListOfList() {
HostColumnVector.DataType childType = new HostColumnVector.ListType(true,
Expand Down
38 changes: 37 additions & 1 deletion java/src/test/java/ai/rapids/cudf/ScalarTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@

package ai.rapids.cudf;

import ai.rapids.cudf.HostColumnVector.BasicType;
import ai.rapids.cudf.HostColumnVector.DataType;
import ai.rapids.cudf.HostColumnVector.ListType;
import ai.rapids.cudf.HostColumnVector.StructData;
import ai.rapids.cudf.HostColumnVector.StructType;

import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
Expand Down Expand Up @@ -56,12 +62,42 @@ public void testNull() {
} else {
type = DType.create(dataType);
}
if (!type.isNestedType() || type.equals(DType.LIST)) {
if (!type.isNestedType()) {
try (Scalar s = Scalar.fromNull(type)) {
assertEquals(type, s.getType());
assertFalse(s.isValid(), "null validity for " + type);
}
}

// list scalar
HostColumnVector.DataType hDataType;
if (DType.EMPTY.equals(type)) {
continue;
} else if (DType.LIST.equals(type)) {
// list of list of int32
hDataType = new ListType(true, new BasicType(true, DType.INT32));
} else if (DType.STRUCT.equals(type)) {
// list of struct of int32
hDataType = new StructType(true, new BasicType(true, DType.INT32));
} else {
// list of non nested type
hDataType = new BasicType(true, type);
}
try (Scalar s = Scalar.listFromNull(hDataType);
ColumnView listCv = s.getListAsColumnView()) {
assertFalse(s.isValid(), "null validity for " + type);
assertEquals(DType.LIST, s.getType());
assertEquals(type, listCv.getType());
assertEquals(0L, listCv.getRowCount());
assertEquals(0L, listCv.getNullCount());
if (type.isNestedType()) {
try (ColumnView child = listCv.getChildColumnView(0)) {
assertEquals(DType.INT32, child.getType());
assertEquals(0L, child.getRowCount());
assertEquals(0L, child.getNullCount());
}
}
}
}
}

Expand Down

0 comments on commit 3813d9b

Please sign in to comment.