Skip to content

Commit

Permalink
join be aware of cancel signal (#9450)
Browse files Browse the repository at this point in the history
close #9430

Signed-off-by: xufei <[email protected]>
Signed-off-by: xufei <[email protected]>

Co-authored-by: Liqi Geng <[email protected]>
  • Loading branch information
windtalker and gengliqi authored Sep 21, 2024
1 parent d22df76 commit 8aba9f0
Show file tree
Hide file tree
Showing 19 changed files with 106 additions and 27 deletions.
2 changes: 1 addition & 1 deletion dbms/src/DataStreams/AggregatingBlockInputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Block AggregatingBlockInputStream::readImpl()
executed = true;
AggregatedDataVariantsPtr data_variants = std::make_shared<AggregatedDataVariants>();

Aggregator::CancellationHook hook = [&]() {
CancellationHook hook = [&]() {
return this->isCancelled();
};
aggregator.setCancellationHook(hook);
Expand Down
2 changes: 0 additions & 2 deletions dbms/src/DataStreams/HashJoinProbeExec.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class HashJoinProbeExec : public std::enable_shared_from_this<HashJoinProbeExec>
const BlockInputStreamPtr & probe_stream,
size_t max_block_size);

using CancellationHook = std::function<bool()>;

HashJoinProbeExec(
const String & req_id,
const JoinPtr & join_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Block ParallelAggregatingBlockInputStream::readImpl()
{
if (!executed)
{
Aggregator::CancellationHook hook = [&]() {
CancellationHook hook = [&]() {
return this->isCancelled();
};
aggregator.setCancellationHook(hook);
Expand Down
10 changes: 7 additions & 3 deletions dbms/src/Flash/Mpp/MPPTask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,9 @@ void MPPTask::runImpl()
GET_RESOURCE_GROUP_METRIC(tiflash_resource_group, type_handling_mpp_task_run, resource_group).Decrement();
});

// set cancellation hook
context->setCancellationHook([this] { return is_cancelled.load(); });

String err_msg;
try
{
Expand Down Expand Up @@ -750,7 +753,7 @@ void MPPTask::abort(const String & message, AbortType abort_type)
if (previous_status == FINISHED || previous_status == CANCELLED || previous_status == FAILED)
{
LOG_WARNING(log, "task already in {} state", magic_enum::enum_name(previous_status));
return;
break;
}
else if (previous_status == INITIALIZING && switchStatus(INITIALIZING, next_task_status))
{
Expand All @@ -759,7 +762,7 @@ void MPPTask::abort(const String & message, AbortType abort_type)
/// so just close all tunnels here
abortTunnels("", false);
LOG_WARNING(log, "Finish abort task from uninitialized");
return;
break;
}
else if (previous_status == RUNNING && switchStatus(RUNNING, next_task_status))
{
Expand All @@ -773,9 +776,10 @@ void MPPTask::abort(const String & message, AbortType abort_type)
scheduleThisTask(ScheduleState::FAILED);
/// runImpl is running, leave remaining work to runImpl
LOG_WARNING(log, "Finish abort task from running");
return;
break;
}
}
is_cancelled = true;
}

bool MPPTask::switchStatus(TaskStatus from, TaskStatus to)
Expand Down
1 change: 1 addition & 0 deletions dbms/src/Flash/Mpp/MPPTask.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ class MPPTask

MPPTaskManager * manager;
std::atomic<bool> is_registered{false};
std::atomic<bool> is_cancelled{false};

MPPTaskScheduleEntry schedule_entry;

Expand Down
1 change: 1 addition & 0 deletions dbms/src/Flash/Planner/Plans/PhysicalJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ void PhysicalJoin::probeSideTransform(DAGPipeline & probe_pipeline, Context & co
execId(),
needScanHashMapAfterProbe(join_ptr->getKind()));
join_ptr->initProbe(probe_pipeline.firstStream()->getHeader(), probe_pipeline.streams.size());
join_ptr->setCancellationHook([&] { return context.isCancelled(); });
size_t probe_index = 0;
for (auto & stream : probe_pipeline.streams)
{
Expand Down
2 changes: 2 additions & 0 deletions dbms/src/Flash/Planner/Plans/PhysicalJoinBuild.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <Flash/Coprocessor/DAGContext.h>
#include <Flash/Coprocessor/InterpreterUtils.h>
#include <Flash/Executor/PipelineExecutorContext.h>
#include <Flash/Pipeline/Exec/PipelineExecBuilder.h>
#include <Flash/Planner/Plans/PhysicalJoinBuild.h>
#include <Interpreters/Context.h>
Expand All @@ -39,6 +40,7 @@ void PhysicalJoinBuild::buildPipelineExecGroupImpl(
join_execute_info.join_build_profile_infos = group_builder.getCurProfileInfos();
join_ptr->initBuild(group_builder.getCurrentHeader(), group_builder.concurrency());
join_ptr->setInitActiveBuildThreads();
join_ptr->setCancellationHook([&]() { return exec_context.isCancelled(); });
join_ptr.reset();
}
} // namespace DB
3 changes: 1 addition & 2 deletions dbms/src/Interpreters/Aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <Interpreters/AggSpillContext.h>
#include <Interpreters/AggregateDescription.h>
#include <Interpreters/AggregationCommon.h>
#include <Interpreters/CancellationHook.h>
#include <TiDB/Collation/Collator.h>
#include <common/StringRef.h>
#include <common/logger_useful.h>
Expand Down Expand Up @@ -1370,8 +1371,6 @@ class Aggregator
*/
Blocks convertBlockToTwoLevel(const Block & block);

using CancellationHook = std::function<bool()>;

/** Set a function that checks whether the current task can be aborted.
*/
void setCancellationHook(CancellationHook cancellation_hook);
Expand Down
22 changes: 22 additions & 0 deletions dbms/src/Interpreters/CancellationHook.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>

namespace DB
{
using CancellationHook = std::function<bool()>;
} // namespace DB
16 changes: 12 additions & 4 deletions dbms/src/Interpreters/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <Core/Types.h>
#include <Debug/MockServerInfo.h>
#include <IO/FileProvider/FileProvider_fwd.h>
#include <Interpreters/CancellationHook.h>
#include <Interpreters/ClientInfo.h>
#include <Interpreters/Context_fwd.h>
#include <Interpreters/Settings.h>
Expand All @@ -34,6 +35,7 @@
#include <mutex>
#include <thread>


namespace pingcap
{
struct ClusterConfig;
Expand Down Expand Up @@ -180,6 +182,9 @@ class Context
TimezoneInfo timezone_info;

DAGContext * dag_context = nullptr;
CancellationHook is_cancelled{[]() {
return false;
}};
using DatabasePtr = std::shared_ptr<IDatabase>;
using Databases = std::map<String, std::shared_ptr<IDatabase>>;
/// Use copy constructor or createGlobal() instead
Expand Down Expand Up @@ -237,8 +242,8 @@ class Context
/// Compute and set actual user settings, client_info.current_user should be set
void calculateUserSettings();

ClientInfo & getClientInfo() { return client_info; };
const ClientInfo & getClientInfo() const { return client_info; };
ClientInfo & getClientInfo() { return client_info; }
const ClientInfo & getClientInfo() const { return client_info; }

void setQuota(
const String & name,
Expand Down Expand Up @@ -375,6 +380,9 @@ class Context
void setDAGContext(DAGContext * dag_context);
DAGContext * getDAGContext() const;

bool isCancelled() const { return is_cancelled(); }
void setCancellationHook(CancellationHook cancellation_hook) { is_cancelled = cancellation_hook; }

/// List all queries.
ProcessList & getProcessList();
const ProcessList & getProcessList() const;
Expand Down Expand Up @@ -505,8 +513,8 @@ class Context

SharedQueriesPtr getSharedQueries();

const TimezoneInfo & getTimezoneInfo() const { return timezone_info; };
TimezoneInfo & getTimezoneInfo() { return timezone_info; };
const TimezoneInfo & getTimezoneInfo() const { return timezone_info; }
TimezoneInfo & getTimezoneInfo() { return timezone_info; }

/// User name and session identifier. Named sessions are local to users.
using SessionKey = std::pair<String, String>;
Expand Down
28 changes: 25 additions & 3 deletions dbms/src/Interpreters/Join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1382,6 +1382,8 @@ Block Join::joinBlockHash(ProbeProcessInfo & probe_process_info) const
restore_config.restore_round);
while (true)
{
if (is_cancelled())
return {};
auto block = doJoinBlockHash(probe_process_info, join_build_info);
assert(block);
block = removeUselessColumn(block);
Expand Down Expand Up @@ -1486,6 +1488,8 @@ Block Join::joinBlockCross(ProbeProcessInfo & probe_process_info) const

while (true)
{
if (is_cancelled())
return {};
Block block = doJoinBlockCross(probe_process_info);
assert(block);
block = removeUselessColumn(block);
Expand Down Expand Up @@ -1567,6 +1571,9 @@ Block Join::joinBlockNullAwareSemiImpl(const ProbeProcessInfo & probe_process_in

RUNTIME_ASSERT(res.size() == rows, "SemiJoinResult size {} must be equal to block size {}", res.size(), rows);

if (is_cancelled())
return {};

Block block{};
for (size_t i = 0; i < probe_process_info.block.columns(); ++i)
{
Expand Down Expand Up @@ -1603,13 +1610,17 @@ Block Join::joinBlockNullAwareSemiImpl(const ProbeProcessInfo & probe_process_in
blocks,
null_rows,
max_block_size,
non_equal_conditions);
non_equal_conditions,
is_cancelled);

helper.joinResult(res_list);

RUNTIME_CHECK_MSG(res_list.empty(), "NASemiJoinResult list must be empty after calculating join result");
}

if (is_cancelled())
return {};

/// Now all results are known.

std::unique_ptr<IColumn::Filter> filter;
Expand Down Expand Up @@ -1754,6 +1765,8 @@ Block Join::joinBlockSemiImpl(const JoinBuildInfo & join_build_info, const Probe
probe_process_info);

RUNTIME_ASSERT(res.size() == rows, "SemiJoinResult size {} must be equal to block size {}", res.size(), rows);
if (is_cancelled())
return {};

const NameSet & probe_output_name_set = has_other_condition
? output_columns_names_set_for_other_condition_after_finalize
Expand Down Expand Up @@ -1788,15 +1801,23 @@ Block Join::joinBlockSemiImpl(const JoinBuildInfo & join_build_info, const Probe
{
if (!res_list.empty())
{
SemiJoinHelper<KIND, typename Maps::MappedType>
helper(block, left_columns, right_column_indices_to_add, max_block_size, non_equal_conditions);
SemiJoinHelper<KIND, typename Maps::MappedType> helper(
block,
left_columns,
right_column_indices_to_add,
max_block_size,
non_equal_conditions,
is_cancelled);

helper.joinResult(res_list);

RUNTIME_CHECK_MSG(res_list.empty(), "SemiJoinResult list must be empty after calculating join result");
}
}

if (is_cancelled())
return {};

/// Now all results are known.

std::unique_ptr<IColumn::Filter> filter;
Expand Down Expand Up @@ -2469,6 +2490,7 @@ std::optional<RestoreInfo> Join::getOneRestoreStream(size_t max_block_size_)
restore_join->initBuild(build_sample_block, restore_join_build_concurrency);
restore_join->setInitActiveBuildThreads();
restore_join->initProbe(probe_sample_block, restore_join_build_concurrency);
restore_join->setCancellationHook(is_cancelled);
BlockInputStreams restore_scan_hash_map_streams;
restore_scan_hash_map_streams.resize(restore_join_build_concurrency, nullptr);
if (needScanHashMapAfterProbe(kind))
Expand Down
6 changes: 6 additions & 0 deletions dbms/src/Interpreters/Join.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <Flash/Coprocessor/JoinInterpreterHelper.h>
#include <Flash/Coprocessor/RuntimeFilterMgr.h>
#include <Interpreters/AggregationCommon.h>
#include <Interpreters/CancellationHook.h>
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/HashJoinSpillContext.h>
#include <Interpreters/JoinHashMap.h>
Expand Down Expand Up @@ -314,6 +315,8 @@ class Join
void flushProbeSideMarkedSpillData(size_t stream_index);
size_t getProbeCacheColumnThreshold() const { return probe_cache_column_threshold; }

void setCancellationHook(CancellationHook cancellation_hook) { is_cancelled = cancellation_hook; }

static const String match_helper_prefix;
static const DataTypePtr match_helper_type;
static const String flag_mapped_entry_helper_prefix;
Expand Down Expand Up @@ -452,6 +455,9 @@ class Join
// the index of vector is the stream_index.
std::vector<MarkedSpillData> build_side_marked_spilled_data;
std::vector<MarkedSpillData> probe_side_marked_spilled_data;
CancellationHook is_cancelled{[]() {
return false;
}};

private:
/** Set information about structure of right hand of JOIN (joined data).
Expand Down
14 changes: 10 additions & 4 deletions dbms/src/Interpreters/NullAwareSemiJoinHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,15 @@ NASemiJoinHelper<KIND, STRICTNESS, Mapped>::NASemiJoinHelper(
const BlocksList & right_blocks_,
const std::vector<RowsNotInsertToMap *> & null_rows_,
size_t max_block_size_,
const JoinNonEqualConditions & non_equal_conditions_)
const JoinNonEqualConditions & non_equal_conditions_,
CancellationHook is_cancelled_)
: block(block_)
, left_columns(left_columns_)
, right_column_indices_to_add(right_column_indices_to_add_)
, right_blocks(right_blocks_)
, null_rows(null_rows_)
, max_block_size(max_block_size_)
, is_cancelled(is_cancelled_)
, non_equal_conditions(non_equal_conditions_)
{
static_assert(KIND == NullAware_Anti || KIND == NullAware_LeftOuterAnti || KIND == NullAware_LeftOuterSemi);
Expand Down Expand Up @@ -280,17 +282,17 @@ void NASemiJoinHelper<KIND, STRICTNESS, Mapped>::joinResult(std::list<NASemiJoin
res_list.swap(next_step_res_list);
}

if (res_list.empty())
if (is_cancelled() || res_list.empty())
return;

runStep<NASemiJoinStep::NOT_NULL_KEY_CHECK_NULL_ROWS>(res_list, next_step_res_list);
res_list.swap(next_step_res_list);
if (res_list.empty())
if (is_cancelled() || res_list.empty())
return;

runStep<NASemiJoinStep::NULL_KEY_CHECK_NULL_ROWS>(res_list, next_step_res_list);
res_list.swap(next_step_res_list);
if (res_list.empty())
if (is_cancelled() || res_list.empty())
return;

runStepAllBlocks(res_list);
Expand Down Expand Up @@ -324,6 +326,8 @@ void NASemiJoinHelper<KIND, STRICTNESS, Mapped>::runStep(

while (!res_list.empty())
{
if (is_cancelled())
return;
MutableColumns columns(block_columns);
for (size_t i = 0; i < block_columns; ++i)
{
Expand Down Expand Up @@ -384,6 +388,8 @@ void NASemiJoinHelper<KIND, STRICTNESS, Mapped>::runStepAllBlocks(std::list<NASe
NASemiJoinHelper::Result * res = *res_list.begin();
for (const auto & right_block : right_blocks)
{
if (is_cancelled())
return;
if (res->getStep() == NASemiJoinStep::DONE)
break;

Expand Down
5 changes: 4 additions & 1 deletion dbms/src/Interpreters/NullAwareSemiJoinHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <Common/Logger.h>
#include <Core/Block.h>
#include <Flash/Coprocessor/JoinInterpreterHelper.h>
#include <Interpreters/CancellationHook.h>
#include <Interpreters/SemiJoinHelper.h>
#include <Parsers/ASTTablesInSelectQuery.h>

Expand Down Expand Up @@ -167,7 +168,8 @@ class NASemiJoinHelper
const BlocksList & right_blocks,
const std::vector<RowsNotInsertToMap *> & null_rows,
size_t max_block_size,
const JoinNonEqualConditions & non_equal_conditions);
const JoinNonEqualConditions & non_equal_conditions,
CancellationHook is_cancelled_);

void joinResult(std::list<Result *> & res_list);

Expand All @@ -192,6 +194,7 @@ class NASemiJoinHelper
const BlocksList & right_blocks;
const std::vector<RowsNotInsertToMap *> & null_rows;
size_t max_block_size;
CancellationHook is_cancelled;

const JoinNonEqualConditions & non_equal_conditions;
};
Expand Down
Loading

0 comments on commit 8aba9f0

Please sign in to comment.