From e1f03dab2c106ef4b5340fb0af633ae79e5d213f Mon Sep 17 00:00:00 2001 From: Mikhail Kaskov Date: Wed, 4 Oct 2023 01:09:00 +0300 Subject: [PATCH] Added IR Constuctor for UnitTests + Added simple GraphComparator for UnitTests based on gtest + Added new instruction Compare --- src/graph.cpp | 6 +++- src/graph.h | 30 ++++++++++++++++++ src/inst.h | 47 +++++++++++++++++++++++++++-- src/ir_constructor.h | 62 ++++++++++++++++++++++++++++++++++++++ src/opcodes.h | 1 + tests/CMakeLists.txt | 1 + tests/graph_comparator.cpp | 42 ++++++++++++++++++++++++++ tests/graph_comparator.h | 25 +++++++++++++++ tests/graph_tests.cpp | 17 ++++++++++- 9 files changed, 226 insertions(+), 5 deletions(-) create mode 100644 src/ir_constructor.h create mode 100644 tests/graph_comparator.cpp create mode 100644 tests/graph_comparator.h diff --git a/src/graph.cpp b/src/graph.cpp index 945a0d1..46b427f 100644 --- a/src/graph.cpp +++ b/src/graph.cpp @@ -1,5 +1,4 @@ #include "graph.h" -#include "graph.h" #include #include @@ -19,4 +18,9 @@ void Graph::SetMethodName(const std::string& name) name_method_ = name; } +std::string Graph::GetMethodName() const +{ + return name_method_; +} + } diff --git a/src/graph.h b/src/graph.h index f66f812..f12ba8c 100644 --- a/src/graph.h +++ b/src/graph.h @@ -24,6 +24,8 @@ class Graph } void SetMethodName(const std::string& name); + std::string GetMethodName() const; + void Dump(std::ostream &out); #define CREATE_CREATORS(OPCODE) \ @@ -39,9 +41,37 @@ class Graph #undef CREATE_CREATORS + template + auto* CreateInstByIndex(uint32_t index) { + if (!unit_test_mode_) { + std::cerr << "Function only for unit tests\n"; + std::abort(); + } + auto inst = new T(); + inst->SetId(index); + if (index >= all_inst_.size()) { + all_inst_.resize(index + 1); + } + all_inst_.at(index) = inst; + return inst; + } + + Inst *GetInstByIndex(uint32_t index) { + return all_inst_.at(index); + } + + void SetUnitTestMode() { + unit_test_mode_ = true; + } + + uint32_t GetNumInsts() { + return all_inst_.size(); + } + private: std::vector all_inst_; std::string name_method_; + bool unit_test_mode_; }; }; diff --git a/src/inst.h b/src/inst.h index 0d41128..3ad308e 100644 --- a/src/inst.h +++ b/src/inst.h @@ -68,6 +68,20 @@ class Inst return opc_; } + virtual Inst *GetInput(uint32_t index) { + std::cerr << "Inst with opcode " << OPCODE_NAME[static_cast(GetOpcode())] << " don't have inputs"; + std::abort(); + } + + virtual void SetInput(uint32_t index, Inst *inst) { + std::cerr << "Inst with opcode " << OPCODE_NAME[static_cast(GetOpcode())] << " don't have inputs"; + std::abort(); + } + + virtual uint32_t NumInputs() { + return 0; + } + virtual void DumpInputs(std::ostream &out) const {}; @@ -89,6 +103,10 @@ class FixedInputs: public Inst Inst(opc), inputs_() {}; + FixedInputs(Opcode opc, Type type): + Inst(opc, type), + inputs_() {}; + FixedInputs(Opcode opc, Type type, std::array inputs) : Inst(opc, type), inputs_(inputs) { @@ -98,7 +116,7 @@ class FixedInputs: public Inst } }; - Inst *GetInput(uint32_t index) { + virtual Inst *GetInput(uint32_t index) override { return inputs_.at(index); } @@ -106,11 +124,15 @@ class FixedInputs: public Inst return inputs_; } - void SetInput(uint32_t index, Inst *inst) { + virtual void SetInput(uint32_t index, Inst *inst) override { assert(index < N); inputs_.at(index) = inst; } + virtual uint32_t NumInputs() override { + return inputs_.size(); + } + void DumpInputs(std::ostream &out) const override { bool first = true; for (auto inst : GetInputs()) { @@ -155,10 +177,26 @@ class DynamicInputs: public Inst inputs_.erase(it); } - void SetInput(std::list::const_iterator &it, Inst *inst) { + virtual void SetInput(uint32_t index, Inst *inst) override { + auto it = inputs_.begin(); + for (int i = 0; i < index; i++) { + it++; + } inputs_.insert(it, inst); } + virtual Inst *GetInput(uint32_t index) override { + auto it = inputs_.begin(); + for (int i = 0; i < index; i++) { + it++; + } + return *it; + } + + virtual uint32_t NumInputs() override { + return inputs_.size(); + } + void DumpInputs(std::ostream &out) const override { bool first = true; for (auto inst : GetInputs()) { @@ -197,6 +235,9 @@ class ImmidiateProperty class AddInst : public FixedInputs<2> { public: + AddInst(): + FixedInputs<2>(Opcode::Add) {} + AddInst(Type type, Inst *input0, Inst *input1): FixedInputs<2>(Opcode::Add, type, {{input0, input1}}) {} diff --git a/src/ir_constructor.h b/src/ir_constructor.h new file mode 100644 index 0000000..18a50c2 --- /dev/null +++ b/src/ir_constructor.h @@ -0,0 +1,62 @@ +#pragma once + +#include "graph.h" +#include "inst.h" +#include + +namespace compiler { + +class IrConstructor { +public: + IrConstructor(): + graph_(new Graph()) { + graph_->SetUnitTestMode(); + }; + + ~IrConstructor() { + delete graph_; + } + template + IrConstructor &CreateInst(uint32_t index) { + auto inst = graph_->CreateInstByIndex(index); + current_inst_ = inst; + return *this; + } + + IrConstructor &Imm(int64_t imm) { + assert(current_inst_ != nullptr); + + switch(current_inst_->GetOpcode()) { + case Opcode::Constant: { + static_cast(current_inst_)->SetImm(imm); + break; + } + default: { + assert(false && ("Should be unreachable!")); + } + } + return *this; + } + + template + IrConstructor &Inputs(Args ...args) { + for (auto it : {args...}) { + static_assert(std::is_same(), "Is not \"Int\" in argument"); + } + auto inputs = std::vector({args...}); + for (int i = 0; i < inputs.size(); i++) { + current_inst_->SetInput(i, graph_->GetInstByIndex(inputs.at(i))); + } + return *this; + } + + Graph *GetGraph() { + return graph_; + } + +private: + Graph *graph_; + Inst *current_inst_; +}; + +} diff --git a/src/opcodes.h b/src/opcodes.h index 19ee037..13c6d8b 100644 --- a/src/opcodes.h +++ b/src/opcodes.h @@ -39,6 +39,7 @@ constexpr std::array(Opcode::NUM_OPCODES) enum class Type { NONE = 0, + BOOL, INT32, UINT32, INT64, diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c2dacb0..aec92e5 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -7,6 +7,7 @@ include_directories( add_executable( graph_tests graph_tests.cpp + graph_comparator.cpp ${COMPILER_SOURCES} ) diff --git a/tests/graph_comparator.cpp b/tests/graph_comparator.cpp new file mode 100644 index 0000000..f289e35 --- /dev/null +++ b/tests/graph_comparator.cpp @@ -0,0 +1,42 @@ +#include "graph_comparator.h" +#include "inst.h" + +namespace compiler { + +void GraphComparator::Compare() { + ASSERT_EQ(left_->GetMethodName(), right_->GetMethodName()); + ASSERT_EQ(left_->GetNumInsts(), right_->GetNumInsts()); + for (int i = 0; i < left_->GetNumInsts(); i++) { + auto left_inst = left_->GetInstByIndex(i); + auto right_inst = right_->GetInstByIndex(i); + + CompareInstructions(left_inst, right_inst); + } +} + +void GraphComparator::CompareInstructions(Inst *left, Inst *right) { + ASSERT_EQ(left->GetId(), right->GetId()); + ASSERT_EQ(left->GetOpcode(), right->GetOpcode()); + ASSERT_EQ(left->GetType(), right->GetType()); + CompareInputs(left, right); + + // Specific checks + switch (left->GetOpcode()) { + case Opcode::Constant: { + ASSERT_EQ(static_cast(left)->GetImm(), static_cast(right)->GetImm()); + break; + } + default: { + break; + } + } +} + +void GraphComparator::CompareInputs(Inst *left, Inst *right) { + ASSERT_EQ(left->NumInputs(), right->NumInputs()); + for (int i = 0; i < left->NumInputs(); i++) { + ASSERT_EQ(left->GetInput(i)->GetId(), right->GetInput(i)->GetId()); + } +} + +} diff --git a/tests/graph_comparator.h b/tests/graph_comparator.h new file mode 100644 index 0000000..6441601 --- /dev/null +++ b/tests/graph_comparator.h @@ -0,0 +1,25 @@ +#pragma once + +#include "graph.h" +#include + +namespace compiler { + +class GraphComparator +{ +public: + GraphComparator(Graph *left, Graph *right): + left_(left), right_(right) {} + + void Compare(); + +private: + void CompareInstructions(Inst *left, Inst *right); + void CompareInputs(Inst *left, Inst *right); + +private: + Graph *left_; + Graph *right_; +}; + +} diff --git a/tests/graph_tests.cpp b/tests/graph_tests.cpp index 04ed839..97acd21 100644 --- a/tests/graph_tests.cpp +++ b/tests/graph_tests.cpp @@ -1,7 +1,10 @@ #include #include #include "inst.h" -#include "src/graph.h" +#include "graph.h" + +#include "ir_constructor.h" +#include "tests/graph_comparator.h" namespace compiler { @@ -139,4 +142,16 @@ TEST(GraphTest, CreateIfFullWorkGraph) { ASSERT_EQ(dump_out.str(), output); } +TEST(GraphTest, TestConstructorAndComparator) { + auto ic = IrConstructor(); + ic.CreateInst(0).Imm(123); + ic.CreateInst(1).Inputs(0, 0); + + auto ic_after = IrConstructor(); + ic_after.CreateInst(0).Imm(123); + ic_after.CreateInst(1).Inputs(0, 0); + + GraphComparator(ic.GetGraph(), ic_after.GetGraph()).Compare(); +} + }