Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-13568: [Gandiva] Support null data type for gandiva. #10884

Closed
wants to merge 10 commits into from
7 changes: 2 additions & 5 deletions c_glib/test/gandiva/test-null-literal-node.rb
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,8 @@ def setup

def test_invalid_type
return_type = Arrow::NullDataType.new
message =
"[gandiva][null-literal-node][new] " +
"failed to create: <#{return_type}>"
assert_raise(Arrow::Error::Invalid.new(message)) do
Gandiva::NullLiteralNode.new(return_type)
literal_node = Gandiva::NullLiteralNode.new(return_type)
assert_equal(return_type, literal_node.return_type)
end
end

Expand Down
2 changes: 2 additions & 0 deletions cpp/src/gandiva/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ set(SRC_FILES
expression_registry.cc
exported_funcs_registry.cc
filter.cc
null_ops.cc
function_ir_builder.cc
function_registry.cc
function_registry_arithmetic.cc
Expand Down Expand Up @@ -236,6 +237,7 @@ add_gandiva_test(internals-test
random_generator_holder_test.cc
hash_utils_test.cc
gdv_function_stubs_test.cc
null_ops_test.cc
EXTRA_DEPENDENCIES
LLVM::LLVM_INTERFACE
${GANDIVA_OPENSSL_LIBS}
Expand Down
18 changes: 13 additions & 5 deletions cpp/src/gandiva/annotator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,21 @@ void Annotator::PrepareBuffersForField(const FieldDescriptor& desc,
++buffer_idx;
}

uint8_t* data_buf = const_cast<uint8_t*>(array_data.buffers[buffer_idx]->data());
eval_batch->SetBuffer(desc.data_idx(), data_buf, array_data.offset);
if (array_data.type->id() == arrow::Type::NA) {
eval_batch->SetBuffer(desc.data_idx(), nullptr, array_data.offset);
} else {
uint8_t* data_buf = const_cast<uint8_t*>(array_data.buffers[buffer_idx]->data());
eval_batch->SetBuffer(desc.data_idx(), data_buf, array_data.offset);
}
if (is_output) {
// pass in the Buffer object for output data buffers. Can be used for resizing.
uint8_t* data_buf_ptr =
reinterpret_cast<uint8_t*>(array_data.buffers[buffer_idx].get());
eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), data_buf_ptr, array_data.offset);
if (array_data.type->id() == arrow::Type::NA) {
eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), nullptr, array_data.offset);
} else {
uint8_t* data_buf_ptr =
reinterpret_cast<uint8_t*>(array_data.buffers[buffer_idx].get());
eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), data_buf_ptr, array_data.offset);
}
}
}

Expand Down
8 changes: 8 additions & 0 deletions cpp/src/gandiva/dex.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,14 @@ class GANDIVA_EXPORT LiteralDex : public Dex {
LiteralHolder holder_;
};

/// \brief Decomposed expression standing in for a null literal.
///
/// Carries no state; Accept() forwards to the visitor's NullLiteralDex
/// overload.
class GANDIVA_EXPORT NullLiteralDex : public Dex {
 public:
  NullLiteralDex() = default;

  void Accept(DexVisitor& visitor) override {
    visitor.Visit(*this);
  }
};

