Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added IR Constuctor for UnitTests #3

Merged
merged 1 commit into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/graph.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "graph.h"
#include "graph.h"
#include <iterator>
#include <ostream>

Expand All @@ -19,4 +18,9 @@ void Graph::SetMethodName(const std::string& name)
name_method_ = name;
}

std::string Graph::GetMethodName() const
{
return name_method_;
}

}
30 changes: 30 additions & 0 deletions src/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class Graph
}

void SetMethodName(const std::string& name);
std::string GetMethodName() const;

void Dump(std::ostream &out);

#define CREATE_CREATORS(OPCODE) \
Expand All @@ -39,9 +41,37 @@ class Graph

#undef CREATE_CREATORS

template <typename T>
auto* CreateInstByIndex(uint32_t index) {
if (!unit_test_mode_) {
std::cerr << "Function only for unit tests\n";
std::abort();
}
auto inst = new T();
inst->SetId(index);
if (index >= all_inst_.size()) {
all_inst_.resize(index + 1);
}
all_inst_.at(index) = inst;
return inst;
}

Inst *GetInstByIndex(uint32_t index) {
return all_inst_.at(index);
}

void SetUnitTestMode() {
unit_test_mode_ = true;
}

uint32_t GetNumInsts() {
return all_inst_.size();
}

private:
std::vector<Inst *> all_inst_;
std::string name_method_;
bool unit_test_mode_;
};

};
47 changes: 44 additions & 3 deletions src/inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,20 @@ class Inst
return opc_;
}

virtual Inst *GetInput(uint32_t index) {
std::cerr << "Inst with opcode " << OPCODE_NAME[static_cast<size_t>(GetOpcode())] << " don't have inputs";
std::abort();
}

virtual void SetInput(uint32_t index, Inst *inst) {
std::cerr << "Inst with opcode " << OPCODE_NAME[static_cast<size_t>(GetOpcode())] << " don't have inputs";
std::abort();
}

virtual uint32_t NumInputs() {
return 0;
}

virtual void DumpInputs(std::ostream &out) const {};


Expand All @@ -89,6 +103,10 @@ class FixedInputs: public Inst
Inst(opc),
inputs_() {};

FixedInputs(Opcode opc, Type type):
Inst(opc, type),
inputs_() {};

FixedInputs(Opcode opc, Type type, std::array<Inst *, N> inputs) :
Inst(opc, type),
inputs_(inputs) {
Expand All @@ -98,19 +116,23 @@ class FixedInputs: public Inst
}
};

