Skip to content

Commit

Permalink
fix mem tracker 0609
Browse files Browse the repository at this point in the history
  • Loading branch information
xinyiZzz committed Jun 11, 2022
1 parent 3f575e3 commit 590dc6a
Show file tree
Hide file tree
Showing 14 changed files with 88 additions and 58 deletions.
1 change: 0 additions & 1 deletion be/src/exec/tablet_sink.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,6 @@ Status NodeChannel::none_of(std::initializer_list<bool> vars) {
}

void NodeChannel::clear_all_batches() {
SCOPED_SWITCH_THREAD_LOCAL_MEM_TRACKER(_node_channel_tracker);
std::lock_guard<std::mutex> lg(_pending_batches_lock);
std::queue<AddBatchReq> empty;
std::swap(_pending_batches, empty);
Expand Down
1 change: 0 additions & 1 deletion be/src/exec/tablet_sink.h
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,6 @@ class IndexChannel {

void for_each_node_channel(
const std::function<void(const std::shared_ptr<NodeChannel>&)>& func) {
SCOPED_SWITCH_THREAD_LOCAL_MEM_TRACKER(_index_channel_tracker);
for (auto& it : _node_channels) {
func(it.second);
}
Expand Down
2 changes: 1 addition & 1 deletion be/src/gutil/strings/numbers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1488,7 +1488,7 @@ string AccurateItoaKMGT(int64 i) {
i = -i;
}

string ret = std::to_string(i) + " = " + StringPrintf("%s", sign);
string ret = StringPrintf("%s", sign) + std::to_string(i) + " = " + StringPrintf("%s", sign);
int64 val;
if ((val = (i >> 40)) > 1) {
ret += StringPrintf("%" PRId64
Expand Down
9 changes: 5 additions & 4 deletions be/src/runtime/load_channel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,18 @@

namespace doris {

LoadChannel::LoadChannel(const UniqueId& load_id, int64_t mem_limit, int64_t timeout_s,
bool is_high_priority, const std::string& sender_ip, bool is_vec)
LoadChannel::LoadChannel(const UniqueId& load_id, int64_t load_mem_limit, int64_t channel_mem_limit,
int64_t timeout_s, bool is_high_priority, const std::string& sender_ip,
bool is_vec)
: _load_id(load_id),
_timeout_s(timeout_s),
_is_high_priority(is_high_priority),
_sender_ip(sender_ip),
_is_vec(is_vec) {
_mem_tracker = MemTracker::create_tracker(
mem_limit, "LoadChannel:tabletId=" + _load_id.to_string(),
channel_mem_limit, "LoadChannel:senderIp=" + sender_ip,
ExecEnv::GetInstance()->task_pool_mem_tracker_registry()->register_load_mem_tracker(
_load_id.to_string(), mem_limit),
_load_id.to_string(), load_mem_limit),
MemTrackerLevel::TASK);
// _last_updated_time should be set before being inserted to
// _load_channels in load_channel_mgr, or it may be erased
Expand Down
5 changes: 3 additions & 2 deletions be/src/runtime/load_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ class Cache;
// corresponding to a certain load job
class LoadChannel {
public:
LoadChannel(const UniqueId& load_id, int64_t mem_limit, int64_t timeout_s,
bool is_high_priority, const std::string& sender_ip, bool is_vec);
LoadChannel(const UniqueId& load_id, int64_t load_mem_limit, int64_t channel_mem_limit,
int64_t timeout_s, bool is_high_priority, const std::string& sender_ip,
bool is_vec);
~LoadChannel();

// open a new load channel if not exist
Expand Down
44 changes: 23 additions & 21 deletions be/src/runtime/load_channel_mgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,22 @@ static int64_t calc_process_max_load_memory(int64_t process_mem_limit) {
return std::min<int64_t>(max_load_memory_bytes, config::load_process_max_memory_limit_bytes);
}

// Calculate the memory limit for a single load job.
static int64_t calc_job_max_load_memory(int64_t mem_limit_in_req, int64_t total_mem_limit) {
// Calculate the memory limit for a single load channel.
static int64_t calc_channel_max_load_memory(int64_t load_mem_limit, int64_t total_mem_limit) {
// default mem limit is used to be compatible with old request.
// new request should be set load_mem_limit.
constexpr int64_t default_load_mem_limit = 2 * 1024 * 1024 * 1024L; // 2GB
int64_t load_mem_limit = default_load_mem_limit;
if (mem_limit_in_req != -1) {
constexpr int64_t default_channel_mem_limit = 2 * 1024 * 1024 * 1024L; // 2GB
int64_t channel_mem_limit = default_channel_mem_limit;
if (load_mem_limit != -1) {
// mem-limit of a certain load should between config::write_buffer_size
// and total-memory-limit
load_mem_limit = std::max<int64_t>(mem_limit_in_req, config::write_buffer_size);
load_mem_limit = std::min<int64_t>(load_mem_limit, total_mem_limit);
channel_mem_limit = std::max<int64_t>(load_mem_limit, config::write_buffer_size);
channel_mem_limit = std::min<int64_t>(channel_mem_limit, total_mem_limit);
}
return load_mem_limit;
return channel_mem_limit;
}

static int64_t calc_job_timeout_s(int64_t timeout_in_req_s) {
static int64_t calc_channel_timeout_s(int64_t timeout_in_req_s) {
int64_t load_channel_timeout_s = config::streaming_load_rpc_max_alive_time_sec;
if (timeout_in_req_s > 0) {
load_channel_timeout_s = std::max<int64_t>(load_channel_timeout_s, timeout_in_req_s);
Expand All @@ -83,8 +83,8 @@ LoadChannelMgr::~LoadChannelMgr() {
}

Status LoadChannelMgr::init(int64_t process_mem_limit) {
int64_t load_mem_limit = calc_process_max_load_memory(process_mem_limit);
_mem_tracker = MemTracker::create_tracker(load_mem_limit, "LoadChannelMgr",
int64_t load_mgr_mem_limit = calc_process_max_load_memory(process_mem_limit);
_mem_tracker = MemTracker::create_tracker(load_mgr_mem_limit, "LoadChannelMgr",
MemTracker::get_process_tracker(),
MemTrackerLevel::OVERVIEW);
SCOPED_SWITCH_THREAD_LOCAL_MEM_TRACKER(_mem_tracker);
Expand All @@ -95,10 +95,12 @@ Status LoadChannelMgr::init(int64_t process_mem_limit) {
return Status::OK();
}

LoadChannel* LoadChannelMgr::_create_load_channel(const UniqueId& load_id, int64_t mem_limit,
int64_t timeout_s, bool is_high_priority,
LoadChannel* LoadChannelMgr::_create_load_channel(const UniqueId& load_id, int64_t load_mem_limit,
int64_t channel_mem_limit, int64_t timeout_s,
bool is_high_priority,
const std::string& sender_ip, bool is_vec) {
return new LoadChannel(load_id, mem_limit, timeout_s, is_high_priority, sender_ip, is_vec);
return new LoadChannel(load_id, load_mem_limit, channel_mem_limit, timeout_s, is_high_priority,
sender_ip, is_vec);
}

Status LoadChannelMgr::open(const PTabletWriterOpenRequest& params) {
Expand All @@ -112,18 +114,18 @@ Status LoadChannelMgr::open(const PTabletWriterOpenRequest& params) {
channel = it->second;
} else {
// create a new load channel
int64_t mem_limit_in_req = params.has_load_mem_limit() ? params.load_mem_limit() : -1;
int64_t job_max_memory =
calc_job_max_load_memory(mem_limit_in_req, _mem_tracker->limit());
int64_t load_mem_limit = params.has_load_mem_limit() ? params.load_mem_limit() : -1;
int64_t channel_mem_limit =
calc_channel_max_load_memory(load_mem_limit, _mem_tracker->limit());

int64_t timeout_in_req_s =
params.has_load_channel_timeout_s() ? params.load_channel_timeout_s() : -1;
int64_t job_timeout_s = calc_job_timeout_s(timeout_in_req_s);
int64_t channel_timeout_s = calc_channel_timeout_s(timeout_in_req_s);

bool is_high_priority = (params.has_is_high_priority() && params.is_high_priority());
channel.reset(_create_load_channel(load_id, job_max_memory, job_timeout_s,
is_high_priority, params.sender_ip(),
params.is_vectorized()));
channel.reset(_create_load_channel(load_id, load_mem_limit, channel_mem_limit,
channel_timeout_s, is_high_priority,
params.sender_ip(), params.is_vectorized()));
_load_channels.insert({load_id, channel});
}
}
Expand Down
7 changes: 4 additions & 3 deletions be/src/runtime/load_channel_mgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ class LoadChannelMgr {
std::shared_ptr<MemTracker> mem_tracker() { return _mem_tracker; }

private:
static LoadChannel* _create_load_channel(const UniqueId& load_id, int64_t mem_limit,
int64_t timeout_s, bool is_high_priority,
const std::string& sender_ip, bool is_vec);
static LoadChannel* _create_load_channel(const UniqueId& load_id, int64_t load_mem_limit,
int64_t channel_mem_limit, int64_t timeout_s,
bool is_high_priority, const std::string& sender_ip,
bool is_vec);

template <typename Request>
Status _get_load_channel(std::shared_ptr<LoadChannel>& channel, bool& is_eof,
Expand Down
6 changes: 5 additions & 1 deletion be/src/runtime/mem_tracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,17 @@ std::shared_ptr<MemTracker> MemTracker::create_tracker_impl(
std::string reset_label;
MemTracker* task_parent_tracker = reset_parent->parent_task_mem_tracker();
if (task_parent_tracker) {
reset_label = fmt::format("{}:{}", label, split(task_parent_tracker->label(), ":")[1]);
reset_label = fmt::format("{}&{}", label, split(task_parent_tracker->label(), "&")[1]);
} else {
reset_label = label;
}
if (byte_limit == -1) byte_limit = reset_parent->limit();

std::shared_ptr<MemTracker> tracker(
new MemTracker(byte_limit, reset_label, reset_parent,
level > reset_parent->_level ? level : reset_parent->_level, profile));
// Do not check limit exceed when add_child_tracker, otherwise it will cause deadlock when log_usage is called.
STOP_CHECK_LIMIT_THREAD_LOCAL_MEM_TRACKER();
reset_parent->add_child_tracker(tracker);
return tracker;
}
Expand Down Expand Up @@ -285,6 +288,7 @@ std::string MemTracker::log_usage(int max_recursive_depth,

Status MemTracker::mem_limit_exceeded(RuntimeState* state, const std::string& details,
int64_t failed_allocation_size, Status failed_alloc) {
STOP_CHECK_LIMIT_THREAD_LOCAL_MEM_TRACKER();
MemTracker* process_tracker = MemTracker::get_raw_process_tracker();
std::string detail =
"Memory exceed limit. fragment={}, details={}, on backend={}. Memory left in process "
Expand Down
3 changes: 1 addition & 2 deletions be/src/runtime/mem_tracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,7 @@ class MemTracker {
/// 'failed_allocation_size' is zero, nothing about the allocation size is logged.
/// If 'state' is non-nullptr, logs the error to 'state'.
Status mem_limit_exceeded(RuntimeState* state, const std::string& details = std::string(),
int64_t failed_allocation = -1,
Status failed_alloc = Status::OK()) WARN_UNUSED_RESULT;
int64_t failed_allocation = -1, Status failed_alloc = Status::OK());

// Usually, a negative values means that the statistics are not accurate,
// 1. The released memory is not consumed.
Expand Down
23 changes: 19 additions & 4 deletions be/src/runtime/mem_tracker_task_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,17 @@ std::shared_ptr<MemTracker> MemTrackerTaskPool::register_query_mem_tracker(
VLOG_FILE << "Register Query memory tracker, query id: " << query_id
<< " limit: " << PrettyPrinter::print(mem_limit, TUnit::BYTES);
return register_task_mem_tracker_impl(query_id, mem_limit,
fmt::format("Query:queryId={}", query_id),
fmt::format("Query&queryId={}", query_id),
ExecEnv::GetInstance()->query_pool_mem_tracker());
}

std::shared_ptr<MemTracker> MemTrackerTaskPool::register_load_mem_tracker(
const std::string& load_id, int64_t mem_limit) {
// In load, the query id of the fragment is executed, which is the same as the load id of the load channel.
VLOG_FILE << "Register Load memory tracker, load id: " << load_id
<< " limit: " << PrettyPrinter::print(mem_limit, TUnit::BYTES);
return register_task_mem_tracker_impl(load_id, mem_limit,
fmt::format("Load:loadId={}", load_id),
fmt::format("Load&loadId={}", load_id),
ExecEnv::GetInstance()->load_pool_mem_tracker());
}

Expand All @@ -66,8 +67,13 @@ std::shared_ptr<MemTracker> MemTrackerTaskPool::get_task_mem_tracker(const std::
void MemTrackerTaskPool::logout_task_mem_tracker() {
std::vector<std::string> expired_tasks;
for (auto it = _task_mem_trackers.begin(); it != _task_mem_trackers.end(); it++) {
// No RuntimeState uses this task MemTracker, it is only referenced by this map, delete it
if (it->second.use_count() == 1) {
if (!it->second) {
// when parallel querying, after phmap _task_mem_trackers.erase,
// there have been cases where the key still exists in _task_mem_trackers.
// https://github.com/apache/incubator-doris/issues/10006
expired_tasks.emplace_back(it->first);
} else if (it->second.use_count() == 1) {
// No RuntimeState uses this task MemTracker, it is only referenced by this map, delete it
if (config::memory_leak_detection && it->second->consumption() != 0) {
// If consumption is not equal to 0 before query mem tracker is destructed,
// there are two possibilities in theory.
Expand All @@ -86,13 +92,22 @@ void MemTrackerTaskPool::logout_task_mem_tracker() {
it->second->parent()->consume_local(-it->second->consumption(),
MemTracker::get_process_tracker().get());
expired_tasks.emplace_back(it->first);
} else {
// Log limit exceeded query tracker.
if (it->second->limit_exceeded()) {
it->second->mem_limit_exceeded(
nullptr,
fmt::format("Task mem limit exceeded but no cancel, queryId:{}", it->first),
0, Status::OK());
}
}
}
for (auto tid : expired_tasks) {
// This means that after all RuntimeState is destructed,
// there are still task mem trackers that are get or register.
// The only known case: after an load task ends all fragments on a BE,`tablet_writer_open` is still
// called to create a channel, and the load task tracker will be re-registered in the channel open.
// https://github.com/apache/incubator-doris/issues/9905
if (_task_mem_trackers[tid].use_count() == 1) {
_task_mem_trackers.erase(tid);
VLOG_FILE << "Deregister task memory tracker, task id: " << tid;
Expand Down
8 changes: 5 additions & 3 deletions be/src/runtime/thread_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,14 @@ SwitchThreadMemTracker<Existed>::~SwitchThreadMemTracker() {
#endif // USE_MEM_TRACKER
}

SwitchThreadMemTrackerErrCallBack::SwitchThreadMemTrackerErrCallBack(
const std::string& action_type, bool cancel_work, ERRCALLBACK err_call_back_func) {
SwitchThreadMemTrackerErrCallBack::SwitchThreadMemTrackerErrCallBack(const std::string& action_type,
bool cancel_work,
ERRCALLBACK err_call_back_func,
bool log_limit_exceeded) {
#ifdef USE_MEM_TRACKER
DCHECK(action_type != std::string());
_old_tracker_cb = tls_ctx()->_thread_mem_tracker_mgr->update_consume_err_cb(
action_type, cancel_work, err_call_back_func);
action_type, cancel_work, err_call_back_func, log_limit_exceeded);
#endif
}

Expand Down
3 changes: 2 additions & 1 deletion be/src/runtime/thread_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ class SwitchThreadMemTrackerErrCallBack {
public:
explicit SwitchThreadMemTrackerErrCallBack(const std::string& action_type,
bool cancel_work = true,
ERRCALLBACK err_call_back_func = nullptr);
ERRCALLBACK err_call_back_func = nullptr,
bool log_limit_exceeded = true);

~SwitchThreadMemTrackerErrCallBack();

Expand Down
17 changes: 8 additions & 9 deletions be/src/runtime/thread_mem_tracker_mgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,25 +60,24 @@ void ThreadMemTrackerMgr::exceeded_cancel_task(const std::string& cancel_details
ExecEnv::GetInstance()->fragment_mgr()->cancel(
_fragment_instance_id, PPlanFragmentCancelReason::MEMORY_LIMIT_EXCEED,
cancel_details);
_fragment_instance_id = TUniqueId(); // Make sure it will only be canceled once
}
}

void ThreadMemTrackerMgr::exceeded(int64_t mem_usage, Status st) {
auto rst = _mem_trackers[_tracker_id]->mem_limit_exceeded(
nullptr, fmt::format("In TCMalloc Hook, {}", _consume_err_cb.cancel_msg), mem_usage,
st);
if (_consume_err_cb.cb_func != nullptr) {
_consume_err_cb.cb_func();
}
if (is_attach_task()) {
if (_consume_err_cb.cancel_task == true) {
if (_consume_err_cb.cancel_task) {
auto rst = _mem_trackers[_tracker_id]->mem_limit_exceeded(
nullptr,
fmt::format("Task mem limit exceeded and cancel it, msg:{}",
_consume_err_cb.cancel_msg),
mem_usage, st);
exceeded_cancel_task(rst.to_string());
} else {
// TODO(zxy) Need other processing, or log (not too often).
_consume_err_cb.cancel_task = false; // Make sure it will only be canceled once
_consume_err_cb.log_limit_exceeded = false;
}
} else {
// TODO(zxy) Need other processing, or log (not too often).
}
}
} // namespace doris
17 changes: 12 additions & 5 deletions be/src/runtime/thread_mem_tracker_mgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,24 @@ typedef void (*ERRCALLBACK)();

struct ConsumeErrCallBackInfo {
std::string cancel_msg;
bool cancel_task; // Whether to cancel the task when the current tracker exceeds the limit
bool cancel_task; // Whether to cancel the task when the current tracker exceeds the limit.
ERRCALLBACK cb_func;
bool log_limit_exceeded; // Whether to print log_usage of mem tracker when mem limit exceeded.

ConsumeErrCallBackInfo() { init(); }

ConsumeErrCallBackInfo(const std::string& cancel_msg, bool cancel_task, ERRCALLBACK cb_func)
: cancel_msg(cancel_msg), cancel_task(cancel_task), cb_func(cb_func) {}
ConsumeErrCallBackInfo(const std::string& cancel_msg, bool cancel_task, ERRCALLBACK cb_func,
bool log_limit_exceeded)
: cancel_msg(cancel_msg),
cancel_task(cancel_task),
cb_func(cb_func),
log_limit_exceeded(log_limit_exceeded) {}

void init() {
cancel_msg = "";
cancel_task = false;
cancel_task = true;
cb_func = nullptr;
log_limit_exceeded = true;
}
};

Expand Down Expand Up @@ -94,11 +100,12 @@ class ThreadMemTrackerMgr {
void add_tracker(const std::shared_ptr<MemTracker>& mem_tracker);

ConsumeErrCallBackInfo update_consume_err_cb(const std::string& cancel_msg, bool cancel_task,
ERRCALLBACK cb_func) {
ERRCALLBACK cb_func, bool log_limit_exceeded) {
_temp_consume_err_cb = _consume_err_cb;
_consume_err_cb.cancel_msg = cancel_msg;
_consume_err_cb.cancel_task = cancel_task;
_consume_err_cb.cb_func = cb_func;
_consume_err_cb.log_limit_exceeded = log_limit_exceeded;
return _temp_consume_err_cb;
}

Expand Down

0 comments on commit 590dc6a

Please sign in to comment.