Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NULL behavior fix for logical operator #380

Merged
merged 11 commits into from
Jan 7, 2020
219 changes: 175 additions & 44 deletions dbms/src/Functions/FunctionsLogical.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,22 @@ struct AndImpl
return true;
}

/// Reports whether the constant `value` is a non-NULL field that converts to 0.
/// A NULL field always yields false.
static inline bool resNotNull(const Field & value)
{
    if (value.isNull())
        return false;

    const auto converted = applyVisitor(FieldVisitorConvertToNumber<bool>(), value);
    return converted == 0;
}

/// Per-row counterpart of the Field overload: true only when the cell is
/// not flagged as NULL (`is_null` is zero) and its `value` is zero.
static inline bool resNotNull(UInt8 value, UInt8 is_null)
{
    return !(is_null || value);
}

static void adjustForNullValue(UInt8 & value, UInt8 & is_null)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not inline this function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

{
is_null = false;
value = false;
}

static inline bool isSaturatedValue(bool a)
{
return !a;
Expand All @@ -49,7 +65,6 @@ struct AndImpl
return a && b;
}

static inline bool specialImplementationForNulls() { return false; }
};

struct OrImpl
Expand All @@ -64,12 +79,27 @@ struct OrImpl
return a;
}

/// Reports whether the constant `value` is a non-NULL field that converts to 1.
/// A NULL field always yields false.
static inline bool resNotNull(const Field & value)
{
    if (value.isNull())
        return false;

    const auto converted = applyVisitor(FieldVisitorConvertToNumber<bool>(), value);
    return converted == 1;
}

/// Per-row counterpart of the Field overload: true only when the cell is
/// not flagged as NULL (`is_null` is zero) and its `value` is non-zero.
static inline bool resNotNull(UInt8 value, UInt8 is_null)
{
    if (is_null)
        return false;
    return static_cast<bool>(value);
}

/// Overwrites a result cell with a definite, non-NULL `true`:
/// clears the NULL flag and sets the value to true.
/// Callers in this file invoke it when an input was NULL but resNotNull()
/// reported that the result is nevertheless determined.
/// Declared `inline` to match the other static helpers of this struct.
static inline void adjustForNullValue(UInt8 & value, UInt8 & is_null)
{
    value = true;
    is_null = false;
}

/// Logical OR of the two operands.
static inline bool apply(bool a, bool b)
{
    if (a)
        return true;
    return b;
}

static inline bool specialImplementationForNulls() { return true; }
};

struct XorImpl
Expand All @@ -84,12 +114,24 @@ struct XorImpl
return false;
}

/// Always reports true; the constant operand is ignored.
static inline bool resNotNull(const Field & /*value*/)
{
    constexpr bool always = true;
    return always;
}

/// Always reports true; both the cell value and its NULL flag are ignored.
static inline bool resNotNull(UInt8 /*value*/, UInt8 /*is_null*/)
{
    constexpr bool always = true;
    return always;
}

/// Intentionally a no-op: neither the value nor the NULL flag is modified.
/// Declared `inline` to match the other static helpers of this struct.
static inline void adjustForNullValue(UInt8 & /*value*/, UInt8 & /*is_null*/)
{
}

/// Logical XOR: true exactly when the operands differ.
static inline bool apply(bool a, bool b)
{
    return a ? !b : b;
}

static inline bool specialImplementationForNulls() { return false; }
};

template <typename A>
Expand Down Expand Up @@ -172,15 +214,19 @@ struct AssociativeOperationImpl<Op, 1>
};