Inst *GetInput(uint32_t index) {
virtual Inst *GetInput(uint32_t index) override {
return inputs_.at(index);
}

const std::array<Inst *, N>& GetInputs() const {
return inputs_;
}

void SetInput(uint32_t index, Inst *inst) {
virtual void SetInput(uint32_t index, Inst *inst) override {
assert(index < N);
inputs_.at(index) = inst;
}

virtual uint32_t NumInputs() override {
return inputs_.size();
}

void DumpInputs(std::ostream &out) const override {
bool first = true;
for (auto inst : GetInputs()) {
Expand Down Expand Up @@ -155,10 +177,26 @@ class DynamicInputs: public Inst
inputs_.erase(it);
}

void SetInput(std::list<Inst *>::const_iterator &it, Inst *inst) {
virtual void SetInput(uint32_t index, Inst *inst) override {
auto it = inputs_.begin();
for (int i = 0; i < index; i++) {
it++;
}
inputs_.insert(it, inst);
}

virtual Inst *GetInput(uint32_t index) override {
auto it = inputs_.begin();
for (int i = 0; i < index; i++) {
it++;
}
return *it;
}

virtual uint32_t NumInputs() override {
return inputs_.size();
}

void DumpInputs(std::ostream &out) const override {
bool first = true;
for (auto inst : GetInputs()) {
Expand Down Expand Up @@ -197,6 +235,9 @@ class ImmidiateProperty
class AddInst : public FixedInputs<2>
{
public:
AddInst():
FixedInputs<2>(Opcode::Add) {}

AddInst(Type type, Inst *input0, Inst *input1):
FixedInputs<2>(Opcode::Add, type, {{input0, input1}}) {}

Expand Down
62 changes: 62 additions & 0 deletions src/ir_constructor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#pragma once

#include "graph.h"
#include "inst.h"
#include <type_traits>

namespace compiler {

class IrConstructor {
public:
IrConstructor():
graph_(new Graph()) {
graph_->SetUnitTestMode();
};

~IrConstructor() {
delete graph_;
}
template<typename T>
IrConstructor &CreateInst(uint32_t index) {
auto inst = graph_->CreateInstByIndex<T>(index);
current_inst_ = inst;
return *this;
}

IrConstructor &Imm(int64_t imm) {
assert(current_inst_ != nullptr);

switch(current_inst_->GetOpcode()) {
case Opcode::Constant: {
static_cast<ConstantInst *>(current_inst_)->SetImm(imm);
break;
}
default: {
assert(false && ("Should be unreachable!"));
}
}
return *this;
}

template<typename... Args>
IrConstructor &Inputs(Args ...args) {
for (auto it : {args...}) {
static_assert(std::is_same<decltype(it), int>(), "Is not \"Int\" in argument");
}
auto inputs = std::vector<int>({args...});
for (int i = 0; i < inputs.size(); i++) {
current_inst_->SetInput(i, graph_->GetInstByIndex(inputs.at(i)));
}
return *this;
}

Graph *GetGraph() {
return graph_;
}

private:
Graph *graph_;
Inst *current_inst_;
};

}
1 change: 1 addition & 0 deletions src/opcodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ constexpr std::array<const char *const, static_cast<size_t>(Opcode::NUM_OPCODES)

enum class Type {
NONE = 0,
BOOL,
INT32,
UINT32,
INT64,
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ include_directories(
add_executable(
graph_tests
graph_tests.cpp
graph_comparator.cpp
${COMPILER_SOURCES}
)

Expand Down
42 changes: 42 additions & 0 deletions tests/graph_comparator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#include "graph_comparator.h"
#include "inst.h"

namespace compiler {

void GraphComparator::Compare() {
ASSERT_EQ(left_->GetMethodName(), right_->GetMethodName());
ASSERT_EQ(left_->GetNumInsts(), right_->GetNumInsts());
for (int i = 0; i < left_->GetNumInsts(); i++) {
auto left_inst = left_->GetInstByIndex(i);
auto right_inst = right_->GetInstByIndex(i);

CompareInstructions(left_inst, right_inst);
}
}

void GraphComparator::CompareInstructions(Inst *left, Inst *right) {
ASSERT_EQ(left->GetId(), right->GetId());
ASSERT_EQ(left->GetOpcode(), right->GetOpcode());
ASSERT_EQ(left->GetType(), right->GetType());
CompareInputs(left, right);

// Specific checks
switch (left->GetOpcode()) {
case Opcode::Constant: {
ASSERT_EQ(static_cast<ConstantInst *>(left)->GetImm(), static_cast<ConstantInst *>(right)->GetImm());
break;
}
default: {
break;
}
}
}

void GraphComparator::CompareInputs(Inst *left, Inst *right) {
ASSERT_EQ(left->NumInputs(), right->NumInputs());
for (int i = 0; i < left->NumInputs(); i++) {
ASSERT_EQ(left->GetInput(i)->GetId(), right->GetInput(i)->GetId());
}
}

}
25 changes: 25 additions & 0 deletions tests/graph_comparator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include "graph.h"
#include <gtest/gtest.h>

namespace compiler {

class GraphComparator
{
public:
GraphComparator(Graph *left, Graph *right):
left_(left), right_(right) {}

void Compare();

private:
void CompareInstructions(Inst *left, Inst *right);
void CompareInputs(Inst *left, Inst *right);

private:
Graph *left_;
Graph *right_;
};

}
17 changes: 16 additions & 1 deletion tests/graph_tests.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#include <gtest/gtest.h>
#include <ostream>
#include "inst.h"
#include "src/graph.h"
#include "graph.h"

#include "ir_constructor.h"
#include "tests/graph_comparator.h"

namespace compiler {

Expand Down Expand Up @@ -139,4 +142,16 @@ TEST(GraphTest, CreateIfFullWorkGraph) {
ASSERT_EQ(dump_out.str(), output);
}

TEST(GraphTest, TestConstructorAndComparator) {
auto ic = IrConstructor();
ic.CreateInst<ConstantInst>(0).Imm(123);
ic.CreateInst<AddInst>(1).Inputs(0, 0);

auto ic_after = IrConstructor();
ic_after.CreateInst<ConstantInst>(0).Imm(123);
ic_after.CreateInst<AddInst>(1).Inputs(0, 0);

GraphComparator(ic.GetGraph(), ic_after.GetGraph()).Compare();
}

}