Skip to content

Commit

Permalink
finish dev
Browse files Browse the repository at this point in the history
  • Loading branch information
taiyang-li committed Jan 16, 2025
1 parent 6b87b7f commit 92e9224
Show file tree
Hide file tree
Showing 8 changed files with 244 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -269,12 +269,24 @@ case class CHHashAggregateExecTransformer(
val childrenNodeList = new util.ArrayList[ExpressionNode]()
val childrenNodes = aggExpr.mode match {
case Partial | Complete =>
aggregateFunc.children.toList.map(
val nodes = aggregateFunc.children.toList.map(
expr => {
ExpressionConverter
.replaceWithExpressionTransformer(expr, child.output)
.doTransform(args)
})

val extraNodes = aggregateFunc match {
case hll: HyperLogLogPlusPlus =>
val relativeSDLiteral = Literal(hll.relativeSD)
Seq(
ExpressionConverter
.replaceWithExpressionTransformer(relativeSDLiteral, child.output)
.doTransform(args))
case _ => Seq.empty
}

nodes ++ extraNodes
case PartialMerge if distinct_modes.contains(Partial) =>
// this is the case where PartialMerge co-exists with Partial
// so far, it only happens in a three-stage count distinct case
Expand Down Expand Up @@ -438,6 +450,12 @@ case class CHHashAggregateExecTransformer(
percentile.percentageExpression.dataType,
percentile.percentageExpression.nullable)
(makeStructType(fields), attr.nullable)
case hllpp: HyperLogLogPlusPlus =>
var fields = Seq[(DataType, Boolean)]()
fields = fields :+ (hllpp.child.dataType, hllpp.child.nullable)
val relativeSDLiteral = Literal(hllpp.relativeSD)
fields = fields :+ (relativeSDLiteral.dataType, false)
(makeStructType(fields), attr.nullable)
case _ =>
(makeStructTypeSingleOne(attr.dataType, attr.nullable), attr.nullable)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.gluten.expression.ConverterUtils.FunctionConfig
import org.apache.gluten.substrait.expression._

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
Expand Down Expand Up @@ -260,6 +261,7 @@ case class GetArrayItemTransformer(
ConverterUtils.getTypeNode(getArrayItem.dataType, getArrayItem.nullable))
}
}

case class CHStringSplitTransformer(
substraitExprName: String,
children: Seq[ExpressionTransformer],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ struct Settings;

namespace ErrorCodes
{

extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,9 @@ namespace
{

AggregateFunctionPtr createAggregateFunctionUniqHyperLogLogPlusPlus(
const std::string & name, const DataTypes & argument_types, const Array & , const Settings *)
const std::string & , const DataTypes & argument_types, const Array & params, const Settings *)
{
if (argument_types.size() != 1)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Aggregate function '{}' requires exactly one argument: got {}",
name,
argument_types.size());

WhichDataType which(argument_types[0]);
if (!which.isUInt64())
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Aggregate function '{}' requires an argument of type UInt64, got {}",
name,
argument_types[0]->getName());

return std::make_shared<AggregateFunctionUniqHyperLogLogPlusPlus>(argument_types);
return std::make_shared<AggregateFunctionUniqHyperLogLogPlusPlus>(argument_types, params);
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,44 +1,53 @@
#pragma once

#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <Common/FieldVisitorConvertToNumber.h>
#include <Common/FieldVisitors.h>
#include <Common/HashTable/Hash.h>
#include <Common/HyperLogLogWithSmallSetOptimization.h>
#include <Parsers/NullsAction.h>
namespace DB
{

namespace ErrorCodes
{
extern const int PARAMETER_OUT_OF_BOUND;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
extern const int LOGICAL_ERROR;
extern const int INCORRECT_DATA;
}

struct HyperLogLogPlusPlusData
{
// HyperLogLogPlusPlusData(Float64 relative_sd_ = 0.05)
HyperLogLogPlusPlusData()
// : relative_sd(relative_sd_)
: relative_sd(0.05)
explicit HyperLogLogPlusPlusData(double relative_sd_ = 0.05)
: relative_sd(relative_sd_)
, p(static_cast<UInt64>(std::ceil(2.0 + std::log(1.106 / relative_sd / std::log(2.0)))))
, idx_shift(64 - p)
, w_padding(1ULL << (p - 1))
, m(1ULL << p)
, num_words(m / REGISTERS_PER_WORD + 1)
, alpha_mm(computeAlphaMM())
, registers(num_words, 0) // Initialize registers with zeros
{
// std::cout << "relative_sd: " << relative_sd << std::endl;
// std::cout << "p: " << p << std::endl;
// std::cout << "idx_shift: " << idx_shift << std::endl;
// std::cout << "w_padding: " << w_padding << std::endl;
// std::cout << "m: " << m << std::endl;
// std::cout << "num_words: " << num_words << std::endl;
// std::cout << "alpha_mm: " << alpha_mm << std::endl;
if (p < 4)
throw Exception(
ErrorCodes::PARAMETER_OUT_OF_BOUND,
"HLL++ requires at least 4 bits for addressing instead of {}. Use a lower error, at most 39%",
p);

if (p > 25)
throw Exception(
ErrorCodes::PARAMETER_OUT_OF_BOUND,
"HLL++ requires at most 25 bits for addressing instead of {} to avoid allocating too much memory",
p);

// std::cout << "relative_sd:" << relative_sd << " p:" << p << " m:" << m << " num_words:" << num_words << " alpha_mm:" << alpha_mm
// << std::endl;
registers = PaddedPODArray<UInt64>(num_words, 0); // Initialize registers with zeros
}

void serialize(WriteBuffer & buf) const
Expand Down Expand Up @@ -2975,18 +2984,42 @@ struct HyperLogLogPlusPlusData
1238126.379, 1244673.795, 1251260.649, 1257697.86, 1264320.983, 1270736.319, 1277274.694, 1283804.95, 1290211.514,
1296858.568, 1303455.691}}};

/// TODO

PaddedPODArray<UInt64> registers; // Declaration of registers
};

