diff --git a/CMakeLists.txt b/CMakeLists.txt index e49bc01..fde2605 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,6 +22,7 @@ set(COMPILER_SOURCES ${CMAKE_SOURCE_DIR}/src/optimizations/linear_scan.cpp ${CMAKE_SOURCE_DIR}/src/optimizations/peepholes.cpp ${CMAKE_SOURCE_DIR}/src/optimizations/constant_folding.cpp + ${CMAKE_SOURCE_DIR}/src/optimizations/inlining.cpp ) add_library(CompilerLibBase ${COMPILER_SOURCES}) diff --git a/src/graph.cpp b/src/graph.cpp index e9af296..97abb5a 100644 --- a/src/graph.cpp +++ b/src/graph.cpp @@ -33,6 +33,13 @@ std::string Graph::GetMethodName() const return name_method_; } +void Graph::SetNumParams(uint32_t num) { + num_params_ = num; +} +uint32_t Graph::GetNumParams() { + return num_params_; +} + Inst *Graph::CreateClearInstByOpcode(Opcode opc) { switch(opc) { @@ -100,7 +107,7 @@ void Graph::DumpDomTree(std::ostream &out) { } } -void Graph::DumpPlacedInsts(std::ostream &out) { +void Graph::DumpPlacedInsts(std::ostream &out) const { bool first = true; out << "Instructions is PLACED:" << std::endl; for (auto region : all_regions_) { @@ -116,6 +123,29 @@ void Graph::DumpPlacedInsts(std::ostream &out) { } } +void Graph::AddInst(Inst *inst) { + all_inst_.push_back(inst); +} +void Graph::DeleteInst(Inst *inst) { + // Delete all inputs + for (id_t i = 0; i < inst->NumAllInputs(); i++) { + if (inst->GetRawInput(i) != nullptr) { + inst->GetRawInput(i)->DeleteRawUser(inst); + } + } + // Delete all users + if (inst->NumDataUsers() > 0) { + for (auto it = inst->GetRawUsers().begin(); it != inst->GetRawUsers().end(); it = (*it == nullptr) ? ++it : inst->GetRawUsers().begin()) { + if (*it != nullptr) { + (*it)->DeleteInput(inst); + inst->DeleteRawUser(*it); + } + } + } + auto del = std::find(all_inst_.begin(), all_inst_.end(), inst); + *del = nullptr; + deleted_insts_.push_back(inst); +} } diff --git a/src/graph.h b/src/graph.h index 8b8131a..74fa82f 100644 --- a/src/graph.h +++ b/src/graph.h @@ -18,13 +18,24 @@ class Graph } delete inst; } + for (auto inst : deleted_insts_) { + if (inst == nullptr) { + continue; + } + delete inst; + } } void SetMethodName(const std::string& name); std::string GetMethodName() const; + void SetNumParams(uint32_t num); + uint32_t GetNumParams(); + void Dump(std::ostream &out); void DumpDomTree(std::ostream &out); + void AddInst(Inst *inst); + void DeleteInst(Inst *inst); #define CREATE_CREATORS(OPCODE, BASE) \ template \ @@ -89,7 +100,7 @@ class Graph unit_test_mode_ = true; } - size_t GetNumInsts() { + size_t GetNumInsts() const { return all_inst_.size(); } @@ -117,7 +128,7 @@ class Graph root_loop_ = loop; } - void DumpPlacedInsts(std::ostream &out); + void DumpPlacedInsts(std::ostream &out) const; RegionInst *GetStartRegion() { return GetInstByIndex(0)->CastToRegion(); @@ -127,7 +138,7 @@ class Graph insts_placed_ = true; } - bool IsInstsPlaced() { + bool IsInstsPlaced() const { return insts_placed_; } @@ -135,10 +146,12 @@ class Graph bool unit_test_mode_ = false; bool insts_placed_ = false; uint32_t num_loops_ = 0; + uint32_t num_params_ = 0; Loop *root_loop_ = nullptr; std::string name_method_; std::vector all_inst_; std::vector all_regions_; + std::vector deleted_insts_; }; } diff --git a/src/inst.cpp b/src/inst.cpp index 0636699..7daa45b 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -50,6 +50,7 @@ void Inst::SetDataInput(id_t index, Inst *inst) { Inst *Inst::GetDataInput(id_t index) { ASSERT(!IsRegion()); // RegionInst has special method GetRegionInput + ASSERT(index < NumDataInputs()); if (HasControlProp()) { index++; } @@ -77,6 +78,13 @@ void Inst::DeleteDataUser(Inst *inst) { users_.erase(it); } +void Inst::DeleteRawUser(Inst *inst) { + auto it = std::find(GetRawUsers().begin(), GetRawUsers().end(), inst); + if (it != GetRawUsers().end()) { + users_.erase(it); + } +} + uint32_t Inst::NumDataUsers() { return HasControlProp() ? GetRawUsers().size() - 1 : GetRawUsers().size(); } @@ -193,15 +201,15 @@ void PhiInst::DumpInputs(std::ostream &out) { } void DynamicInputs::DumpInputs(std::ostream &out) { - bool first = true; - for (auto inst : GetAllInputs()) { - if (inst == nullptr) { - continue; - } - out << std::string(first ? "" : ", ") << std::string("v") << std::to_string(inst->GetId()); - first = false; + bool first = true; + for (auto inst : GetAllInputs()) { + if (inst == nullptr) { + continue; } + out << std::string(first ? "" : ", ") << std::string("v") << std::to_string(inst->GetId()); + first = false; } +} std::string OpcodeToString(Opcode opc) { return OPCODE_NAME.at(static_cast(opc)); @@ -270,10 +278,27 @@ ConstantInst *Inst::CastToConstant() { return static_cast(this); } +ParameterInst *Inst::CastToParameter() { + ASSERT(GetOpcode() == Opcode::Parameter); + return static_cast(this); +} + +CallInst *Inst::CastToCall() { + ASSERT(GetOpcode() == Opcode::Call); + return static_cast(this); +} + +JumpInst *Inst::CastToJump() { + ASSERT(GetOpcode() == Opcode::Jump); + return static_cast(this); +} + void Inst::ReplaceDataUsers(Inst *from) { - ASSERT(!HasControlProp()); for (auto it = from->StartIteratorDataUsers(); it != from->GetRawUsers().end(); it = from->StartIteratorDataUsers()) { auto user = *it; + if (user == nullptr) { + continue; + } for (id_t i = 0; i < user->NumDataInputs(); i++) { if (user->GetDataInput(i) == from) { user->SetDataInput(i, this); @@ -283,4 +308,97 @@ void Inst::ReplaceDataUsers(Inst *from) { from->GetRawUsers().erase(from->StartIteratorDataUsers(), from->GetRawUsers().end()); } +void CallInst::SetNameFunc(const std::string &name) { + name_func_ = name; +} + +std::string CallInst::GetNameFunc() const { + return name_func_; +} + +Inst *Inst::LiteClone(Graph *target_graph, std::map &connect) { + auto new_inst = target_graph->CreateClearInstByOpcode(GetOpcode()); + new_inst->type_ = type_; + new_inst->SetId(target_graph->GetNumInsts()); + target_graph->AddInst(new_inst); + + // To many specific cases in common code! + auto opc = GetOpcode(); + if (opc == Opcode::Start || opc == Opcode::Parameter || opc == Opcode::Region || opc == Opcode::End) { + return new_inst; + } + // if (opc == Opcode::Region || opc == Opcode::End) { + // auto num_inputs = CastToRegion()->NumRegionInputs(); + // for (id_t index = 0; index < num_inputs; index++) { + // new_inst->CastToRegion()->SetRegionInput(index, target_graph->GetInstByIndex(connect[CastToRegion()->GetRegionInput(index)->GetId()])); + // } + // return new_inst; + // } + + if (HasControlProp() && GetControlInput() != nullptr) { + new_inst->SetControlInput(target_graph->GetInstByIndex(connect[GetControlInput()->GetId()])); + } + auto num_inputs = NumDataInputs(); + for (id_t index = 0; index < num_inputs; index++) { + // If id_t > num_all_inst in graph, write input in order (case for Jump and If) + new_inst->SetDataInput(index, target_graph->GetInstByIndex(connect[GetDataInput(index)->GetId()])); + } + return new_inst; +} + +Inst *ConstantInst::LiteClone(Graph *target_graph, std::map &connect) { + auto new_inst = static_cast(Inst::LiteClone(target_graph, connect)); + new_inst->SetImm(GetImm()); + return new_inst; +} + +Inst *CompareInst::LiteClone(Graph *target_graph, std::map &connect) { + auto new_inst = static_cast(FixedInputs<2>::LiteClone(target_graph, connect)); + new_inst->SetCC(GetCC()); + return new_inst; +} + +Inst *ParameterInst::LiteClone(Graph *target_graph, std::map &connect) { + auto new_inst = static_cast(Inst::LiteClone(target_graph, connect)); + new_inst->SetIndexParam(GetIndexParam()); + return new_inst; +} + +Inst *CallInst::LiteClone(Graph *target_graph, std::map &connect) { + auto new_inst = static_cast(Inst::LiteClone(target_graph, connect)); + new_inst->SetNameFunc(GetNameFunc()); + return new_inst; +} + +void CallInst::DumpInputs(std::ostream &out) { + out << "\"" << GetNameFunc() << "\" "; + Base::DumpInputs(out); +} + +void ParameterInst::DumpInputs(std::ostream &out) { + out << "\"" << GetIndexParam() << "\" "; + Inst::DumpInputs(out); +} + +void Inst::ReplaceAllUsers(Inst *from) { + ReplaceDataUsers(from); + ReplaceCtrUser(from); +} + +void Inst::ReplaceCtrUser(Inst *from) { + ASSERT(HasSingleDataUser()); + auto c_user = from->GetControlUser(); + if (c_user == nullptr) { + return; + } + SetControlUser(c_user); + from->SetControlUser(nullptr); +} + +void DynamicInputs::DeleteInput(Inst *inst) { + auto it = std::find(inputs_.begin(), inputs_.end(), inst); + ASSERT(it != inputs_.end()); + inputs_.erase(it); +} + } diff --git a/src/inst.h b/src/inst.h index d9ac8d7..0c5a3f1 100644 --- a/src/inst.h +++ b/src/inst.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -42,9 +43,13 @@ std::string OpcodeToString(Opcode opc); std::string CcToString(ConditionCode cc); std::string TypeToString(Type type); +class Graph; class RegionInst; class IfInst; +class JumpInst; class ConstantInst; +class ParameterInst; +class CallInst; class Inst { @@ -62,6 +67,13 @@ class Inst type_(type) {}; virtual ~Inst() = default; + virtual Inst *LiteClone(Graph *target_graph, std::map &connect); + + virtual void DeleteInput([[maybe_unused]] Inst *inst) { + std::cerr << "Inst with opcode " << OPCODE_NAME[static_cast(GetOpcode())] << " don't have inputs\n"; + UNREACHABLE(); + } + void Dump(std::ostream& out); virtual void DumpOpcode(std::ostream& out); @@ -97,7 +109,6 @@ class Inst uint32_t NumDataInputs() { ASSERT(!IsRegion()); // RegionInst has special method NumRegionInputs auto num_all = NumAllInputs(); - ASSERT(num_all > 0); return HasControlProp() ? num_all - 1 : num_all; } @@ -117,6 +128,7 @@ class Inst void AddDataUser(Inst *inst); void DeleteDataUser(Inst *inst); + void DeleteRawUser(Inst *inst); bool HasSingleDataUser() { return NumDataUsers() == 1; @@ -165,6 +177,10 @@ class Inst return GetOpcode() == Opcode::Constant; } + bool IsCall() const { + return GetOpcode() == Opcode::Call; + } + Inst *GetPrev() { return prev_; } @@ -179,7 +195,10 @@ class Inst RegionInst *CastToRegion(); IfInst *CastToIf(); + JumpInst *CastToJump(); ConstantInst *CastToConstant(); + ParameterInst *CastToParameter(); + CallInst *CastToCall(); bool IsPlaced() const { return inst_placed_; @@ -206,6 +225,8 @@ class Inst } void ReplaceDataUsers(Inst *from); + void ReplaceAllUsers(Inst *from); + void ReplaceCtrUser(Inst *from); private: auto StartIteratorDataUsers() { @@ -256,6 +277,12 @@ class FixedInputs: public Inst return inputs_.size(); } + virtual void DeleteInput(Inst *inst) override { + auto it = std::find(inputs_.begin(), inputs_.end(), inst); + ASSERT(it != inputs_.end()); + *it = nullptr; + } + virtual void DumpInputs(std::ostream &out) override { bool first = true; for (auto inst : GetAllInputs()) { @@ -311,11 +338,7 @@ class DynamicInputs: public Inst inputs_.push_back(inst); } - void DeleteInput(Inst *inst) { - auto it = std::find(inputs_.begin(), inputs_.end(), inst); - ASSERT(it == inputs_.end()); - inputs_.erase(it); - } + virtual void DeleteInput(Inst *inst) override; virtual uint32_t NumAllInputs() override { return inputs_.size(); @@ -424,6 +447,8 @@ class ConstantInst : public Inst, public ImmidiateProperty Inst(Opcode::Constant, Type::INT64), ImmidiateProperty(value) {} + virtual Inst *LiteClone(Graph *target_graph, std::map &connect) override; + virtual void DumpInputs(std::ostream &out) override { out << std::string("0x") << std::hex << GetImm() << std::dec; } @@ -582,6 +607,7 @@ class IfInst : public ControlProp> *(++GetRawUsers().begin()) = inst; static_cast(inst)->SetRegionInput(inst->NumAllInputs(), this); } + // We can't copy a Jump, because can be jump on inst, whitch still haven't create virtual void DumpUsers(std::ostream &out) override; }; @@ -598,6 +624,12 @@ class JumpInst : public ControlProp> SetControlUser(inst); static_cast(inst)->SetRegionInput(inst->NumAllInputs(), this); } + + Inst *GetJumpTo() { + return GetControlUser(); + } + + // We can't copy a Jump, because can be jump on inst, whitch still haven't create }; class CompareInst : public FixedInputs<2> @@ -615,6 +647,7 @@ class CompareInst : public FixedInputs<2> } virtual void DumpOpcode(std::ostream& out) override; + virtual Inst *LiteClone(Graph *target_graph, std::map &connect) override; private: ConditionCode cc_; @@ -652,18 +685,14 @@ class ParameterInst : public Inst return idx_param_; } + virtual void DumpInputs(std::ostream &out) override; + + virtual Inst *LiteClone(Graph *target_graph, std::map &connect) override; + private: id_t idx_param_; }; -// Call "NameFunc" -// Inputs: -// 0) CFG element -// 1, ...) - arguments of function -// -// Users: -// 0) CFG element -// 1, ...) - users of return value class CallInst : public ControlProp { public: @@ -672,8 +701,11 @@ class CallInst : public ControlProp Base(Opcode::Call), name_func_() {}; - void SetCFGUser(Inst *inst); - Inst *GetCFGUser(); + void SetNameFunc(const std::string &name); + std::string GetNameFunc() const; + + virtual Inst *LiteClone(Graph *target_graph, std::map &connect) override; + virtual void DumpInputs(std::ostream &out) override; private: std::string name_func_; diff --git a/src/ir_constructor.h b/src/ir_constructor.h index 2268ed5..6193c30 100644 --- a/src/ir_constructor.h +++ b/src/ir_constructor.h @@ -70,6 +70,13 @@ class IrConstructor { } return *this; } + IrConstructor &NameFunc(std::string name) { + if (!current_inst_->IsCall()) { + UNREACHABLE(); + } + current_inst_->CastToCall()->SetNameFunc(name); + return *this; + } IrConstructor &Branches(id_t true_br, id_t false_br) { ASSERT(current_inst_ != nullptr); diff --git a/src/optimizations/inlining.cpp b/src/optimizations/inlining.cpp new file mode 100644 index 0000000..79d211d --- /dev/null +++ b/src/optimizations/inlining.cpp @@ -0,0 +1,221 @@ +#include "inlining.h" +#include "analysis/rpo.h" + +namespace compiler { + +Inlining::Inlining(Graph *main_graph, std::vector additional_graphs): + graph_(main_graph) { + FillMapAdditionalGraphs(additional_graphs); +} + +void Inlining::Run() { + // Find call and try to inline them: + // 1) Find Call in RPO + // 2) Find callee method in additioanal graphs + // Checks: + // 1) Build graph (in our case already built) + // 2) Do something optimization + // 3) Calculate number of instructions + FindAllCalls(); + TryInlineCalls(); +} + +void Inlining::TryInlineCalls() { + for (auto inst : all_calls_) { + inst->Dump(std::cerr); + auto call = inst->CastToCall(); + auto func_graph = GetGraphFuncByCall(call); + if (!func_graph.has_value()) { + std::cerr << "Not found \n"; + continue; + } + if (!CanBeInlined(func_graph.value())) { + std::cerr << "Bad \n"; + continue; + } + InlineFunc(call, func_graph.value()); + already_inlined_insts_ += func_graph.value()->GetNumInsts(); + } +} + +void Inlining::InlineFunc([[maybe_unused]]CallInst *call, Graph *func_graph) { + inlined_subgraph_.clear(); + last_new_cfg_ = nullptr; + InsertExternalGraph(func_graph); + UpdateParameters(call); + UpdateReturn(call); + UpdateCfgSubgraph(call); + DeleteUnnecessaryInst(); +} + +Inst *FindFirstJump(Inst *inst) { + while (inst->GetOpcode() != Opcode::Jump) { + inst = inst->GetControlUser(); + } + return inst; +} + +void Inlining::DeleteUnnecessaryInst() { + graph_->DeleteInst(inlined_subgraph_.back()); + graph_->DeleteInst(inlined_subgraph_.front()); +} + +void Inlining::UpdateCfgSubgraph(CallInst *call) { + ASSERT(last_new_cfg_ != nullptr); + auto lower_connect = call->GetControlUser(); + auto upper_connect = call->GetControlInput(); + graph_->DeleteInst(call); + FindFirstJump(inlined_subgraph_.front())->SetControlInput(upper_connect); + lower_connect->SetControlInput(last_new_cfg_); +} + +void Inlining::UpdateReturn(Inst *call) { + auto end_region = inlined_subgraph_.back()->CastToRegion(); + auto num_returns = end_region->NumAllInputs(); + if (num_returns == 1) { + SingleReturn(call, end_region); + } else { + MultipleReturn(call, end_region, num_returns); + } +} + +void Inlining::SingleReturn(Inst *call, RegionInst *end_region) { + auto last_jump = end_region->GetRegionInput(0); + ASSERT(last_jump->GetOpcode() == Opcode::Jump); + auto inlined_return = last_jump->GetControlInput(); + ASSERT(inlined_return->GetOpcode() == Opcode::Return); + auto return_value = inlined_return->GetDataInput(0); + graph_->DeleteInst(inlined_return); + return_value->ReplaceDataUsers(call); + last_jump->SetControlInput(inlined_return->GetControlInput()); + last_new_cfg_ = graph_->CreateRegionInst(); + last_new_cfg_->SetControlInput(last_jump); +} + +void Inlining::MultipleReturn(Inst *call, RegionInst *end_region, uint32_t num_returns) { + auto sum_region = graph_->CreateRegionInst(); + auto sum_phi = graph_->CreatePhiInst(); + sum_phi->SetControlInput(sum_region); + + for (uint32_t i = 0; i < num_returns; i++) { + // Find multiple return, on each way to End region + auto last_jump = end_region->GetRegionInput(i); + ASSERT(last_jump->GetOpcode() == Opcode::Jump); + auto inlined_return = last_jump->GetControlInput(); + ASSERT(inlined_return->GetOpcode() == Opcode::Return); + // Take return value (data input) and delete them + auto return_value = inlined_return->GetDataInput(0); + graph_->DeleteInst(inlined_return); + last_jump->SetControlInput(inlined_return->GetControlInput()); + + sum_region->SetRegionInput(i, last_jump); + last_jump->SetControlUser(sum_region); + sum_phi->SetDataInput(i, return_value); + } + sum_phi->ReplaceDataUsers(call); + sum_phi->SetControlInput(call->GetControlUser()); + last_new_cfg_ = sum_phi; +} + +void Inlining::UpdateParameters(CallInst *call) { + [[maybe_unused]] auto num_params = call->NumDataInputs(); + for (id_t index = 0; index < num_params; index++) { + auto comparator = [index](Inst *inst) { + if (inst->GetOpcode() != Opcode::Parameter) { + return false; + } + return inst->CastToParameter()->GetIndexParam() == index; + }; + auto it_param = std::find_if(inlined_subgraph_.begin(), inlined_subgraph_.end(), comparator); + ASSERT(it_param != inlined_subgraph_.end()); + + auto argument = call->GetDataInput(index); + argument->ReplaceDataUsers(*it_param); + graph_->DeleteInst(*it_param); + } +} + +Inst *Inlining::InsertExternalGraph(Graph *ext_graph) { + auto rpo = RpoInsts(ext_graph); + rpo.Run(); + auto rpo_vector = rpo.GetVector(); + std::vector branchs_fill_after; + std::map connection_index; + + const id_t offset_index = graph_->GetNumInsts(); + // Exclude Start and End insts (first and last in rpo) + for (auto inst : rpo_vector) { + auto new_inst = inst->LiteClone(graph_, connection_index); + inlined_subgraph_.push_back(new_inst); + connection_index[inst->GetId()] = new_inst->GetId(); + + Opcode opc = new_inst->GetOpcode(); + if (opc == Opcode::Jump || opc == Opcode::If) { + branchs_fill_after.push_back(inst); + } + } + + for (auto inst : branchs_fill_after) { + auto old_index = inst->GetId(); + auto new_inst = graph_->GetInstByIndex(connection_index[inst->GetId()]); + if (inst->GetOpcode() == Opcode::Jump) { + auto new_jmp_to = connection_index[ext_graph->GetInstByIndex(old_index)->CastToJump()->GetJumpTo()->GetId()]; + new_inst->CastToJump()->SetJmpTo(graph_->GetInstByIndex(new_jmp_to)); + continue; + } + ASSERT(inst->GetOpcode() == Opcode::If); + auto new_true_branch = connection_index[inst->CastToIf()->GetTrueBranch()->GetId()]; + auto new_false_branch = connection_index[inst->CastToIf()->GetFalseBranch()->GetId()]; + new_inst->CastToIf()->SetTrueBranch(graph_->GetInstByIndex(new_true_branch)); + new_inst->CastToIf()->SetFalseBranch(graph_->GetInstByIndex(new_false_branch)); + } + return graph_->GetInstByIndex(offset_index); +} + +std::optional Inlining::GetGraphFuncByCall(CallInst *call) { + const auto num_params = call->NumDataInputs(); + const auto &name_func = call->GetNameFunc(); + auto it = map_.find({name_func, num_params}); + if (it == map_.end()) { + return std::nullopt; + } + return it->second; +} + +bool Inlining::CanBeInlined(const Graph *ext_call) { + // Special attribute in name for disable inlining for function + if (ext_call->GetMethodName().find("__noinline__") != std::string::npos) { + return false; + } + return ext_call->GetNumInsts() + already_inlined_insts_ <= max_inline_insts_; +} + +std::vector Inlining::GetRPOVector() { + auto rpo = RpoInsts(graph_); + rpo.Run(); + return rpo.GetVector(); +} + +void Inlining::FillMapAdditionalGraphs(std::vector &additional_graphs) { + for (auto graph : additional_graphs) { + if (map_.find({graph->GetMethodName(), graph->GetNumParams()}) != map_.end()) { + std::cerr << "Method \"" << graph->GetMethodName() << "\" with " << graph->GetNumParams() << " param(s) is already exist!\n"; + exit(1); + } + map_[{graph->GetMethodName(), graph->GetNumParams()}] = graph; + } +} + +void Inlining::FindAllCalls() { + auto rpo_vector = GetRPOVector(); + for (auto inst : rpo_vector) { + inst->Dump(std::cerr); + if (!inst->IsCall()) { + continue; + } + all_calls_.push_back(inst); + } +} + + +} diff --git a/src/optimizations/inlining.h b/src/optimizations/inlining.h new file mode 100644 index 0000000..cb903eb --- /dev/null +++ b/src/optimizations/inlining.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include +#include + +#include "graph.h" + +namespace compiler { + +class Inlining { +public: + Inlining(Graph *main_graph, std::vector additional_graphs); + + void Run(); + +private: + void TryInlineCalls(); + void FillMapAdditionalGraphs(std::vector &additional_graphs); + std::vector GetRPOVector(); + void FindAllCalls(); + bool CanBeInlined(const Graph *ext_graph); + std::optional GetGraphFuncByCall(CallInst *call); + void InlineFunc(CallInst *call, Graph *func_graph); + // Return value: first non-start Region of copied external graph + Inst *InsertExternalGraph(Graph *ext_graph); + void UpdateParameters(CallInst *call); + void UpdateReturn(Inst *call); + void SingleReturn(Inst *call, RegionInst *end_region); + void MultipleReturn(Inst *call, RegionInst *end, uint32_t num_returns); + void UpdateCfgSubgraph(CallInst *call); + void DeleteUnnecessaryInst(); + +private: + Inst *last_new_cfg_ = nullptr; + uint32_t max_inline_insts_ = 20; // This small default value for testing + uint32_t already_inlined_insts_ = 0; + Graph *graph_; + + std::vector all_calls_; + std::map, Graph *> map_; + std::vector inlined_subgraph_; +}; + +} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 30991c8..ccb5cd3 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -64,6 +64,25 @@ add_custom_target( COMMAND peepholes_tests ) +add_executable( + inlining_test + inlining_tests.cpp + graph_comparator.cpp +) + +target_link_libraries( + inlining_test + ${ALL_LIBS_FOR_TESTS} +) + +target_include_directories(inlining_test PUBLIC "${CMAKE_SOURCE_DIR}/src") + +gtest_discover_tests(inlining_test) + +add_custom_target( + inlining_test_gtest + COMMAND inlining_test +) add_custom_target( tests diff --git a/tests/graph_tests.cpp b/tests/graph_tests.cpp index 802a959..b4efd38 100644 --- a/tests/graph_tests.cpp +++ b/tests/graph_tests.cpp @@ -327,7 +327,7 @@ TEST(GraphTest, TestParameterReturn) { " 0. Start -> v2\n" " 1. End v4\n" " 2. Region v0 -> v4\n" - " 3. Parameter -> v4\n" + " 3. Parameter \"2\" -> v4\n" " 4. Return v2, v3 -> v1\n"; ic.GetFinalGraph()->Dump(dump_out); ASSERT_EQ(dump_out.str(), output); @@ -350,7 +350,7 @@ TEST(GraphTest, TestCall) { " 1. End v5\n" " 2. Region v0 -> v4\n" " 3.i64 Constant 0x4 -> v4\n" - " 4. Call v2, v3 -> v5, v5\n" + " 4. Call \"\" v2, v3 -> v5, v5\n" " 5. Return v4, v4 -> v1\n"; ic.GetFinalGraph()->Dump(dump_out); ASSERT_EQ(dump_out.str(), output); diff --git a/tests/inlining_tests.cpp b/tests/inlining_tests.cpp new file mode 100644 index 0000000..90f902e --- /dev/null +++ b/tests/inlining_tests.cpp @@ -0,0 +1,136 @@ +#include +#include +#include "graph.h" + +#include "ir_constructor.h" +#include "tests/graph_comparator.h" +#include "optimizations/analysis/rpo.h" +#include "optimizations/analysis/domtree.h" +#include "optimizations/analysis/loop_analysis.h" + +#include "optimizations/inlining.h" + +namespace compiler { + +TEST(InliningTest, SimpleTestJump) { + auto ic = IrConstructor(); + ic.CreateInst(0); + ic.CreateInst(10).CtrlInput(0).JmpTo(7); + + ic.CreateInst(7); + ic.CreateInst(8).NameFunc("Foo").CtrlInput(7); + ic.CreateInst(11).CtrlInput(8).DataInputs(8); + ic.CreateInst(9).CtrlInput(11).JmpTo(1); + + ic.CreateInst(1); + auto main_graph = ic.GetFinalGraph(); + main_graph->SetMethodName("main"); + + auto ic2 = IrConstructor(); + ic2.CreateInst(0); + ic2.CreateInst(3).Imm(0); + ic2.CreateInst(2).CtrlInput(0).JmpTo(5); + + ic2.CreateInst(5); + ic2.CreateInst(4).CtrlInput(5).DataInputs(3); + ic2.CreateInst(6).CtrlInput(4).JmpTo(1); + + ic2.CreateInst(1); + + ic2.FinalizeRegions(); + auto ext_graph = ic2.GetFinalGraph(); + ext_graph->SetMethodName("Foo"); + + auto inl = Inlining(main_graph, {ext_graph}); + inl.Run(); + main_graph->Dump(std::cerr); +} + +TEST(InliningTest, SimpleTestIf) { + auto ic = IrConstructor(); + ic.CreateInst(0); + ic.CreateInst(5).Imm(26); + ic.CreateInst(10).CtrlInput(0).JmpTo(7); + + ic.CreateInst(7); + ic.CreateInst(8).NameFunc("Foo").CtrlInput(7).DataInputs(5); + ic.CreateInst(11).CtrlInput(8).DataInputs(8); + ic.CreateInst(9).CtrlInput(11).JmpTo(1); + + ic.CreateInst(1); + auto main_graph = ic.GetFinalGraph(); + main_graph->SetMethodName("main"); + + // External method + auto ic2 = IrConstructor(); + ic2.CreateInst(0); + ic2.CreateInst(3).Imm(0); + ic2.CreateInst(2).CtrlInput(0).JmpTo(5); + + ic2.CreateInst(5); + ic2.CreateInst(7).CtrlInput(5).DataInputs(3).Branches(8, 10); + + ic2.CreateInst(8); + ic2.CreateInst(9).CtrlInput(8).JmpTo(10); + + ic2.CreateInst(10); + ic2.CreateInst(4).CtrlInput(10).DataInputs(3); + ic2.CreateInst(6).CtrlInput(4).JmpTo(1); + + ic2.CreateInst(1); + + ic2.FinalizeRegions(); + auto ext_graph = ic2.GetFinalGraph(); + ext_graph->SetMethodName("Foo"); + ext_graph->SetNumParams(1); + + auto inl = Inlining(main_graph, {ext_graph}); + inl.Run(); + main_graph->Dump(std::cerr); +} + +TEST(InliningTest, WillNotApplied) { + auto ic = IrConstructor(); + ic.CreateInst(0); + ic.CreateInst(5).Imm(26); + ic.CreateInst(10).CtrlInput(0).JmpTo(7); + + ic.CreateInst(7); + ic.CreateInst(8).NameFunc("FooBar").CtrlInput(7).DataInputs(5); + ic.CreateInst(11).CtrlInput(8).DataInputs(8); + ic.CreateInst(9).CtrlInput(11).JmpTo(1); + + ic.CreateInst(1); + auto main_graph = ic.GetFinalGraph(); + main_graph->SetMethodName("main"); + + // External method + auto ic2 = IrConstructor(); + ic2.CreateInst(0); + ic2.CreateInst(3).Imm(0); + ic2.CreateInst(2).CtrlInput(0).JmpTo(5); + + ic2.CreateInst(5); + ic2.CreateInst(7).CtrlInput(5).DataInputs(3).Branches(8, 10); + + ic2.CreateInst(8); + ic2.CreateInst(9).CtrlInput(8).JmpTo(10); + + ic2.CreateInst(10); + ic2.CreateInst(4).CtrlInput(10).DataInputs(3); + ic2.CreateInst(6).CtrlInput(4).JmpTo(1); + + ic2.CreateInst(1); + + ic2.FinalizeRegions(); + auto ext_graph = ic2.GetFinalGraph(); + ext_graph->SetMethodName("Foo"); + ext_graph->SetNumParams(1); + + auto inl = Inlining(main_graph, {ext_graph}); + inl.Run(); + main_graph->Dump(std::cerr); +} + + +}