diff --git a/be/src/util/jni-util.cpp b/be/src/util/jni-util.cpp index 54ca27a64d77952..9eb096b52bf9e8b 100644 --- a/be/src/util/jni-util.cpp +++ b/be/src/util/jni-util.cpp @@ -16,6 +16,7 @@ // under the License. #include "util/jni-util.h" +#include "jni_md.h" #ifdef LIBJVM #include #include @@ -45,8 +46,10 @@ void FindOrCreateJavaVM() { vm_args.nOptions = 1; vm_args.ignoreUnrecognized = JNI_TRUE; - int res = JNI_CreateJavaVM(&g_vm, (void**)&env, &vm_args); - DCHECK_LT(res, 0) << "Failed tp create JVM, code= " << res; + jint res = JNI_CreateJavaVM(&g_vm, (void**)&env, &vm_args); + if (JNI_OK != res) { + DCHECK(false) << "Failed tp create JVM, code= " << res; + } } else { CHECK_EQ(rv, 0) << "Could not find any created Java VM"; CHECK_EQ(num_vms, 1) << "No VMs returned"; diff --git a/be/src/vec/functions/function_java_udf.cpp b/be/src/vec/functions/function_java_udf.cpp index ba7f58259aeb4bc..9e3681eb40ad09d 100644 --- a/be/src/vec/functions/function_java_udf.cpp +++ b/be/src/vec/functions/function_java_udf.cpp @@ -49,10 +49,12 @@ JavaFunctionCall::JavaFunctionCall(const TFunction& fn, const DataTypes& argumen Status JavaFunctionCall::prepare(FunctionContext* context, FunctionContext::FunctionStateScope scope) { - DCHECK(executor_cl_ == NULL) << "Init() already called!"; - JNIEnv* env; + //DCHECK(executor_cl_ == NULL) << "Init() already called!"; + JNIEnv* env = nullptr; RETURN_IF_ERROR(JniUtil::GetJNIEnv(&env)); - if (env == NULL) return Status::InternalError("Failed to get/create JVM"); + if (env == nullptr) { + return Status::InternalError("Failed to get/create JVM"); + } RETURN_IF_ERROR(JniUtil::GetGlobalClassRef(env, EXECUTOR_CLASS, &executor_cl_)); executor_ctor_id_ = env->GetMethodID(executor_cl_, "", EXECUTOR_CTOR_SIGNATURE); RETURN_ERROR_IF_EXC(env); @@ -101,7 +103,7 @@ Status JavaFunctionCall::prepare(FunctionContext* context, Status JavaFunctionCall::execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t num_rows, bool dry_run) { - JNIEnv* env; + JNIEnv* env = nullptr; RETURN_IF_ERROR(JniUtil::GetJNIEnv(&env)); JniContext* jni_ctx = reinterpret_cast( context->get_function_state(FunctionContext::THREAD_LOCAL)); @@ -109,6 +111,7 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, for (size_t col_idx : arguments) { ColumnWithTypeAndName& column = block.get_by_position(col_idx); auto col = column.column->convert_to_full_column_if_const(); + auto& col_type = column.type; if (!_argument_types[arg_idx]->equals(*column.type)) { return Status::InvalidArgument(strings::Substitute( "$0-th input column's type $1 does not equal to required type $2", arg_idx, @@ -117,6 +120,7 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, auto data_col = col; if (auto* nullable = check_and_get_column(*col)) { data_col = nullable->get_nested_column_ptr(); + col_type = remove_nullable(col_type); auto null_col = check_and_get_column>(nullable->get_null_map_column_ptr()); jni_ctx->input_nulls_buffer_ptr.get()[arg_idx] = @@ -124,12 +128,15 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, } else { jni_ctx->input_nulls_buffer_ptr.get()[arg_idx] = -1; } - if (const ColumnString* str_col = check_and_get_column(data_col.get())) { + WhichDataType type(col_type); + if (type.is_string_or_fixed_string()) { + const ColumnString* str_col = assert_cast(data_col.get()); jni_ctx->input_values_buffer_ptr.get()[arg_idx] = reinterpret_cast(str_col->get_chars().data()); jni_ctx->input_offsets_ptrs.get()[arg_idx] = reinterpret_cast(str_col->get_offsets().data()); - } else if (data_col->is_numeric()) { + } else if (type.is_int() || type.is_uint() || type.is_float() || + type.is_date_or_datetime() || type.is_decimal()) { jni_ctx->input_values_buffer_ptr.get()[arg_idx] = reinterpret_cast(data_col->get_raw_data().data); } else { @@ -151,7 +158,8 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, *(jni_ctx->output_null_value) = reinterpret_cast(null_col->get_data().data()); #ifndef EVALUATE_JAVA_UDF #define EVALUATE_JAVA_UDF \ - if (const ColumnString* str_col = check_and_get_column(data_col.get())) { \ + if (data_col->is_column_string()) { \ + const ColumnString* str_col = assert_cast(data_col.get()); \ ColumnString::Chars& chars = const_cast(str_col->get_chars()); \ ColumnString::Offsets& offsets = \ const_cast(str_col->get_offsets()); \ @@ -177,7 +185,7 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, env->CallNonvirtualVoidMethodA(jni_ctx->executor, executor_cl_, executor_evaluate_id_, \ nullptr); \ } \ - } else if (data_col->is_numeric()) { \ + } else if (data_col->is_numeric() || data_col->is_column_decimal()) { \ data_col->reserve(num_rows); \ data_col->resize(num_rows); \ *(jni_ctx->output_value_buffer) = \ @@ -205,7 +213,7 @@ Status JavaFunctionCall::close(FunctionContext* context, FunctionContext::FunctionStateScope scope) { JniContext* jni_ctx = reinterpret_cast( context->get_function_state(FunctionContext::THREAD_LOCAL)); - if (jni_ctx != NULL) { + if (jni_ctx != nullptr) { delete jni_ctx; context->set_function_state(FunctionContext::THREAD_LOCAL, nullptr); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java index 2446fb8249b13ab..d1c2f830c7398d9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java @@ -56,11 +56,15 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.lang.reflect.Parameter; +import java.math.BigDecimal; +import java.math.BigInteger; import java.net.MalformedURLException; import java.net.URL; import java.net.URLClassLoader; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; +import java.time.LocalDate; +import java.time.LocalDateTime; import java.util.List; import java.util.Map; import java.util.Set; @@ -370,6 +374,10 @@ private void analyzeJavaUdf(String clazz) throws AnalysisException { .put(PrimitiveType.CHAR, Sets.newHashSet(String.class)) .put(PrimitiveType.VARCHAR, Sets.newHashSet(String.class)) .put(PrimitiveType.STRING, Sets.newHashSet(String.class)) + .put(PrimitiveType.DATE, Sets.newHashSet(LocalDate.class)) + .put(PrimitiveType.DATETIME, Sets.newHashSet(LocalDateTime.class)) + .put(PrimitiveType.LARGEINT, Sets.newHashSet(BigInteger.class)) + .put(PrimitiveType.DECIMALV2, Sets.newHashSet(BigDecimal.class)) .build(); private void checkUdfType(Class clazz, Method method, Type expType, Class pType, String pname) diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java index cca787400a681ae..708b8fe8f54564e 100644 --- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java +++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java @@ -33,11 +33,16 @@ import java.io.IOException; import java.lang.reflect.Constructor; import java.lang.reflect.Method; +import java.math.BigDecimal; +import java.math.BigInteger; import java.net.MalformedURLException; import java.net.URL; import java.net.URLClassLoader; import java.nio.charset.StandardCharsets; +import java.time.LocalDate; +import java.time.LocalDateTime; import java.util.ArrayList; +import java.util.Arrays; public class UdfExecutor { private static final Logger LOG = Logger.getLogger(UdfExecutor.class); @@ -95,7 +100,11 @@ public enum JavaUdfDataType { DOUBLE("DOUBLE", TPrimitiveType.DOUBLE, 8), CHAR("CHAR", TPrimitiveType.CHAR, 0), VARCHAR("VARCHAR", TPrimitiveType.VARCHAR, 0), - STRING("STRING", TPrimitiveType.STRING, 0); + STRING("STRING", TPrimitiveType.STRING, 0), + DATE("DATE", TPrimitiveType.DATE, 8), + DATETIME("DATETIME", TPrimitiveType.DATETIME, 8), + LARGEINT("LARGEINT", TPrimitiveType.LARGEINT, 16), + DECIMALV2("DECIMALV2", TPrimitiveType.DECIMALV2, 16); private final String description_; private final TPrimitiveType thriftType_; @@ -139,13 +148,23 @@ public static JavaUdfDataType getType(Class c) { return JavaUdfDataType.CHAR; } else if (c == String.class) { return JavaUdfDataType.STRING; + } else if (c == LocalDate.class) { + return JavaUdfDataType.DATE; + } else if (c == LocalDateTime.class) { + return JavaUdfDataType.DATETIME; + } else if (c == BigInteger.class) { + return JavaUdfDataType.LARGEINT; + } else if (c == BigDecimal.class) { + return JavaUdfDataType.DECIMALV2; } return JavaUdfDataType.INVALID_TYPE; } public static boolean isSupported(Type t) { for (JavaUdfDataType javaType : JavaUdfDataType.values()) { - if (javaType == JavaUdfDataType.INVALID_TYPE) continue; + if (javaType == JavaUdfDataType.INVALID_TYPE) { + continue; + } if (javaType.getPrimitiveType() == t.getPrimitiveType().toThrift()) { return true; } @@ -160,14 +179,12 @@ public static boolean isSupported(Type t) { */ public UdfExecutor(byte[] thriftParams) throws Exception { TJavaUdfExecutorCtorParams request = new TJavaUdfExecutorCtorParams(); - TDeserializer deserializer = new TDeserializer(PROTOCOL_FACTORY); try { deserializer.deserialize(request, thriftParams); } catch (TException e) { throw new InternalException(e.getMessage()); } - String className = request.fn.scalar_fn.symbol; String jarFile = request.location; Type retType = UdfUtils.fromThrift(request.fn.ret_type, 0).first; @@ -221,9 +238,11 @@ public void close() { */ public void evaluate() throws UdfRuntimeException { int batch_size = UdfUtils.UNSAFE.getInt(null, batch_size_ptr_); + LOG.info("evaluate() and the row is " + batch_size); + try { if (retType_.equals(JavaUdfDataType.STRING) || retType_.equals(JavaUdfDataType.VARCHAR) - || retType_.equals(JavaUdfDataType.CHAR)) { + || retType_.equals(JavaUdfDataType.CHAR)) { // If this udf return variable-size type (e.g.) String, we have to allocate output // buffer multiple times until buffer size is enough to store output column. So we // always begin with the last evaluated row instead of beginning of this batch. @@ -240,9 +259,9 @@ public void evaluate() throws UdfRuntimeException { // Currently, -1 indicates this column is not nullable. So input argument is // null iff inputNullsPtrs_ != -1 and nullCol[row_idx] != 0. if (UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputNullsPtrs_, i)) == -1 || - UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputNullsPtrs_, i)) + row_idx_) == 0) { + UdfUtils.getAddressAtOffset(inputNullsPtrs_, i)) == -1 + || UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputNullsPtrs_, i)) + row_idx_) == 0) { inputArgs_[i] = inputObjects_[i]; } else { inputArgs_[i] = null; @@ -307,36 +326,96 @@ private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException switch (retType_) { case BOOLEAN: { boolean val = (boolean) obj; - UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), val ? (byte) 1 : 0); + UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + val ? (byte) 1 : 0); return true; } case TINYINT: { - UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), (byte) obj); + UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + (byte) obj); return true; } case SMALLINT: { - UdfUtils.UNSAFE.putShort(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), (short) obj); + UdfUtils.UNSAFE.putShort(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + (short) obj); return true; } case INT: { - UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), (int) obj); + UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + (int) obj); return true; } case BIGINT: { - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), (long) obj); + UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + (long) obj); return true; } case FLOAT: { - UdfUtils.UNSAFE.putFloat(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), (float) obj); + UdfUtils.UNSAFE.putFloat(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + (float) obj); return true; } case DOUBLE: { - UdfUtils.UNSAFE.putDouble(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), (double) obj); + UdfUtils.UNSAFE.putDouble(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + (double) obj); + return true; + } + case DATE: { + LocalDate date = (LocalDate) obj; + long time = + convertDateTimeToLong(date.getYear(), date.getMonthValue(), date.getDayOfMonth(), 0, 0, 0, true); + UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + time); + return true; + } + case DATETIME: { + LocalDateTime date = (LocalDateTime) obj; + long time = + convertDateTimeToLong(date.getYear(), date.getMonthValue(), date.getDayOfMonth(), date.getHour(), + date.getMinute(), date.getSecond(), false); + UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), + time); + return true; + } + case LARGEINT: { + BigInteger data = (BigInteger) obj; + byte[] bytes = convertByteOrder(data.toByteArray()); + + //here value is 16 bytes, so is result data greater than the maximum of 16 bytes + //it will return a wrong num to backend; + byte[] value = new byte[16]; + //check data is negative + if (data.signum() == -1) { + Arrays.fill(value, (byte) -1); + } + for (int index = 0; index < Math.min(bytes.length, value.length); ++index) { + value[index] = bytes[index]; + } + + UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, + UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), value.length); + return true; + } + case DECIMALV2: { + BigInteger data = ((BigDecimal) obj).unscaledValue(); + byte[] bytes = convertByteOrder(data.toByteArray()); + + byte[] value = new byte[16]; + if (data.signum() == -1) { + Arrays.fill(value, (byte) -1); + } + + for (int index = 0; index < Math.min(bytes.length, value.length); ++index) { + value[index] = bytes[index]; + } + + UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, + UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + row * retType_.getLen(), value.length); return true; } case CHAR: case VARCHAR: - case STRING: + case STRING: { long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr_); byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8); if (outputOffset_ + bytes.length + 1 > bufferSize) { @@ -344,13 +423,14 @@ private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException } outputOffset_ += (bytes.length + 1); UdfUtils.UNSAFE.putChar(UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + - outputOffset_ - 1, UdfUtils.END_OF_STRING); + outputOffset_ - 1, UdfUtils.END_OF_STRING); UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr_) + 4L * row, - Integer.parseUnsignedInt(String.valueOf(outputOffset_))); + Integer.parseUnsignedInt(String.valueOf(outputOffset_))); UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + - outputOffset_ - bytes.length - 1, bytes.length); + UdfUtils.UNSAFE.getLong(null, outputBufferPtr_) + + outputOffset_ - bytes.length - 1, bytes.length); return true; + } default: throw new UdfRuntimeException("Unsupported return type: " + retType_); } @@ -365,41 +445,81 @@ private void allocateInputObjects(long row) throws UdfRuntimeException { for (int i = 0; i < argTypes_.length; ++i) { switch (argTypes_[i]) { case BOOLEAN: - inputObjects_[i] = UdfUtils.UNSAFE.getBoolean(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + row); + inputObjects_[i] = UdfUtils.UNSAFE.getBoolean(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + row); break; case TINYINT: - inputObjects_[i] = UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + row); + inputObjects_[i] = UdfUtils.UNSAFE.getByte(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + row); break; case SMALLINT: - inputObjects_[i] = UdfUtils.UNSAFE.getShort(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 2L * row); + inputObjects_[i] = UdfUtils.UNSAFE.getShort(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 2L * row); break; case INT: - inputObjects_[i] = UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 4L * row); + inputObjects_[i] = UdfUtils.UNSAFE.getInt(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 4L * row); break; case BIGINT: - inputObjects_[i] = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L * row); + inputObjects_[i] = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L * row); break; case FLOAT: - inputObjects_[i] = UdfUtils.UNSAFE.getFloat(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 4L * row); + inputObjects_[i] = UdfUtils.UNSAFE.getFloat(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 4L * row); break; case DOUBLE: - inputObjects_[i] = UdfUtils.UNSAFE.getDouble(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L * row); + inputObjects_[i] = UdfUtils.UNSAFE.getDouble(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L * row); + break; + case DATE: { + long data = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L * row); + inputObjects_[i] = convertToDate(data); break; + } + case DATETIME: { + long data = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 8L * row); + inputObjects_[i] = convertToDateTime(data); + break; + } + case LARGEINT: { + long base = + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 16L * row; + byte[] bytes = new byte[16]; + UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16); + + inputObjects_[i] = new BigInteger(convertByteOrder(bytes)); + break; + } + case DECIMALV2: { + long base = + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + 16L * row; + byte[] bytes = new byte[16]; + UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16); + + BigInteger value = new BigInteger(convertByteOrder(bytes)); + inputObjects_[i] = new BigDecimal(value, 9); + break; + } case CHAR: case VARCHAR: - case STRING: - long offset = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs_, i)) + 4L * row)); + case STRING: { + long offset = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputOffsetsPtrs_, i)) + 4L * row)); long numBytes = row == 0 ? offset - 1 : offset - Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs_, i)) + 4L * (row - 1))) - 1; - long base = row == 0 ? UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) : - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + offset - numBytes - 1; + UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputOffsetsPtrs_, i)) + 4L * (row - 1))) - 1; + long base = + row == 0 ? UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) : + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs_, i)) + offset - + numBytes - 1; byte[] bytes = new byte[(int) numBytes]; UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes); inputObjects_[i] = new String(bytes, StandardCharsets.UTF_8); break; + } default: throw new UdfRuntimeException("Unsupported argument type: " + argTypes_[i]); } @@ -408,7 +528,7 @@ private void allocateInputObjects(long row) throws UdfRuntimeException { private URLClassLoader getClassLoader(String jarPath) throws MalformedURLException { URL url = new File(jarPath).toURI().toURL(); - return URLClassLoader.newInstance(new URL[]{url}, getClass().getClassLoader()); + return URLClassLoader.newInstance(new URL[] {url}, getClass().getClassLoader()); } /** @@ -417,7 +537,7 @@ private URLClassLoader getClassLoader(String jarPath) throws MalformedURLExcepti * if the return type is not supported. */ private boolean setReturnType(Type retType, Class udfReturnType) - throws InternalException { + throws InternalException { if (!JavaUdfDataType.isSupported(retType)) { throw new InternalException("Unsupported return type: " + retType.toSql()); } @@ -442,7 +562,7 @@ private boolean setArgTypes(Type[] parameterTypes, Class[] udfArgTypes) { for (int i = 0; i < udfArgTypes.length; ++i) { argTypes_[i] = JavaUdfDataType.getType(udfArgTypes[i]); if (argTypes_[i].getPrimitiveType() - != parameterTypes[i].getPrimitiveType().toThrift()) { + != parameterTypes[i].getPrimitiveType().toThrift()) { return false; } } @@ -469,32 +589,42 @@ private void init(String jarPath, String udfPath, Method[] methods = c.getMethods(); for (Method m : methods) { // By convention, the udf must contain the function "evaluate" - if (!m.getName().equals(UDF_FUNCTION_NAME)) continue; + if (!m.getName().equals(UDF_FUNCTION_NAME)) { + continue; + } signatures.add(m.toGenericString()); Class[] methodTypes = m.getParameterTypes(); // Try to match the arguments - if (methodTypes.length != parameterTypes.length) continue; + if (methodTypes.length != parameterTypes.length) { + continue; + } method_ = m; if (methodTypes.length == 0 && parameterTypes.length == 0) { // Special case where the UDF doesn't take any input args - if (!setReturnType(retType, m.getReturnType())) continue; + if (!setReturnType(retType, m.getReturnType())) { + continue; + } LOG.debug("Loaded UDF '" + udfPath + "' from " + jarPath); return; } - if (!setReturnType(retType, m.getReturnType())) continue; - if (!setArgTypes(parameterTypes, methodTypes)) continue; + if (!setReturnType(retType, m.getReturnType())) { + continue; + } + if (!setArgTypes(parameterTypes, methodTypes)) { + continue; + } LOG.debug("Loaded UDF '" + udfPath + "' from " + jarPath); return; } StringBuilder sb = new StringBuilder(); sb.append("Unable to find evaluate function with the correct signature: ") - .append(udfPath + ".evaluate(") - .append(Joiner.on(", ").join(parameterTypes)) - .append(")\n") - .append("UDF contains: \n ") - .append(Joiner.on("\n ").join(signatures)); + .append(udfPath + ".evaluate(") + .append(Joiner.on(", ").join(parameterTypes)) + .append(")\n") + .append("UDF contains: \n ") + .append(Joiner.on("\n ").join(signatures)); throw new UdfRuntimeException(sb.toString()); } catch (MalformedURLException e) { throw new UdfRuntimeException("Unable to load jar.", e); @@ -504,12 +634,74 @@ private void init(String jarPath, String udfPath, throw new UdfRuntimeException("Unable to find class.", e); } catch (NoSuchMethodException e) { throw new UdfRuntimeException( - "Unable to find constructor with no arguments.", e); + "Unable to find constructor with no arguments.", e); } catch (IllegalArgumentException e) { throw new UdfRuntimeException( - "Unable to call UDF constructor with no arguments.", e); + "Unable to call UDF constructor with no arguments.", e); } catch (Exception e) { throw new UdfRuntimeException("Unable to call create UDF instance.", e); } } -} + + // input is a 64bit num from backend, and then get year, month, day, hour, minus, second by the order of bits + // return a new LocalDateTime data to evaluate method; + private LocalDateTime convertToDateTime(long date) { + int year = (int) (date >> 48); + int year_month = (int) (date >> 40); + int year_month_day = (int) (date >> 32); + + int month = (year_month & 0XFF); + int day = (year_month_day & 0XFF); + + int hour_minute_second = (int) (date % (1 << 31)); + int minute_type_neg = (hour_minute_second % (1 << 16)); + + int hour = (hour_minute_second >> 24); + int minute = ((hour_minute_second >> 16) & 0XFF); + int second = (minute_type_neg >> 4); + //here don't need those bits are type = ((minus_type_neg >> 1) & 0x7); + + LocalDateTime value = LocalDateTime.of(year, month, day, hour, minute, second); + return value; + } + + private LocalDate convertToDate(long date) { + int year = (int) (date >> 48); + int year_month = (int) (date >> 40); + int year_month_day = (int) (date >> 32); + + int month = (year_month & 0XFF); + int day = (year_month_day & 0XFF); + LocalDate value = LocalDate.of(year, month, day); + return value; + } + + //input is the second, minute, hours, day , month and year respectively + //and then combining all num to a 64bit value return to backend; + long convertDateTimeToLong(int year, int month, int day, int hour, int minute, int second, Boolean isDate) { + long time = 0; + int type = isDate ? 2 : 3; + + time = time + year; + time = (time << 8) + month; + time = (time << 8) + day; + time = (time << 8) + hour; + time = (time << 8) + minute; + time = (time << 12) + second; + time = (time << 3) + type; + //this bit is int neg = 0; + time = (time << 1); + return time; + } + + // Change the order of the bytes, Because JVM is Big-Endian , x86 is Little-Endian + private byte[] convertByteOrder(byte[] bytes) { + int length = bytes.length; + for (int i = 0; i < length / 2; ++i) { + byte temp = bytes[i]; + bytes[i] = bytes[length - 1 - i]; + bytes[length - 1 - i] = temp; + } + return bytes; + } +} \ No newline at end of file