
[GLUTEN-8528][CH]Support approx_count_distinct #8550

Open · wants to merge 8 commits into base: main

Conversation


@taiyang-li taiyang-li commented Jan 16, 2025

What changes were proposed in this pull request?

(Please fill in changes proposed in this fix)

(Fixes: #8528)

How was this patch tested?

Newly added UTs (unit tests).

@taiyang-li taiyang-li changed the title [GLUTEN-8528][CH]Support approx count distinct [GLUTEN-8528][CH]Support approx_count_distinct Jan 16, 2025
@github-actions github-actions bot added CORE works for Gluten Core CLICKHOUSE labels Jan 16, 2025
Run Gluten ClickHouse CI on ARM

@taiyang-li (Contributor Author)

@CodiumAI-Agent /review


CodiumAI-Agent commented Feb 5, 2025

PR Reviewer Guide 🔍

(Review updated until commit 92e9224)

Here are some key observations to aid the review process:

🎫 Ticket compliance analysis 🔶

8528 - Partially compliant

Compliant requirements:

  • Implement support for approx_count_distinct functionality.
  • Ensure compatibility with Spark's approx_count_distinct function.
  • Add necessary tests to validate the implementation.

Non-compliant requirements: (none listed)

Requires further human verification:

  • Validate the correctness of the approx_count_distinct implementation through integration testing.
  • Verify the performance impact of the changes in a real-world scenario.
⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
🔒 No security concerns identified
⚡ Recommended focus areas for review

Possible Issue

The logic for handling HyperLogLogPlusPlus in lines 279-289 introduces additional nodes for relativeSDLiteral. Ensure this does not inadvertently affect other aggregate functions or introduce unexpected behavior.

val extraNodes = aggregateFunc match {
  case hll: HyperLogLogPlusPlus =>
    val relativeSDLiteral = Literal(hll.relativeSD)
    Seq(
      ExpressionConverter
        .replaceWithExpressionTransformer(relativeSDLiteral, child.output)
        .doTransform(args))
  case _ => Seq.empty
}

nodes ++ extraNodes
Edge Case Handling

The parser implementation for approx_count_distinct (lines 48-148) should be reviewed for edge cases, such as invalid input types or unexpected argument counts.

template <typename NameStruct>
class AggregateFunctionParserApproxCountDistinct final : public AggregateFunctionParser
{
public:
    static constexpr auto name = NameStruct::spark_name;

    AggregateFunctionParserApproxCountDistinct(ParserContextPtr parser_context_) : AggregateFunctionParser(parser_context_) { }
    ~AggregateFunctionParserApproxCountDistinct() override = default;

    String getName() const override { return NameStruct::spark_name; }

    String getCHFunctionName(const CommonFunctionInfo &) const override { return NameStruct::ch_name; }

    String getCHFunctionName(DataTypes & types) const override
    {
        /// Always invoked during second stage, the first argument is expr, the second argument is relative_sd.
        /// 1. Remove the second argument because types are used to create the aggregate function.
        /// 2. Replace the first argument type with UInt64 or Nullable(UInt64) because uniqHLLPP requires it.
        types.resize(1);
        const auto old_type = types[0];
        types[0] = std::make_shared<DataTypeUInt64>();
        if (old_type->isNullable())
            types[0] = std::make_shared<DataTypeNullable>(types[0]);

        return NameStruct::ch_name;
    }

    Array parseFunctionParameters(
        const CommonFunctionInfo & func_info, ActionsDAG::NodeRawConstPtrs & arg_nodes, ActionsDAG & actions_dag) const override
    {
        if (func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE
            || func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT
            || func_info.phase == substrait::AGGREGATION_PHASE_UNSPECIFIED)
        {
            const auto & arguments = func_info.arguments;
            const size_t num_args = arguments.size();
            const size_t num_nodes = arg_nodes.size();
            if (num_args != num_nodes || num_args > 2 || num_args < 1 || num_nodes > 2 || num_nodes < 1)
                throw Exception(
                    ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
                    "Function {} takes 1 or 2 arguments in phase {}",
                    getName(),
                    magic_enum::enum_name(func_info.phase));

            Array params;
            if (num_args == 2)
            {
                const auto & relative_sd_arg = arguments[1].value();
                if (relative_sd_arg.has_literal())
                {
                    auto [_, field] = parseLiteral(relative_sd_arg.literal());
                    params.push_back(std::move(field));
                }
                else
                    throw Exception(ErrorCodes::BAD_ARGUMENTS, "Second argument of function {} must be literal", getName());
            }
            else
            {
                params.push_back(0.05);
            }

            const auto & expr_arg = arg_nodes[0];
            const auto * is_null_node = toFunctionNode(actions_dag, "isNull", {expr_arg});
            const auto * hash_node = toFunctionNode(actions_dag, "sparkXxHash64", {expr_arg});
            const auto * null_node
                = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeNullable>(std::make_shared<DataTypeUInt64>()), {});
            const auto * if_node = toFunctionNode(actions_dag, "if", {is_null_node, null_node, hash_node});
            /// Replace the first argument expr with if(isNull(expr), null, sparkXxHash64(expr))
            arg_nodes[0] = if_node;
            arg_nodes.resize(1);

            return params;
        }
        else
        {
            if (arg_nodes.size() != 1)
                throw Exception(
                    ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
                    "Function {} takes 1 argument in phase {}",
                    getName(),
                    magic_enum::enum_name(func_info.phase));

            const auto & result_type = arg_nodes[0]->result_type;
            const auto * aggregate_function_type = checkAndGetDataType<DataTypeAggregateFunction>(result_type.get());
            if (!aggregate_function_type)
                throw Exception(
                    ErrorCodes::BAD_ARGUMENTS,
                    "The first argument type of function {} in phase {} must be AggregateFunction, but is {}",
                    getName(),
                    magic_enum::enum_name(func_info.phase),
                    result_type->getName());

            return aggregate_function_type->getParameters();
        }
    }

    Array getDefaultFunctionParameters() const override { return {0.05}; }
};

