Skip to content

Commit

Permalink
[Feature] support ARRAY/MAP input for Java Scalar UDF
Browse files: browse the repository at this point in the history
Signed-off-by: stdpain <[email protected]>
  • Loading branch information
stdpain committed Jan 23, 2025
1 parent a3566ca commit 6f3d167
Show file tree
Hide file tree
Showing 18 changed files with 1,206 additions and 248 deletions.
9 changes: 5 additions & 4 deletions be/src/exec/jdbc_scanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,12 @@ Status JDBCScanner::_init_column_class_name(RuntimeState* state) {
LOCAL_REF_GUARD_ENV(env, column_class_names);

auto& helper = JVMFunctionHelper::getInstance();
int len = helper.list_size(column_class_names);
JavaListStub list_stub(column_class_names);
ASSIGN_OR_RETURN(int len, list_stub.size());

_result_chunk = std::make_shared<Chunk>();
for (int i = 0; i < len; i++) {
jobject jelement = helper.list_get(column_class_names, i);
ASSIGN_OR_RETURN(jobject jelement, list_stub.get(i));
LOCAL_REF_GUARD_ENV(env, jelement);
std::string class_name = helper.to_string((jstring)(jelement));
ASSIGN_OR_RETURN(auto ret_type, _precheck_data_type(class_name, _slot_descs[i]));
Expand Down Expand Up @@ -311,12 +312,12 @@ Status JDBCScanner::_fill_chunk(jobject jchunk, size_t num_rows, ChunkPtr* chunk
{
auto& helper = JVMFunctionHelper::getInstance();
auto* env = helper.getEnv();

JavaListStub list_stub(jchunk);
COUNTER_UPDATE(_profile.rows_read_counter, num_rows);
(*chunk)->reset();

for (size_t i = 0; i < _slot_descs.size(); i++) {
jobject jcolumn = helper.list_get(jchunk, i);
ASSIGN_OR_RETURN(jobject jcolumn, list_stub.get(i));
LOCAL_REF_GUARD_ENV(env, jcolumn);
auto& result_column = _result_chunk->columns()[i];
auto st =
Expand Down
25 changes: 12 additions & 13 deletions be/src/exprs/agg/java_udaf_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
Column* to) const final {
auto* udaf_ctx = ctx->udaf_ctxs();
jvalue val = udaf_ctx->_func->finalize(this->data(state).handle);
append_jvalue(udaf_ctx->finalize->method_desc[0], to, val);
auto st = append_jvalue(ctx->get_return_type(), udaf_ctx->finalize->method_desc[0].is_box, to, val);
SET_FUNCTION_CONTEXT_ERR(st, ctx);
RETURN_IF_UNLIKELY(!st.ok(), (void)0);
release_jvalue(udaf_ctx->finalize->method_desc[0].is_box, val);
}

Expand All @@ -116,7 +118,6 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
RETURN_IF_UNLIKELY_NULL(rets, (void)0);
// 1.2 convert the input columns into an input array
int num_cols = ctx->get_num_args();
std::vector<DirectByteBuffer> buffers;
std::vector<jobject> args;
DeferOp defer = DeferOp([&]() {
// clean up arrays
Expand All @@ -131,8 +132,9 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
for (int i = 0; i < src.size(); ++i) {
raw_input_ptrs[i] = src[i].get();
}
auto st = JavaDataTypeConverter::convert_to_boxed_array(ctx, &buffers, raw_input_ptrs.data(), num_cols,
batch_size, &args);
auto st =
JavaDataTypeConverter::convert_to_boxed_array(ctx, raw_input_ptrs.data(), num_cols, batch_size, &args);
SET_FUNCTION_CONTEXT_ERR(st, ctx);
RETURN_IF_UNLIKELY(!st.ok(), (void)0);

// 2 batch call update
Expand Down Expand Up @@ -196,7 +198,6 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
void update_batch(FunctionContext* ctx, size_t batch_size, size_t state_offset, const Column** columns,
AggDataPtr* states) const override {
auto& helper = JVMFunctionHelper::getInstance();
std::vector<DirectByteBuffer> buffers;
std::vector<jobject> args;
int num_cols = ctx->get_num_args();
helper.getEnv()->PushLocalFrame(num_cols * 3 + 1);
Expand All @@ -205,8 +206,8 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
{
auto states_arr = JavaDataTypeConverter::convert_to_states(ctx, states, state_offset, batch_size);
RETURN_IF_UNLIKELY_NULL(states_arr, (void)0);
auto st =
JavaDataTypeConverter::convert_to_boxed_array(ctx, &buffers, columns, num_cols, batch_size, &args);
auto st = JavaDataTypeConverter::convert_to_boxed_array(ctx, columns, num_cols, batch_size, &args);
SET_FUNCTION_CONTEXT_ERR(st, ctx);
RETURN_IF_UNLIKELY(!st.ok(), (void)0);
helper.batch_update(ctx, ctx->udaf_ctxs()->handle.handle(), ctx->udaf_ctxs()->update->method.handle(),
states_arr, args.data(), args.size());
Expand All @@ -216,7 +217,6 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
void update_batch_selectively(FunctionContext* ctx, size_t batch_size, size_t state_offset, const Column** columns,
AggDataPtr* states, const Filter& filter) const override {
auto [env, helper] = JVMFunctionHelper::getInstanceWithEnv();
std::vector<DirectByteBuffer> buffers;
std::vector<jobject> args;
int num_cols = ctx->get_num_args();
helper.getEnv()->PushLocalFrame(num_cols * 3 + 1);
Expand All @@ -225,8 +225,8 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
auto states_arr = JavaDataTypeConverter::convert_to_states_with_filter(ctx, states, state_offset,
filter.data(), batch_size);
RETURN_IF_UNLIKELY_NULL(states_arr, (void)0);
auto st =
JavaDataTypeConverter::convert_to_boxed_array(ctx, &buffers, columns, num_cols, batch_size, &args);
auto st = JavaDataTypeConverter::convert_to_boxed_array(ctx, columns, num_cols, batch_size, &args);
SET_FUNCTION_CONTEXT_ERR(st, ctx);
RETURN_IF_UNLIKELY(!st.ok(), (void)0);
helper.batch_update_if_not_null(ctx, ctx->udaf_ctxs()->handle.handle(),
ctx->udaf_ctxs()->update->method.handle(), states_arr, args.data(),
Expand All @@ -239,13 +239,12 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
auto& helper = JVMFunctionHelper::getInstance();
auto* env = helper.getEnv();
std::vector<jobject> args;
std::vector<DirectByteBuffer> buffers;
int num_cols = ctx->get_num_args();
env->PushLocalFrame(num_cols * 3 + 1);
auto defer = DeferOp([env = env]() { env->PopLocalFrame(nullptr); });
{
auto st =
JavaDataTypeConverter::convert_to_boxed_array(ctx, &buffers, columns, num_cols, batch_size, &args);
auto st = JavaDataTypeConverter::convert_to_boxed_array(ctx, columns, num_cols, batch_size, &args);
SET_FUNCTION_CONTEXT_ERR(st, ctx);
RETURN_IF_UNLIKELY(!st.ok(), (void)0);

auto* stub = ctx->udaf_ctxs()->update_batch_call_stub.get();
Expand Down
5 changes: 2 additions & 3 deletions be/src/exprs/agg/java_window_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ class JavaWindowFunction final : public JavaUDAFAggregateFunction {
}

std::vector<jobject> args;
std::vector<DirectByteBuffer> buffers;
ConvertDirectBufferVistor vistor(buffers);
auto& helper = JVMFunctionHelper::getInstance();
JNIEnv* env = helper.getEnv();
DeferOp defer = DeferOp([&]() {
Expand All @@ -61,7 +59,8 @@ class JavaWindowFunction final : public JavaUDAFAggregateFunction {
}
}
});
auto st = JavaDataTypeConverter::convert_to_boxed_array(ctx, &buffers, columns, num_args, num_rows, &args);
auto st = JavaDataTypeConverter::convert_to_boxed_array(ctx, columns, num_args, num_rows, &args);
SET_FUNCTION_CONTEXT_ERR(st, ctx);
RETURN_IF_UNLIKELY(!st.ok(), (void)0);

ctx->udaf_ctxs()->_func->window_update_batch(data(state).handle, peer_group_start, peer_group_end, frame_start,
Expand Down
24 changes: 6 additions & 18 deletions be/src/exprs/java_function_call_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,9 @@ struct UDFFunctionCallHelper {
StatusOr<ColumnPtr> call(FunctionContext* ctx, Columns& columns, size_t size) {
auto& helper = JVMFunctionHelper::getInstance();
JNIEnv* env = helper.getEnv();
std::vector<DirectByteBuffer> buffers;
int num_cols = ctx->get_num_args();
std::vector<const Column*> input_cols;

for (auto& column : columns) {
if (column->only_null()) {
// we will handle NULL later
} else if (column->is_constant()) {
column = ColumnHelper::unpack_and_duplicate_const_column(size, column);
}
}

for (const auto& col : columns) {
input_cols.emplace_back(col.get());
}
Expand All @@ -68,29 +59,26 @@ struct UDFFunctionCallHelper {
auto defer = DeferOp([env]() { env->PopLocalFrame(nullptr); });
// convert input columns to object columns
std::vector<jobject> input_col_objs;
auto st = JavaDataTypeConverter::convert_to_boxed_array(ctx, &buffers, input_cols.data(), num_cols, size,
&input_col_objs);
RETURN_IF_UNLIKELY(!st.ok(), ColumnHelper::create_const_null_column(size));
auto st =
JavaDataTypeConverter::convert_to_boxed_array(ctx, input_cols.data(), num_cols, size, &input_col_objs);
RETURN_IF_ERROR(st);

// call UDF method
ASSIGN_OR_RETURN(auto res, helper.batch_call(fn_desc->call_stub.get(), input_col_objs.data(),
input_col_objs.size(), size));

RETURN_IF_UNLIKELY_NULL(res, ColumnHelper::create_const_null_column(size));
// get result
auto result_cols = get_boxed_result(ctx, res, size);
return result_cols;
}

ColumnPtr get_boxed_result(FunctionContext* ctx, jobject result, size_t num_rows) {
StatusOr<ColumnPtr> get_boxed_result(FunctionContext* ctx, jobject result, size_t num_rows) {
if (result == nullptr) {
return ColumnHelper::create_const_null_column(num_rows);
}
auto& helper = JVMFunctionHelper::getInstance();
DCHECK(call_desc->method_desc[0].is_box);
TypeDescriptor type_desc(call_desc->method_desc[0].type);
auto res = ColumnHelper::create_column(type_desc, true);
helper.get_result_from_boxed_array(ctx, type_desc.type, res.get(), result, num_rows);
auto res = ColumnHelper::create_column(ctx->get_return_type(), true);
RETURN_IF_ERROR(helper.get_result_from_boxed_array(ctx->get_return_type().type, res.get(), result, num_rows));
down_cast<NullableColumn*>(res.get())->update_has_null();
return res;
}
Expand Down
37 changes: 26 additions & 11 deletions be/src/exprs/table_function/java_udtf_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "column/nullable_column.h"
#include "column/vectorized_fwd.h"
#include "common/compiler_util.h"
#include "exprs/function_context.h"
#include "exprs/table_function/table_function.h"
#include "gutil/casts.h"
#include "jni.h"
Expand All @@ -41,15 +42,19 @@ const TableFunction* getJavaUDTFFunction() {

class JavaUDTFState : public TableFunctionState {
public:
JavaUDTFState(std::string libpath, std::string symbol, const TTypeDesc& desc)
: _libpath(std::move(libpath)), _symbol(std::move(symbol)), _ret_type(TypeDescriptor::from_thrift(desc)) {}
JavaUDTFState(std::string libpath, std::string symbol, std::vector<TypeDescriptor> type_desc, const TTypeDesc& desc)
: _libpath(std::move(libpath)),
_symbol(std::move(symbol)),
_arg_type_descs(std::move(type_desc)),
_ret_type(TypeDescriptor::from_thrift(desc)) {}
~JavaUDTFState() override = default;

Status open();
void close();

const TypeDescriptor& type_desc() { return _ret_type; }
JavaMethodDescriptor* method_process() { return _process.get(); }
const std::vector<TypeDescriptor>& arg_type_descs() const { return _arg_type_descs; }
jclass get_udtf_clazz() { return _udtf_class.clazz(); }
jobject handle() { return _udtf_handle.handle(); }

Expand All @@ -62,6 +67,7 @@ class JavaUDTFState : public TableFunctionState {
JVMClass _udtf_class = nullptr;
JavaGlobalRef _udtf_handle = nullptr;
std::unique_ptr<JavaMethodDescriptor> _process;
std::vector<TypeDescriptor> _arg_type_descs;
TypeDescriptor _ret_type;
};

Expand Down Expand Up @@ -96,7 +102,11 @@ Status JavaUDTFFunction::init(const TFunction& fn, TableFunctionState** state) c
std::string libpath;
RETURN_IF_ERROR(UserFunctionCache::instance()->get_libpath(fn.fid, fn.hdfs_location, fn.checksum, &libpath));
// For now, only a single return type is supported
*state = new JavaUDTFState(std::move(libpath), fn.table_fn.symbol, fn.table_fn.ret_types[0]);
std::vector<TypeDescriptor> arg_typedescs;
for (auto& type : fn.arg_types) {
arg_typedescs.push_back(TypeDescriptor::from_thrift(type));
}
*state = new JavaUDTFState(std::move(libpath), fn.table_fn.symbol, arg_typedescs, fn.table_fn.ret_types[0]);
return Status::OK();
}

Expand Down Expand Up @@ -162,8 +172,12 @@ std::pair<Columns, UInt32Column::Ptr> JavaUDTFFunction::process(RuntimeState* ru

for (int j = 0; j < num_cols; ++j) {
auto method_type = stateUDTF->method_process()->method_desc[j + 1];
jvalue val = cast_to_jvalue<true>(method_type.type, method_type.is_box, cols[j].get(), i);
call_stack.push_back(val);
auto val_st = cast_to_jvalue(stateUDTF->arg_type_descs()[j], method_type.is_box, cols[j].get(), i);
if (!val_st.ok()) {
stateUDTF->set_status(val_st.status());
return {};
}
call_stack.push_back(val_st.value());
}

rets[i] = env->CallObjectMethodA(stateUDTF->handle(), methodID, call_stack.data());
Expand All @@ -185,22 +199,23 @@ std::pair<Columns, UInt32Column::Ptr> JavaUDTFFunction::process(RuntimeState* ru
auto col = ColumnHelper::create_column(stateUDTF->type_desc(), true);
col->reserve(num_rows);

// TODO: support primitive array
MethodTypeDescriptor method_desc{stateUDTF->type_desc().type, true, true};

for (int i = 0; i < num_rows; ++i) {
int len = rets[i] != nullptr ? env->GetArrayLength((jarray)rets[i]) : 0;
offsets[i + 1] = offsets[i] + len;
// update for col
for (int j = 0; j < len; ++j) {
jobject vi = env->GetObjectArrayElement((jobjectArray)rets[i], j);
LOCAL_REF_GUARD_ENV(env, vi);
auto st = check_type_matched(method_desc, vi);
auto st = check_type_matched(stateUDTF->type_desc(), vi);
if (UNLIKELY(!st.ok())) {
state->set_status(st);
return std::make_pair(Columns{}, nullptr);
return {};
}
auto res = append_jvalue(stateUDTF->type_desc(), true, col.get(), {.l = vi});
if (UNLIKELY(!res.ok())) {
state->set_status(Status::InternalError(res.to_string()));
return {};
}
append_jvalue(method_desc, col.get(), {.l = vi});
}
}

Expand Down
Loading

0 comments on commit 6f3d167

Please sign in to comment.