Skip to content

Commit

Permalink
Added new control flow insts
Browse files Browse the repository at this point in the history
+ Added insts Start, Region, If
+ Added test on full correct simple graph
  • Loading branch information
techie-mike committed Sep 27, 2023
1 parent e3d68d4 commit 4cc05e8
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 7 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
.cache
build
compile_commands.json
140 changes: 134 additions & 6 deletions src/inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,33 @@

#include <cstdint>
#include <array>
#include <vector>
#include <ios>
#include <list>
#include <assert.h>
#include <initializer_list>
#include <string>
#include <iostream>
#include <algorithm>

#include "opcodes.h"

namespace compiler {

/* Full list of instructions, plus and minus show what support in compiler at the moment:
/* ======================================================================================
* Full list of instructions, plus and minus show what support in compiler at the moment:
* + Constant
* + Add
* - Mul
* - Region
* - Start
* - If
* + Region
* + Start
* + If
* - Jmp
* - Phi
* - Return
* - Compare
* - Parameter
* ======================================================================================
*/

class Inst
Expand All @@ -34,12 +38,14 @@ class Inst
opc_(Opcode::NONE),
type_(Type::NONE) {};

Inst(Opcode opc):
opc_(opc),
type_(Type::NONE) {};

Inst(Opcode opc, Type type):
opc_(opc),
type_(type) {};

// virtual ~Inst() = default;

void Dump(std::ostream& out);

uint32_t GetId() const {
Expand Down Expand Up @@ -79,6 +85,10 @@ class FixedInputs: public Inst
FixedInputs():
inputs_() {};

FixedInputs(Opcode opc):
Inst(opc),
inputs_() {};

FixedInputs(Opcode opc, Type type, std::array<Inst *, N> inputs) :
Inst(opc, type),
inputs_(inputs) {
Expand All @@ -97,12 +107,16 @@ class FixedInputs: public Inst
}

void SetInput(uint32_t index, Inst *inst) {
assert(index < N);
inputs_.at(index) = inst;
}

void DumpInputs(std::ostream &out) const override {
bool first = true;
for (auto inst : GetInputs()) {
if (inst == nullptr) {
continue;
}
out << std::string(first ? "" : ", ") << std::string("v") << std::to_string(inst->GetId());
first = false;
}
Expand All @@ -112,6 +126,54 @@ class FixedInputs: public Inst
std::array<Inst *, N> inputs_;
};

class DynamicInputs: public Inst
{
public:
DynamicInputs():
Inst(),
inputs_() {};

DynamicInputs(Opcode opc):
Inst(opc),
inputs_() {};

DynamicInputs(Opcode opc, const std::list<Inst *>& inputs):
Inst(opc),
inputs_(inputs) {};

const std::list<Inst *>& GetInputs() const {
return inputs_;
}

void AddInput(Inst *inst) {
inputs_.push_back(inst);
}

void DeleteInput(Inst *inst) {
auto it = std::find(inputs_.begin(), inputs_.end(), inst);
assert(it == inputs_.end());
inputs_.erase(it);
}

void SetInput(std::list<Inst *>::const_iterator &it, Inst *inst) {
inputs_.insert(it, inst);
}

void DumpInputs(std::ostream &out) const override {
bool first = true;
for (auto inst : GetInputs()) {
if (inst == nullptr) {
continue;
}
out << std::string(first ? "" : ", ") << std::string("v") << std::to_string(inst->GetId());
first = false;
}
}

private:
std::list<Inst *> inputs_;
};

using ImmType = int64_t;
class ImmidiateProperty
{
Expand Down Expand Up @@ -153,4 +215,70 @@ class ConstantInst : public Inst, public ImmidiateProperty
private:
};

class StartInst : public FixedInputs<1>
{
public:
StartInst():
FixedInputs<1>(Opcode::Start) {}
};

class RegionInst : public DynamicInputs
{
public:
RegionInst():
DynamicInputs(Opcode::Region) {}

void SetUser(Inst *inst) {
user_ = inst;
}

private:
Inst *user_;
};

class IfInst : public FixedInputs<2>
{
public:
// First input is Region
// Second input is Bool condition value
IfInst():
FixedInputs<2>(Opcode::If) {}

Inst *GetTrueBranch() {
return GetBranch<BranchWay::True>();
}

Inst *GetFalseBranch() {
return GetBranch<BranchWay::False>();
}

void SetTrueBranch(Inst *inst) {
SetBranch<BranchWay::True>(inst);
}

void SetFalseBranch(Inst *inst) {
SetBranch<BranchWay::False>(inst);
}

private:
enum class BranchWay {
True = 0,
False
};

template<BranchWay V>
Inst *GetBranch() {
return branchs_[static_cast<size_t>(V)];
}

template<BranchWay V>
void SetBranch(Inst *inst) {
assert(inst->GetOpcode() == Opcode::Region);
static_cast<RegionInst *>(inst)->AddInput(this);
branchs_[static_cast<size_t>(V)] = inst;
}

std::array<Inst *, 2> branchs_;
};

}
6 changes: 5 additions & 1 deletion src/opcodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ namespace compiler {

#define OPCODE_LIST(ACTION) \
ACTION( Add ) \
ACTION( Constant )
ACTION( Constant ) \
ACTION( Start ) \
ACTION( If ) \
ACTION( Region )


enum class Opcode {
NONE = 0,
Expand Down
88 changes: 88 additions & 0 deletions tests/graph_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,92 @@ TEST(GraphTest, CreateAdd) {
ASSERT_EQ(dump_out.str(), output);
}

TEST(GraphTest, CreateStart) {
Graph graph;
auto *start = graph.CreateStartInst();
auto *end = graph.CreateStartInst();
end->SetInput(0, start);

std::ostringstream dump_out;
std::string output =
"Method: \n"
"Instructions:\n"
" 0. Start \n"
" 1. Start v0\n";
graph.Dump(dump_out);
ASSERT_EQ(dump_out.str(), output);
}

TEST(GraphTest, CreateRegions) {
Graph graph;
auto *start = graph.CreateStartInst();
auto *reg1 = graph.CreateRegionInst();
reg1->AddInput(start);
auto *end = graph.CreateStartInst();
end->SetInput(0, reg1);

std::ostringstream dump_out;
std::string output =
"Method: \n"
"Instructions:\n"
" 0. Start \n"
" 1. Region v0\n"
" 2. Start v1\n";
graph.Dump(dump_out);
ASSERT_EQ(dump_out.str(), output);
}

/*
* Start
* |
* \/
* Region <-+
* Constant 0 | |
* | | |
* \/ \/ |
* ------------ |
* | If | |
* ------------ |
* |True|False| |
* ------------ |
* | | |
* \/ +-->--+
* Region
* |
* \/
* Start
*/
TEST(GraphTest, CreateIfFullWorkGraph) {
Graph graph;
auto *start = graph.CreateStartInst();
auto *reg_loop = graph.CreateRegionInst();
reg_loop->AddInput(start);

auto *cnst = graph.CreateConstantInst();
auto *if_inst = graph.CreateIfInst();
if_inst->SetInput(0, reg_loop);
if_inst->SetInput(1, cnst);

auto *region_end = graph.CreateRegionInst();
if_inst->SetTrueBranch(region_end);
if_inst->SetFalseBranch(reg_loop);

auto *end = graph.CreateStartInst();
end->SetInput(0, region_end);

std::ostringstream dump_out;
std::string output =
"Method: \n"
"Instructions:\n"
" 0. Start \n"
" 1. Region v0, v3\n"
" 2. Constant 0x0\n"
" 3. If v1, v2\n"
" 4. Region v3\n"
" 5. Start v4\n";

graph.Dump(dump_out);
ASSERT_EQ(dump_out.str(), output);
}

}

0 comments on commit 4cc05e8

Please sign in to comment.