Skip to content

Commit

Permalink
Care types other than uchar
Browse files Browse the repository at this point in the history
  • Loading branch information
long-long-float committed Jul 19, 2020
1 parent 0baebed commit 1e26f02
Showing 1 changed file with 57 additions and 18 deletions.
75 changes: 57 additions & 18 deletions src/normalization/Normalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
#include "MemoryAccess.h"
#include "Rewrite.h"

#ifdef __GNUC__
#include <cxxabi.h>
#endif

#include "log.h"

#include <string>
Expand Down Expand Up @@ -359,11 +363,10 @@ void ValueBinaryOp::expand(ExpandedExprs& exprs)
break;
}

int sign = num >= 0;
// TODO: Care other types
auto value = Value(Literal(std::abs(num)), TYPE_INT32);
std::shared_ptr<ValueExpr> expr = std::make_shared<ValueTerm>(value);
exprs.push_back(std::make_pair(sign, expr));
exprs.push_back(std::make_pair(true, expr));
}
else
{
Expand Down Expand Up @@ -584,9 +587,6 @@ std::shared_ptr<ValueExpr> calcValueExpr(std::shared_ptr<ValueExpr> expr)

void combineDMALoads(const Module& module, Method& method, const Configuration& config)
{
// vload16(unsigned int, unsigned char*)
const std::string VLOAD16_METHOD_NAME = "_Z7vload16jPU3AS1Kh";

for(auto& bb : method)
{
std::vector<intermediate::MethodCall*> loadInstrs;
Expand All @@ -597,26 +597,55 @@ void combineDMALoads(const Module& module, Method& method, const Configuration&
// Find all method calls
if(auto call = dynamic_cast<intermediate::MethodCall*>(it.get()))
{
if(call->methodName == VLOAD16_METHOD_NAME)

auto name = call->methodName;

#ifdef __GNUC__
// Copied from src/spirv/SPIRVHelper.cpp
// TODO: Move these codes to the new helper file.
int status;
char* real_name = abi::__cxa_demangle(name.data(), nullptr, nullptr, &status);
std::string result = name;

if(status == 0)
{
offsetValues.push_back(call->assertArgument(0));
loadInstrs.push_back(call);
// if demangling is successful, output the demangled function name
result = real_name;
// the demangled name contains the arguments, so we need ignore them
result = result.substr(0, result.find('('));
}
free(real_name);
auto isVload16 = result == "vload16";
#else
auto isVload16 = name.find("vload16") != std::string::npos;
#endif

// TODO: Check whether all second argument values are equal.
if(isVload16)
{
if (!addrValue.has_value())
{
addrValue = call->getArgument(1);
}
else if (addrValue == call->getArgument(1))
else if (addrValue != call->getArgument(1))
{
continue;
}

offsetValues.push_back(call->assertArgument(0));
loadInstrs.push_back(call);
}
}
}

if(offsetValues.size() <= 1)
continue;

for(auto& inst : loadInstrs)
{
logging::debug() << inst->to_string() << logging::endl;
}

std::vector<std::pair<Value, std::shared_ptr<ValueExpr>>> addrExprs;

for(auto& addrValue : offsetValues)
Expand All @@ -643,10 +672,10 @@ void combineDMALoads(const Module& module, Method& method, const Configuration&
}
}

/*for(auto& pair : addrExprs)
for(auto& pair : addrExprs)
{
logging::debug() << pair.first.to_string() << " = " << pair.second->to_string() << logging::endl;
}*/
}

std::shared_ptr<ValueExpr> diff = nullptr;
bool eqDiff = true;
Expand All @@ -671,10 +700,17 @@ void combineDMALoads(const Module& module, Method& method, const Configuration&
}
}

// logging::debug() << "all loads are " << (eqDiff ? "" : "not ") << "equal difference" << logging::endl;
logging::debug() << addrExprs.size() << " loads are " << (eqDiff ? "" : "not ") << "equal difference" << logging::endl;

if(eqDiff)
{
// The form of diff should be "0 (+/-) expressions...", then remove the value 0 at most right.
ValueExpr::ExpandedExprs expanded;
diff->expand(expanded);
diff = expanded[0].second;

logging::debug() << "diff = " << diff->to_string() << logging::endl;

if (auto term = std::dynamic_pointer_cast<ValueTerm>(diff))
{
if (auto mpValue = term->value.getConstantValue())
Expand All @@ -683,8 +719,6 @@ void combineDMALoads(const Module& module, Method& method, const Configuration&
{
if (mpLiteral->unsignedInt() < (1u << 12))
{
// TODO: cover types other than uchar
uint16_t memoryPitch = static_cast<uint16_t>(mpLiteral->unsignedInt()) * 1 * 16;

auto it = bb.walk();
bool firstCall = true;
Expand All @@ -700,17 +734,22 @@ void combineDMALoads(const Module& module, Method& method, const Configuration&
{
firstCall = false;

// TODO: limit loadInstrs.size()
auto addrArg = call->assertArgument(1);

// TODO: limit loadInstrs.size()
Value offset = assign(it, TYPE_INT32) = offsetValues[0] << 4_val;
Value addr = assign(it, TYPE_INT32) = offset + call->assertArgument(1);
Value addr = assign(it, TYPE_INT32) = offset + addrArg;

auto elemType = addrArg.type.getElementType();
uint16_t memoryPitch = static_cast<uint16_t>(mpLiteral->unsignedInt()) * elemType.getInMemoryWidth() * 16;

DataType TYPE_UCHAR16{DataType::BYTE, 16, false};
// TODO: cover types other than uchar
DataType TYPE16{elemType.getInMemoryWidth() * DataType::BYTE, 16, false};

uint64_t rows = loadInstrs.size();
VPMArea area(VPMUsage::SCRATCH, 0, static_cast<uint8_t>(rows));
auto entries = Value(Literal(static_cast<uint32_t>(rows)), TYPE_INT32);
it = method.vpm->insertReadRAM(method, it, addr, TYPE_UCHAR16,/* &area */ nullptr,
it = method.vpm->insertReadRAM(method, it, addr, TYPE16,/* &area */ nullptr,
true, INT_ZERO, entries, Optional<uint16_t>(memoryPitch));

// const VPMArea* area = nullptr, bool useMutex = true, const Value& inAreaOffset = INT_ZERO);
Expand Down

0 comments on commit 1e26f02

Please sign in to comment.