class AggregateFunctionUniqHyperLogLogPlusPlus final
: public IAggregateFunctionDataHelper<HyperLogLogPlusPlusData, AggregateFunctionUniqHyperLogLogPlusPlus>
{
public:
explicit AggregateFunctionUniqHyperLogLogPlusPlus(const DataTypes & argument_types_)
explicit AggregateFunctionUniqHyperLogLogPlusPlus(const DataTypes & argument_types_, const Array & params)
: IAggregateFunctionDataHelper(argument_types_, {}, createResultType())
{
if (argument_types.size() != 1)
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Aggregate function '{}' requires exactly one argument: got {}",
getName(),
argument_types.size());

WhichDataType which(argument_types[0]);
if (!which.isUInt64())
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Aggregate function '{}' requires an argument of type UInt64, got {}",
getName(),
argument_types[0]->getName());

if (params.empty())
relative_sd = 0.05;
else if (params.size() == 1)
{
relative_sd = applyVisitor(FieldVisitorConvertToNumber<Float64>(), params[0]);
if (isNaN(relative_sd) || relative_sd <= 0 || relative_sd > 1)
throw Exception(ErrorCodes::PARAMETER_OUT_OF_BOUND, "Relative standard deviation must be in the range (0, 1]");
}
else
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function {} requires one parameter: relative_sd", getName());
}

/// Name of this aggregate function: always "uniqHLLPP".
String getName() const override { return "uniqHLLPP"; }
Expand All @@ -2995,6 +3028,11 @@ class AggregateFunctionUniqHyperLogLogPlusPlus final

/// Always reports false, i.e. this function declares that it does not allocate in the Arena.
bool allocatesMemoryInArena() const override { return false; }

/// Constructs the aggregation state in place at `place`, forwarding this
/// function's relative_sd to the Data constructor.
/// NOTE(review): Data is presumably HyperLogLogPlusPlusData (the first template
/// argument of the IAggregateFunctionDataHelper base) — confirm.
void create(AggregateDataPtr __restrict place) const override /// NOLINT
{
new (place) Data(relative_sd);
}

void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{
auto value = static_cast<const ColumnUInt64 &>(*columns[0]).getData()[row_num];
Expand All @@ -3010,10 +3048,25 @@ class AggregateFunctionUniqHyperLogLogPlusPlus final
data(place).deserialize(buf);
}

