Skip to content

Commit

Permalink
Don't count iterators as data spaces #19
Browse files Browse the repository at this point in the history
  • Loading branch information
riftEmber committed Nov 22, 2021
1 parent c3aaa29 commit 1b664ef
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 23 deletions.
6 changes: 4 additions & 2 deletions include/ComputationBuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ class ComputationBuilder {
//! Computations referenced from any others, stored for potential re-use
static std::map<std::string, Computation *> subComputations;

//! Context information about the position we're currently at.
//! Updated to the most recently processed statement in any Computation
static PositionContext *positionContext;

private:
//! Top-level Computation being built up
Computation *computation;
//! Context information about the position we're currently at
PositionContext context;
//! Whether a return Stmt has been hit in this function
bool haveFoundAReturn = false;

Expand Down
3 changes: 3 additions & 0 deletions include/PositionContext.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ struct PositionContext {
//! Get a string representing the given data access
std::string getDataAccessString(DataAccess *);

//! Check whether the given name is an iterator in this context
bool isIteratorName(const std::string &varName);

// enter* and exit* methods add iterators and constraints when entering a
// new scope, remove when leaving the scope

Expand Down
40 changes: 21 additions & 19 deletions src/ComputationBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ using namespace clang;

namespace spf_ie {

PositionContext *ComputationBuilder::positionContext;

/* ComputationBuilder */

std::map<std::string, Computation *> ComputationBuilder::subComputations;
Expand All @@ -30,7 +32,7 @@ ComputationBuilder::buildComputationFromFunction(FunctionDecl *funcDecl) {
Utils::printErrorAndExit("Invalid function body", funcDecl->getBody());
}

context = PositionContext();
positionContext = new PositionContext();
computation = new iegenlib::Computation(funcDecl->getNameAsString());

// add function parameters to the Computation
Expand Down Expand Up @@ -76,25 +78,25 @@ void ComputationBuilder::processSingleStmt(clang::Stmt *stmt) {
}

if (auto *asForStmt = dyn_cast<ForStmt>(stmt)) {
context.schedule.advanceSchedule();
context.enterFor(asForStmt);
positionContext->schedule.advanceSchedule();
positionContext->enterFor(asForStmt);
processBody(asForStmt->getBody());
context.exitFor();
positionContext->exitFor();
} else if (auto *asIfStmt = dyn_cast<IfStmt>(stmt)) {
if (asIfStmt->getConditionVariable()) {
Utils::printErrorAndExit(
"If statement condition variable declarations are unsupported",
asIfStmt);
}
context.enterIf(asIfStmt);
positionContext->enterIf(asIfStmt);
processBody(asIfStmt->getThen());
context.exitIf();
positionContext->exitIf();
// treat else clause (if present) as another if statement, but with
// condition inverted
if (asIfStmt->hasElseStorage()) {
context.enterIf(asIfStmt, true);
positionContext->enterIf(asIfStmt, true);
processBody(asIfStmt->getElse());
context.exitIf();
positionContext->exitIf();
}
} else if (auto *asCallExpr = dyn_cast<CallExpr>(stmt)) {
// TODO: detect function calls that are not the only thing in the statement
Expand All @@ -103,7 +105,7 @@ void ComputationBuilder::processSingleStmt(clang::Stmt *stmt) {
Utils::printErrorAndExit("Cannot processes this kind of call expression", asCallExpr);
}
auto *calleeDefinition = callee->getDefinition();
context.schedule.advanceSchedule();
positionContext->schedule.advanceSchedule();
std::string calleeName = calleeDefinition->getNameAsString();
if (!subComputations.count(calleeName)) {
// build Computation from calleeDefinition, if we haven't done so already
Expand All @@ -121,15 +123,15 @@ void ComputationBuilder::processSingleStmt(clang::Stmt *stmt) {
callArgStrings.emplace_back(Utils::stmtToString(arg));
}
auto appendResult = computation->appendComputation(subComputations[calleeName],
context.getIterSpaceString(),
context.getExecScheduleString(),
positionContext->getIterSpaceString(),
positionContext->getExecScheduleString(),
callArgStrings);

context.schedule.skipToPosition(appendResult.tuplePosition);
positionContext->schedule.skipToPosition(appendResult.tuplePosition);

// TODO: handle return value
} else {
context.schedule.advanceSchedule();
positionContext->schedule.advanceSchedule();
addStmt(stmt);
}
}
Expand Down Expand Up @@ -176,10 +178,10 @@ void ComputationBuilder::addStmt(clang::Stmt *clangStmt) {
}
newStmt->setStmtSourceCode(stmtSourceCode);
// iteration space
std::string iterationSpace = context.getIterSpaceString();
std::string iterationSpace = positionContext->getIterSpaceString();
newStmt->setIterationSpace(iterationSpace);
// execution schedule
std::string executionSchedule = context.getExecScheduleString();
std::string executionSchedule = positionContext->getExecScheduleString();
newStmt->setExecutionSchedule(executionSchedule);
// data accesses
std::vector<std::pair<std::string, std::string>> dataReads;
Expand All @@ -188,7 +190,7 @@ void ComputationBuilder::addStmt(clang::Stmt *clangStmt) {
std::string dataSpaceAccessed = it_accesses.name;
// enforce loop invariance
if (!it_accesses.isRead) {
for (const auto &invariantGroup: context.invariants) {
for (const auto &invariantGroup: positionContext->invariants) {
if (std::find(
invariantGroup.begin(), invariantGroup.end(),
dataSpaceAccessed) != invariantGroup.end()) {
Expand All @@ -203,10 +205,10 @@ void ComputationBuilder::addStmt(clang::Stmt *clangStmt) {
// insert data access
if (it_accesses.isRead) {
newStmt->addRead(dataSpaceAccessed,
context.getDataAccessString(&it_accesses));
positionContext->getDataAccessString(&it_accesses));
} else {
newStmt->addWrite(dataSpaceAccessed,
context.getDataAccessString(&it_accesses));
positionContext->getDataAccessString(&it_accesses));
}
}

Expand All @@ -222,7 +224,7 @@ void ComputationBuilder::addStmt(clang::Stmt *clangStmt) {

void ComputationBuilder::processReturnStmt(clang::ReturnStmt *returnStmt) {
haveFoundAReturn = true;
if (context.nestLevel != 0) {
if (positionContext->nestLevel != 0) {
Utils::printErrorAndExit("Return within nested structures is disallowed.", returnStmt);
}

Expand Down
6 changes: 5 additions & 1 deletion src/DataAccessHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
#include <sstream>
#include <stack>
#include <string>
#include <utility>
#include <vector>

#include "Driver.hpp"
#include "Utils.hpp"
#include "ComputationBuilder.hpp"
#include "clang/AST/Expr.h"

using namespace clang;
Expand Down Expand Up @@ -76,6 +76,10 @@ void DataAccessHandler::processSingleAccessExpr(Expr *fullExpr,
auto accesses = gatherDataAccessesInExpr(fullExpr, isRead);

for (const auto &access: accesses) {
// skip counting iterators as data accesses
if (ComputationBuilder::positionContext->isIteratorName(access.name)) {
continue;
}
dataSpacesAccessed.emplace(access.name);
stmtDataAccesses.push_back(access);
}
Expand Down
10 changes: 9 additions & 1 deletion src/PositionContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include <memory>
#include <sstream>
#include <stack>
#include <string>
#include <tuple>
#include <vector>
Expand Down Expand Up @@ -106,6 +105,15 @@ std::string PositionContext::getDataAccessString(DataAccess *access) {
return os.str();
}

bool PositionContext::isIteratorName(const std::string &varName) {
for (const auto &iteratorName: iterators) {
if (varName == iteratorName) {
return true;
}
}
return false;
}

void PositionContext::enterFor(ForStmt *forStmt) {
std::string error;
std::string errorReason;
Expand Down

0 comments on commit 1b664ef

Please sign in to comment.