Skip to content

Commit

Permalink
[Vectorized][java-udf] add datetime&&largeint&&decimal type to java-udf
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangstar333 committed May 7, 2022
1 parent 0604ecb commit 5e0dc00
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 62 deletions.
7 changes: 5 additions & 2 deletions be/src/util/jni-util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

#include "util/jni-util.h"
#include "jni_md.h"
#ifdef LIBJVM
#include <jni.h>
#include <stdlib.h>
Expand Down Expand Up @@ -45,8 +46,10 @@ void FindOrCreateJavaVM() {
vm_args.nOptions = 1;
vm_args.ignoreUnrecognized = JNI_TRUE;

int res = JNI_CreateJavaVM(&g_vm, (void**)&env, &vm_args);
DCHECK_LT(res, 0) << "Failed tp create JVM, code= " << res;
jint res = JNI_CreateJavaVM(&g_vm, (void**)&env, &vm_args);
if (JNI_OK != res) {
DCHECK(false) << "Failed tp create JVM, code= " << res;
}
} else {
CHECK_EQ(rv, 0) << "Could not find any created Java VM";
CHECK_EQ(num_vms, 1) << "No VMs returned";
Expand Down
26 changes: 17 additions & 9 deletions be/src/vec/functions/function_java_udf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ JavaFunctionCall::JavaFunctionCall(const TFunction& fn, const DataTypes& argumen

Status JavaFunctionCall::prepare(FunctionContext* context,
FunctionContext::FunctionStateScope scope) {
DCHECK(executor_cl_ == NULL) << "Init() already called!";
JNIEnv* env;
//DCHECK(executor_cl_ == NULL) << "Init() already called!";
JNIEnv* env = nullptr;
RETURN_IF_ERROR(JniUtil::GetJNIEnv(&env));
if (env == NULL) return Status::InternalError("Failed to get/create JVM");
if (env == nullptr) {
return Status::InternalError("Failed to get/create JVM");
}
RETURN_IF_ERROR(JniUtil::GetGlobalClassRef(env, EXECUTOR_CLASS, &executor_cl_));
executor_ctor_id_ = env->GetMethodID(executor_cl_, "<init>", EXECUTOR_CTOR_SIGNATURE);
RETURN_ERROR_IF_EXC(env);
Expand Down Expand Up @@ -101,14 +103,15 @@ Status JavaFunctionCall::prepare(FunctionContext* context,
Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
const ColumnNumbers& arguments, size_t result, size_t num_rows,
bool dry_run) {
JNIEnv* env;
JNIEnv* env = nullptr;
RETURN_IF_ERROR(JniUtil::GetJNIEnv(&env));
JniContext* jni_ctx = reinterpret_cast<JniContext*>(
context->get_function_state(FunctionContext::THREAD_LOCAL));
int arg_idx = 0;
for (size_t col_idx : arguments) {
ColumnWithTypeAndName& column = block.get_by_position(col_idx);
auto col = column.column->convert_to_full_column_if_const();
auto& col_type = column.type;
if (!_argument_types[arg_idx]->equals(*column.type)) {
return Status::InvalidArgument(strings::Substitute(
"$0-th input column's type $1 does not equal to required type $2", arg_idx,
Expand All @@ -117,19 +120,23 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
auto data_col = col;
if (auto* nullable = check_and_get_column<const ColumnNullable>(*col)) {
data_col = nullable->get_nested_column_ptr();
col_type = remove_nullable(col_type);
auto null_col =
check_and_get_column<ColumnVector<UInt8>>(nullable->get_null_map_column_ptr());
jni_ctx->input_nulls_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(null_col->get_data().data());
} else {
jni_ctx->input_nulls_buffer_ptr.get()[arg_idx] = -1;
}
if (const ColumnString* str_col = check_and_get_column<ColumnString>(data_col.get())) {
WhichDataType type(col_type);
if (type.is_string_or_fixed_string()) {
const ColumnString* str_col = assert_cast<const ColumnString*>(data_col.get());
jni_ctx->input_values_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(str_col->get_chars().data());
jni_ctx->input_offsets_ptrs.get()[arg_idx] =
reinterpret_cast<int64_t>(str_col->get_offsets().data());
} else if (data_col->is_numeric()) {
} else if (type.is_int() || type.is_uint() || type.is_float() ||
type.is_date_or_datetime() || type.is_decimal()) {
jni_ctx->input_values_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(data_col->get_raw_data().data);
} else {
Expand All @@ -151,7 +158,8 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
*(jni_ctx->output_null_value) = reinterpret_cast<int64_t>(null_col->get_data().data());
#ifndef EVALUATE_JAVA_UDF
#define EVALUATE_JAVA_UDF \
if (const ColumnString* str_col = check_and_get_column<ColumnString>(data_col.get())) { \
if (data_col->is_column_string()) { \
const ColumnString* str_col = assert_cast<const ColumnString*>(data_col.get()); \
ColumnString::Chars& chars = const_cast<ColumnString::Chars&>(str_col->get_chars()); \
ColumnString::Offsets& offsets = \
const_cast<ColumnString::Offsets&>(str_col->get_offsets()); \
Expand All @@ -177,7 +185,7 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
env->CallNonvirtualVoidMethodA(jni_ctx->executor, executor_cl_, executor_evaluate_id_, \
nullptr); \
} \
} else if (data_col->is_numeric()) { \
} else if (data_col->is_numeric() || data_col->is_column_decimal()) { \
data_col->reserve(num_rows); \
data_col->resize(num_rows); \
*(jni_ctx->output_value_buffer) = \
Expand Down Expand Up @@ -205,7 +213,7 @@ Status JavaFunctionCall::close(FunctionContext* context,
FunctionContext::FunctionStateScope scope) {
JniContext* jni_ctx = reinterpret_cast<JniContext*>(
context->get_function_state(FunctionContext::THREAD_LOCAL));
if (jni_ctx != NULL) {
if (jni_ctx != nullptr) {
delete jni_ctx;
context->set_function_state(FunctionContext::THREAD_LOCAL, nullptr);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,15 @@
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Parameter;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -370,6 +374,10 @@ private void analyzeJavaUdf(String clazz) throws AnalysisException {
.put(PrimitiveType.CHAR, Sets.newHashSet(String.class))
.put(PrimitiveType.VARCHAR, Sets.newHashSet(String.class))
.put(PrimitiveType.STRING, Sets.newHashSet(String.class))
.put(PrimitiveType.DATE, Sets.newHashSet(LocalDate.class))
.put(PrimitiveType.DATETIME, Sets.newHashSet(LocalDateTime.class))
.put(PrimitiveType.LARGEINT, Sets.newHashSet(BigInteger.class))
.put(PrimitiveType.DECIMALV2, Sets.newHashSet(BigDecimal.class))
.build();

private void checkUdfType(Class clazz, Method method, Type expType, Class pType, String pname)
Expand Down
Loading

0 comments on commit 5e0dc00

Please sign in to comment.