DataTypePtr getNormalizedStateType() const override
{
    /// The normalized state type is AggregateFunction(<getName()>(0.05), argument_types):
    /// the function is looked up again in the factory under this function's own name,
    /// with a single parameter fixed at 0.05.
    /// NOTE(review): the parameter is always 0.05 here, not this instance's relative_sd —
    /// confirm that this is intended.
    Array normalized_params{0.05};
    AggregateFunctionProperties props = {.returns_default_when_only_null = true, .is_order_dependent = false};

    auto normalized_function
        = AggregateFunctionFactory::instance().get(getName(), NullsAction::EMPTY, this->argument_types, normalized_params, props);

    return std::make_shared<DataTypeAggregateFunction>(normalized_function, this->argument_types, normalized_params);
}

void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
    /// The destination column is treated as a UInt64 column; the value returned by
    /// query() on the state at `place` is appended to it.
    auto & result_values = static_cast<ColumnUInt64 &>(to).getData();
    result_values.push_back(data(place).query());
}

private:
/// Defines the maximum relative standard deviation allowed.
Float64 relative_sd;
};

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Parser/AggregateFunctionParser.h>
#include <DataTypes/DataTypeNullable.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>
#include "DataTypes/DataTypeAggregateFunction.h"
#include "DataTypes/DataTypesNumber.h"
#include "Functions/FunctionHelpers.h"

/// Forward declarations of the DB error codes used in this file.
/// They are declared `extern` here; their definitions live elsewhere.
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int BAD_ARGUMENTS;
}
}

namespace local_engine
{

using namespace DB;

/// Name pair consumed by the parser template below: the Spark-side function name
/// and the ClickHouse aggregate function name it is mapped to.
struct ApproxCountDistinctNameStruct
{
    static constexpr const char * spark_name = "approx_count_distinct";
    static constexpr const char * ch_name = "uniqHLLPP";
};

/// Parser that maps Spark's approx_count_distinct onto ClickHouse's uniqHLLPP aggregate function:
///   Spark approx_count_distinct(expr, relative_sd) = CH uniqHLLPP(relative_sd)(if(isNull(expr), null, sparkXxHash64(expr)))
///   Spark approx_count_distinct(expr)              = CH uniqHLLPP(0.05)(if(isNull(expr), null, sparkXxHash64(expr)))
///
/// NameStruct must provide `spark_name` and `ch_name`.
template <typename NameStruct>
class AggregateFunctionParserApproxCountDistinct final : public AggregateFunctionParser
{
public:
    static constexpr auto name = NameStruct::spark_name;

    AggregateFunctionParserApproxCountDistinct(ParserContextPtr parser_context_) : AggregateFunctionParser(parser_context_) { }
    ~AggregateFunctionParserApproxCountDistinct() override = default;

    /// Spark-side function name.
    String getName() const override { return NameStruct::spark_name; }

    /// ClickHouse-side function name; the function info is not needed to choose it.
    String getCHFunctionName(const CommonFunctionInfo &) const override { return NameStruct::ch_name; }

