Skip to content

Commit

Permalink
Added IR Constuctor for UnitTests
Browse files Browse the repository at this point in the history
+ Added simple GraphComparator for UnitTests based on gtest
+ Added new instruction Compare
  • Loading branch information
techie-mike committed Oct 5, 2023
1 parent 4cc05e8 commit e1f03da
Show file tree
Hide file tree
Showing 9 changed files with 226 additions and 5 deletions.
6 changes: 5 additions & 1 deletion src/graph.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "graph.h"
#include "graph.h"
#include <iterator>
#include <ostream>

Expand All @@ -19,4 +18,9 @@ void Graph::SetMethodName(const std::string& name)
name_method_ = name;
}

std::string Graph::GetMethodName() const
{
return name_method_;
}

}
30 changes: 30 additions & 0 deletions src/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class Graph
}

void SetMethodName(const std::string& name);
std::string GetMethodName() const;

void Dump(std::ostream &out);

#define CREATE_CREATORS(OPCODE) \
Expand All @@ -39,9 +41,37 @@ class Graph

#undef CREATE_CREATORS

template <typename T>
auto* CreateInstByIndex(uint32_t index) {
if (!unit_test_mode_) {
std::cerr << "Function only for unit tests\n";
std::abort();
}
auto inst = new T();
inst->SetId(index);
if (index >= all_inst_.size()) {
all_inst_.resize(index + 1);
}
all_inst_.at(index) = inst;
return inst;
}

Inst *GetInstByIndex(uint32_t index) {
return all_inst_.at(index);
}

void SetUnitTestMode() {
unit_test_mode_ = true;
}

uint32_t GetNumInsts() {
return all_inst_.size();
}

private:
std::vector<Inst *> all_inst_;
std::string name_method_;
bool unit_test_mode_;
};

};
47 changes: 44 additions & 3 deletions src/inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,20 @@ class Inst
return opc_;
}

virtual Inst *GetInput(uint32_t index) {
std::cerr << "Inst with opcode " << OPCODE_NAME[static_cast<size_t>(GetOpcode())] << " don't have inputs";
std::abort();
}

virtual void SetInput(uint32_t index, Inst *inst) {
std::cerr << "Inst with opcode " << OPCODE_NAME[static_cast<size_t>(GetOpcode())] << " don't have inputs";
std::abort();
}

virtual uint32_t NumInputs() {
return 0;
}

virtual void DumpInputs(std::ostream &out) const {};


Expand All @@ -89,6 +103,10 @@ class FixedInputs: public Inst
Inst(opc),
inputs_() {};

FixedInputs(Opcode opc, Type type):
Inst(opc, type),
inputs_() {};

FixedInputs(Opcode opc, Type type, std::array<Inst *, N> inputs) :
Inst(opc, type),
inputs_(inputs) {
Expand All @@ -98,19 +116,23 @@ class FixedInputs: public Inst
}
};

