Skip to content

Commit

Permalink
Reduce ad-hoc type conversions on functions
Browse files Browse the repository at this point in the history
By not treating function pointers specially in `GetLLVMType`, most of
the ad-hoc conversions from function types to pointer types are
eliminated. The only remaining conversion is when referencing the ID of
a function; the type of such ID expression should be a pointer type
instead of a function type. Additionally, when using such an ID
expression in a function call, we need to convert it back to a function
type.
  • Loading branch information
Lai-YT authored and leewei05 committed Jul 6, 2024
1 parent 7fc353d commit 366b3dc
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 31 deletions.
3 changes: 0 additions & 3 deletions include/llvm/util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ class LLVMIRBuilderHelper {
llvm::Function* CurrFunc();

/// @brief Get the corresponding LLVM type from our type.
/// @note For Function Pointers, even though it is a pointer type, we return
/// `FunctionType` instead of `PointerType` because `FunctionType` is needed
/// for creating LLVM IR function call.
/// @throw `std::runtime_error` if the `type` is not unknown.
llvm::Type* GetLLVMType(const Type& type);

Expand Down
7 changes: 0 additions & 7 deletions src/llvm/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,6 @@ llvm::Type* LLVMIRBuilderHelper::GetLLVMType(const Type& type) {
}
throw std::runtime_error{"unknown type in GetLLVMType!"};
} else if (type.IsPtr()) {
auto ptr_type = dynamic_cast<const PtrType*>(&type);
auto base_type = ptr_type->base_type().Clone();
auto llvm_base_type = GetLLVMType(*base_type);
// Function pointers
if (llvm_base_type->isFunctionTy()) {
return llvm_base_type;
}
return builder_.getPtrTy();
} else if (type.IsArr()) {
auto arr_type = dynamic_cast<const ArrType*>(&type);
Expand Down
38 changes: 17 additions & 21 deletions src/llvm_ir_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,6 @@ void LLVMIRGenerator::Visit(const DeclStmtNode& decl_stmt) {

void LLVMIRGenerator::Visit(const VarDeclNode& decl) {
auto var_type = builder_helper_.GetLLVMType(*(decl.type));
// For function pointer, we need to change from FunctionType to PointerType
var_type = var_type->isFunctionTy() ? var_type->getPointerTo() : var_type;
auto addr = builder_.CreateAlloca(var_type);
if (decl.init) {
decl.init->Accept(*this);
Expand Down Expand Up @@ -249,13 +247,7 @@ void LLVMIRGenerator::Visit(const FuncDefNode& func_def) {
parameter->Accept(*this);
args_iter->setName(parameter->id);

llvm::Type* param_type = builder_helper_.GetLLVMType(*(parameter->type));
// Update type from FunctionType to PointerType for function pointer.
if (param_type->isFunctionTy()) {
param_type = param_type->getPointerTo();
}
args_iter->mutateType(param_type);
auto addr = builder_.CreateAlloca(param_type);
auto addr = builder_.CreateAlloca(args_iter->getType());
builder_.CreateStore(args_iter, addr);
id_to_val[parameter->id] = addr;
++args_iter;
Expand Down Expand Up @@ -528,13 +520,7 @@ void LLVMIRGenerator::Visit(const IdExprNode& id_expr) {
assert(id_to_val.count(id_expr.id) != 0);
auto id_val = id_to_val.at(id_expr.id);

llvm::Type* id_type = nullptr;
// LLVM requires the function to have pointer type when being referenced.
if (id_expr.type->IsPtr() || id_expr.type->IsFunc()) {
id_type = builder_.getPtrTy();
} else {
id_type = builder_helper_.GetLLVMType(*(id_expr.type));
}
auto id_type = builder_helper_.GetLLVMType(*(id_expr.type));
auto res = builder_.CreateLoad(id_type, id_val);
val_recorder.Record(res);
val_to_id_addr[res] = id_val;
Expand Down Expand Up @@ -632,9 +618,19 @@ void LLVMIRGenerator::Visit(const FuncCallExprNode& call_expr) {
val_recorder.Record(return_res);
}
} else if (val->getType()->isPointerTy()) {
// function pointer
auto type = builder_helper_.GetLLVMType(*(call_expr.func_expr->type));
if (auto func_type = llvm::dyn_cast<llvm::FunctionType>(type)) {
llvm::Type* expr_type = nullptr;
if (auto ptr_type =
dynamic_cast<PtrType*>((call_expr.func_expr->type).get())) {
expr_type = builder_helper_.GetLLVMType(ptr_type->base_type());
} else {
// If the expression is a dereferenced function pointer, the type of the
// value diverges from the type of the expression, with the former being a
// pointer type and the latter being a function type.
assert(call_expr.func_expr->type->IsFunc());
expr_type = builder_helper_.GetLLVMType(*(call_expr.func_expr->type));
}

if (auto func_type = llvm::dyn_cast<llvm::FunctionType>(expr_type)) {
auto return_res = builder_.CreateCall(func_type, val, arg_vals);
val_recorder.Record(return_res);
} else {
Expand Down Expand Up @@ -716,7 +712,7 @@ void LLVMIRGenerator::Visit(const UnaryExprNode& unary_expr) {
} break;
case UnaryOperator::kAddr: {
if (unary_expr.operand->type->IsFunc()) {
// No-op; the function itself already evaluates to the address.
// No-op; the value is still the function itself.
break;
}
auto operand = val_recorder.ValOfPrevExpr();
Expand All @@ -729,7 +725,7 @@ void LLVMIRGenerator::Visit(const UnaryExprNode& unary_expr) {
dynamic_cast<PtrType*>((unary_expr.operand->type).get())
->base_type()
.IsFunc()) {
// No-op; the function itself also evaluates to the address.
// No-op; the value is still the function itself.
break;
}

Expand Down

0 comments on commit 366b3dc

Please sign in to comment.