template <typename Impl, typename Name>
/**
* NULL handling of the `and` and `or` functions follows SQL three-valued logic (3VL), as described in
* https://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_(3VL)
*/
template <typename Impl, typename Name, bool special_impl_for_nulls>
class FunctionAnyArityLogical : public IFunction
{
public:
static constexpr auto name = Name::name;
static FunctionPtr create(const Context &) { return std::make_shared<FunctionAnyArityLogical>(); };

private:
bool extractConstColumns(ColumnRawPtrs & in, UInt8 & res)
bool extractConstColumns(ColumnRawPtrs & in, UInt8 & res, UInt8 & res_not_null, UInt8 & input_has_null)
{
bool has_res = false;
for (int i = static_cast<int>(in.size()) - 1; i >= 0; --i)
Expand All @@ -189,6 +235,11 @@ class FunctionAnyArityLogical : public IFunction
continue;

Field value = (*in[i])[0];
if constexpr (special_impl_for_nulls)
{
input_has_null |= value.isNull();
res_not_null |= Impl::resNotNull(value);
}

UInt8 x = !value.isNull() && applyVisitor(FieldVisitorConvertToNumber<bool>(), value);
if (has_res)
Expand All @@ -207,21 +258,26 @@ class FunctionAnyArityLogical : public IFunction
}

/// Converts a non-nullable column with element type T into 0/1 values.
/// Returns false without touching `res` when `column` is not a ColumnVector<T>.
/// Otherwise writes the truthiness of each element into res[i] for the first
/// res.size() rows and returns true.
/// When special_impl_for_nulls is enabled, Impl::resNotNull(res[i], false) is
/// OR-ed into res_not_null[i] (the NULL flag passed is always false here).
/// NOTE(review): two signature lines follow because this text is a rendered diff
/// (pre-change line, then post-change line); only the second declares res_not_null.
template <typename T>
bool convertTypeToUInt8(const IColumn * column, UInt8Container & res)
bool convertTypeToUInt8(const IColumn * column, UInt8Container & res, UInt8Container & res_not_null)
{
auto col = checkAndGetColumn<ColumnVector<T>>(column);
if (!col)
return false;
const auto & vec = col->getData();
size_t n = res.size();
for (size_t i = 0; i < n; ++i)
{
// Double negation normalizes any non-zero element to 1.
res[i] = !!vec[i];
if constexpr (special_impl_for_nulls)
res_not_null[i] |= Impl::resNotNull(res[i], false);
}

return true;
}

template <typename T>
bool convertNullableTypeToUInt8(const IColumn * column, UInt8Container & res)
bool convertNullableTypeToUInt8(const IColumn * column, UInt8Container & res, UInt8Container & res_not_null,
UInt8Container & input_has_null)
{
auto col_nullable = checkAndGetColumn<ColumnNullable>(column);

Expand All @@ -234,32 +290,40 @@ class FunctionAnyArityLogical : public IFunction

size_t n = res.size();
for (size_t i = 0; i < n; ++i)
{
res[i] = !!vec[i] && !null_map[i];
if constexpr (special_impl_for_nulls)
{
res_not_null[i] |= Impl::resNotNull(res[i], null_map[i]);
input_has_null[i] |= null_map[i];
}
}

return true;
}

/// Converts `column` to 0/1 values in `res` by trying each supported element type in turn.
/// The non-nullable converters are tried first, then the nullable ones; the `&&` chain
/// stops at the first converter that returns true.
/// Throws Exception with ErrorCodes::ILLEGAL_COLUMN when no converter accepts the column.
/// NOTE(review): this text is a rendered diff - the first signature and its argument list
/// (without res_not_null / input_has_null) are the pre-change lines, followed by the
/// post-change version; the throw statement and closing brace are shared.
/// NOTE(review): plain UInt8 is absent from the non-nullable list - presumably because the
/// caller handles ColumnUInt8 inputs itself before calling this; confirm against executeImpl.
void convertToUInt8(const IColumn * column, UInt8Container & res)
{
if (!convertTypeToUInt8<Int8>(column, res) &&
!convertTypeToUInt8<Int16>(column, res) &&
!convertTypeToUInt8<Int32>(column, res) &&
!convertTypeToUInt8<Int64>(column, res) &&
!convertTypeToUInt8<UInt16>(column, res) &&
!convertTypeToUInt8<UInt32>(column, res) &&
!convertTypeToUInt8<UInt64>(column, res) &&
!convertTypeToUInt8<Float32>(column, res) &&
!convertTypeToUInt8<Float64>(column, res) &&
!convertNullableTypeToUInt8<Int8>(column, res) &&
!convertNullableTypeToUInt8<Int16>(column, res) &&
!convertNullableTypeToUInt8<Int32>(column, res) &&
!convertNullableTypeToUInt8<Int64>(column, res) &&
!convertNullableTypeToUInt8<UInt8>(column, res) &&
!convertNullableTypeToUInt8<UInt16>(column, res) &&
!convertNullableTypeToUInt8<UInt32>(column, res) &&
!convertNullableTypeToUInt8<UInt64>(column, res) &&
!convertNullableTypeToUInt8<Float32>(column, res) &&
!convertNullableTypeToUInt8<Float64>(column, res))
void convertToUInt8(const IColumn * column, UInt8Container & res, UInt8Container & res_not_null,
UInt8Container & input_has_null)
{
if (!convertTypeToUInt8<Int8>(column, res, res_not_null) &&
!convertTypeToUInt8<Int16>(column, res, res_not_null) &&
!convertTypeToUInt8<Int32>(column, res, res_not_null) &&
!convertTypeToUInt8<Int64>(column, res, res_not_null) &&
!convertTypeToUInt8<UInt16>(column, res, res_not_null) &&
!convertTypeToUInt8<UInt32>(column, res, res_not_null) &&
!convertTypeToUInt8<UInt64>(column, res, res_not_null) &&
!convertTypeToUInt8<Float32>(column, res, res_not_null) &&
!convertTypeToUInt8<Float64>(column, res, res_not_null) &&
!convertNullableTypeToUInt8<Int8>(column, res, res_not_null, input_has_null) &&
!convertNullableTypeToUInt8<Int16>(column, res, res_not_null, input_has_null) &&
!convertNullableTypeToUInt8<Int32>(column, res, res_not_null, input_has_null) &&
!convertNullableTypeToUInt8<Int64>(column, res, res_not_null, input_has_null) &&
!convertNullableTypeToUInt8<UInt8>(column, res, res_not_null, input_has_null) &&
!convertNullableTypeToUInt8<UInt16>(column, res, res_not_null, input_has_null) &&
!convertNullableTypeToUInt8<UInt32>(column, res, res_not_null, input_has_null) &&
!convertNullableTypeToUInt8<UInt64>(column, res, res_not_null, input_has_null) &&
!convertNullableTypeToUInt8<Float32>(column, res, res_not_null, input_has_null) &&
!convertNullableTypeToUInt8<Float64>(column, res, res_not_null, input_has_null))
throw Exception("Unexpected type of column: " + column->getName(), ErrorCodes::ILLEGAL_COLUMN);
}

Expand All @@ -272,7 +336,7 @@ class FunctionAnyArityLogical : public IFunction
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }

/// The default NULL handling is requested exactly when this function has no special NULL implementation.
/// NOTE(review): rendered diff - the first line (querying Impl::specialImplementationForNulls())
/// is the pre-change version; the second (using the special_impl_for_nulls template flag) replaces it.
bool useDefaultImplementationForNulls() const override { return !Impl::specialImplementationForNulls(); }
bool useDefaultImplementationForNulls() const override { return !special_impl_for_nulls; }

/// Get result types by argument types. If the function does not apply to these arguments, throw an exception.
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
Expand All @@ -282,20 +346,32 @@ class FunctionAnyArityLogical : public IFunction
+ toString(arguments.size()) + ", should be at least 2.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

bool has_nullable_input_column = false;
for (size_t i = 0; i < arguments.size(); ++i)
{
has_nullable_input_column |= arguments[i]->isNullable();
if (!(arguments[i]->isNumber()
|| (Impl::specialImplementationForNulls() && (arguments[i]->onlyNull() || removeNullable(arguments[i])->isNumber()))))
|| (special_impl_for_nulls && (arguments[i]->onlyNull() || removeNullable(arguments[i])->isNumber()))))
throw Exception("Illegal type ("
+ arguments[i]->getName()
+ ") of " + toString(i + 1) + " argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
+ arguments[i]->getName()
+ ") of " + toString(i + 1) + " argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}

return std::make_shared<DataTypeUInt8>();
if (has_nullable_input_column)
return makeNullable(std::make_shared<DataTypeUInt8>());
else
return std::make_shared<DataTypeUInt8>();
}

