Skip to content

Commit

Permalink
[SYCL] Eliminate const_cast
Browse files Browse the repository at this point in the history
Signed-off-by: Sergey Kanaev <[email protected]>
  • Loading branch information
Sergey Kanaev committed May 18, 2020
1 parent 692bf79 commit 6f3b4d7
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 34 deletions.
53 changes: 27 additions & 26 deletions sycl/source/detail/scheduler/graph_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ MemObjRecord *Scheduler::GraphBuilder::getMemObjRecord(SYCLMemObjI *MemObject) {

MemObjRecord *
Scheduler::GraphBuilder::getOrInsertMemObjRecord(const QueueImplPtr &Queue,
Requirement *Req) {
const Requirement *Req) {
SYCLMemObjI *MemObject = Req->MSYCLMemObj;
MemObjRecord *Record = getMemObjRecord(MemObject);

Expand Down Expand Up @@ -416,8 +416,8 @@ Command *Scheduler::GraphBuilder::addHostAccessor(Requirement *Req,
Command *UpdateHostAccCmd = insertUpdateHostReqCmd(Record, Req, HostQueue);

// Need empty command to be blocked until host accessor is destructed
EmptyCommand *EmptyCmd = addEmptyCmd(UpdateHostAccCmd, {Req}, HostQueue,
Command::BlockReason::HostAccessor);
EmptyCommand *EmptyCmd = addEmptyCmd<Requirement>(
UpdateHostAccCmd, {Req}, HostQueue, Command::BlockReason::HostAccessor);

Req->MBlockedCmd = EmptyCmd;

Expand Down Expand Up @@ -446,7 +446,7 @@ Command *Scheduler::GraphBuilder::addCGUpdateHost(
/// 2. New and examined commands has non-overlapping requirements -> can bypass
/// 3. New and examined commands have different contexts -> cannot bypass
std::set<Command *>
Scheduler::GraphBuilder::findDepsForReq(MemObjRecord *Record, Requirement *Req,
Scheduler::GraphBuilder::findDepsForReq(MemObjRecord *Record, const Requirement *Req,
const ContextImplPtr &Context) {
std::set<Command *> RetDeps;
std::set<Command *> Visited;
Expand Down Expand Up @@ -514,7 +514,7 @@ DepDesc Scheduler::GraphBuilder::findDepForRecord(Command *Cmd,
// The function searches for the alloca command matching context and
// requirement.
AllocaCommandBase *Scheduler::GraphBuilder::findAllocaForReq(
MemObjRecord *Record, Requirement *Req, const ContextImplPtr &Context) {
MemObjRecord *Record, const Requirement *Req, const ContextImplPtr &Context) {
auto IsSuitableAlloca = [&Context, Req](AllocaCommandBase *AllocaCmd) {
bool Res = sameCtx(AllocaCmd->getQueue()->getContextImplPtr(), Context);
if (IsSuitableSubReq(Req)) {
Expand All @@ -535,7 +535,7 @@ AllocaCommandBase *Scheduler::GraphBuilder::findAllocaForReq(
// Note, creation of new allocation command can lead to the current context
// (Record->MCurContext) change.
AllocaCommandBase *Scheduler::GraphBuilder::getOrCreateAllocaForReq(
MemObjRecord *Record, Requirement *Req, QueueImplPtr Queue) {
MemObjRecord *Record, const Requirement *Req, QueueImplPtr Queue) {

AllocaCommandBase *AllocaCmd =
findAllocaForReq(Record, Req, Queue->getContextImplPtr());
Expand Down Expand Up @@ -640,9 +640,14 @@ void Scheduler::GraphBuilder::markModifiedIfWrite(MemObjRecord *Record,
}
}

EmptyCommand *Scheduler::GraphBuilder::addEmptyCmd(
Command *Cmd, const std::vector<Requirement *> &Reqs,
const QueueImplPtr &Queue, Command::BlockReason Reason) {
template<typename T>
typename std::enable_if<std::is_same<typename std::remove_cv<T>::type,
Requirement>::value,
EmptyCommand *>::type
Scheduler::GraphBuilder::addEmptyCmd(Command *Cmd,
const std::vector<T *> &Reqs,
const QueueImplPtr &Queue,
Command::BlockReason Reason) {
EmptyCommand *EmptyCmd =
new EmptyCommand(Scheduler::getInstance().getDefaultHostQueue());

Expand All @@ -653,7 +658,7 @@ EmptyCommand *Scheduler::GraphBuilder::addEmptyCmd(
EmptyCmd->MEnqueueStatus = EnqueueResultT::SyclEnqueueBlocked;
EmptyCmd->MBlockReason = Reason;

for (Requirement *Req : Reqs) {
for (T *Req : Reqs) {
MemObjRecord *Record = getOrInsertMemObjRecord(Queue, Req);
AllocaCommandBase *AllocaCmd = getOrCreateAllocaForReq(Record, Req, Queue);
EmptyCmd->addRequirement(Cmd, AllocaCmd, Req);
Expand Down Expand Up @@ -941,23 +946,19 @@ void Scheduler::GraphBuilder::connectDepEvent(Command *const Cmd,
EmptyCommand *EmptyCmd = nullptr;

if (Dep.MDepRequirement) {
Requirement *Req = const_cast<Requirement *>(Dep.MDepRequirement);

// make ConnectCmd depend on requirement
{
ConnectCmd->addDep(Dep);
assert(reinterpret_cast<Command *>(DepEvent->getCommand()) ==
Dep.MDepCommand);
// add user to Dep.MDepCommand is already performed beyond this if branch
ConnectCmd->addDep(Dep);
assert(reinterpret_cast<Command *>(DepEvent->getCommand()) ==
Dep.MDepCommand);
// add user to Dep.MDepCommand is already performed beyond this if branch

MemObjRecord *Record = getMemObjRecord(Req->MSYCLMemObj);
MemObjRecord *Record = getMemObjRecord(Dep.MDepRequirement->MSYCLMemObj);

updateLeaves({ Dep.MDepCommand }, Record, Req->MAccessMode);
addNodeToLeaves(Record, ConnectCmd, Req->MAccessMode);
}
updateLeaves({ Dep.MDepCommand }, Record, Dep.MDepRequirement->MAccessMode);
addNodeToLeaves(Record, ConnectCmd, Dep.MDepRequirement->MAccessMode);

const std::vector<Requirement *> Reqs(1, Req);
EmptyCmd = addEmptyCmd(ConnectCmd, Reqs,
const std::vector<const Requirement *> Reqs(1, Dep.MDepRequirement);
EmptyCmd = addEmptyCmd<>(ConnectCmd, Reqs,
Scheduler::getInstance().getDefaultHostQueue(),
Command::BlockReason::HostTask);
// Dependencies for EmptyCmd are set in addEmptyCmd for provided Reqs.
Expand All @@ -970,9 +971,9 @@ void Scheduler::GraphBuilder::connectDepEvent(Command *const Cmd,
Cmd->addDep(CmdDep);
}
} else {
EmptyCmd = addEmptyCmd(ConnectCmd, {},
Scheduler::getInstance().getDefaultHostQueue(),
Command::BlockReason::HostTask);
EmptyCmd = addEmptyCmd<Requirement>(
ConnectCmd, {}, Scheduler::getInstance().getDefaultHostQueue(),
Command::BlockReason::HostTask);

// There is no requirement thus, empty command will only depend on
// ConnectCmd via its event.
Expand Down
20 changes: 12 additions & 8 deletions sycl/source/detail/scheduler/scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ class Scheduler {
/// \return a pointer to MemObjRecord for pointer to memory object. If the
/// record is not found, nullptr is returned.
MemObjRecord *getOrInsertMemObjRecord(const QueueImplPtr &Queue,
Requirement *Req);
const Requirement *Req);

/// Decrements leaf counters for all leaves of the record.
void decrementLeafCountersForRecord(MemObjRecord *Record);
Expand Down Expand Up @@ -546,20 +546,24 @@ class Scheduler {
const QueueImplPtr &Queue);

/// Finds dependencies for the requirement.
std::set<Command *> findDepsForReq(MemObjRecord *Record, Requirement *Req,
std::set<Command *> findDepsForReq(MemObjRecord *Record,
const Requirement *Req,
const ContextImplPtr &Context);

EmptyCommand *addEmptyCmd(Command *Cmd,
const std::vector<Requirement *> &Req,
const QueueImplPtr &Queue,
Command::BlockReason Reason);
template<typename T>
typename std::enable_if<std::is_same<typename std::remove_cv<T>::type,
Requirement>::value,
EmptyCommand *>::type
addEmptyCmd(Command *Cmd, const std::vector<T *> &Req,
const QueueImplPtr &Queue, Command::BlockReason Reason);

protected:
/// Finds a command dependency corresponding to the record.
DepDesc findDepForRecord(Command *Cmd, MemObjRecord *Record);

/// Searches for suitable alloca in memory record.
AllocaCommandBase *findAllocaForReq(MemObjRecord *Record, Requirement *Req,
AllocaCommandBase *findAllocaForReq(MemObjRecord *Record,
const Requirement *Req,
const ContextImplPtr &Context);

friend class Command;
Expand All @@ -569,7 +573,7 @@ class Scheduler {
///
/// If none found, creates new one.
AllocaCommandBase *getOrCreateAllocaForReq(MemObjRecord *Record,
Requirement *Req,
const Requirement *Req,
QueueImplPtr Queue);

void markModifiedIfWrite(MemObjRecord *Record, Requirement *Req);
Expand Down

0 comments on commit 6f3b4d7

Please sign in to comment.