Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test: wrap some API for aggregation function tests #5787

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions dbms/src/Flash/tests/gtest_aggregation_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <TestUtils/ExecutorTestUtils.h>
#include <TestUtils/AggregationTestUtils.h>
#include <TestUtils/mockExecutor.h>

namespace DB
Expand All @@ -30,7 +30,7 @@ namespace tests
types_col_name[a], types_col_name[b] \
}

class ExecutorAggTestRunner : public DB::tests::ExecutorTest
class ExecutorAggTestRunner : public DB::tests::AggregationTest
{
public:
using ColStringNullableType = std::optional<typename TypeTraits<String>::FieldType>;
Expand Down Expand Up @@ -355,6 +355,16 @@ try
}
CATCH

TEST_F(ExecutorAggTestRunner, TestFramwork)
try
{
executeGroupByAndAssert({toNullableVec<Int8>("tinyint_", col_tinyint)}, {toNullableVec<Int8>({-1, 2, {}, 0, 1, 3, -2})});
executeGroupByAndAssert({toNullableVec<Int8>("tinyint_", col_tinyint), toNullableVec<Int16>("smallint_", col_smallint)}, {toNullableVec<Int8>({0, 2, 0, -1, 1, -2, 3, {}, {}}), toNullableVec<Int16>({-1, 3, -2, 4, 2, 0, {}, {}, 0})});
executeAggFunctionAndAssert({"Max"}, toNullableVec<Int8>("tinyint_", col_tinyint), {toNullableVec<Int8>(ColumnWithNullableInt8{3})});
executeAggFunctionAndAssert({"Max", "Min"}, toNullableVec<Int8>("tinyint_", col_tinyint), {toNullableVec<Int8>(ColumnWithNullableInt8{3}), toNullableVec<Int8>(ColumnWithNullableInt8{-2})});
}
CATCH

// TODO support more type of min, max, count.
// support more aggregation functions: sum, forst_row, group_concat

Expand Down
80 changes: 80 additions & 0 deletions dbms/src/TestUtils/AggregationTestUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,85 @@ void AggregationTest::SetUpTestCase()
};

register_func(DB::registerAggregateFunctions);
register_func(DB::registerFunctions);
}

::testing::AssertionResult AggregationTest::checkAggReturnType(const String & agg_name, const DataTypes & data_types, const DataTypePtr & expect_type)
{
AggregateFunctionPtr agg_ptr = DB::AggregateFunctionFactory::instance().get(agg_name, data_types, {});
const DataTypePtr & ret_type = agg_ptr->getReturnType();
if (ret_type->equals(*expect_type))
return ::testing::AssertionSuccess();
return ::testing::AssertionFailure() << "Expect type: " << expect_type->getName() << " Actual type: " << ret_type->getName();
}

void AggregationTest::executeAggFunctionAndAssert(const std::vector<String> & func_names, const ColumnWithTypeAndName & column, const ColumnsWithTypeAndName & expected_cols)
{
String db_name = "test_agg_function";
String table_name = "test_table_agg";
std::vector<ASTPtr> agg_funcs;
for (const auto & func_name : func_names)
agg_funcs.push_back(aggFunctionBuilder(func_name, column.name));

MockColumnInfoVec column_infos;
column_infos.push_back({column.name, dataTypeToTP(column.type)});
context.addMockTable(db_name, table_name, column_infos, {column});

auto request = context.scan(db_name, table_name)
.aggregation(agg_funcs, {})
.build(context);

checkResult(request, expected_cols);
}

void AggregationTest::executeGroupByAndAssert(const ColumnsWithTypeAndName & cols, const ColumnsWithTypeAndName & expected_cols)
{
RUNTIME_CHECK_MSG(cols.size() == expected_cols.size(), "number of group_by columns don't match number of expected columns");

String db_name = "test_group";
String table_name = "test_table_group";
MockAstVec group_by_cols;
MockColumnNameVec proj_names;
MockColumnInfoVec column_infos;
for (const auto & col : cols)
{
group_by_cols.push_back(col(col.name));
proj_names.push_back(col.name);
column_infos.push_back({col.name, dataTypeToTP(col.type)});
}

context.addMockTable(db_name, table_name, column_infos, cols);

auto request = context.scan(db_name, table_name)
.aggregation({}, group_by_cols)
.project(proj_names)
.build(context);

checkResult(request, expected_cols);
}

void AggregationTest::checkResult(std::shared_ptr<tipb::DAGRequest> request, const ColumnsWithTypeAndName & expected_cols)
{
for (size_t i = 1; i <= 10; ++i)
ASSERT_COLUMNS_EQ_UR(expected_cols, executeStreams(request, i)) << "expected_cols: " << getColumnsContent(expected_cols) << ", actual_cols: " << getColumnsContent(executeStreams(request, i));
}

ASTPtr AggregationTest::aggFunctionBuilder(const String & func_name, const String & col_name)
{
ASTPtr func;
String func_name_lowercase = Poco::toLower(func_name);

// TODO support more agg functions.
if (func_name_lowercase == "max")
func = Max(col(col_name));
else if (func_name_lowercase == "min")
func = Min(col(col_name));
else if (func_name_lowercase == "count")
func = Count(col(col_name));
else if (func_name_lowercase == "sum")
func = Sum(col(col_name));
else
throw Exception(fmt::format("Unsupported agg function {}", func_name), ErrorCodes::LOGICAL_ERROR);
return func;
}
} // namespace DB::tests
32 changes: 19 additions & 13 deletions dbms/src/TestUtils/AggregationTestUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,33 @@
#pragma once

#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/registerAggregateFunctions.h>
#include <TestUtils/TiFlashTestEnv.h>
#include <gtest/gtest.h>
#include <TestUtils/ExecutorTestUtils.h>

namespace DB::tests
{

class AggregationTest : public ::testing::Test
class AggregationTest : public ExecutorTest
{
public:
::testing::AssertionResult checkAggReturnType(const String & agg_name, const DataTypes & data_types, const DataTypePtr & expect_type)
{
AggregateFunctionPtr agg_ptr = DB::AggregateFunctionFactory::instance().get(agg_name, data_types, {});
const DataTypePtr & ret_type = agg_ptr->getReturnType();
if (ret_type->equals(*expect_type))
return ::testing::AssertionSuccess();
return ::testing::AssertionFailure() << "Expect type: " << expect_type->getName() << " Actual type: " << ret_type->getName();
}
static ::testing::AssertionResult checkAggReturnType(const String & agg_name, const DataTypes & data_types, const DataTypePtr & expect_type);

// Test aggregation functions without group by.
void executeAggFunctionAndAssert(
const std::vector<String> & func_names,
const ColumnWithTypeAndName & column,
const ColumnsWithTypeAndName & expected_cols);

// Test group by columns
// Note that we must give columns in cols a name.
void executeGroupByAndAssert(
const ColumnsWithTypeAndName & cols,
const ColumnsWithTypeAndName & expected_cols);

static void SetUpTestCase();

private:
void checkResult(std::shared_ptr<tipb::DAGRequest> request, const ColumnsWithTypeAndName & expected_cols);
ASTPtr aggFunctionBuilder(const String & func_name, const String & col_name);
};

} // namespace DB::tests