void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override
{
bool has_nullable_input_column = false;
size_t num_arguments = arguments.size();

for (size_t i = 0; i < num_arguments; ++i)
has_nullable_input_column |= block.getByPosition(arguments[i]).type->isNullable();

ColumnRawPtrs in(num_arguments);
for (size_t i = 0; i < num_arguments; ++i)
in[i] = block.getByPosition(arguments[i]).column.get();
Expand All @@ -304,32 +380,62 @@ class FunctionAnyArityLogical : public IFunction

/// Combine all constant columns into a single value.
UInt8 const_val = 0;
bool has_consts = extractConstColumns(in, const_val);
UInt8 const_val_input_has_null = 0;
UInt8 const_val_res_not_null = 0;
bool has_consts = extractConstColumns(in, const_val, const_val_res_not_null, const_val_input_has_null);

// If this value uniquely determines the result, return it.
if (has_consts && (in.empty() || Impl::apply(const_val, 0) == Impl::apply(const_val, 1)))
if (has_consts && (in.empty() || (!has_nullable_input_column && Impl::apply(const_val, 0) == Impl::apply(const_val, 1))))
{
if (!in.empty())
const_val = Impl::apply(const_val, 0);
block.getByPosition(result).column = DataTypeUInt8().createColumnConst(rows, toField(const_val));
if constexpr (!special_impl_for_nulls)
block.getByPosition(result).column = DataTypeUInt8().createColumnConst(rows, toField(const_val));
else
{
if (const_val_input_has_null && const_val_res_not_null) {
zanmato1984 marked this conversation as resolved.
Show resolved Hide resolved
Impl::adjustForNullValue(const_val, const_val_input_has_null);
}
if (const_val_input_has_null)
block.getByPosition(result).column =
block.getByPosition(result).type->createColumnConst(rows,Null());
else
block.getByPosition(result).column = has_nullable_input_column ? makeNullable(
DataTypeUInt8().createColumnConst(rows, toField(const_val))) :
DataTypeUInt8().createColumnConst(rows, toField(const_val));
}
return;
}

/// If this value is a neutral element, let's forget about it.
if (has_consts && Impl::apply(const_val, 0) == 0 && Impl::apply(const_val, 1) == 1)
if (!has_nullable_input_column && has_consts && Impl::apply(const_val, 0) == 0 && Impl::apply(const_val, 1) == 1)
has_consts = false;

auto col_res = ColumnUInt8::create();
UInt8Container & vec_res = col_res->getData();
auto col_input_has_null = ColumnUInt8::create();
UInt8Container & vec_input_has_null = col_input_has_null->getData();
auto col_res_not_null = ColumnUInt8::create();
UInt8Container & vec_res_not_null = col_res_not_null->getData();

if (has_consts)
{
vec_res.assign(rows, const_val);
in.push_back(col_res.get());
if constexpr (special_impl_for_nulls)
{
vec_input_has_null.assign(rows, const_val_input_has_null);
vec_res_not_null.assign(rows, const_val_res_not_null);
}
}
else
{
vec_res.resize(rows);
if constexpr (special_impl_for_nulls)
{
vec_input_has_null.assign(rows, (UInt8) 0);
vec_res_not_null.assign(rows, (UInt8) 0);
}
}

/// Convert all columns to UInt8
Expand All @@ -339,11 +445,20 @@ class FunctionAnyArityLogical : public IFunction
for (const IColumn * column : in)
{
if (auto uint8_column = checkAndGetColumn<ColumnUInt8>(column))
{
uint8_in.push_back(uint8_column);
const auto & data = uint8_column->getData();
if constexpr (special_impl_for_nulls)
{
size_t n = uint8_column->size();
for (size_t i = 0; i < n; i++)
vec_res_not_null[i] |= Impl::resNotNull(data[i], false);
}
}
else
{
auto converted_column = ColumnUInt8::create(rows);
convertToUInt8(column, converted_column->getData());
convertToUInt8(column, converted_column->getData(), vec_res_not_null, vec_input_has_null);
uint8_in.push_back(converted_column.get());
converted_columns.emplace_back(std::move(converted_column));
}
Expand All @@ -362,7 +477,23 @@ class FunctionAnyArityLogical : public IFunction
if (uint8_in[0] != col_res.get())
vec_res.assign(uint8_in[0]->getData());

block.getByPosition(result).column = std::move(col_res);
if constexpr (!special_impl_for_nulls)
{
block.getByPosition(result).column = std::move(col_res);
}
else {
if (has_nullable_input_column) {
zanmato1984 marked this conversation as resolved.
Show resolved Hide resolved
for (size_t i = 0; i < rows; i++)
{
if (vec_input_has_null[i] && vec_res_not_null[i])
Impl::adjustForNullValue(vec_res[i], vec_input_has_null[i]);
}
block.getByPosition(result).column = ColumnNullable::create(std::move(col_res),
std::move(col_input_has_null));
}
else
block.getByPosition(result).column = std::move(col_res);
}
}
};

Expand Down Expand Up @@ -438,9 +569,9 @@ struct NameOr { static constexpr auto name = "or"; };
struct NameXor { static constexpr auto name = "xor"; };
struct NameNot { static constexpr auto name = "not"; };

/// Concrete logical functions.
/// NOTE(review): rendered diff - the three two-argument FunctionAnyArityLogical aliases are the
/// pre-change lines; the three-argument aliases below them replace those.
/// The third template argument is special_impl_for_nulls: enabled for `and` and `or`, disabled for `xor`.
using FunctionAnd = FunctionAnyArityLogical<AndImpl, NameAnd>;
using FunctionOr = FunctionAnyArityLogical<OrImpl, NameOr>;
using FunctionXor = FunctionAnyArityLogical<XorImpl, NameXor>;
using FunctionAnd = FunctionAnyArityLogical<AndImpl, NameAnd, true>;
using FunctionOr = FunctionAnyArityLogical<OrImpl, NameOr, true>;
using FunctionXor = FunctionAnyArityLogical<XorImpl, NameXor, false>;
using FunctionNot = FunctionUnaryLogical<NotImpl, NameNot>;

}
Loading