This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
threaded_engine.h
609 lines (577 loc) · 22.1 KB
/
threaded_engine.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2015 by Contributors
* \file threaded_engine.h
* \brief Implements base class of threaded engine
* that tracks the dependency and pushes actions to execute.
* \author Yutian Li
*/
#ifndef MXNET_ENGINE_THREADED_ENGINE_H_
#define MXNET_ENGINE_THREADED_ENGINE_H_
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/omp.h>
#include <mxnet/storage.h>
#include <vector>
#include <functional>
#include <condition_variable>
#include <atomic>
#include <utility>
#include <mutex>
#include <string>
#include <thread>
#include "./engine_impl.h"
#include "../profiler/profiler.h"
#include "./openmp.h"
#include "../common/object_pool.h"
#include "../profiler/custom_op_profiler.h"
namespace mxnet {
namespace engine {
// Define helper macros for debug information.
#if ENGINE_DEBUG
#define DEFINE_ENGINE_DEBUG_INFO(Type) \
static std::atomic<std::size_t> counter; \
Type() { LOG(INFO) << __func__ << " " << ++counter; } \
~Type() { LOG(INFO) << __func__ << " " << --counter; }
#else
#define DEFINE_ENGINE_DEBUG_INFO(Type)
#endif
// Forward declarations
struct ThreadedOpr;
/*! shared_ptr to exception_ptr, used for exception handling */
typedef std::shared_ptr<std::exception_ptr> ExceptionRef;
/*!
* \brief Operation block in the scheduler.
* Each OprBlock corresponds to an operation pushed to the engine.
*/
struct OprBlock : public common::ObjectPoolAllocatable<OprBlock> {
  /*!
   * \brief Number of pending dependencies this OprBlock is still waiting for.
   *  Decremented (via decr_wait) as each dependency is satisfied.
   */
  std::atomic<int> wait{0};
  /*! \brief Pointer to the operator (function + dependency lists) to execute. */
  ThreadedOpr* opr{nullptr};
  /*! \brief The context (device) this operator runs on. */
  Context ctx;
  /*! \brief Priority of the function; higher values may be scheduled earlier. */
  int priority;
  /*! \brief Whether to profile this operator's execution. */
  bool profiling{false};
  /*! \brief Per-operator execution statistics collected by the profiler. */
  std::unique_ptr<profiler::ProfileOperator> opr_profile;
  // Optional debug counters, compiled in only when ENGINE_DEBUG is set.
  DEFINE_ENGINE_DEBUG_INFO(OprBlock);
  /*!
   * \brief Decrease the wait counter by one (one dependency satisfied).
   * \return the wait counter after the decrement; 0 indicates all
   *  dependencies are satisfied.
   */
  inline int decr_wait() {
    // Invariant check: the counter must never drop below zero,
    // i.e. each dependency may complete (trigger) at most once.
    const int ret = --wait;
    CHECK_GE(ret, 0);
    return ret;
  }
};  // struct OprBlock
/*!
* \brief VersionedVarBlock that corresponding to a variable version.
* This is a basic unit of LinkedList in the ThreadedVar.
*/
struct VersionedVarBlock
    : public common::ObjectPoolAllocatable<VersionedVarBlock> {
  /*! \brief Next block in the linked list maintained by ThreadedVar. */
  VersionedVarBlock* next{nullptr};
  /*! \brief The operation this block triggers when this version is reached. */
  OprBlock* trigger{nullptr};
  /*! \brief Whether the triggering operation is a write (mutate) operation. */
  bool write{false};
  /*! \brief Optional debug counters, compiled in only when ENGINE_DEBUG is set. */
  DEFINE_ENGINE_DEBUG_INFO(VersionedVarBlock);
};  // struct VersionedVarBlock
/*!
* \brief Variable implementation.
* Each ThreadedVar is a linked list(queue) of operations to be performed.
*/
class ThreadedVar final
    : public Var, public common::ObjectPoolAllocatable<ThreadedVar> {
 public:
  /*!
   * \brief constructor
   * \param head head block of the linked list;
   *  must be initialized with next == nullptr and trigger == nullptr.
   */
  explicit ThreadedVar(VersionedVarBlock* head);
  /*!
   * \brief Schedule a read operation on this variable.
   *  If the opr_block can be run right away,
   *  the wait counter of opr_block will be decreased.
   *  Otherwise, the opr_block will be added to the waiting queue.
   * \param opr_block The operation to be scheduled.
   */
  inline void AppendReadDependency(OprBlock* opr_block);
  /*!
   * \brief Schedule a write operation on this variable.
   *  If the opr_block can be run right away,
   *  the wait counter of opr_block will be decreased.
   *  Otherwise, the opr_block will be added to the waiting queue.
   * \param opr_block The operation to be scheduled.
   */
  inline void AppendWriteDependency(OprBlock* opr_block);
  /*!
   * \brief A read operation is completed on this variable.
   *  This function may trigger subsequent waiting operations on this variable.
   *
   * \param dispatcher the function called to trigger the operation,
   *  when all of its dependencies are satisfied.
   * \tparam Dispatcher the function called to trigger an operation.
   */
  template <typename Dispatcher>
  inline void CompleteReadDependency(Dispatcher dispatcher);
  /*!
   * \brief A write operation is completed on this variable.
   *  This function may trigger subsequent waiting operations on this variable.
   *
   * \param dispatcher the function called to trigger the operation,
   *  when all of its dependencies are satisfied.
   * \tparam Dispatcher the function called to trigger an operation.
   * \return to_delete, whether this Variable can be deleted after this function.
   */
  template <typename Dispatcher>
  inline bool CompleteWriteDependency(Dispatcher dispatcher);
  /*! \brief Mark this variable to be deleted once pending operations finish. */
  inline void SetToDelete();
  /*! \return whether this variable is ready to read. */
  inline bool ready_to_read();
  /*! \return the current version number of this variable. */
  inline size_t version() override;
  /*!
   * \brief Cast a Var pointer to a ThreadedVar pointer.
   * \param ptr pointer from base.
   * \return a casted pointer.
   */
  inline static ThreadedVar* CastFromBase(Var* ptr) {
    return ptr->Cast<ThreadedVar>();
  }
  // Code for debug: track live-instance count when ENGINE_DEBUG is enabled.
#if ENGINE_DEBUG
  static std::atomic<std::size_t> counter;
  ~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; }
#endif  // ENGINE_DEBUG
  /*!
   * \brief exception_ptr associated with the ThreadedVar.
   *  We cannot modify the state of the exception object since dereferencing
   *  an exception_ptr is undefined behavior; using a shared_ptr to hold the
   *  exception_ptr overcomes this limitation. */
  ExceptionRef var_exception;
 private:
  // TODO(hotpxl) change this to spinlock for faster runtime
  // TODO(hotpxl) consider rename head
  /*! \brief internal mutex guarding all mutable state of the ThreadedVar */
  std::mutex mutex_;
  /*!
   * \brief number of pending read operations on the variable;
   *  set to kWriteTriggered (-1) while a pending write has been triggered.
   */
  int num_pending_reads_{0};
  /*!
   * \brief Points to the last VersionedVarBlock in the queue.
   *  head_ always points to an empty VersionedVarBlock.
   *  So when we want to append an operation to the queue:
   *  1) update head_->trigger to be the new op
   *  2) update head_->next to be a new VersionedVarBlock
   *  3) move head_ to head_->next.
   */
  VersionedVarBlock* head_{nullptr};
  /*!
   * \brief The pointer to the next write to perform.
   *  This pointer is only updated when a write completes.
   *  This is effectively the head (oldest operation) of the queue.
   */
  VersionedVarBlock* pending_write_{nullptr};
  /*!
   * \brief If true, delete this variable after the current operation completes.
   */
  bool to_delete_{false};
  /*! \brief special value of num_pending_reads_ marking a triggered write */
  static constexpr int kWriteTriggered = -1;
  /*!
   * \brief derived invariant of "ready to read", checked without the lock.
   * \return whether the current variable is ready to read.
   */
  inline bool is_ready_to_read() const {
    return pending_write_ == nullptr;
  }
};  // struct ThreadedVar
/*!
* \brief Operator used in ThreadedEngine.
*/
struct ThreadedOpr final : public Opr,
                           public common::ObjectPoolAllocatable<ThreadedOpr> {
  /*! \brief The function to be invoked each time this operator runs. */
  Engine::AsyncFn fn;
  /*! \brief The variables this operation will read from. */
  std::vector<ThreadedVar*> const_vars;
  /*! \brief The variables this operation will mutate. */
  std::vector<ThreadedVar*> mutable_vars;
  /*! \brief The property (kind/behavior class) of the operator. */
  FnProperty prop;
  /*! \brief The name of the operator (may be nullptr when unnamed). */
  const char* opr_name{nullptr};
  /*!
   * \brief Whether this is a temporary operator
   *  that can be deleted right after the operation completes.
   */
  bool temporary{false};
  /*!
   * \brief Whether this is a WaitForVar operation.
   */
  bool wait{false};
  /*!
   * \brief Cast an Opr pointer to a ThreadedOpr pointer.
   * \param ptr pointer from base.
   * \return a casted pointer.
   */
  inline static ThreadedOpr* CastFromBase(Opr* ptr) {
    return ptr->Cast<ThreadedOpr>();
  }
  // Optional debug counters, compiled in only when ENGINE_DEBUG is set.
  DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr);
  /*!
   * \brief exception_ptr associated with the ThreadedOpr.
   *  We cannot modify the state of the exception object since dereferencing
   *  an exception_ptr is undefined behavior; using a shared_ptr to hold the
   *  exception_ptr overcomes this limitation. */
  ExceptionRef opr_exception;
};  // struct ThreadedOpr
/*!
* \brief Base class of all ThreadedEngine.
* This class implements a thread safe version of engine.
* The engine tracks the dependencies, and will call PushToExecute
* to execute a specific task.
*
* Subclass can implement PushToExecute to design specific
* execution policy for the tasks.
*/
class ThreadedEngine : public Engine {
 public:
  // Implementations of the Engine interface (definitions in threaded_engine.cc).
  ThreadedVar* NewVariable() override;
  ThreadedOpr* NewOperator(AsyncFn fn,
                           std::vector<VarHandle> const& const_vars,
                           std::vector<VarHandle> const& mutable_vars,
                           FnProperty prop = FnProperty::kNormal,
                           const char* opr_name = nullptr,
                           bool wait = false) override;
  void DeleteOperator(OprHandle op) override;
  void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) override;
  void PushAsync(AsyncFn exec_fun, Context exec_ctx,
                 std::vector<VarHandle> const& const_vars,
                 std::vector<VarHandle> const& mutable_vars,
                 FnProperty prop = FnProperty::kNormal,
                 int priority = 0,
                 const char* opr_name = nullptr,
                 bool wait = false) override;
  void PushSync(SyncFn exec_fn, Context exec_ctx,
                std::vector<VarHandle> const& const_vars,
                std::vector<VarHandle> const& mutable_vars,
                FnProperty prop = FnProperty::kNormal,
                int priority = 0,
                const char* opr_name = nullptr) override;
  void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override;
  void WaitForVar(VarHandle var) override;
  void WaitForAll() override;
  void Throw(VarHandle var) override;
  /*! \brief Enter shutdown phase: new non-cleanup work will be skipped. */
  void NotifyShutdown() override {
    shutdown_phase_.store(true);
  }
  ThreadedEngine() {
    engine_info_ = dmlc::GetEnv("MXNET_ENGINE_INFO", false);
    // Hold shared refs to the object pools (and storage) so they are not
    // destructed before this engine; see
    // https://github.com/dmlc/mxnet/issues/309.
    objpool_opr_ref_ = common::ObjectPool<ThreadedOpr>::_GetSharedRef();
    objpool_blk_ref_ = common::ObjectPool<OprBlock>::_GetSharedRef();
    objpool_varblk_ref_ = common::ObjectPool<VersionedVarBlock>::_GetSharedRef();
    objpool_var_ref_ = common::ObjectPool<ThreadedVar>::_GetSharedRef();
    storage_ref_ = Storage::_GetSharedRef();
    // Get a ref to the profiler so that it doesn't get killed before us
    profiler::Profiler::Get(&profiler_);
  }
  ~ThreadedEngine() {
    // Set kill_ under the lock so waiters see a consistent state,
    // then wake them all up so they can observe the flag and exit.
    {
      std::unique_lock<std::mutex> lock{finished_m_};
      kill_.store(true);
    }
    finished_cv_.notify_all();
  }
 protected:
  /*!
   * \brief Push the opr block to the execution queue to be executed.
   *  This function is implemented by the corresponding subclass
   *  for its specific scheduling policy.
   *
   * \param opr_block The operator block.
   * \param pusher_thread whether the caller is the thread that calls push
   */
  virtual void PushToExecute(OprBlock* opr_block, bool pusher_thread) = 0;
  /*!
   * \brief Actually execute an opr_block.
   *  The opr_block is deleted via the completion callback after execution.
   * \param run_ctx runtime context used to execute the function.
   * \param opr_block the opr_block to be executed and deleted.
   */
  void ExecuteOprBlock(RunContext run_ctx, OprBlock* opr_block) {
    ThreadedOpr* threaded_opr = opr_block->opr;
    // Start per-operator profiling before the function runs, if requested.
    if (opr_block->profiling && threaded_opr->opr_name) {
      std::unique_ptr<profiler::ProfileOperator::Attributes> attrs;
      if (profiler_->AggregateEnabled()) {
        attrs.reset(new profiler::ProfileOperator::Attributes());
      }
      const Context& ctx = opr_block->ctx;
      opr_block->opr_profile.reset(new profiler::ProfileOperator(threaded_opr->opr_name,
                                                                 attrs.release()));
      opr_block->opr_profile->startForDevice(ctx.dev_type, ctx.dev_id);
    }
    CallbackOnComplete callback =
        this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block);
    const bool debug_info = (engine_info_ && debug_push_opr_ == opr_block);
    if (debug_info) {
      // NOTE(review): there is no separator between the pointer and
      // "shutdown_phase=" in this log line — confirm before relying on its format.
      LOG(INFO) << "ExecuteOprBlock " << opr_block
                << "shutdown_phase=" << shutdown_phase_;
    }
    // Still run cleanup (kDeleteVar) operations during the shutdown phase;
    // all other work is skipped and only the completion callback fires.
    if (!shutdown_phase_ || threaded_opr->prop == FnProperty::kDeleteVar) {
      try {
        // Propagate exceptions from dependency vars into this operator.
        OnStart(threaded_opr);
        if (debug_info) {
          LOG(INFO) << "ExecuteOprFn ";
        }
        try {
          // Run the function unless a dependency already failed; kNoSkip
          // operators and WaitForVar operators run regardless so the
          // exception can be surfaced to the waiter.
          if ((!(threaded_opr->opr_exception && *threaded_opr->opr_exception) ||
               threaded_opr->prop == FnProperty::kNoSkip) || threaded_opr->wait) {
            threaded_opr->fn(run_ctx, callback);
          } else {
            callback();
          }
        } catch (const std::exception& e) {
          // Record the exception on the operator and still complete it so
          // dependency tracking moves forward.
          threaded_opr->opr_exception =
              std::make_shared<std::exception_ptr>(std::current_exception());
          callback();
        }
        if (debug_info) {
          LOG(INFO) << "Fin ExecuteOprFn ";
        }
      } catch (std::exception& e) {
        // "driver shutting down" errors during shutdown are expected noise;
        // anything else at this level is fatal.
        std::string what = e.what();
        if (what.find("driver shutting down") == std::string::npos &&
            !shutdown_phase_) {
          LOG(FATAL)
              << e.what() << "\n"
              << "A fatal error occurred in asynchronous engine operation. "
                 "If you do not know what caused this error, "
                 "you can try set environment variable MXNET_ENGINE_TYPE "
                 "to NaiveEngine and run with debugger (i.e. gdb). "
                 "This will force all operations to be synchronous and "
                 "backtrace will give you the series of calls that lead "
                 "to this error. Remember to set MXNET_ENGINE_TYPE back to "
                 "empty after debugging.";
        }
      }
    } else {
      callback();
    }
  }
  /*!
   * \return the current bulk size; 0 (bulking disabled) while profiler
   *  aggregation is running, since bulking would merge per-op statistics.
   */
  int bulk_size() const override {
    const profiler::Profiler *prof = profiler::Profiler::Get();
    return (prof && prof->AggregateRunning()) ? 0 : BulkStatusStore::Get()->bulk_size;
  }
  /*!
   * \brief Set the bulk size for this thread and flush if the pending
   *  count already exceeds the new size.
   * \return the previous bulk size (swap idiom).
   */
  int set_bulk_size(int bulk_size) override {
    BulkStatus& bulk_status = *BulkStatusStore::Get();
    // After the swap, the parameter holds the PREVIOUS size.
    std::swap(bulk_status.bulk_size, bulk_size);
    if (bulk_status.count >= bulk_status.bulk_size) BulkFlush();
    if (!bulk_status.functions) {
      bulk_status.functions.reset(new std::vector<SyncFn>());
    }
    // NOTE(review): this reserves with the PREVIOUS bulk size (the swapped
    // parameter), not the new one — likely intended to be
    // bulk_status.bulk_size; confirm upstream.
    bulk_status.functions->reserve(bulk_size);
    return bulk_size;
  }
 private:
  /*! \brief structure for holding bulk execution status (one per thread) */
  struct BulkStatus {
    /*! \brief maximum number of ops per bulk */
    int bulk_size = 0;
    /*! \brief current number of ops in the bulk */
    int count = 0;
    /*! \brief context of the current ops */
    Context ctx;
    /*! \brief current op functions */
    std::shared_ptr<std::vector<SyncFn>> functions;
    /*! \brief constant variables */
    std::vector<VarHandle> const_vars;
    /*! \brief mutable variables */
    std::vector<VarHandle> mutable_vars;
  };
  /*! \brief thread-local store for bulk status */
  typedef dmlc::ThreadLocalStore<BulkStatus> BulkStatusStore;
  /*!
   * \brief Check if there is duplication in const_vars and mutable_vars.
   * \param const_vars the variables to read from.
   * \param mutable_vars the variables to mutate.
   */
  void CheckDuplicate(std::vector<VarHandle> const& const_vars,
                      std::vector<VarHandle> const& mutable_vars);
  /*!
   * \brief Callback on operation completion.
   *
   *  On operation completion, this will trigger subsequent operations.
   */
  inline void OnComplete(ThreadedOpr* threaded_opr);
  /*!
   * \brief Rethrow a caught exception in WaitForVar.
   * \param threaded_var the var that we are waiting to read
   */
  inline void ThrowException(ThreadedVar* threaded_var);
  /*!
   * \brief Mark exceptions before operation execution.
   *
   *  Will mark the operator as a failure and associate an exception_ptr
   *  if any of its read or mutate dependencies has an exception associated.
   */
  inline void OnStart(ThreadedOpr* threaded_opr) {
    // First inherit an exception from any read dependency...
    for (auto&& i : threaded_opr->const_vars) {
      if (i->var_exception && *i->var_exception) {
        threaded_opr->opr_exception = i->var_exception;
        AddToGlobalExceptions(threaded_opr->opr_exception);
        break;
      }
    }
    // ...and only if none was found, from a mutate dependency.
    if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception)) {
      for (auto&& i : threaded_opr->mutable_vars) {
        if (i->var_exception && *i->var_exception) {
          threaded_opr->opr_exception = i->var_exception;
          AddToGlobalExceptions(threaded_opr->opr_exception);
          break;
        }
      }
    }
  }
  /*! \brief static trampoline used as the engine completion callback */
  static void OnCompleteStatic(Engine *engine, void *threaded_opr,
                               const dmlc::Error* error);
  /*!
   * \brief Find an exception in global_exception_refs_ and add it if missing.
   * \param opr_exception the exception to be added to global_exception_refs_
   */
  inline void AddToGlobalExceptions(const ExceptionRef& opr_exception) {
    auto it = std::find(global_exception_refs_.begin(),
                        global_exception_refs_.end(), opr_exception);
    if (it == global_exception_refs_.end()) {
      global_exception_refs_.push_back(opr_exception);
    }
    return;
  }
  /*! \brief Append an operator to the current thread's bulk. */
  inline void BulkAppend(SyncFn exec_fn, Context exec_ctx,
                         std::vector<VarHandle> const& const_vars,
                         std::vector<VarHandle> const& mutable_vars) {
    BulkStatus& bulk_status = *BulkStatusStore::Get();
    if (!bulk_status.functions) {
      bulk_status.functions.reset(new std::vector<SyncFn>());
    }
    bulk_status.functions->push_back(exec_fn);
    // The first op in a bulk fixes the context for the whole bulk.
    if (!bulk_status.count) {
      bulk_status.ctx = exec_ctx;
    }
    ++bulk_status.count;
    bulk_status.const_vars.insert(
        bulk_status.const_vars.end(), const_vars.begin(), const_vars.end());
    bulk_status.mutable_vars.insert(
        bulk_status.mutable_vars.end(), mutable_vars.begin(), mutable_vars.end());
    if (bulk_status.count >= bulk_status.bulk_size) BulkFlush();
  }
  /*! \brief Flush the current bulk as a single async operation. */
  inline void BulkFlush() {
    BulkStatus& bulk_status = *BulkStatusStore::Get();
    if (!bulk_status.count) return;
    bulk_status.count = 0;
    DeduplicateVarHandle(&bulk_status.const_vars, &bulk_status.mutable_vars);
    // Capture the function list by shared_ptr so it outlives this scope.
    auto functions = bulk_status.functions;
    this->PushAsync([functions](RunContext ctx, CallbackOnComplete on_complete) {
        ctx.is_bulk = true;
        for (auto& fn : *functions) {
          fn(ctx);
        }
        ctx.is_bulk = false;
        // On GPU, wait for the stream so the bulk's work is complete
        // before signalling completion.
        bool is_gpu = ctx.ctx.dev_mask() == gpu::kDevMask;
        if (is_gpu) {
          ctx.get_stream<gpu>()->Wait();
        }
        on_complete();
      }, bulk_status.ctx, bulk_status.const_vars, bulk_status.mutable_vars,
      FnProperty::kNormal, 0, "ImperativeBulk");
    // Reset the per-thread bulk state for the next batch.
    bulk_status.functions.reset(new std::vector<SyncFn>());
    bulk_status.functions->reserve(bulk_status.bulk_size);
    bulk_status.const_vars.clear();
    bulk_status.mutable_vars.clear();
  }
  /*!
   * \brief Number of pending operations.
   */
  std::atomic<int> pending_{0};
  /*! \brief whether we want to kill the waiters */
  std::atomic<bool> kill_{false};
  /*! \brief whether it is during shutdown phase */
  std::atomic<bool> shutdown_phase_{false};
  /*! \brief show more information from engine actions */
  bool engine_info_{false};
  /*! \brief debug information about wait for var. */
  std::atomic<ThreadedVar*> debug_wait_var_{nullptr};
  /*! \brief debug information about the pushed operator. */
  std::atomic<OprBlock*> debug_push_opr_{nullptr};
  /*!
   * \brief Mutex and condition_variable,
   *  used to notify waits for single or all variables.
   */
  std::mutex finished_m_;
  std::condition_variable finished_cv_;
  /*! \brief global exception refs, which are rethrown when WaitForAll is called */
  std::vector<ExceptionRef> global_exception_refs_;
  /*!
   * \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early
   *  See also #309 (https://github.com/dmlc/mxnet/issues/309)
   */
  std::shared_ptr<common::ObjectPool<ThreadedOpr> > objpool_opr_ref_;
  std::shared_ptr<common::ObjectPool<OprBlock> > objpool_blk_ref_;
  std::shared_ptr<common::ObjectPool<VersionedVarBlock> > objpool_varblk_ref_;
  std::shared_ptr<common::ObjectPool<ThreadedVar> > objpool_var_ref_;
  /*!
   * \brief Async destruction of some objects relies on storage;
   *  prevent it from being destructed too early.
   */
  std::shared_ptr<Storage> storage_ref_;
#if MXNET_USE_CUDA
  /*! \brief Number of GPU devices available (-1 means not yet queried). */
  std::atomic<int> device_count_{-1};
#endif
  /*! \brief Hold a ref count to the profiler so it outlives the engine. */
  std::shared_ptr<profiler::Profiler> profiler_;
  /*!
   * \brief Disallow copy construction and assignment.
   * \note This must be last
   */
  DISALLOW_COPY_AND_ASSIGN(ThreadedEngine);
};  // class ThreadedEngine
} // namespace engine
} // namespace mxnet
#endif // MXNET_ENGINE_THREADED_ENGINE_H_