    /// Returns the ClickHouse function name and rewrites `types` in place to what uniqHLLPP expects.
    ///
    /// Always invoked during second stage, the first argument is expr, the second argument is relative_sd.
    /// 1. Remove the second argument because types are used to create the aggregate function.
    /// 2. Replace the first argument type with UInt64 or Nullable(UInt64) because uniqHLLPP requires it.
    ///
    /// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH if `types` is empty.
    String getCHFunctionName(DataTypes & types) const override
    {
        /// Without this guard, resize(1) on an empty vector would append a value-initialized (null)
        /// type pointer, which would then be dereferenced by isNullable().
        if (types.empty())
            throw Exception(
                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least one argument type, got none", getName());

        const bool expr_is_nullable = types[0]->isNullable();
        types.resize(1);
        types[0] = std::make_shared<DataTypeUInt64>();
        if (expr_is_nullable)
            types[0] = std::make_shared<DataTypeNullable>(types[0]);

        return NameStruct::ch_name;
    }

    /// Computes the parameters of uniqHLLPP and, in phases that read raw input, rewrites `arg_nodes`.
    ///
    /// Phases INITIAL_TO_INTERMEDIATE, INITIAL_TO_RESULT and UNSPECIFIED:
    ///   - accepts 1 or 2 arguments; a second argument must be a literal and becomes the single parameter,
    ///     otherwise the default parameters are used;
    ///   - `arg_nodes` is reduced to the single node if(isNull(expr), null, sparkXxHash64(expr)).
    /// Any other phase:
    ///   - accepts exactly 1 argument whose type must be an AggregateFunction type; its parameters are returned
    ///     and `arg_nodes` is left untouched.
    ///
    /// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH or BAD_ARGUMENTS when these expectations are violated.
    Array parseFunctionParameters(
        const CommonFunctionInfo & func_info, ActionsDAG::NodeRawConstPtrs & arg_nodes, ActionsDAG & actions_dag) const override
    {
        const bool consumes_raw_input = func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE
            || func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT
            || func_info.phase == substrait::AGGREGATION_PHASE_UNSPECIFIED;

        if (!consumes_raw_input)
        {
            if (arg_nodes.size() != 1)
                throw Exception(
                    ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
                    "Function {} takes 1 argument in phase {}",
                    getName(),
                    magic_enum::enum_name(func_info.phase));

            /// The only argument is the intermediate state; reuse the parameters recorded in its type.
            const auto & result_type = arg_nodes[0]->result_type;
            const auto * aggregate_function_type = checkAndGetDataType<DataTypeAggregateFunction>(result_type.get());
            if (!aggregate_function_type)
                throw Exception(
                    ErrorCodes::BAD_ARGUMENTS,
                    "The first argument type of function {} in phase {} must be AggregateFunction, but is {}",
                    getName(),
                    magic_enum::enum_name(func_info.phase),
                    result_type->getName());

            return aggregate_function_type->getParameters();
        }

        const auto & arguments = func_info.arguments;
        const size_t num_args = arguments.size();
        if (num_args != arg_nodes.size() || num_args < 1 || num_args > 2)
            throw Exception(
                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
                "Function {} takes 1 or 2 arguments in phase {}",
                getName(),
                magic_enum::enum_name(func_info.phase));

        Array params;
        if (num_args == 2)
        {
            /// relative_sd is a function parameter on the ClickHouse side, so it has to be a literal.
            const auto & relative_sd_arg = arguments[1].value();
            if (!relative_sd_arg.has_literal())
                throw Exception(ErrorCodes::BAD_ARGUMENTS, "Second argument of function {} must be literal", getName());

            auto [_, field] = parseLiteral(relative_sd_arg.literal());
            params.push_back(std::move(field));
        }
        else
        {
            /// Single source of truth for the default relative_sd.
            params = getDefaultFunctionParameters();
        }

        /// Replace the first argument expr with if(isNull(expr), null, sparkXxHash64(expr))
        /// and drop the relative_sd argument node, if any.
        const auto * expr_node = arg_nodes[0];
        const auto * is_null_node = toFunctionNode(actions_dag, "isNull", {expr_node});
        const auto * hash_node = toFunctionNode(actions_dag, "sparkXxHash64", {expr_node});
        const auto * null_node
            = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeNullable>(std::make_shared<DataTypeUInt64>()), {});
        const auto * if_node = toFunctionNode(actions_dag, "if", {is_null_node, null_node, hash_node});
        arg_nodes[0] = if_node;
        arg_nodes.resize(1);

        return params;
    }

    /// Default parameters: a single relative_sd of 0.05.
    Array getDefaultFunctionParameters() const override { return {0.05}; }
};

/// File-local static object for the approx_count_distinct parser.
/// NOTE(review): presumably the AggregateFunctionParserRegister constructor registers the parser
/// under its `name` during static initialization — confirm against its definition.
static const AggregateFunctionParserRegister<AggregateFunctionParserApproxCountDistinct<ApproxCountDistinctNameStruct>> registerer_approx_count_distinct;
}
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ object ExpressionMappings {
Sig[Skewness](SKEWNESS),
Sig[Kurtosis](KURTOSIS),
Sig[ApproximatePercentile](APPROX_PERCENTILE),
Sig[HyperLogLogPlusPlus](APPROX_COUNT_DISTINCT),
Sig[Percentile](PERCENTILE)
) ++ SparkShimLoader.getSparkShims.aggregateExpressionMappings

Expand Down
Loading

0 comments on commit 92e9224

Please sign in to comment.