diff --git a/c_glib/test/gandiva/test-null-literal-node.rb b/c_glib/test/gandiva/test-null-literal-node.rb index ae14f3c15e411..87de4e811ca7a 100644 --- a/c_glib/test/gandiva/test-null-literal-node.rb +++ b/c_glib/test/gandiva/test-null-literal-node.rb @@ -22,11 +22,8 @@ def setup def test_invalid_type return_type = Arrow::NullDataType.new - message = - "[gandiva][null-literal-node][new] " + - "failed to create: <#{return_type}>" - assert_raise(Arrow::Error::Invalid.new(message)) do - Gandiva::NullLiteralNode.new(return_type) + literal_node = Gandiva::NullLiteralNode.new(return_type) + assert_equal(return_type, literal_node.return_type) end end diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt index 654a4a40be151..18932a9523cab 100644 --- a/cpp/src/gandiva/CMakeLists.txt +++ b/cpp/src/gandiva/CMakeLists.txt @@ -69,6 +69,7 @@ set(SRC_FILES expression_registry.cc exported_funcs_registry.cc filter.cc + null_ops.cc function_ir_builder.cc function_registry.cc function_registry_arithmetic.cc @@ -236,6 +237,7 @@ add_gandiva_test(internals-test random_generator_holder_test.cc hash_utils_test.cc gdv_function_stubs_test.cc + null_ops_test.cc EXTRA_DEPENDENCIES LLVM::LLVM_INTERFACE ${GANDIVA_OPENSSL_LIBS} diff --git a/cpp/src/gandiva/annotator.cc b/cpp/src/gandiva/annotator.cc index f6acaff180411..8d0eb145e1719 100644 --- a/cpp/src/gandiva/annotator.cc +++ b/cpp/src/gandiva/annotator.cc @@ -77,13 +77,21 @@ void Annotator::PrepareBuffersForField(const FieldDescriptor& desc, ++buffer_idx; } - uint8_t* data_buf = const_cast(array_data.buffers[buffer_idx]->data()); - eval_batch->SetBuffer(desc.data_idx(), data_buf, array_data.offset); + if (array_data.type->id() == arrow::Type::NA) { + eval_batch->SetBuffer(desc.data_idx(), nullptr, array_data.offset); + } else { + uint8_t* data_buf = const_cast(array_data.buffers[buffer_idx]->data()); + eval_batch->SetBuffer(desc.data_idx(), data_buf, array_data.offset); + } if (is_output) { // pass in the Buffer 
object for output data buffers. Can be used for resizing. - uint8_t* data_buf_ptr = - reinterpret_cast(array_data.buffers[buffer_idx].get()); - eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), data_buf_ptr, array_data.offset); + if (array_data.type->id() == arrow::Type::NA) { + eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), nullptr, array_data.offset); + } else { + uint8_t* data_buf_ptr = + reinterpret_cast(array_data.buffers[buffer_idx].get()); + eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), data_buf_ptr, array_data.offset); + } } } diff --git a/cpp/src/gandiva/dex.h b/cpp/src/gandiva/dex.h index d1115c0516a0d..0b6cc22c19756 100644 --- a/cpp/src/gandiva/dex.h +++ b/cpp/src/gandiva/dex.h @@ -205,6 +205,14 @@ class GANDIVA_EXPORT LiteralDex : public Dex { LiteralHolder holder_; }; +/// decomposed expression for a null literal. +class GANDIVA_EXPORT NullLiteralDex : public Dex { + public: + NullLiteralDex() {} + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } +}; + /// decomposed if-else expression. 
class GANDIVA_EXPORT IfDex : public Dex { public: diff --git a/cpp/src/gandiva/dex_visitor.h b/cpp/src/gandiva/dex_visitor.h index 5d160bb22ca68..28378db0c19e7 100644 --- a/cpp/src/gandiva/dex_visitor.h +++ b/cpp/src/gandiva/dex_visitor.h @@ -31,6 +31,7 @@ class VectorReadFixedLenValueDex; class VectorReadVarLenValueDex; class LocalBitMapValidityDex; class LiteralDex; +class NullLiteralDex; class TrueDex; class FalseDex; class NonNullableFuncDex; @@ -54,6 +55,7 @@ class GANDIVA_EXPORT DexVisitor { virtual void Visit(const TrueDex& dex) = 0; virtual void Visit(const FalseDex& dex) = 0; virtual void Visit(const LiteralDex& dex) = 0; + virtual void Visit(const NullLiteralDex& dex) = 0; virtual void Visit(const NonNullableFuncDex& dex) = 0; virtual void Visit(const NullableNeverFuncDex& dex) = 0; virtual void Visit(const NullableInternalFuncDex& dex) = 0; @@ -80,6 +82,7 @@ class GANDIVA_EXPORT DexDefaultVisitor : public DexVisitor { VISIT_DCHECK(TrueDex) VISIT_DCHECK(FalseDex) VISIT_DCHECK(LiteralDex) + VISIT_DCHECK(NullLiteralDex) VISIT_DCHECK(NonNullableFuncDex) VISIT_DCHECK(NullableNeverFuncDex) VISIT_DCHECK(NullableInternalFuncDex) diff --git a/cpp/src/gandiva/exported_funcs.h b/cpp/src/gandiva/exported_funcs.h index 58205266094d9..1dc1f57f77004 100644 --- a/cpp/src/gandiva/exported_funcs.h +++ b/cpp/src/gandiva/exported_funcs.h @@ -32,6 +32,12 @@ class ExportedFuncsBase { virtual void AddMappings(Engine* engine) const = 0; }; +// Class for exporting Null functions +class ExportedNullFunctions : public ExportedFuncsBase { + void AddMappings(Engine* engine) const override; +}; +REGISTER_EXPORTED_FUNCS(ExportedNullFunctions); + // Class for exporting Stub functions class ExportedStubFunctions : public ExportedFuncsBase { void AddMappings(Engine* engine) const override; diff --git a/cpp/src/gandiva/expr_decomposer.cc b/cpp/src/gandiva/expr_decomposer.cc index 1c09d28f5e036..02bb050724b4e 100644 --- a/cpp/src/gandiva/expr_decomposer.cc +++ 
b/cpp/src/gandiva/expr_decomposer.cc @@ -225,6 +225,13 @@ Status ExprDecomposer::Visit(const LiteralNode& node) { return Status::OK(); } +Status ExprDecomposer::Visit(const NullLiteralNode& node) { + auto value_dex = std::make_shared(); + auto validity_dex = std::make_shared(); + result_ = std::make_shared(validity_dex, value_dex); + return Status::OK(); +} + // The bolow functions use a stack to detect : // a. nested if-else expressions. // In such cases, the local bitmap can be re-used. diff --git a/cpp/src/gandiva/expr_decomposer.h b/cpp/src/gandiva/expr_decomposer.h index f68b8a8fc0277..d5b3866ea353e 100644 --- a/cpp/src/gandiva/expr_decomposer.h +++ b/cpp/src/gandiva/expr_decomposer.h @@ -64,6 +64,7 @@ class GANDIVA_EXPORT ExprDecomposer : public NodeVisitor { Status Visit(const FunctionNode& node) override; Status Visit(const IfNode& node) override; Status Visit(const LiteralNode& node) override; + Status Visit(const NullLiteralNode& node) override; Status Visit(const BooleanNode& node) override; Status Visit(const InExpressionNode& node) override; Status Visit(const InExpressionNode& node) override; diff --git a/cpp/src/gandiva/expr_validator.cc b/cpp/src/gandiva/expr_validator.cc index c3c784c9511db..32aab53bf0752 100644 --- a/cpp/src/gandiva/expr_validator.cc +++ b/cpp/src/gandiva/expr_validator.cc @@ -42,11 +42,14 @@ Status ExprValidator::Validate(const ExpressionPtr& expr) { } Status ExprValidator::Visit(const FieldNode& node) { - auto llvm_type = types_->IRType(node.return_type()->id()); - ARROW_RETURN_IF(llvm_type == nullptr, - Status::ExpressionValidationError("Field ", node.field()->name(), - " has unsupported data type ", - node.return_type()->name())); + auto return_type = node.return_type(); + if (return_type->id() != arrow::Type::NA) { + auto llvm_type = types_->DataVecType(node.return_type()); + ARROW_RETURN_IF(llvm_type == nullptr, + Status::ExpressionValidationError("Field ", node.field()->name(), + " has unsupported data type ", + 
node.return_type()->name())); + } // Ensure that field is found in schema auto field_in_schema_entry = field_map_.find(node.field()->name()); @@ -120,6 +123,15 @@ Status ExprValidator::Visit(const LiteralNode& node) { return Status::OK(); } +Status ExprValidator::Visit(const NullLiteralNode& node) { + auto llvm_type = types_->DataVecType(node.return_type()); + ARROW_RETURN_IF(llvm_type != nullptr, + Status::ExpressionValidationError("Should be data type ", + node.return_type()->name())); + + return Status::OK(); +} + Status ExprValidator::Visit(const BooleanNode& node) { ARROW_RETURN_IF( node.children().size() < 2, diff --git a/cpp/src/gandiva/expr_validator.h b/cpp/src/gandiva/expr_validator.h index daaf50897fc75..08b3e4227619c 100644 --- a/cpp/src/gandiva/expr_validator.h +++ b/cpp/src/gandiva/expr_validator.h @@ -57,6 +57,7 @@ class ExprValidator : public NodeVisitor { Status Visit(const FunctionNode& node) override; Status Visit(const IfNode& node) override; Status Visit(const LiteralNode& node) override; + Status Visit(const NullLiteralNode& node) override; Status Visit(const BooleanNode& node) override; Status Visit(const InExpressionNode& node) override; Status Visit(const InExpressionNode& node) override; diff --git a/cpp/src/gandiva/function_registry.cc b/cpp/src/gandiva/function_registry.cc index d5d015c10b4b4..2d6221241025b 100644 --- a/cpp/src/gandiva/function_registry.cc +++ b/cpp/src/gandiva/function_registry.cc @@ -20,6 +20,7 @@ #include "gandiva/function_registry_datetime.h" #include "gandiva/function_registry_hash.h" #include "gandiva/function_registry_math_ops.h" +#include "gandiva/function_registry_null.h" #include "gandiva/function_registry_string.h" #include "gandiva/function_registry_timestamp_arithmetic.h" @@ -65,6 +66,9 @@ SignatureMap FunctionRegistry::InitPCMap() { auto v6 = GetDateTimeArithmeticFunctionRegistry(); pc_registry_.insert(std::end(pc_registry_), v6.begin(), v6.end()); + auto v8 = GetNullFunctionRegistry(); + 
pc_registry_.insert(std::end(pc_registry_), v8.begin(), v8.end()); + for (auto& elem : pc_registry_) { for (auto& func_signature : elem.signatures()) { map.insert(std::make_pair(&(func_signature), &elem)); diff --git a/cpp/src/gandiva/function_registry_common.h b/cpp/src/gandiva/function_registry_common.h index 66f945150897a..b95b8684d6c53 100644 --- a/cpp/src/gandiva/function_registry_common.h +++ b/cpp/src/gandiva/function_registry_common.h @@ -44,6 +44,7 @@ using arrow::int32; using arrow::int64; using arrow::int8; using arrow::month_interval; +using arrow::null; using arrow::uint16; using arrow::uint32; using arrow::uint64; diff --git a/cpp/src/gandiva/function_registry_null.h b/cpp/src/gandiva/function_registry_null.h new file mode 100644 index 0000000000000..ab45e6f4e4101 --- /dev/null +++ b/cpp/src/gandiva/function_registry_null.h @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <vector> + +#include "gandiva/native_function.h" + +namespace gandiva { + +inline std::vector<NativeFunction> GetNullFunctionRegistry() { // header-only, so inline + static std::vector<NativeFunction> null_fn_registry_ = { + NativeFunction("equal", + {"not_equal", "less_than", "less_than_or_equal_to", "greater_than", + "greater_than_or_equal_to"}, + DataTypeVector{null(), null()}, null(), kResultNullNever, + "compare_null_null"), + NativeFunction("isnull", {}, DataTypeVector{null()}, boolean(), kResultNullNever, + "isnull_null"), + NativeFunction("isnotnull", {}, DataTypeVector{null()}, boolean(), kResultNullNever, + "isnotnull_null")}; + return null_fn_registry_; +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/llvm_generator.cc b/cpp/src/gandiva/llvm_generator.cc index 0129e52784f65..7e5c074acb2b2 100644 --- a/cpp/src/gandiva/llvm_generator.cc +++ b/cpp/src/gandiva/llvm_generator.cc @@ -170,6 +170,9 @@ llvm::Value* LLVMGenerator::GetDataReference(llvm::Value* arg_addrs, int idx, llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name); llvm::Type* base_type = types()->DataVecType(field->type()); llvm::Value* ret; + if (base_type == nullptr) { + return nullptr; + } if (base_type->isPointerTy()) { ret = ir_builder()->CreateIntToPtr(load, base_type, name + "_darray"); } else { @@ -363,6 +366,8 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count, AddFunctionCall("gdv_fn_populate_varlen_vector", types()->i32_type(), {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, loop_var, output_value->data(), output_value->length()}); + } else if (output_type_id == arrow::Type::NA) { + // Do nothing when data type is null } else { return Status::NotImplemented("output type ", output->Type()->ToString(), " not supported"); @@ -452,6 +457,10 @@ void LLVMGenerator::ComputeBitMapsForExpr(const CompiledExpr& compiled_expr, // Extract the destination bitmap address. 
int out_idx = compiled_expr.output()->validity_idx(); uint8_t* dst_bitmap = eval_batch.GetBuffer(out_idx); + if (dst_bitmap == nullptr) { + // Return when dst_bitmap is null meaning data type is null + return; + } // Compute the destination bitmap. if (selection_vector == nullptr) { accumulator.ComputeResult(dst_bitmap); @@ -491,7 +500,7 @@ llvm::Value* LLVMGenerator::AddFunctionCall(const std::string& full_name, // build a call to the llvm function. llvm::Value* value; - if (ret_type->isVoidTy()) { + if (ret_type == nullptr || ret_type->isVoidTy()) { // void functions can't have a name for the call. value = ir_builder()->CreateCall(fn, args); } else { @@ -556,6 +565,9 @@ void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex& dex) { break; } + case arrow::Type::NA: + break; + default: { auto slot_offset = CreateGEP(builder, slot_ref, slot_index); slot_value = CreateLoad(builder, slot_offset, dex.FieldName()); @@ -720,6 +732,13 @@ void LLVMGenerator::Visitor::Visit(const LiteralDex& dex) { result_.reset(new LValue(value, len)); } +void LLVMGenerator::Visitor::Visit(const NullLiteralDex& dex) { + llvm::Value* value = nullptr; + llvm::Value* len = nullptr; + ADD_VISITOR_TRACE("visit Literal null"); + result_.reset(new LValue(value, len)); +} + void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex& dex) { const std::string& function_name = dex.func_descriptor()->name(); ADD_VISITOR_TRACE("visit NonNullableFunc base function " + function_name); @@ -1133,6 +1152,9 @@ LValuePtr LLVMGenerator::Visitor::BuildIfElse(llvm::Value* condition, // Emit the merge block. 
builder->SetInsertPoint(merge_bb); auto llvm_type = types->IRType(result_type->id()); + if (llvm_type == nullptr) { + return nullptr; + } llvm::PHINode* result_value = builder->CreatePHI(llvm_type, 2, "res_value"); result_value->addIncoming(then_lvalue->data(), then_bb); result_value->addIncoming(else_lvalue->data(), else_bb); @@ -1248,10 +1270,11 @@ std::vector LLVMGenerator::Visitor::BuildParams( // build value. DexPtr value_expr = pair->value_expr(); value_expr->Accept(*this); - LValue& result_ref = *result(); - - // append all the parameters corresponding to this LValue. - result_ref.AppendFunctionParams(&params); + if (auto result_ptr = result()) { + LValue& result_ref = *result_ptr; + // append all the parameters corresponding to this LValue. + result_ref.AppendFunctionParams(&params); + } // build validity. if (with_validity) { diff --git a/cpp/src/gandiva/llvm_generator.h b/cpp/src/gandiva/llvm_generator.h index ff6d846024cb9..d18a47a27350f 100644 --- a/cpp/src/gandiva/llvm_generator.h +++ b/cpp/src/gandiva/llvm_generator.h @@ -100,6 +100,7 @@ class GANDIVA_EXPORT LLVMGenerator { void Visit(const TrueDex& dex) override; void Visit(const FalseDex& dex) override; void Visit(const LiteralDex& dex) override; + void Visit(const NullLiteralDex& dex) override; void Visit(const NonNullableFuncDex& dex) override; void Visit(const NullableNeverFuncDex& dex) override; void Visit(const NullableInternalFuncDex& dex) override; diff --git a/cpp/src/gandiva/node.h b/cpp/src/gandiva/node.h index 20807d4a0cb5b..6e4c22e93b101 100644 --- a/cpp/src/gandiva/node.h +++ b/cpp/src/gandiva/node.h @@ -23,7 +23,6 @@ #include #include "arrow/status.h" - #include "gandiva/arrow.h" #include "gandiva/func_descriptor.h" #include "gandiva/gandiva_aliases.h" @@ -94,6 +93,20 @@ class GANDIVA_EXPORT LiteralNode : public Node { bool is_null_; }; +/// \brief Node in the expression tree, representing a NullLiteralNode. 
+class GANDIVA_EXPORT NullLiteralNode : public Node { + public: + NullLiteralNode() : Node(arrow::null()) {} + + Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); } + + std::string ToString() const override { + std::stringstream ss; + ss << "(const " << return_type()->ToString() << ") null"; + return ss.str(); + } +}; + /// \brief Node in the expression tree, representing an arrow field. class GANDIVA_EXPORT FieldNode : public Node { public: diff --git a/cpp/src/gandiva/node_visitor.h b/cpp/src/gandiva/node_visitor.h index 8f233f5b77c0b..a8f94fe873508 100644 --- a/cpp/src/gandiva/node_visitor.h +++ b/cpp/src/gandiva/node_visitor.h @@ -31,6 +31,7 @@ class FieldNode; class FunctionNode; class IfNode; class LiteralNode; +class NullLiteralNode; class BooleanNode; template class InExpressionNode; @@ -44,6 +45,7 @@ class GANDIVA_EXPORT NodeVisitor { virtual Status Visit(const FunctionNode& node) = 0; virtual Status Visit(const IfNode& node) = 0; virtual Status Visit(const LiteralNode& node) = 0; + virtual Status Visit(const NullLiteralNode& node) = 0; virtual Status Visit(const BooleanNode& node) = 0; virtual Status Visit(const InExpressionNode& node) = 0; virtual Status Visit(const InExpressionNode& node) = 0; diff --git a/cpp/src/gandiva/null_ops.cc b/cpp/src/gandiva/null_ops.cc new file mode 100644 index 0000000000000..e21ce236848a7 --- /dev/null +++ b/cpp/src/gandiva/null_ops.cc @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "gandiva/null_ops.h" + +#include "gandiva/engine.h" +#include "gandiva/exported_funcs.h" +#include "gandiva/gdv_function_stubs.h" + +/// Stub functions that can be accessed from LLVM or the pre-compiled library. + +extern "C" { + +GANDIVA_EXPORT +void compare_null_null(bool in1_valid, bool in2_valid) {} + +GANDIVA_EXPORT +bool isnull_null(bool in_valid) { return true; } + +GANDIVA_EXPORT +bool isnotnull_null(bool in_valid) { return false; } +} + +namespace gandiva { +void ExportedNullFunctions::AddMappings(Engine* engine) const { + std::vector args; + auto types = engine->types(); + + args = {types->i1_type(), types->i1_type()}; + engine->AddGlobalMappingForFunc("compare_null_null", types->void_type() /*return_type*/, + args, reinterpret_cast(compare_null_null)); + + args = {types->i1_type()}; + engine->AddGlobalMappingForFunc("isnull_null", types->i1_type() /*return_type*/, args, + reinterpret_cast(isnull_null)); + + args = {types->i1_type()}; + engine->AddGlobalMappingForFunc("isnotnull_null", types->i1_type() /*return_type*/, + args, reinterpret_cast(isnotnull_null)); +} +} // namespace gandiva diff --git a/cpp/src/gandiva/null_ops.h b/cpp/src/gandiva/null_ops.h new file mode 100644 index 0000000000000..7da22f7d0d480 --- /dev/null +++ b/cpp/src/gandiva/null_ops.h @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +/// Stub functions that can be accessed from LLVM. +extern "C" { + +GANDIVA_EXPORT +void compare_null_null(bool in1_valid, bool in2_valid); + +GANDIVA_EXPORT +bool isnull_null(bool in_valid); + +GANDIVA_EXPORT +bool isnotnull_null(bool in_valid); +} diff --git a/cpp/src/gandiva/null_ops_test.cc b/cpp/src/gandiva/null_ops_test.cc new file mode 100644 index 0000000000000..a979b82a771b8 --- /dev/null +++ b/cpp/src/gandiva/null_ops_test.cc @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include + +#include "gandiva/precompiled/types.h" + +namespace gandiva { + +TEST(TestNullOps, Test) { + compare_null_null(true, true); + EXPECT_TRUE(isnull_null(true)); + EXPECT_FALSE(isnotnull_null(true)); +} +} // namespace gandiva diff --git a/cpp/src/gandiva/precompiled/types.h b/cpp/src/gandiva/precompiled/types.h index 1a35741cc26d9..0180cd6f7a62c 100644 --- a/cpp/src/gandiva/precompiled/types.h +++ b/cpp/src/gandiva/precompiled/types.h @@ -20,6 +20,7 @@ #include #include "gandiva/gdv_function_stubs.h" +#include "gandiva/null_ops.h" // Use the same names as in arrow data types. Makes it easy to write pre-processor macros. using gdv_boolean = bool; diff --git a/cpp/src/gandiva/projector.cc b/cpp/src/gandiva/projector.cc index ff167538f9c1c..dbdd746d6712c 100644 --- a/cpp/src/gandiva/projector.cc +++ b/cpp/src/gandiva/projector.cc @@ -24,7 +24,6 @@ #include "arrow/util/hash_util.h" #include "arrow/util/logging.h" - #include "gandiva/cache.h" #include "gandiva/expr_validator.h" #include "gandiva/llvm_generator.h" @@ -296,6 +295,8 @@ Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records, } else if (arrow::is_binary_like(type_id)) { // we don't know the expected size for varlen output vectors. 
data_len = 0; + } else if (type_id == arrow::Type::NA) { + data_len = 0; } else { return Status::Invalid("Unsupported output data type " + type->ToString()); } @@ -308,7 +309,11 @@ Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records, } buffers.push_back(std::move(data_buffer)); - *array_data = arrow::ArrayData::Make(type, num_records, std::move(buffers)); + if (type_id == arrow::Type::NA) { + *array_data = arrow::ArrayData::Make(type, num_records, {nullptr}); + } else { + *array_data = arrow::ArrayData::Make(type, num_records, std::move(buffers)); + } return Status::OK(); } @@ -357,6 +362,10 @@ Status Projector::ValidateArrayDataCapacity(const arrow::ArrayData& array_data, int64_t data_len = array_data.buffers[1]->capacity(); ARROW_RETURN_IF(data_len < min_data_len, Status::Invalid("Data buffer too small for ", field.name())); + } else if (type_id == arrow::Type::NA) { // null type: expect one null buffer + ARROW_RETURN_IF(array_data.buffers.size() != 1 || array_data.buffers[0] != nullptr, + Status::Invalid("Data buffer should be nullptr for null typed field ", + field.name())); } else { return Status::Invalid("Unsupported output data type " + field.type()->ToString()); } diff --git a/cpp/src/gandiva/tests/CMakeLists.txt b/cpp/src/gandiva/tests/CMakeLists.txt index 5fa2da16c632f..a57085c589e69 100644 --- a/cpp/src/gandiva/tests/CMakeLists.txt +++ b/cpp/src/gandiva/tests/CMakeLists.txt @@ -25,6 +25,7 @@ add_gandiva_test(binary_test) add_gandiva_test(date_time_test) add_gandiva_test(to_string_test) add_gandiva_test(utf8_test) +add_gandiva_test(null_test) add_gandiva_test(hash_test) add_gandiva_test(in_expr_test) add_gandiva_test(null_validity_test) diff --git a/cpp/src/gandiva/tests/null_test.cc b/cpp/src/gandiva/tests/null_test.cc new file mode 100644 index 0000000000000..db67117c3efa4 --- /dev/null +++ b/cpp/src/gandiva/tests/null_test.cc @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/memory_pool.h" +#include "arrow/status.h" +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::null; + +class TestNull : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestNull, TestSimple) { + // schema for input fields + auto field_null = field("field_null", null()); + auto schema = arrow::schema({field_null}); + + auto literal_null = TreeExprBuilder::MakeNull(arrow::null()); + auto node_field_null = TreeExprBuilder::MakeField(field_null); + + // output fields + auto res_1 = field("res1", null()); + auto res_2 = field("res2", null()); + auto expr_1 = TreeExprBuilder::MakeExpression(literal_null, res_1); + auto expr_2 = TreeExprBuilder::MakeExpression(node_field_null, res_2); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + auto status = + Projector::Make(schema, {expr_1, expr_2}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + arrow::ArrayVector outputs; + auto null_array = std::make_shared(4); + auto in_batch = arrow::RecordBatch::Make(schema, 4, {null_array}); + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(null_array, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(null_array, outputs.at(1)); +} + +TEST_F(TestNull, TestOps) { + // schema for input fields + auto field_null = field("field_null", null()); + auto schema = arrow::schema({field_null}); + + // output fields + auto res_1 = field("res1", null()); + auto res_2 = field("res2", null()); + auto res_3 = field("res3", null()); + auto res_4 = field("res4", null()); + auto res_5 = field("res5", null()); + auto res_6 = field("res6", null()); + auto res_7 = field("res7", boolean()); + auto res_8 = field("res8", boolean()); + auto expr_1 = TreeExprBuilder::MakeExpression("equal", {field_null, field_null}, res_1); + auto expr_2 = + TreeExprBuilder::MakeExpression("not_equal", {field_null, field_null}, res_2); + auto expr_3 = + TreeExprBuilder::MakeExpression("less_than", {field_null, field_null}, res_3); + auto expr_4 = TreeExprBuilder::MakeExpression("less_than_or_equal_to", + {field_null, field_null}, res_4); + auto expr_5 = + TreeExprBuilder::MakeExpression("greater_than", {field_null, field_null}, res_5); + auto expr_6 = TreeExprBuilder::MakeExpression("greater_than_or_equal_to", + {field_null, field_null}, res_6); + auto expr_7 = TreeExprBuilder::MakeExpression("isnull", {field_null}, res_7); + auto expr_8 = TreeExprBuilder::MakeExpression("isnotnull", {field_null}, res_8); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + auto status = Projector::Make( + schema, {expr_1, expr_2, expr_3, expr_4, expr_5, expr_6, expr_7, expr_8}, + TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + arrow::ArrayVector outputs; + auto null_array = std::make_shared(4); + auto in_batch = arrow::RecordBatch::Make(schema, 4, {null_array}); + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + auto exp_true = MakeArrowArrayBool({true, true, true, true}, {true, true, true, true}); + auto exp_false = + MakeArrowArrayBool({false, false, false, false}, {true, true, true, true}); + for (int i = 0; i < 6; i++) { + EXPECT_EQ(outputs.at(i)->null_count(), 4); + } + EXPECT_ARROW_ARRAY_EQUALS(exp_true, outputs.at(6)); + EXPECT_ARROW_ARRAY_EQUALS(exp_false, outputs.at(7)); +} + +TEST_F(TestNull, TestMakeIf) { + // schema for input fields + auto field_null = field("field_null", null()); + auto schema = arrow::schema({field_null}); + + // output fields + auto res_1 = field("res1", null()); + auto res_2 = field("res2", null()); + + auto null_node = TreeExprBuilder::MakeNull(null()); + auto expr_1 = TreeExprBuilder::MakeExpression( + TreeExprBuilder::MakeIf(TreeExprBuilder::MakeLiteral(true), null_node, null_node, + null()), + res_1); + auto expr_2 = TreeExprBuilder::MakeExpression( + TreeExprBuilder::MakeIf(TreeExprBuilder::MakeLiteral(false), null_node, null_node, + null()), + res_2); + + // Build a projector for the expressions. 
+ std::shared_ptr projector; + auto status = + Projector::Make(schema, {expr_1, expr_2}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + arrow::ArrayVector outputs; + auto null_array = std::make_shared(4); + auto in_batch = arrow::RecordBatch::Make(schema, 4, {null_array}); + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + for (auto& output : outputs) { + EXPECT_EQ(output->null_count(), 4); + } +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/tree_expr_builder.cc b/cpp/src/gandiva/tree_expr_builder.cc index de8e3445a1259..c7869c147ecf0 100644 --- a/cpp/src/gandiva/tree_expr_builder.cc +++ b/cpp/src/gandiva/tree_expr_builder.cc @@ -105,6 +105,8 @@ NodePtr TreeExprBuilder::MakeNull(DataTypePtr data_type) { DecimalScalar128 literal(decimal_type->precision(), decimal_type->scale()); return std::make_shared(data_type, LiteralHolder(literal), true); } + case arrow::Type::NA: + return std::make_shared(); default: return nullptr; }