/// decomposed if-else expression.
class GANDIVA_EXPORT IfDex : public Dex {
public:
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/gandiva/dex_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class VectorReadFixedLenValueDex;
class VectorReadVarLenValueDex;
class LocalBitMapValidityDex;
class LiteralDex;
class NullLiteralDex;
class TrueDex;
class FalseDex;
class NonNullableFuncDex;
Expand All @@ -54,6 +55,7 @@ class GANDIVA_EXPORT DexVisitor {
virtual void Visit(const TrueDex& dex) = 0;
virtual void Visit(const FalseDex& dex) = 0;
virtual void Visit(const LiteralDex& dex) = 0;
virtual void Visit(const NullLiteralDex& dex) = 0;
virtual void Visit(const NonNullableFuncDex& dex) = 0;
virtual void Visit(const NullableNeverFuncDex& dex) = 0;
virtual void Visit(const NullableInternalFuncDex& dex) = 0;
Expand All @@ -80,6 +82,7 @@ class GANDIVA_EXPORT DexDefaultVisitor : public DexVisitor {
VISIT_DCHECK(TrueDex)
VISIT_DCHECK(FalseDex)
VISIT_DCHECK(LiteralDex)
VISIT_DCHECK(NullLiteralDex)
VISIT_DCHECK(NonNullableFuncDex)
VISIT_DCHECK(NullableNeverFuncDex)
VISIT_DCHECK(NullableInternalFuncDex)
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/gandiva/exported_funcs.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ class ExportedFuncsBase {
virtual void AddMappings(Engine* engine) const = 0;
};

// Class for exporting Null functions.
// Declares the AddMappings() override required by ExportedFuncsBase; the
// definition is not part of this header.
class ExportedNullFunctions : public ExportedFuncsBase {
void AddMappings(Engine* engine) const override;
};
// NOTE(review): presumably registers the class with the exported-functions
// registry, as done for the sibling Exported*Functions classes -- confirm
// against the macro definition.
REGISTER_EXPORTED_FUNCS(ExportedNullFunctions);

// Class for exporting Stub functions
class ExportedStubFunctions : public ExportedFuncsBase {
void AddMappings(Engine* engine) const override;
Expand Down
7 changes: 7 additions & 0 deletions cpp/src/gandiva/expr_decomposer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,13 @@ Status ExprDecomposer::Visit(const LiteralNode& node) {
return Status::OK();
}

// Decomposes a null literal into a value/validity pair: the value side is a
// NullLiteralDex and the validity side is a FalseDex.
Status ExprDecomposer::Visit(const NullLiteralNode& node) {
  auto validity = std::make_shared<FalseDex>();
  auto value = std::make_shared<NullLiteralDex>();
  result_ = std::make_shared<ValueValidityPair>(validity, value);
  return Status::OK();
}

// The below functions use a stack to detect :
// a. nested if-else expressions.
// In such cases, the local bitmap can be re-used.
Expand Down
1 change: 1 addition & 0 deletions cpp/src/gandiva/expr_decomposer.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class GANDIVA_EXPORT ExprDecomposer : public NodeVisitor {
Status Visit(const FunctionNode& node) override;
Status Visit(const IfNode& node) override;
Status Visit(const LiteralNode& node) override;
Status Visit(const NullLiteralNode& node) override;
Status Visit(const BooleanNode& node) override;
Status Visit(const InExpressionNode<int32_t>& node) override;
Status Visit(const InExpressionNode<int64_t>& node) override;
Expand Down
22 changes: 17 additions & 5 deletions cpp/src/gandiva/expr_validator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,14 @@ Status ExprValidator::Validate(const ExpressionPtr& expr) {
}

Status ExprValidator::Visit(const FieldNode& node) {
auto llvm_type = types_->IRType(node.return_type()->id());
ARROW_RETURN_IF(llvm_type == nullptr,
Status::ExpressionValidationError("Field ", node.field()->name(),
" has unsupported data type ",
node.return_type()->name()));
auto return_type = node.return_type();
if (return_type->id() != arrow::Type::NA) {
auto llvm_type = types_->DataVecType(node.return_type());
ARROW_RETURN_IF(llvm_type == nullptr,
Status::ExpressionValidationError("Field ", node.field()->name(),
" has unsupported data type ",
node.return_type()->name()));
}

// Ensure that field is found in schema
auto field_in_schema_entry = field_map_.find(node.field()->name());
Expand Down Expand Up @@ -120,6 +123,15 @@ Status ExprValidator::Visit(const LiteralNode& node) {
return Status::OK();
}

// Validates a null literal node.
//
// The node is rejected when DataVecType() returns an LLVM type for its return
// type; a nullptr result is what lets the node pass.
// NOTE(review): the FieldNode validator treats a nullptr from DataVecType()
// as "unsupported data type", so this check presumably relies on the null
// type having no LLVM vector type -- confirm against LLVMTypes.
//
// Returns Status::OK() on success, or an ExpressionValidationError that names
// the offending type.
Status ExprValidator::Visit(const NullLiteralNode& node) {
  auto return_type = node.return_type();
  auto llvm_type = types_->DataVecType(return_type);
  // The previous message ("Should be data type <name>") printed the rejected
  // type as though it were the expected one; state the expectation instead.
  ARROW_RETURN_IF(llvm_type != nullptr,
                  Status::ExpressionValidationError(
                      "Null literal must have the null data type, found ",
                      return_type->name()));

  return Status::OK();
}

Status ExprValidator::Visit(const BooleanNode& node) {
ARROW_RETURN_IF(
node.children().size() < 2,
Expand Down
1 change: 1 addition & 0 deletions cpp/src/gandiva/expr_validator.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class ExprValidator : public NodeVisitor {
Status Visit(const FunctionNode& node) override;
Status Visit(const IfNode& node) override;
Status Visit(const LiteralNode& node) override;
Status Visit(const NullLiteralNode& node) override;
Status Visit(const BooleanNode& node) override;
Status Visit(const InExpressionNode<int32_t>& node) override;
Status Visit(const InExpressionNode<int64_t>& node) override;
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/gandiva/function_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "gandiva/function_registry_datetime.h"
#include "gandiva/function_registry_hash.h"
#include "gandiva/function_registry_math_ops.h"
#include "gandiva/function_registry_null.h"
#include "gandiva/function_registry_string.h"
#include "gandiva/function_registry_timestamp_arithmetic.h"

Expand Down Expand Up @@ -65,6 +66,9 @@ SignatureMap FunctionRegistry::InitPCMap() {
auto v6 = GetDateTimeArithmeticFunctionRegistry();
pc_registry_.insert(std::end(pc_registry_), v6.begin(), v6.end());

auto v8 = GetNullFunctionRegistry();
pc_registry_.insert(std::end(pc_registry_), v8.begin(), v8.end());

for (auto& elem : pc_registry_) {
for (auto& func_signature : elem.signatures()) {
map.insert(std::make_pair(&(func_signature), &elem));
Expand Down
1 change: 1 addition & 0 deletions cpp/src/gandiva/function_registry_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ using arrow::int32;
using arrow::int64;
using arrow::int8;
using arrow::month_interval;
using arrow::null;
using arrow::uint16;
using arrow::uint32;
using arrow::uint64;
Expand Down
40 changes: 40 additions & 0 deletions cpp/src/gandiva/function_registry_null.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <vector>

#include "gandiva/native_function.h"

namespace gandiva {

/// \brief Returns the native functions registered for the null data type.
///
/// The function is `inline` because it is defined in a header: a non-inline
/// definition would be emitted in every translation unit that includes this
/// file and fail at link time with multiple-definition errors.
///
/// The registry holds:
///  - "equal" (with the comparison aliases listed below) on (null, null),
///    returning null, mapped to "compare_null_null";
///  - "isnull" on null, returning boolean, mapped to "isnull_null";
///  - "isnotnull" on null, returning boolean, mapped to "isnotnull_null".
///
/// \return a copy of the function-local static registry vector.
inline std::vector<NativeFunction> GetNullFunctionRegistry() {
  static std::vector<NativeFunction> null_fn_registry_ = {
      NativeFunction("equal",
                     {"not_equal", "less_than", "less_than_or_equal_to", "greater_than",
                      "greater_than_or_equal_to"},
                     DataTypeVector{null(), null()}, null(), kResultNullNever,
                     "compare_null_null"),
      NativeFunction("isnull", {}, DataTypeVector{null()}, boolean(), kResultNullNever,
                     "isnull_null"),
      NativeFunction("isnotnull", {}, DataTypeVector{null()}, boolean(), kResultNullNever,
                     "isnotnull_null")};
  return null_fn_registry_;
}

} // namespace gandiva
33 changes: 28 additions & 5 deletions cpp/src/gandiva/llvm_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ llvm::Value* LLVMGenerator::GetDataReference(llvm::Value* arg_addrs, int idx,
llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, but should we add this null type to the supported types so that Dremio can see it as a supported type? There seems to be some code in llvm_types.h for supported types, but I do not know this well enough to tell whether something has to be added there so that Dremio knows that Gandiva can be used as a viable codegen for this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @rkavanap. Could you involve some guy familiar with dremio and help review this? If dremio might update code correspondingly, I think the change should be small.

llvm::Type* base_type = types()->DataVecType(field->type());
llvm::Value* ret;
if (base_type == nullptr) {
return nullptr;
}
if (base_type->isPointerTy()) {
ret = ir_builder()->CreateIntToPtr(load, base_type, name + "_darray");
} else {
Expand Down Expand Up @@ -363,6 +366,8 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count,
AddFunctionCall("gdv_fn_populate_varlen_vector", types()->i32_type(),
{arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, loop_var,
output_value->data(), output_value->length()});
} else if (output_type_id == arrow::Type::NA) {
// Do nothing when data type is null
} else {
return Status::NotImplemented("output type ", output->Type()->ToString(),
" not supported");
Expand Down Expand Up @@ -452,6 +457,10 @@ void LLVMGenerator::ComputeBitMapsForExpr(const CompiledExpr& compiled_expr,
// Extract the destination bitmap address.
int out_idx = compiled_expr.output()->validity_idx();
uint8_t* dst_bitmap = eval_batch.GetBuffer(out_idx);
if (dst_bitmap == nullptr) {
// Return when dst_bitmap is null meaning data type is null
return;
}
// Compute the destination bitmap.
if (selection_vector == nullptr) {
accumulator.ComputeResult(dst_bitmap);
Expand Down Expand Up @@ -491,7 +500,7 @@ llvm::Value* LLVMGenerator::AddFunctionCall(const std::string& full_name,

// build a call to the llvm function.
llvm::Value* value;
if (ret_type->isVoidTy()) {
if (ret_type == nullptr || ret_type->isVoidTy()) {
// void functions can't have a name for the call.
value = ir_builder()->CreateCall(fn, args);
} else {
Expand Down Expand Up @@ -556,6 +565,9 @@ void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex& dex) {
break;
}

case arrow::Type::NA:
break;

default: {
auto slot_offset = CreateGEP(builder, slot_ref, slot_index);
slot_value = CreateLoad(builder, slot_offset, dex.FieldName());
Expand Down Expand Up @@ -720,6 +732,13 @@ void LLVMGenerator::Visitor::Visit(const LiteralDex& dex) {
result_.reset(new LValue(value, len));
}

void LLVMGenerator::Visitor::Visit(const NullLiteralDex& dex) {
llvm::Value* value = nullptr;
llvm::Value* len = nullptr;
ADD_VISITOR_TRACE("visit Literal null");
result_.reset(new LValue(value, len));
}

void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex& dex) {
const std::string& function_name = dex.func_descriptor()->name();
ADD_VISITOR_TRACE("visit NonNullableFunc base function " + function_name);
Expand Down Expand Up @@ -1133,6 +1152,9 @@ LValuePtr LLVMGenerator::Visitor::BuildIfElse(llvm::Value* condition,
// Emit the merge block.
builder->SetInsertPoint(merge_bb);
auto llvm_type = types->IRType(result_type->id());
if (llvm_type == nullptr) {
return nullptr;
}
llvm::PHINode* result_value = builder->CreatePHI(llvm_type, 2, "res_value");
result_value->addIncoming(then_lvalue->data(), then_bb);
result_value->addIncoming(else_lvalue->data(), else_bb);
Expand Down Expand Up @@ -1248,10 +1270,11 @@ std::vector<llvm::Value*> LLVMGenerator::Visitor::BuildParams(
// build value.
DexPtr value_expr = pair->value_expr();
value_expr->Accept(*this);
LValue& result_ref = *result();

// append all the parameters corresponding to this LValue.
result_ref.AppendFunctionParams(&params);
if (auto result_ptr = result()) {
LValue& result_ref = *result_ptr;
// append all the parameters corresponding to this LValue.
result_ref.AppendFunctionParams(&params);
}

// build validity.
if (with_validity) {
Expand Down
1 change: 1 addition & 0 deletions cpp/src/gandiva/llvm_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class GANDIVA_EXPORT LLVMGenerator {
void Visit(const TrueDex& dex) override;
void Visit(const FalseDex& dex) override;
void Visit(const LiteralDex& dex) override;
void Visit(const NullLiteralDex& dex) override;
void Visit(const NonNullableFuncDex& dex) override;
void Visit(const NullableNeverFuncDex& dex) override;
void Visit(const NullableInternalFuncDex& dex) override;
Expand Down
15 changes: 14 additions & 1 deletion cpp/src/gandiva/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <vector>

#include "arrow/status.h"

#include "gandiva/arrow.h"
#include "gandiva/func_descriptor.h"
#include "gandiva/gandiva_aliases.h"
Expand Down Expand Up @@ -94,6 +93,20 @@ class GANDIVA_EXPORT LiteralNode : public Node {
bool is_null_;
};

/// \brief Expression-tree node for a null literal.
///
/// The node's return type is always arrow::null(); it takes no constructor
/// arguments.
class GANDIVA_EXPORT NullLiteralNode : public Node {
 public:
  NullLiteralNode() : Node(arrow::null()) {}

  /// Dispatches to the visitor's NullLiteralNode overload and returns its status.
  Status Accept(NodeVisitor& visitor) const override {
    return visitor.Visit(*this);
  }

  /// Renders the node as "(const <return type>) null".
  std::string ToString() const override {
    return "(const " + return_type()->ToString() + ") null";
  }
};

/// \brief Node in the expression tree, representing an arrow field.
class GANDIVA_EXPORT FieldNode : public Node {
public:
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/gandiva/node_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class FieldNode;
class FunctionNode;
class IfNode;
class LiteralNode;
class NullLiteralNode;
class BooleanNode;
template <typename Type>
class InExpressionNode;
Expand All @@ -44,6 +45,7 @@ class GANDIVA_EXPORT NodeVisitor {
virtual Status Visit(const FunctionNode& node) = 0;
virtual Status Visit(const IfNode& node) = 0;
virtual Status Visit(const LiteralNode& node) = 0;
virtual Status Visit(const NullLiteralNode& node) = 0;
virtual Status Visit(const BooleanNode& node) = 0;
virtual Status Visit(const InExpressionNode<int32_t>& node) = 0;
virtual Status Visit(const InExpressionNode<int64_t>& node) = 0;
Expand Down
Loading