From 3ef3154089e9c24614038db873f71d3d80d07934 Mon Sep 17 00:00:00 2001 From: SeaRise Date: Thu, 20 Apr 2023 12:07:19 +0800 Subject: [PATCH] Pipeline: support spill for fine grained aggregation (#7220) ref pingcap/tiflash#6518 --- .../AggregatingBlockInputStream.cpp | 2 +- .../ParallelAggregatingBlockInputStream.cpp | 23 ++-- .../Plans/PhysicalAggregationConvergent.cpp | 32 ++---- .../Flash/tests/gtest_spill_aggregation.cpp | 84 +++++++++++---- dbms/src/Interpreters/Aggregator.cpp | 32 +++++- dbms/src/Interpreters/Aggregator.h | 4 + dbms/src/Operators/AggregateContext.cpp | 68 ++++++++---- dbms/src/Operators/AggregateContext.h | 34 ++++-- dbms/src/Operators/BucketInput.cpp | 65 +++++++++++ dbms/src/Operators/BucketInput.h | 42 ++++++++ dbms/src/Operators/LocalAggregateRestorer.cpp | 102 ++++++++++++++++++ dbms/src/Operators/LocalAggregateRestorer.h | 78 ++++++++++++++ .../src/Operators/LocalAggregateTransform.cpp | 102 ++++++++++++++---- dbms/src/Operators/LocalAggregateTransform.h | 37 +++++-- dbms/src/TestUtils/ExecutorTestUtils.cpp | 17 ++- dbms/src/TestUtils/ExecutorTestUtils.h | 8 +- 16 files changed, 603 insertions(+), 127 deletions(-) create mode 100644 dbms/src/Operators/BucketInput.cpp create mode 100644 dbms/src/Operators/BucketInput.h create mode 100644 dbms/src/Operators/LocalAggregateRestorer.cpp create mode 100644 dbms/src/Operators/LocalAggregateRestorer.h diff --git a/dbms/src/DataStreams/AggregatingBlockInputStream.cpp b/dbms/src/DataStreams/AggregatingBlockInputStream.cpp index 52839aa1724..85b9772cb4c 100644 --- a/dbms/src/DataStreams/AggregatingBlockInputStream.cpp +++ b/dbms/src/DataStreams/AggregatingBlockInputStream.cpp @@ -63,7 +63,7 @@ Block AggregatingBlockInputStream::readImpl() if (!isCancelled()) { /// Flush data in the RAM to disk also. It's easier than merging on-disk and RAM data. 
- if (!data_variants->empty()) + if (data_variants->tryMarkNeedSpill()) aggregator.spill(*data_variants); } aggregator.finishSpill(); diff --git a/dbms/src/DataStreams/ParallelAggregatingBlockInputStream.cpp b/dbms/src/DataStreams/ParallelAggregatingBlockInputStream.cpp index a3c8729fa26..1f945aa6ef8 100644 --- a/dbms/src/DataStreams/ParallelAggregatingBlockInputStream.cpp +++ b/dbms/src/DataStreams/ParallelAggregatingBlockInputStream.cpp @@ -134,11 +134,14 @@ Block ParallelAggregatingBlockInputStream::readImpl() void ParallelAggregatingBlockInputStream::Handler::onBlock(Block & block, size_t thread_num) { + auto & data = *parent.many_data[thread_num]; parent.aggregator.executeOnBlock( block, - *parent.many_data[thread_num], + data, parent.threads_data[thread_num].key_columns, parent.threads_data[thread_num].aggregate_columns); + if (data.need_spill) + parent.aggregator.spill(data); parent.threads_data[thread_num].src_rows += block.rows(); parent.threads_data[thread_num].src_bytes += block.bytes(); @@ -150,11 +153,7 @@ void ParallelAggregatingBlockInputStream::Handler::onFinishThread(size_t thread_ { /// Flush data in the RAM to disk. So it's easier to unite them later. auto & data = *parent.many_data[thread_num]; - - if (data.isConvertibleToTwoLevel()) - data.convertToTwoLevel(); - - if (!data.empty()) + if (data.tryMarkNeedSpill()) parent.aggregator.spill(data); } } @@ -167,10 +166,7 @@ void ParallelAggregatingBlockInputStream::Handler::onFinish() /// because at the time of `onFinishThread` call, no data has been flushed to disk, and then some were. for (auto & data : parent.many_data) { - if (data->isConvertibleToTwoLevel()) - data->convertToTwoLevel(); - - if (!data->empty()) + if (data->tryMarkNeedSpill()) parent.aggregator.spill(*data); } } @@ -245,11 +241,16 @@ void ParallelAggregatingBlockInputStream::execute() /// If there was no data, and we aggregate without keys, we must return single row with the result of empty aggregation. 
/// To do this, we pass a block with zero rows to aggregate. if (total_src_rows == 0 && params.keys_size == 0 && !params.empty_result_for_aggregation_by_empty_set) + { + auto & data = *many_data[0]; aggregator.executeOnBlock( children.at(0)->getHeader(), - *many_data[0], + data, threads_data[0].key_columns, threads_data[0].aggregate_columns); + if (data.need_spill) + aggregator.spill(data); + } } void ParallelAggregatingBlockInputStream::appendInfo(FmtBuffer & buffer) const diff --git a/dbms/src/Flash/Planner/Plans/PhysicalAggregationConvergent.cpp b/dbms/src/Flash/Planner/Plans/PhysicalAggregationConvergent.cpp index 2e9177c58df..1a17048d3ce 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalAggregationConvergent.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalAggregationConvergent.cpp @@ -29,29 +29,15 @@ void PhysicalAggregationConvergent::buildPipelineExecGroup( assert(!fine_grained_shuffle.enable()); aggregate_context->initConvergent(); - - if (unlikely(aggregate_context->useNullSource())) - { - group_builder.init(1); - group_builder.transform([&](auto & builder) { - builder.setSourceOp(std::make_unique( - exec_status, - aggregate_context->getHeader(), - log->identifier())); - }); - } - else - { - group_builder.init(aggregate_context->getConvergentConcurrency()); - size_t index = 0; - group_builder.transform([&](auto & builder) { - builder.setSourceOp(std::make_unique( - exec_status, - aggregate_context, - index++, - log->identifier())); - }); - } + group_builder.init(aggregate_context->getConvergentConcurrency()); + size_t index = 0; + group_builder.transform([&](auto & builder) { + builder.setSourceOp(std::make_unique( + exec_status, + aggregate_context, + index++, + log->identifier())); + }); executeExpression(exec_status, group_builder, expr_after_agg, log); } diff --git a/dbms/src/Flash/tests/gtest_spill_aggregation.cpp b/dbms/src/Flash/tests/gtest_spill_aggregation.cpp index 417592a9353..03b28982ae6 100644 --- 
a/dbms/src/Flash/tests/gtest_spill_aggregation.cpp +++ b/dbms/src/Flash/tests/gtest_spill_aggregation.cpp @@ -58,36 +58,34 @@ try context.context->setSetting("max_block_size", Field(static_cast(max_block_size))); /// disable spill context.context->setSetting("max_bytes_before_external_group_by", Field(static_cast(0))); - auto ref_columns = executeStreams(request, original_max_streams, true); + auto ref_columns = executeStreams(request, original_max_streams); /// enable spill context.context->setSetting("max_bytes_before_external_group_by", Field(static_cast(total_data_size / 200))); context.context->setSetting("group_by_two_level_threshold", Field(static_cast(1))); context.context->setSetting("group_by_two_level_threshold_bytes", Field(static_cast(1))); /// don't use `executeAndAssertColumnsEqual` since it takes too long to run /// test single thread aggregation - /// need to enable memory tracker since currently, the memory usage in aggregator is - /// calculated by memory tracker, if memory tracker is not enabled, spill will never be triggered. 
- ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, 1, true)); + ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, 1)); /// test parallel aggregation - ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams, true)); + ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams)); /// enable spill and use small max_cached_data_bytes_in_spiller context.context->setSetting("max_cached_data_bytes_in_spiller", Field(static_cast(total_data_size / 200))); /// test single thread aggregation - ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, 1, true)); + ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, 1)); /// test parallel aggregation - ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams, true)); + ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams)); /// test spill with small max_block_size /// the avg rows in one bucket is ~10240/256 = 400, so set the small_max_block_size to 300 /// is enough to test the output spilt size_t small_max_block_size = 300; context.context->setSetting("max_block_size", Field(static_cast(small_max_block_size))); - auto blocks = getExecuteStreamsReturnBlocks(request, 1, true); + auto blocks = getExecuteStreamsReturnBlocks(request, 1); for (auto & block : blocks) { ASSERT_EQ(block.rows() <= small_max_block_size, true); } ASSERT_COLUMNS_EQ_UR(ref_columns, vstackBlocks(std::move(blocks)).getColumnsWithTypeAndName()); - blocks = getExecuteStreamsReturnBlocks(request, original_max_streams, true); + blocks = getExecuteStreamsReturnBlocks(request, original_max_streams); for (auto & block : blocks) { ASSERT_EQ(block.rows() <= small_max_block_size, true); @@ -166,10 +164,7 @@ try context.context->setSetting("group_by_two_level_threshold_bytes", Field(static_cast(0))); context.context->setSetting("max_bytes_before_external_group_by", Field(static_cast(0))); context.context->setSetting("max_block_size", 
Field(static_cast(unique_rows * 2))); - /// here has to enable memory tracker otherwise the processList in the context is the last query's processList - /// and may cause segment fault, maybe a bug but should not happens in TiDB because all the tasks from tidb - /// enable memory tracker - auto reference = executeStreams(request, 1, true); + auto reference = executeStreams(request, 1); if (current_collator->isCI()) { /// for ci collation, need to sort and compare the result manually @@ -198,7 +193,7 @@ try context.context->setSetting("group_by_two_level_threshold_bytes", Field(static_cast(1))); context.context->setSetting("max_bytes_before_external_group_by", Field(static_cast(max_bytes_before_external_agg))); context.context->setSetting("max_block_size", Field(static_cast(max_block_size))); - auto blocks = getExecuteStreamsReturnBlocks(request, concurrency, true); + auto blocks = getExecuteStreamsReturnBlocks(request, concurrency); for (auto & block : blocks) { block.checkNumberOfRows(); @@ -302,10 +297,7 @@ try context.context->setSetting("group_by_two_level_threshold_bytes", Field(static_cast(0))); context.context->setSetting("max_bytes_before_external_group_by", Field(static_cast(0))); context.context->setSetting("max_block_size", Field(static_cast(unique_rows * 2))); - /// here has to enable memory tracker otherwise the processList in the context is the last query's processList - /// and may cause segment fault, maybe a bug but should not happens in TiDB because all the tasks from tidb - /// enable memory tracker - auto reference = executeStreams(request, 1, true); + auto reference = executeStreams(request, 1); if (current_collator->isCI()) { /// for ci collation, need to sort and compare the result manually @@ -334,7 +326,7 @@ try context.context->setSetting("group_by_two_level_threshold_bytes", Field(static_cast(1))); context.context->setSetting("max_bytes_before_external_group_by", Field(static_cast(max_bytes_before_external_agg))); 
context.context->setSetting("max_block_size", Field(static_cast(max_block_size))); - auto blocks = getExecuteStreamsReturnBlocks(request, concurrency, true); + auto blocks = getExecuteStreamsReturnBlocks(request, concurrency); for (auto & block : blocks) { block.checkNumberOfRows(); @@ -358,5 +350,59 @@ try } } CATCH + +TEST_F(SpillAggregationTestRunner, FineGrainedShuffle) +try +{ + DB::MockColumnInfoVec column_infos{{"a", TiDB::TP::TypeLongLong}, {"b", TiDB::TP::TypeLongLong}, {"c", TiDB::TP::TypeLongLong}, {"d", TiDB::TP::TypeLongLong}, {"e", TiDB::TP::TypeLongLong}}; + DB::MockColumnInfoVec partition_column_infos{{"a", TiDB::TP::TypeLongLong}, {"b", TiDB::TP::TypeLongLong}}; + ColumnsWithTypeAndName column_datas; + size_t table_rows = 5120; + size_t duplicated_rows = 2560; + UInt64 max_block_size = 100; + size_t total_data_size = 0; + for (const auto & column_info : mockColumnInfosToTiDBColumnInfos(column_infos)) + { + ColumnGeneratorOpts opts{table_rows, getDataTypeByColumnInfoForComputingLayer(column_info)->getName(), RANDOM, column_info.name}; + column_datas.push_back(ColumnGenerator::instance().generate(opts)); + total_data_size += column_datas.back().column->byteSize(); + } + for (auto & column_data : column_datas) + column_data.column->assumeMutable()->insertRangeFrom(*column_data.column, 0, duplicated_rows); + context.addExchangeReceiver("exchange_receiver_1_concurrency", column_infos, column_datas, 1, partition_column_infos); + context.addExchangeReceiver("exchange_receiver_3_concurrency", column_infos, column_datas, 3, partition_column_infos); + context.addExchangeReceiver("exchange_receiver_5_concurrency", column_infos, column_datas, 5, partition_column_infos); + context.addExchangeReceiver("exchange_receiver_10_concurrency", column_infos, column_datas, 10, partition_column_infos); + std::vector exchange_receiver_concurrency = {1, 3, 5, 10}; + + auto gen_request = [&](size_t exchange_concurrency) { + return context + 
.receive(fmt::format("exchange_receiver_{}_concurrency", exchange_concurrency), exchange_concurrency) + .aggregation({Min(col("c")), Max(col("d")), Count(col("e"))}, {col("a"), col("b")}, exchange_concurrency) + .build(context); + }; + context.context->setSetting("max_block_size", Field(static_cast(max_block_size))); + + /// disable spill + context.context->setSetting("max_bytes_before_external_group_by", Field(static_cast(0))); + enablePipeline(false); + auto baseline = executeStreams(gen_request(1), 1); + + /// enable spill + context.context->setSetting("max_bytes_before_external_group_by", Field(static_cast(total_data_size / 200))); + context.context->setSetting("group_by_two_level_threshold", Field(static_cast(1))); + context.context->setSetting("group_by_two_level_threshold_bytes", Field(static_cast(1))); + for (size_t exchange_concurrency : exchange_receiver_concurrency) + { + /// don't use `executeAndAssertColumnsEqual` since it takes too long to run + auto request = gen_request(exchange_concurrency); + enablePipeline(false); + ASSERT_COLUMNS_EQ_UR(baseline, executeStreams(request, exchange_concurrency)); + enablePipeline(true); + ASSERT_COLUMNS_EQ_UR(baseline, executeStreams(request, exchange_concurrency)); + } +} +CATCH + } // namespace tests } // namespace DB diff --git a/dbms/src/Interpreters/Aggregator.cpp b/dbms/src/Interpreters/Aggregator.cpp index e8fe5800a3f..b67e8314e00 100644 --- a/dbms/src/Interpreters/Aggregator.cpp +++ b/dbms/src/Interpreters/Aggregator.cpp @@ -65,6 +65,22 @@ AggregatedDataVariants::~AggregatedDataVariants() destroyAggregationMethodImpl(); } +bool AggregatedDataVariants::tryMarkNeedSpill() +{ + assert(!need_spill); + if (empty()) + return false; + if (!isTwoLevel()) + { + /// Data can only be flushed to disk if a two-level aggregation is supported. 
+ if (!isConvertibleToTwoLevel()) + return false; + convertToTwoLevel(); + } + need_spill = true; + return true; +} + void AggregatedDataVariants::destroyAggregationMethodImpl() { if (!aggregation_method_impl) @@ -719,6 +735,8 @@ bool Aggregator::executeOnBlock( ColumnRawPtrs & key_columns, AggregateColumns & aggregate_columns) { + assert(!result.need_spill); + if (is_cancelled()) return true; @@ -811,15 +829,11 @@ bool Aggregator::executeOnBlock( result.convertToTwoLevel(); /** Flush data to disk if too much RAM is consumed. - * Data can only be flushed to disk if a two-level aggregation is supported. */ if (max_bytes_before_external_group_by && result_size > 0 - && (result.isTwoLevel() || result.isConvertibleToTwoLevel()) && result_size_bytes > max_bytes_before_external_group_by) { - if (!result.isTwoLevel()) - result.convertToTwoLevel(); - spill(result); + result.tryMarkNeedSpill(); } return true; @@ -847,6 +861,7 @@ void Aggregator::initThresholdByAggregatedDataVariantsSize(size_t aggregated_dat void Aggregator::spill(AggregatedDataVariants & data_variants) { + assert(data_variants.need_spill); bool init_value = false; if (spill_triggered.compare_exchange_strong(init_value, true, std::memory_order_relaxed)) { @@ -872,6 +887,7 @@ void Aggregator::spill(AggregatedDataVariants & data_variants) /// NOTE Instead of freeing up memory and creating new hash tables and arenas, you can re-use the old ones. 
data_variants.init(data_variants.type); + data_variants.need_spill = false; data_variants.aggregates_pools = Arenas(1, std::make_shared()); data_variants.aggregates_pool = data_variants.aggregates_pools.back().get(); data_variants.without_key = nullptr; @@ -978,12 +994,18 @@ void Aggregator::execute(const BlockInputStreamPtr & stream, AggregatedDataVaria if (!executeOnBlock(block, result, key_columns, aggregate_columns)) break; + if (result.need_spill) + spill(result); } /// If there was no data, and we aggregate without keys, and we must return single row with the result of empty aggregation. /// To do this, we pass a block with zero rows to aggregate. if (result.empty() && params.keys_size == 0 && !params.empty_result_for_aggregation_by_empty_set) + { executeOnBlock(stream->getHeader(), result, key_columns, aggregate_columns); + if (result.need_spill) + spill(result); + } double elapsed_seconds = watch.elapsedSeconds(); size_t rows = result.size(); diff --git a/dbms/src/Interpreters/Aggregator.h b/dbms/src/Interpreters/Aggregator.h index 55b5481e077..fddc245c54d 100644 --- a/dbms/src/Interpreters/Aggregator.h +++ b/dbms/src/Interpreters/Aggregator.h @@ -672,6 +672,10 @@ struct AggregatedDataVariants : private boost::noncopyable Type type{Type::EMPTY}; + bool need_spill = false; + + bool tryMarkNeedSpill(); + void destroyAggregationMethodImpl(); AggregatedDataVariants() diff --git a/dbms/src/Operators/AggregateContext.cpp b/dbms/src/Operators/AggregateContext.cpp index 11958282cd9..110d707ddd6 100644 --- a/dbms/src/Operators/AggregateContext.cpp +++ b/dbms/src/Operators/AggregateContext.cpp @@ -18,7 +18,8 @@ namespace DB { void AggregateContext::initBuild(const Aggregator::Params & params, size_t max_threads_, Aggregator::CancellationHook && hook) { - RUNTIME_CHECK(!inited_build && !inited_convergent); + assert(status.load() == AggStatus::init); + is_cancelled = std::move(hook); max_threads = max_threads_; empty_result_for_aggregation_by_empty_set = 
params.empty_result_for_aggregation_by_empty_set; keys_size = params.keys_size; @@ -31,21 +32,53 @@ void AggregateContext::initBuild(const Aggregator::Params & params, size_t max_t } aggregator = std::make_unique(params, log->identifier()); - aggregator->setCancellationHook(std::move(hook)); + aggregator->setCancellationHook(is_cancelled); aggregator->initThresholdByAggregatedDataVariantsSize(many_data.size()); - inited_build = true; + status = AggStatus::build; build_watch.emplace(); LOG_TRACE(log, "Aggregate Context inited"); } void AggregateContext::buildOnBlock(size_t task_index, const Block & block) { - RUNTIME_CHECK(inited_build && !inited_convergent); + assert(status.load() == AggStatus::build); aggregator->executeOnBlock(block, *many_data[task_index], threads_data[task_index]->key_columns, threads_data[task_index]->aggregate_columns); threads_data[task_index]->src_bytes += block.bytes(); threads_data[task_index]->src_rows += block.rows(); } +bool AggregateContext::hasSpilledData() const +{ + assert(status.load() == AggStatus::build); + return aggregator->hasSpilledData(); +} + +bool AggregateContext::needSpill(size_t task_index, bool try_mark_need_spill) +{ + assert(status.load() == AggStatus::build); + auto & data = *many_data[task_index]; + if (try_mark_need_spill && !data.need_spill) + data.tryMarkNeedSpill(); + return data.need_spill; +} + +void AggregateContext::spillData(size_t task_index) +{ + assert(status.load() == AggStatus::build); + aggregator->spill(*many_data[task_index]); +} + +LocalAggregateRestorerPtr AggregateContext::buildLocalRestorer() +{ + assert(status.load() == AggStatus::build); + aggregator->finishSpill(); + LOG_INFO(log, "Begin restore data from disk for local aggregation."); + auto input_streams = aggregator->restoreSpilledData(); + status = AggStatus::restore; + RUNTIME_CHECK_MSG(!input_streams.empty(), "There will be at least one spilled file."); + return std::make_unique(input_streams, *aggregator, is_cancelled, 
log->identifier()); +} + void AggregateContext::initConvergentPrefix() { assert(build_watch); @@ -87,43 +120,32 @@ void AggregateContext::initConvergentPrefix() void AggregateContext::initConvergent() { - RUNTIME_CHECK(inited_build && !inited_convergent); + assert(status.load() == AggStatus::build); initConvergentPrefix(); merging_buckets = aggregator->mergeAndConvertToBlocks(many_data, true, max_threads); - inited_convergent = true; + status = AggStatus::convergent; RUNTIME_CHECK(!merging_buckets || merging_buckets->getConcurrency() > 0); } size_t AggregateContext::getConvergentConcurrency() { - RUNTIME_CHECK(inited_convergent); - - return isTwoLevel() ? merging_buckets->getConcurrency() : 1; + assert(status.load() == AggStatus::convergent); + return merging_buckets ? merging_buckets->getConcurrency() : 1; } Block AggregateContext::getHeader() const { - RUNTIME_CHECK(inited_build); + assert(aggregator); return aggregator->getHeader(true); } -bool AggregateContext::isTwoLevel() -{ - RUNTIME_CHECK(inited_build); - return many_data[0]->isTwoLevel(); -} - -bool AggregateContext::useNullSource() -{ - RUNTIME_CHECK(inited_convergent); - return !merging_buckets; -} - Block AggregateContext::readForConvergent(size_t index) { - RUNTIME_CHECK(inited_convergent); + assert(status.load() == AggStatus::convergent); + if unlikely (!merging_buckets) + return {}; return merging_buckets->getData(index); } } // namespace DB diff --git a/dbms/src/Operators/AggregateContext.h b/dbms/src/Operators/AggregateContext.h index 4f29605bd4b..01fb49c3135 100644 --- a/dbms/src/Operators/AggregateContext.h +++ b/dbms/src/Operators/AggregateContext.h @@ -17,6 +17,7 @@ #include #include #include +#include #include namespace DB @@ -50,6 +51,14 @@ class AggregateContext void buildOnBlock(size_t task_index, const Block & block); + bool hasSpilledData() const; + + bool needSpill(size_t task_index, bool try_mark_need_spill = false); + + void spillData(size_t task_index); + + LocalAggregateRestorerPtr 
buildLocalRestorer(); + void initConvergent(); // Called before convergent to trace aggregate statistics and handle empty table with result case. @@ -61,18 +70,29 @@ class AggregateContext Block getHeader() const; - bool useNullSource(); - -private: - bool isTwoLevel(); - private: std::unique_ptr aggregator; bool keys_size = false; bool empty_result_for_aggregation_by_empty_set = false; - std::atomic_bool inited_build = false; - std::atomic_bool inited_convergent = false; + /** + * init────►build───┬───►convergent + * │ + * ▼ + * restore + */ + enum class AggStatus + { + init, + build, + convergent, + restore, + }; + std::atomic status{AggStatus::init}; + + Aggregator::CancellationHook is_cancelled{[]() { + return false; + }}; MergingBucketsPtr merging_buckets; ManyAggregatedDataVariants many_data; diff --git a/dbms/src/Operators/BucketInput.cpp b/dbms/src/Operators/BucketInput.cpp new file mode 100644 index 00000000000..30458e17b0e --- /dev/null +++ b/dbms/src/Operators/BucketInput.cpp @@ -0,0 +1,65 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +namespace DB +{ +BucketInput::BucketInput(const BlockInputStreamPtr & stream_) + : stream(stream_) +{ + stream->readPrefix(); +} + +bool BucketInput::needLoad() const +{ + return !is_exhausted && !output.has_value(); +} + +void BucketInput::load() +{ + assert(needLoad()); + Block ret = stream->read(); + if unlikely (!ret) + { + is_exhausted = true; + stream->readSuffix(); + } + else + { + /// Only two level data can be spilled. + assert(ret.info.bucket_num != -1); + output.emplace(std::move(ret)); + } +} + +bool BucketInput::hasOutput() const +{ + return output.has_value(); +} + +Int32 BucketInput::bucketNum() const +{ + assert(hasOutput()); + return output->info.bucket_num; +} + +Block BucketInput::moveOutput() +{ + assert(hasOutput()); + Block ret = std::move(*output); + output.reset(); + return ret; +} +} // namespace DB diff --git a/dbms/src/Operators/BucketInput.h b/dbms/src/Operators/BucketInput.h new file mode 100644 index 00000000000..e6c8b9756ef --- /dev/null +++ b/dbms/src/Operators/BucketInput.h @@ -0,0 +1,42 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace DB +{ +/// Used for reading the spilled bucket data of the aggregator. 
+class BucketInput +{ +public: + explicit BucketInput(const BlockInputStreamPtr & stream_); + + bool needLoad() const; + void load(); + + bool hasOutput() const; + Int32 bucketNum() const; + Block moveOutput(); + +private: + BlockInputStreamPtr stream; + std::optional output; + bool is_exhausted = false; +}; +using BucketInputs = std::vector; + +} // namespace DB diff --git a/dbms/src/Operators/LocalAggregateRestorer.cpp b/dbms/src/Operators/LocalAggregateRestorer.cpp new file mode 100644 index 00000000000..12527c97e29 --- /dev/null +++ b/dbms/src/Operators/LocalAggregateRestorer.cpp @@ -0,0 +1,102 @@ +// Copyright 2023 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +namespace DB +{ +LocalAggregateRestorer::LocalAggregateRestorer( + const BlockInputStreams & bucket_streams, + Aggregator & aggregator_, + std::function is_cancelled_, + const String & req_id) + : aggregator(aggregator_) + , is_cancelled(std::move(is_cancelled_)) + , log(Logger::get(req_id)) +{ + for (const auto & bucket_stream : bucket_streams) + bucket_inputs.emplace_back(bucket_stream); + assert(!bucket_inputs.empty()); +} + +void LocalAggregateRestorer::storeToBucketData() +{ + assert(!finished); + assert(!bucket_inputs.empty()); + + // get min bucket num. 
+ Int32 min_bucket_num = NUM_BUCKETS; + for (auto & bucket_input : bucket_inputs) + { + if (bucket_input.hasOutput()) + min_bucket_num = std::min(bucket_input.bucketNum(), min_bucket_num); + } + if unlikely (min_bucket_num >= NUM_BUCKETS) + { + assert(!finished); + finished = true; + LOG_DEBUG(log, "local agg restore finished"); + return; + } + + // store bucket data of min bucket num. + for (auto & bucket_input : bucket_inputs) + { + if (bucket_input.hasOutput() && min_bucket_num == bucket_input.bucketNum()) + bucket_data.push_back(bucket_input.moveOutput()); + } + assert(!bucket_data.empty()); +} + +void LocalAggregateRestorer::loadBucketData() +{ + if unlikely (finished || is_cancelled()) + return; + + // load bucket data from inputs. + assert(bucket_data.empty()); + for (auto & bucket_input : bucket_inputs) + { + if unlikely (is_cancelled()) + return; + if (bucket_input.needLoad()) + bucket_input.load(); + } + if unlikely (is_cancelled()) + return; + + storeToBucketData(); +} + +bool LocalAggregateRestorer::tryPop(Block & block) +{ + if unlikely (finished || is_cancelled()) + return true; + + if (restored_blocks.empty()) + { + if (bucket_data.empty()) + return false; + + BlocksList tmp; + std::swap(tmp, bucket_data); + restored_blocks = aggregator.vstackBlocks(tmp, true); + assert(!restored_blocks.empty()); + } + block = std::move(restored_blocks.front()); + restored_blocks.pop_front(); + return true; +} +} // namespace DB diff --git a/dbms/src/Operators/LocalAggregateRestorer.h b/dbms/src/Operators/LocalAggregateRestorer.h new file mode 100644 index 00000000000..df71bcc595a --- /dev/null +++ b/dbms/src/Operators/LocalAggregateRestorer.h @@ -0,0 +1,78 @@ +// Copyright 2023 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include + +namespace DB +{ +class Aggregator; + +/** + * ┌──────────────────────────────────────────────────┐ + * │ {bucket0, bucket1, ... bucket255}spilled_file0──┼────┐ + * │ {bucket0, bucket1, ... bucket255}spilled_file1──┼────┤ + * │ {bucket0, bucket1, ... bucket255}spilled_file2──┼────┤ + * │ ... │ │ + * │ {bucket0, bucket1, ... bucket255}spilled_filen──┼────┤ + * └──────────────────────────────────────────────────┘ │ + * │ loadBucketData + * bucket_data◄──────────────────────────┘ + * │ + * │ tryPop + * ▼ + * restored_blocks + */ +class LocalAggregateRestorer +{ +public: + LocalAggregateRestorer( + const BlockInputStreams & bucket_streams, + Aggregator & aggregator_, + std::function is_cancelled_, + const String & req_id); + + // load data from bucket_inputs to bucket_data. + void loadBucketData(); + + // return true if pop success + // return false means that `loadBucketData` needs to be called. + bool tryPop(Block & block); + +private: + void storeToBucketData(); + +private: + Aggregator & aggregator; + + std::function is_cancelled; + + LoggerPtr log; + + bool finished = false; + + // bucket_inputs --> bucket_data --> restored_blocks. 
+ BlocksList bucket_data; + BlocksList restored_blocks; + BucketInputs bucket_inputs; + + static constexpr Int32 NUM_BUCKETS = 256; +}; +using LocalAggregateRestorerPtr = std::unique_ptr; +} // namespace DB diff --git a/dbms/src/Operators/LocalAggregateTransform.cpp b/dbms/src/Operators/LocalAggregateTransform.cpp index 4883e9a0196..df44151eba0 100644 --- a/dbms/src/Operators/LocalAggregateTransform.cpp +++ b/dbms/src/Operators/LocalAggregateTransform.cpp @@ -15,14 +15,16 @@ #include #include +#include + namespace DB { namespace { -/// for local agg, the concurrency of build and convert must both be 1. +/// for local agg, the concurrency of build and convergent must both be 1. constexpr size_t local_concurrency = 1; -/// for local agg, the task_index of build and convert must both be 0. +/// for local agg, the task_index of build and convergent must both be 0. constexpr size_t task_index = 0; } // namespace @@ -42,24 +44,55 @@ OperatorStatus LocalAggregateTransform::transformImpl(Block & block) switch (status) { case LocalAggStatus::build: - if (unlikely(!block)) + if unlikely (!block) { - // status from build to convert. - status = LocalAggStatus::convert; - agg_context.initConvergent(); - if likely (!agg_context.useNullSource()) - { - RUNTIME_CHECK(agg_context.getConvergentConcurrency() == local_concurrency); - block = agg_context.readForConvergent(task_index); - } - return OperatorStatus::HAS_OUTPUT; + return agg_context.hasSpilledData() + ? fromBuildToFinalSpillOrRestore() + : fromBuildToConvergent(block); } agg_context.buildOnBlock(task_index, block); block.clear(); - return OperatorStatus::NEED_INPUT; - case LocalAggStatus::convert: - throw Exception("Unexpected status: convert"); + return tryFromBuildToSpill(); + default: + throw Exception(fmt::format("Unexpected status: {}", magic_enum::enum_name(status))); + } +} + +OperatorStatus LocalAggregateTransform::fromBuildToConvergent(Block & block) +{ + // status from build to convergent. 
+ assert(status == LocalAggStatus::build); + status = LocalAggStatus::convergent; + agg_context.initConvergent(); + RUNTIME_CHECK(agg_context.getConvergentConcurrency() == local_concurrency); + block = agg_context.readForConvergent(task_index); + return OperatorStatus::HAS_OUTPUT; +} + +OperatorStatus LocalAggregateTransform::fromBuildToFinalSpillOrRestore() +{ + assert(status == LocalAggStatus::build); + if (agg_context.needSpill(task_index, /*try_mark_need_spill=*/true)) + { + status = LocalAggStatus::final_spill; + } + else + { + restorer = agg_context.buildLocalRestorer(); + status = LocalAggStatus::restore; } + return OperatorStatus::IO; +} + +OperatorStatus LocalAggregateTransform::tryFromBuildToSpill() +{ + assert(status == LocalAggStatus::build); + if (agg_context.needSpill(task_index)) + { + status = LocalAggStatus::spill; + return OperatorStatus::IO; + } + return OperatorStatus::NEED_INPUT; } OperatorStatus LocalAggregateTransform::tryOutputImpl(Block & block) @@ -68,10 +101,43 @@ OperatorStatus LocalAggregateTransform::tryOutputImpl(Block & block) { case LocalAggStatus::build: return OperatorStatus::NEED_INPUT; - case LocalAggStatus::convert: - if likely (!agg_context.useNullSource()) - block = agg_context.readForConvergent(task_index); + case LocalAggStatus::convergent: + block = agg_context.readForConvergent(task_index); return OperatorStatus::HAS_OUTPUT; + case LocalAggStatus::restore: + return restorer->tryPop(block) + ? 
OperatorStatus::HAS_OUTPUT + : OperatorStatus::IO; + default: + throw Exception(fmt::format("Unexpected status: {}", magic_enum::enum_name(status))); + } +} + +OperatorStatus LocalAggregateTransform::executeIOImpl() +{ + switch (status) + { + case LocalAggStatus::spill: + { + agg_context.spillData(task_index); + status = LocalAggStatus::build; + return OperatorStatus::NEED_INPUT; + } + case LocalAggStatus::final_spill: + { + agg_context.spillData(task_index); + restorer = agg_context.buildLocalRestorer(); + status = LocalAggStatus::restore; + return OperatorStatus::IO; + } + case LocalAggStatus::restore: + { + assert(restorer); + restorer->loadBucketData(); + return OperatorStatus::HAS_OUTPUT; + } + default: + throw Exception(fmt::format("Unexpected status: {}", magic_enum::enum_name(status))); } } diff --git a/dbms/src/Operators/LocalAggregateTransform.h b/dbms/src/Operators/LocalAggregateTransform.h index 7c5ff2639bb..20fe5fd3b45 100644 --- a/dbms/src/Operators/LocalAggregateTransform.h +++ b/dbms/src/Operators/LocalAggregateTransform.h @@ -19,12 +19,6 @@ namespace DB { -enum class LocalAggStatus -{ - build, - convert, -}; - /// Only do build and convert at the current operator, no sharing of objects with other operators. class LocalAggregateTransform : public TransformOp { @@ -44,12 +38,43 @@ class LocalAggregateTransform : public TransformOp OperatorStatus tryOutputImpl(Block & block) override; + OperatorStatus executeIOImpl() override; + void transformHeaderImpl(Block & header_) override; +private: + OperatorStatus tryFromBuildToSpill(); + + OperatorStatus fromBuildToConvergent(Block & block); + + OperatorStatus fromBuildToFinalSpillOrRestore(); + private: Aggregator::Params params; AggregateContext agg_context; + /** + * spill◄────►build────┬─────────────►restore + * │ │ ▲ + * │ └───►final_spill──┘ + * ▼ + * convergent + */ + enum class LocalAggStatus + { + // Accept the block and build aggregate data. + build, + // spill the aggregate data into disk. 
+ spill, + // convert the aggregate data to block and then output it. + convergent, + // spill the remaining in-memory aggregate data. + final_spill, + // load the spilled aggregate data from disk into memory, then convert it to block and output it. + restore, + }; LocalAggStatus status{LocalAggStatus::build}; + + LocalAggregateRestorerPtr restorer; }; } // namespace DB diff --git a/dbms/src/TestUtils/ExecutorTestUtils.cpp b/dbms/src/TestUtils/ExecutorTestUtils.cpp index e010bf73f04..fad8bb6af87 100644 --- a/dbms/src/TestUtils/ExecutorTestUtils.cpp +++ b/dbms/src/TestUtils/ExecutorTestUtils.cpp @@ -288,31 +288,30 @@ void ExecutorTest::enablePipeline(bool is_enable) const DB::ColumnsWithTypeAndName ExecutorTest::executeStreams( const std::shared_ptr & request, - size_t concurrency, - bool enable_memory_tracker) + size_t concurrency) { DAGContext dag_context(*request, "executor_test", concurrency); - return executeStreams(&dag_context, enable_memory_tracker); + return executeStreams(&dag_context); } -ColumnsWithTypeAndName ExecutorTest::executeStreams(DAGContext * dag_context, bool enable_memory_tracker) +ColumnsWithTypeAndName ExecutorTest::executeStreams(DAGContext * dag_context) { TiFlashTestEnv::setUpTestContext(*context.context, dag_context, context.mockStorage(), TestType::EXECUTOR_TEST); // Currently, don't care about regions information in tests. 
Blocks blocks; - queryExecute(*context.context, /*internal=*/!enable_memory_tracker)->execute([&blocks](const Block & block) { blocks.push_back(block); }).verify(); + queryExecute(*context.context, /*internal=*/true)->execute([&blocks](const Block & block) { blocks.push_back(block); }).verify(); return vstackBlocks(std::move(blocks)).getColumnsWithTypeAndName(); } -Blocks ExecutorTest::getExecuteStreamsReturnBlocks(const std::shared_ptr & request, - size_t concurrency, - bool enable_memory_tracker) +Blocks ExecutorTest::getExecuteStreamsReturnBlocks( + const std::shared_ptr & request, + size_t concurrency) { DAGContext dag_context(*request, "executor_test", concurrency); TiFlashTestEnv::setUpTestContext(*context.context, &dag_context, context.mockStorage(), TestType::EXECUTOR_TEST); // Currently, don't care about regions information in tests. Blocks blocks; - queryExecute(*context.context, /*internal=*/!enable_memory_tracker)->execute([&blocks](const Block & block) { blocks.push_back(block); }).verify(); + queryExecute(*context.context, /*internal=*/true)->execute([&blocks](const Block & block) { blocks.push_back(block); }).verify(); return blocks; } diff --git a/dbms/src/TestUtils/ExecutorTestUtils.h b/dbms/src/TestUtils/ExecutorTestUtils.h index 669b7a92415..c9095361936 100644 --- a/dbms/src/TestUtils/ExecutorTestUtils.h +++ b/dbms/src/TestUtils/ExecutorTestUtils.h @@ -113,17 +113,15 @@ class ExecutorTest : public ::testing::Test } } - ColumnsWithTypeAndName executeStreams(DAGContext * dag_context, bool enable_memory_tracker = false); + ColumnsWithTypeAndName executeStreams(DAGContext * dag_context); ColumnsWithTypeAndName executeStreams( const std::shared_ptr & request, - size_t concurrency = 1, - bool enable_memory_tracker = false); + size_t concurrency = 1); Blocks getExecuteStreamsReturnBlocks( const std::shared_ptr & request, - size_t concurrency = 1, - bool enable_memory_tracker = false); + size_t concurrency = 1); /// test execution summary //