Skip to content

Commit

Permalink
Add jni for sequences (#9972)
Browse files Browse the repository at this point in the history
This PR add java binding for sequences API. and to fix #9424.

Authors:
  - Bobby Wang (https://github.com/wbo4958)

Approvers:
  - Jason Lowe (https://github.com/jlowe)
  - Robert (Bobby) Evans (https://github.com/revans2)

URL: #9972
  • Loading branch information
wbo4958 authored Jan 5, 2022
1 parent f7cc6a0 commit 6a6fbb3
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 3 deletions.
41 changes: 40 additions & 1 deletion java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -498,6 +498,42 @@ public static ColumnVector sequence(Scalar initialValue, int rows) {
}
return new ColumnVector(sequence(initialValue.getScalarHandle(), 0, rows));
}

/**
* Create a list column in which each row is a sequence of values starting from a `start` value,
* incrementing by one, and its cardinality is specified by a `size` value. The `start` and `size`
* values used to generate each list is taken from the corresponding row of the input start and
* size columns.
* @param start first values in the result sequences
* @param size numbers of values in the result sequences
* @return the new ColumnVector.
*/
public static ColumnVector sequence(ColumnView start, ColumnView size) {
assert start.getNullCount() == 0 || size.getNullCount() == 0 : "starts and sizes input " +
"columns must not have nulls.";
return new ColumnVector(sequences(start.getNativeView(), size.getNativeView(), 0));
}

/**
* Create a list column in which each row is a sequence of values starting from a `start` value,
* incrementing by a `step` value, and its cardinality is specified by a `size` value.
* The values `start`, `step`, and `size` used to generate each list is taken from the
* corresponding row of the input starts, steps, and sizes columns.
* @param start first values in the result sequences
* @param size numbers of values in the result sequences
* @param step increment values for the result sequences.
* @return the new ColumnVector.
*/
public static ColumnVector sequence(ColumnView start, ColumnView size, ColumnView step) {
assert start.getNullCount() == 0 || size.getNullCount() == 0 || step.getNullCount() == 0:
"start, size and step must not have nulls.";
assert step.getType() == start.getType() : "start and step input columns must" +
" have the same type.";

return new ColumnVector(sequences(start.getNativeView(), size.getNativeView(),
step.getNativeView()));
}

/**
* Create a new vector by concatenating multiple columns together.
* Note that all columns must have the same type.
Expand Down Expand Up @@ -789,6 +825,9 @@ public ColumnVector castTo(DType type) {

private static native long sequence(long initialValue, long step, int rows);

private static native long sequences(long startHandle, long sizeHandle, long stepHandle)
throws CudfException;

private static native long fromArrow(int type, long col_length,
long null_count, ByteBuffer data, ByteBuffer validity,
ByteBuffer offsets) throws CudfException;
Expand Down
25 changes: 24 additions & 1 deletion java/src/main/native/src/ColumnVectorJni.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,6 +23,7 @@
#include <cudf/interop.hpp>
#include <cudf/lists/combine.hpp>
#include <cudf/lists/detail/concatenate.hpp>
#include <cudf/lists/filling.hpp>
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/reshape.hpp>
#include <cudf/scalar/scalar_factories.hpp>
Expand Down Expand Up @@ -54,6 +55,28 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_sequence(JNIEnv *env, j
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_sequences(JNIEnv *env, jclass,
jlong j_start_handle,
jlong j_size_handle,
jlong j_step_handle) {
JNI_NULL_CHECK(env, j_start_handle, "start is null", 0);
JNI_NULL_CHECK(env, j_size_handle, "size is null", 0);
try {
cudf::jni::auto_set_device(env);
auto start = reinterpret_cast<cudf::column_view const *>(j_start_handle);
auto size = reinterpret_cast<cudf::column_view const *>(j_size_handle);
auto step = reinterpret_cast<cudf::column_view const *>(j_step_handle);
std::unique_ptr<cudf::column> col;
if (step) {
col = cudf::lists::sequences(*start, *step, *size);
} else {
col = cudf::lists::sequences(*start, *size);
}
return reinterpret_cast<jlong>(col.release());
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromArrow(
JNIEnv *env, jclass, jint j_type, jlong j_col_length, jlong j_null_count, jobject j_data_obj,
jobject j_validity_obj, jobject j_offsets_obj) {
Expand Down
54 changes: 53 additions & 1 deletion java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -1216,6 +1216,58 @@ void testSequenceOtherTypes() {
});
}

@Test
void testSequencesInt() {
try (ColumnVector start = ColumnVector.fromBoxedInts(1, 2, 3, 4, 5);
ColumnVector size = ColumnVector.fromBoxedInts(2, 3, 2, 0, 1);
ColumnVector step = ColumnVector.fromBoxedInts(2, -1, 1, 1, 0);
ColumnVector cv = ColumnVector.sequence(start, size, step);
ColumnVector cv1 = ColumnVector.sequence(start, size);
ColumnVector expectCv = ColumnVector.fromLists(
new ListType(true, new BasicType(false, DType.INT32)),
Arrays.asList(1, 3),
Arrays.asList(2, 1, 0),
Arrays.asList(3, 4),
Arrays.asList(),
Arrays.asList(5));
ColumnVector expectCv1 = ColumnVector.fromLists(
new ListType(true, new BasicType(false, DType.INT32)),
Arrays.asList(1, 2),
Arrays.asList(2, 3, 4),
Arrays.asList(3, 4),
Arrays.asList(),
Arrays.asList(5))) {
assertColumnsAreEqual(expectCv, cv);
assertColumnsAreEqual(expectCv1, cv1);
}
}

@Test
void testSequencesDouble() {
try (ColumnVector start = ColumnVector.fromBoxedDoubles(1.2, 2.2, 3.2, 4.2, 5.2);
ColumnVector size = ColumnVector.fromBoxedInts(2, 3, 2, 0, 1);
ColumnVector step = ColumnVector.fromBoxedDoubles(2.1, -1.1, 1.1, 1.1, 0.1);
ColumnVector cv = ColumnVector.sequence(start, size, step);
ColumnVector cv1 = ColumnVector.sequence(start, size);
ColumnVector expectCv = ColumnVector.fromLists(
new ListType(true, new BasicType(false, DType.FLOAT64)),
Arrays.asList(1.2, 3.3),
Arrays.asList(2.2, 1.1, 0.0),
Arrays.asList(3.2, 4.3),
Arrays.asList(),
Arrays.asList(5.2));
ColumnVector expectCv1 = ColumnVector.fromLists(
new ListType(true, new BasicType(false, DType.FLOAT64)),
Arrays.asList(1.2, 2.2),
Arrays.asList(2.2, 3.2, 4.2),
Arrays.asList(3.2, 4.2),
Arrays.asList(),
Arrays.asList(5.2))) {
assertColumnsAreEqual(expectCv, cv);
assertColumnsAreEqual(expectCv1, cv1);
}
}

@Test
void testFromScalarZeroRows() {
// magic number to invoke factory method specialized for decimal types
Expand Down

0 comments on commit 6a6fbb3

Please sign in to comment.