Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
fix string date
Browse files Browse the repository at this point in the history
Signed-off-by: Yuan Zhou <[email protected]>
  • Loading branch information
zhouyuan committed Jun 24, 2022
1 parent 174aad0 commit 1a9fa29
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
prepare_str_ += prepare_ss.str();
} else if (func_name.find("cast") != std::string::npos &&
func_name.compare("castDATE") != 0 &&
func_name.compare("castDATE_nullsafe") != 0 &&
func_name.compare("castDECIMAL") != 0 &&
func_name.compare("castDECIMALNullOnOverflow") != 0 &&
func_name.compare("castINTOrNull") != 0 &&
Expand Down Expand Up @@ -503,8 +504,8 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
fix_ss << " * 1.0 ";
}
prepare_ss << codes_str_ << " = static_cast<"
<< GetCTypeString(node.return_type()) << ">("
<< child_visitor_list[0]->GetResult() << fix_ss.str() << ");"
<< GetCTypeString(node.return_type()) << ">(castDATE32("
<< child_visitor_list[0]->GetResult() << fix_ss.str() << "));"
<< std::endl;
}
} else {
Expand Down Expand Up @@ -557,13 +558,13 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
real_codes_str_ = codes_str_;
real_validity_str_ = check_str_;
header_list_.push_back(R"(#include "third_party/murmurhash/murmurhash32.h")");
} else if (func_name.compare("castDATE") == 0) {
} else if (func_name.compare("castDATE") == 0 || func_name.compare("castDATE_nullsafe") == 0) {
codes_str_ = func_name + "_" + std::to_string(cur_func_id);
auto validity = codes_str_ + "_validity";
real_codes_str_ = codes_str_;
real_validity_str_ = validity;
std::stringstream prepare_ss;
auto typed_func_name = func_name;
auto typed_func_name = std::string("castDATE");
if (node.return_type()->id() == arrow::Type::INT32 ||
node.return_type()->id() == arrow::Type::DATE32) {
typed_func_name += "32";
Expand All @@ -580,7 +581,7 @@ arrow::Status ExpressionCodegenVisitor::Visit(const gandiva::FunctionNode& node)
<< ";" << std::endl;
prepare_ss << "if (" << validity << ") {" << std::endl;
prepare_ss << codes_str_ << " = " << typed_func_name << "("
<< child_visitor_list[0]->GetResult() << ");" << std::endl;
<< child_visitor_list[0]->GetResult() << ", reinterpret_cast<int64_t>(execution_context_.get()));" << std::endl;
prepare_ss << "}" << std::endl;

for (int i = 0; i < 1; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,8 @@ class TypedWholeStageCodeGenImpl : public CodeGenBase {
WholeStageCodeGenResultIterator(arrow::compute::ExecContext* ctx,
std::vector<std::shared_ptr<GandivaProjector>> gandiva_projector_list,
const std::shared_ptr<arrow::Schema>& result_schema)
: ctx_(ctx), result_schema_(result_schema), gandiva_projector_list_(gandiva_projector_list) {)";
: ctx_(ctx), result_schema_(result_schema), gandiva_projector_list_(gandiva_projector_list) {
execution_context_.reset(new gandiva::ExecutionContext());)";
if (!is_aggr_) {
codes_ss << GetBuilderInitializeCodes(output_field_list) << std::endl;
} else {
Expand Down Expand Up @@ -536,6 +537,7 @@ class TypedWholeStageCodeGenImpl : public CodeGenBase {
codes_ss << R"(
private:
arrow::compute::ExecContext* ctx_;
std::unique_ptr<gandiva::ExecutionContext> execution_context_;
bool should_stop_ = false;
std::vector<std::shared_ptr<GandivaProjector>> gandiva_projector_list_;
std::shared_ptr<arrow::Schema> result_schema_;)"
Expand Down
6 changes: 4 additions & 2 deletions native-sql-engine/cpp/src/precompile/gandiva.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <arrow/json/api.h>
#include <arrow/json/parser.h>
#include <arrow/util/decimal.h>
#include <gandiva/execution_context.h>
#include <math.h>
#include <re2/re2.h>

Expand All @@ -31,8 +32,9 @@
#include "third_party/gandiva/decimal_ops.h"
#include "third_party/gandiva/types.h"

int32_t castDATE32(int32_t in) { return castDATE_int32(in); }
int64_t castDATE64(int32_t in) { return castDATE_date32(in); }
int32_t castDATE32(int32_t in, int64_t ctx = 0) { return castDATE_int32(in); }
int64_t castDATE64(int32_t in, int64_t ctx = 0) { return castDATE_date32(in); }
int64_t castDATE64(const std::string in, int64_t ctx = 0) { return castDATE_utf8(ctx, in.c_str(), in.length()); }
int64_t extractYear(int64_t millis) { return extractYear_timestamp(millis); }
template <typename T>
T round2(T val, int precision = 2) {
Expand Down

0 comments on commit 1a9fa29

Please sign in to comment.