static const AggregateFunctionParserRegister<AggregateFunctionParserApproxCountDistinct<ApproxCountDistinctNameStruct>> registerer_approx_count_distinct;
}
Test Coverage

The added tests focus on specific scenarios for HyperLogLogPlusPlus. Ensure that all edge cases, such as empty inputs or extreme values, are covered.

#include <gtest/gtest.h>
#include <AggregateFunctions/AggregateFunctionUniqHyperLogLogPlusPlus.h>
#include "IO/ReadBufferFromString.h"

using namespace DB;

static std::vector<UInt64> random_uint64s
    = {17956993516945311251ULL,
       4306050051188505054ULL,
       14289061765075743502ULL,
       16763375724458316157ULL,
       6144297519955185930ULL,
       18446472757487308114ULL,
       16923578592198257123ULL,
       13557354668567515845ULL,
       15328387702200001967ULL,
       15878166530370497646ULL};

static void initSmallHLL(HyperLogLogPlusPlusData & hll)
{
    for (auto x : random_uint64s)
        hll.add(x);
}

static void initLargeHLL(HyperLogLogPlusPlusData & hll)
{
    for (auto x : random_uint64s)
    {
        for (size_t i = 0; i < 100; ++i)
            hll.add(x * (i+1));
    }
}

TEST(HyperLogLogPlusPlusDataTest, Small)
{
    HyperLogLogPlusPlusData hll;
    initSmallHLL(hll);
    EXPECT_EQ(hll.query(), 10);
}

TEST(HyperLogLogPlusPlusDataTest, Large)
{
    HyperLogLogPlusPlusData hll;
    initLargeHLL(hll);
    EXPECT_EQ(hll.query(), 806);
}

TEST(HyperLogLogPlusPlusDataTest, Merge) {
    HyperLogLogPlusPlusData hll1;
    initSmallHLL(hll1);

    HyperLogLogPlusPlusData hll2;
    initLargeHLL(hll2);

    hll1.merge(hll2);
    EXPECT_EQ(hll1.query(), 806);
}

TEST(HyperLogLogPlusPlusDataTest, SerializeAndDeserialize) {
    HyperLogLogPlusPlusData hll1;
    initLargeHLL(hll1);

    WriteBufferFromOwnString write_buffer;
    hll1.serialize(write_buffer);

    ReadBufferFromString read_buffer(write_buffer.str());
    HyperLogLogPlusPlusData hll2;
    hll2.deserialize(read_buffer);

    EXPECT_EQ(hll2.query(), 806);
}


taiyang-li commented Feb 10, 2025

The native approx_count_distinct implementation is about 25x faster than vanilla Spark's (5.82 s vs. 149.915 s on the query below):

0: jdbc:hive2://localhost:10000/> select approx_count_distinct(id, 0.001), approx_count_distinct(id, 0.01), approx_count_distinct(id, 0.1) from range(1000);    
+----------------------------+----------------------------+----------------------------+
| approx_count_distinct(id)  | approx_count_distinct(id)  | approx_count_distinct(id)  |
+----------------------------+----------------------------+----------------------------+
| 999                        | 996                        | 928                        |
+----------------------------+----------------------------+----------------------------+
1 row selected (5.82 seconds)
0: jdbc:hive2://localhost:10000/> 
0: jdbc:hive2://localhost:10000/> set spark.gluten.enabled = false; 
+-----------------------+--------+
|          key          | value  |
+-----------------------+--------+
| spark.gluten.enabled  | false  |
+-----------------------+--------+
1 row selected (0.137 seconds)
0: jdbc:hive2://localhost:10000/> select approx_count_distinct(id, 0.001), approx_count_distinct(id, 0.01), approx_count_distinct(id, 0.1) from range(1000);     
+----------------------------+----------------------------+----------------------------+
| approx_count_distinct(id)  | approx_count_distinct(id)  | approx_count_distinct(id)  |
+----------------------------+----------------------------+----------------------------+
| 999                        | 996                        | 928                        |
+----------------------------+----------------------------+----------------------------+
1 row selected (149.915 seconds)

@CodiumAI-Agent

Persistent review updated to latest commit 92e9224

@taiyang-li taiyang-li marked this pull request as ready for review February 10, 2025 08:57

@zhanglistar (Contributor)

Let's enable the Spark HLL UTs to see what happens.


@taiyang-li (Contributor Author)

done.


@lgbo-ustc (Contributor)

LGTM

@zhanglistar (Contributor) left a comment

Better to add some comments about HLLPP for future readers.

#include <DataTypes/DataTypeNullable.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>
#include "DataTypes/DataTypeAggregateFunction.h"

Use <> instead of "" in the include directive.

@@ -25,6 +25,7 @@
#include <Parser/TypeParser.h>
#include <Common/CHUtil.h>
#include <Common/Exception.h>
#include <Common/logger_useful.h>

Unused header.


inline static const std::vector<std::vector<double>> BIAS_DATA = {
// precision 4
{10,

Please reformat this with 10 elements per line.


struct HyperLogLogPlusPlusData
{
explicit HyperLogLogPlusPlusData(double relative_sd_ = 0.05)

Use Float64 for consistency

Successfully merging this pull request may close these issues.

[CH] support approx_count_distinct
4 participants