Inst *GetInput(uint32_t index) {
virtual Inst *GetInput(uint32_t index) override {
return inputs_.at(index);
}

const std::array<Inst *, N>& GetInputs() const {
return inputs_;
}

void SetInput(uint32_t index, Inst *inst) {
virtual void SetInput(uint32_t index, Inst *inst) override {
assert(index < N);
inputs_.at(index) = inst;
}

virtual uint32_t NumInputs() override {
return inputs_.size();
}

void DumpInputs(std::ostream &out) const override {
bool first = true;
for (auto inst : GetInputs()) {
Expand Down Expand Up @@ -155,10 +177,26 @@ class DynamicInputs: public Inst
inputs_.erase(it);
}

void SetInput(std::list<Inst *>::const_iterator &it, Inst *inst) {
virtual void SetInput(uint32_t index, Inst *inst) override {
auto it = inputs_.begin();
for (int i = 0; i < index; i++) {
it++;
}
inputs_.insert(it, inst);
}

virtual Inst *GetInput(uint32_t index) override {
auto it = inputs_.begin();
for (int i = 0; i < index; i++) {
it++;
}
return *it;
}

virtual uint32_t NumInputs() override {
return inputs_.size();
}

void DumpInputs(std::ostream &out) const override {
bool first = true;
for (auto inst : GetInputs()) {
Expand Down Expand Up @@ -197,6 +235,9 @@ class ImmidiateProperty
class AddInst : public FixedInputs<2>
{
public:
AddInst():
FixedInputs<2>(Opcode::Add) {}

AddInst(Type type, Inst *input0, Inst *input1):
FixedInputs<2>(Opcode::Add, type, {{input0, input1}}) {}

Expand Down
62 changes: 62 additions & 0 deletions src/ir_constructor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#pragma once

#include "graph.h"
#include "inst.h"
#include <type_traits>

namespace compiler {

class IrConstructor {
public:
IrConstructor():
graph_(new Graph()) {
graph_->SetUnitTestMode();
};

~IrConstructor() {
delete graph_;
}
template<typename T>
IrConstructor &CreateInst(uint32_t index) {
auto inst = graph_->CreateInstByIndex<T>(index);
current_inst_ = inst;
return *this;
}

IrConstructor &Imm(int64_t imm) {
assert(current_inst_ != nullptr);

switch(current_inst_->GetOpcode()) {
case Opcode::Constant: {
static_cast<ConstantInst *>(current_inst_)->SetImm(imm);
break;
}
default: {
assert(false && ("Should be unreachable!"));
}
}
return *this;
}

template<typename... Args>
IrConstructor &Inputs(Args ...args) {
for (auto it : {args...}) {
static_assert(std::is_same<decltype(it), int>(), "Is not \"Int\" in argument");
}
auto inputs = std::vector<int>({args...});
for (int i = 0; i < inputs.size(); i++) {
current_inst_->SetInput(i, graph_->GetInstByIndex(inputs.at(i)));
}
return *this;
}

Graph *GetGraph() {
return graph_;
}

private:
Graph *graph_;
Inst *current_inst_;
};

}
1 change: 1 addition & 0 deletions src/opcodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ constexpr std::array<const char *const, static_cast<size_t>(Opcode::NUM_OPCODES)

enum class Type {
NONE = 0,
BOOL,
INT32,
UINT32,
INT64,
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ include_directories(
add_executable(
graph_tests
graph_tests.cpp
graph_comparator.cpp
${COMPILER_SOURCES}
)

Expand Down
42 changes: 42 additions & 0 deletions tests/graph_comparator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#include "graph_comparator.h"
#include "inst.h"

namespace compiler {

void GraphComparator::Compare() {
ASSERT_EQ(left_->GetMethodName(), right_->GetMethodName());
ASSERT_EQ(left_->GetNumInsts(), right_->GetNumInsts());
for (int i = 0; i < left_->GetNumInsts(); i++) {
auto left_inst = left_->GetInstByIndex(i);
auto right_inst = right_->GetInstByIndex(i);

CompareInstructions(left_inst, right_inst);
}
}

void GraphComparator::CompareInstructions(Inst *left, Inst *right) {
ASSERT_EQ(left->GetId(), right->GetId());
ASSERT_EQ(left->GetOpcode(), right->GetOpcode());
ASSERT_EQ(left->GetType(), right->GetType());
CompareInputs(left, right);

// Specific checks
switch (left->GetOpcode()) {
case Opcode::Constant: {
ASSERT_EQ(static_cast<ConstantInst *>(left)->GetImm(), static_cast<ConstantInst *>(right)->GetImm());
break;
}
default: {
break;
}
}
}

void GraphComparator::CompareInputs(Inst *left, Inst *right) {
ASSERT_EQ(left->NumInputs(), right->NumInputs());
for (int i = 0; i < left->NumInputs(); i++) {
ASSERT_EQ(left->GetInput(i)->GetId(), right->GetInput(i)->GetId());
}
}

}
25 changes: 25 additions & 0 deletions tests/graph_comparator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include "graph.h"
#include <gtest/gtest.h>

namespace compiler {

class GraphComparator
{
public:
GraphComparator(Graph *left, Graph *right):
left_(left), right_(right) {}

void Compare();

private:
void CompareInstructions(Inst *left, Inst *right);
void CompareInputs(Inst *left, Inst *right);

private:
Graph *left_;
Graph *right_;
};

}
17 changes: 16 additions & 1 deletion tests/graph_tests.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#include <gtest/gtest.h>
#include <ostream>
#include "inst.h"
#include "src/graph.h"
#include "graph.h"

#include "ir_constructor.h"
#include "tests/graph_comparator.h"

namespace compiler {

Expand Down Expand Up @@ -139,4 +142,16 @@ TEST(GraphTest, CreateIfFullWorkGraph) {
ASSERT_EQ(dump_out.str(), output);
}

TEST(GraphTest, TestConstructorAndComparator) {
auto ic = IrConstructor();
ic.CreateInst<ConstantInst>(0).Imm(123);
ic.CreateInst<AddInst>(1).Inputs(0, 0);

auto ic_after = IrConstructor();
ic_after.CreateInst<ConstantInst>(0).Imm(123);
ic_after.CreateInst<AddInst>(1).Inputs(0, 0);

GraphComparator(ic.GetGraph(), ic_after.GetGraph()).Compare();
}

}

0 comments on commit e1f03da

Please sign in to comment.