Skip to content

Commit

Permalink
Propagate underlying type from LoadRegAddress
Browse files Browse the repository at this point in the history
  • Loading branch information
tetsuo-cpp committed Jun 22, 2022
1 parent e9c9af1 commit 7fdcb51
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 53 deletions.
6 changes: 4 additions & 2 deletions include/remill/BC/InstructionLifter.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class LLVMContext;
class IntegerType;
class BasicBlock;
class Value;
class Type;
} // namespace llvm

namespace remill {
Expand Down Expand Up @@ -78,8 +79,9 @@ class InstructionLifter {
bool is_delayed = false);

// Load the address of a register.
llvm::Value *LoadRegAddress(llvm::BasicBlock *block, llvm::Value *state_ptr,
std::string_view reg_name) const;
std::pair<llvm::Value *, llvm::Type *>
LoadRegAddress(llvm::BasicBlock *block, llvm::Value *state_ptr,
std::string_view reg_name) const;

// Load the value of a register.
llvm::Value *LoadRegValue(llvm::BasicBlock *block, llvm::Value *state_ptr,
Expand Down
10 changes: 6 additions & 4 deletions include/remill/BC/Util.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,15 @@ llvm::CallInst *AddTerminatingTailCall(llvm::BasicBlock *source_block,

// Find a local variable defined in the entry block of the function. We use
// this to find register variables.
llvm::Value *FindVarInFunction(llvm::BasicBlock *block, std::string_view name,
bool allow_failure = false);
std::pair<llvm::Value *, llvm::Type *>
FindVarInFunction(llvm::BasicBlock *block, std::string_view name,
bool allow_failure = false);

// Find a local variable defined in the entry block of the function. We use
// this to find register variables.
llvm::Value *FindVarInFunction(llvm::Function *func, std::string_view name,
bool allow_failure = false);
std::pair<llvm::Value *, llvm::Type *>
FindVarInFunction(llvm::Function *func, std::string_view name,
bool allow_failure = false);

// Find the machine state pointer. The machine state pointer is, by convention,
// passed as the first argument to every lifted function.
Expand Down
10 changes: 5 additions & 5 deletions lib/Arch/Arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,11 +386,11 @@ namespace {

// These variables must always be defined within any lifted function.
static bool BlockHasSpecialVars(llvm::Function *basic_block) {
return FindVarInFunction(basic_block, kStateVariableName, true) &&
FindVarInFunction(basic_block, kMemoryVariableName, true) &&
FindVarInFunction(basic_block, kPCVariableName, true) &&
FindVarInFunction(basic_block, kNextPCVariableName, true) &&
FindVarInFunction(basic_block, kBranchTakenVariableName, true);
return FindVarInFunction(basic_block, kStateVariableName, true).first &&
FindVarInFunction(basic_block, kMemoryVariableName, true).first &&
FindVarInFunction(basic_block, kPCVariableName, true).first &&
FindVarInFunction(basic_block, kNextPCVariableName, true).first &&
FindVarInFunction(basic_block, kBranchTakenVariableName, true).first;
}

// Add attributes to llvm::Argument in a way portable across LLVMs
Expand Down
48 changes: 25 additions & 23 deletions lib/BC/InstructionLifter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ LiftStatus InstructionLifter::LiftIntoBlock(Instruction &arch_inst,
}

llvm::IRBuilder<> ir(block);
const auto mem_ptr_ref =
const auto [mem_ptr_ref, mem_ptr_ref_type] =
LoadRegAddress(block, state_ptr, kMemoryVariableName);
const auto pc_ref = LoadRegAddress(block, state_ptr, kPCVariableName);
const auto next_pc_ref =
const auto [pc_ref, pc_ref_type] = LoadRegAddress(block, state_ptr, kPCVariableName);
const auto [next_pc_ref, next_pc_ref_type] =
LoadRegAddress(block, state_ptr, kNextPCVariableName);
const auto next_pc = ir.CreateLoad(impl->word_type, next_pc_ref);

Expand Down Expand Up @@ -225,7 +225,7 @@ LiftStatus InstructionLifter::LiftIntoBlock(Instruction &arch_inst,
}

// Load the address of a register.
llvm::Value *
std::pair<llvm::Value *, llvm::Type *>
InstructionLifter::LoadRegAddress(llvm::BasicBlock *block,
llvm::Value *state_ptr,
std::string_view reg_name_) const {
Expand All @@ -241,18 +241,20 @@ InstructionLifter::LoadRegAddress(llvm::BasicBlock *block,
}

std::string reg_name(reg_name_.data(), reg_name_.size());
auto [reg_ptr_it, added] =
impl->reg_ptr_cache.emplace(std::move(reg_name), nullptr);
auto [reg_ptr_it, added] = impl->reg_ptr_cache.emplace(
std::move(reg_name),
std::pair<llvm::Value *, llvm::Type *>{nullptr, nullptr});

if (reg_ptr_it->second) {
if (reg_ptr_it->second.first) {
(void) added;
return reg_ptr_it->second;
}

// It's already a variable in the function.
if (const auto var_ptr = FindVarInFunction(func, reg_name_, true)) {
reg_ptr_it->second = var_ptr;
return var_ptr;
const auto [var_ptr, var_ptr_type] = FindVarInFunction(func, reg_name_, true);
if (var_ptr) {
reg_ptr_it->second = {var_ptr, var_ptr_type};
return reg_ptr_it->second;
}

// It's a register known to this architecture, so go and build a GEP to it
Expand Down Expand Up @@ -288,21 +290,21 @@ InstructionLifter::LoadRegAddress(llvm::BasicBlock *block,
<< LLVMThingToString(state_ptr);
}

reg_ptr_it->second = reg_ptr;
return reg_ptr;
reg_ptr_it->second = {reg_ptr, reg->type};
return reg_ptr_it->second;
}

// Try to find it as a global variable.
if (auto gvar = module->getGlobalVariable(reg_name)) {
return gvar;
return {gvar, gvar->getValueType()};
}

// Invent a fake one and keep going.
std::stringstream unk_var;
unk_var << "__remill_unknown_register_" << reg_name;
auto unk_var_name = unk_var.str();
if (auto var = module->getGlobalVariable(unk_var_name)) {
return var;
return {var, var->getValueType()};
}

// TODO(pag): Eventually refactor into a higher-level issue, perhaps a
Expand All @@ -311,9 +313,11 @@ InstructionLifter::LoadRegAddress(llvm::BasicBlock *block,
LOG(ERROR)
<< "Could not locate variable or register " << reg_name_;

return new llvm::GlobalVariable(
*module, impl->word_type, false, llvm::GlobalValue::ExternalLinkage,
llvm::UndefValue::get(impl->word_type), unk_var_name);
return {new llvm::GlobalVariable(*module, impl->word_type, false,
llvm::GlobalValue::ExternalLinkage,
llvm::UndefValue::get(impl->word_type),
unk_var_name),
impl->word_type};
}

// Clear out the cache of the current register values/addresses loaded.
Expand All @@ -326,10 +330,8 @@ void InstructionLifter::ClearCache(void) const {
llvm::Value *InstructionLifter::LoadRegValue(llvm::BasicBlock *block,
llvm::Value *state_ptr,
std::string_view reg_name) const {
auto ptr = LoadRegAddress(block, state_ptr, reg_name);
auto [ptr, ptr_ty] = LoadRegAddress(block, state_ptr, reg_name);
CHECK_NOTNULL(ptr);
// NOTE(alex): This isn't right. Not sure how to solve this right now.
auto ptr_ty = impl->word_type;
return new llvm::LoadInst(ptr_ty, ptr, llvm::Twine::createNull(), block);
}

Expand Down Expand Up @@ -562,7 +564,7 @@ llvm::Value *InstructionLifter::LiftRegisterOperand(Instruction &inst,
auto arg_type = IntendedArgumentType(arg);

if (llvm::isa<llvm::PointerType>(arg_type)) {
auto val = LoadRegAddress(block, state_ptr, arch_reg.name);
auto [val, val_type] = LoadRegAddress(block, state_ptr, arch_reg.name);
return ConvertToIntendedType(inst, op, block, val, real_arg_type);

} else {
Expand Down Expand Up @@ -763,7 +765,7 @@ llvm::Value *InstructionLifter::LiftExpressionOperandRec(
if (!arg || !llvm::isa<llvm::PointerType>(arg->getType())) {
return LoadRegValue(block, state_ptr, (*reg_op)->name);
} else {
return LoadRegAddress(block, state_ptr, (*reg_op)->name);
return LoadRegAddress(block, state_ptr, (*reg_op)->name).first;
}

} else if (auto ci_op = std::get_if<llvm::Constant *>(op)) {
Expand All @@ -773,7 +775,7 @@ llvm::Value *InstructionLifter::LiftExpressionOperandRec(
if (!arg || !llvm::isa<llvm::PointerType>(arg->getType())) {
return LoadRegValue(block, state_ptr, *str_op);
} else {
return LoadRegAddress(block, state_ptr, *str_op);
return LoadRegAddress(block, state_ptr, *str_op).first;
}
} else {
LOG(FATAL) << "Uninitialized Operand Expression";
Expand Down
3 changes: 2 additions & 1 deletion lib/BC/InstructionLifter.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ class InstructionLifter::Impl {
llvm::Type *const memory_ptr_type;

// Cache of looked up registers inside of `last_func`.
std::unordered_map<std::string, llvm::Value *> reg_ptr_cache;
std::unordered_map<std::string, std::pair<llvm::Value *, llvm::Type *>>
reg_ptr_cache;

// The function into which we're lifting. If this gets out of date, we
// clear out `reg_ptr_cache`.
Expand Down
6 changes: 3 additions & 3 deletions lib/BC/TraceLifter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ bool TraceLifter::Impl::Lift(

func = get_trace_decl(trace_addr);
blocks.clear();

if (!func || !func->isDeclaration()) {
func = arch->DeclareLiftedFunction(manager.TraceName(trace_addr), module);
}
Expand All @@ -305,8 +305,8 @@ bool TraceLifter::Impl::Lift(

if (auto entry_block = &(func->front())) {
auto pc = LoadProgramCounterArg(func);
auto next_pc_ref = inst_lifter.LoadRegAddress(entry_block, state_ptr,
kNextPCVariableName);
auto [next_pc_ref, next_pc_ref_type] = inst_lifter.LoadRegAddress(
entry_block, state_ptr, kNextPCVariableName);

// Initialize `NEXT_PC`.
(void) new llvm::StoreInst(pc, next_pc_ref, entry_block);
Expand Down
33 changes: 18 additions & 15 deletions lib/BC/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,38 +207,41 @@ llvm::CallInst *AddTerminatingTailCall(llvm::BasicBlock *source_block,

// Find a local variable defined in the entry block of the function. We use
// this to find register variables.
llvm::Value *FindVarInFunction(llvm::BasicBlock *block, std::string_view name,
bool allow_failure) {
std::pair<llvm::Value *, llvm::Type *>
FindVarInFunction(llvm::BasicBlock *block, std::string_view name,
bool allow_failure) {
return FindVarInFunction(block->getParent(), name, allow_failure);
}

// Find a local variable defined in the entry block of the function. We use
// this to find register variables.
llvm::Value *FindVarInFunction(llvm::Function *function, std::string_view name_,
bool allow_failure) {
std::pair<llvm::Value *, llvm::Type *>
FindVarInFunction(llvm::Function *function, std::string_view name_,
bool allow_failure) {
// TODO(alex): Figure out how to get types for the two cases below.
llvm::StringRef name(name_.data(), name_.size());
if (!function->empty()) {
for (auto &instr : function->getEntryBlock()) {
if (instr.getName() == name) {
return &instr;
return {&instr, nullptr};
}
}
}

for (auto &arg : function->args()) {
if (arg.getName() == name) {
return &arg;
return {&arg, nullptr};
}
}

auto module = function->getParent();
if (auto var = module->getGlobalVariable(name)) {
return var;
return {var, var->getValueType()};
}

CHECK(allow_failure) << "Could not find variable " << name_ << " in function "
<< function->getName().str();
return nullptr;
return {nullptr, nullptr};
}

// Find the machine state pointer.
Expand Down Expand Up @@ -293,12 +296,12 @@ llvm::Value *LoadProgramCounter(llvm::BasicBlock *block,

// Return a reference to the current program counter.
llvm::Value *LoadProgramCounterRef(llvm::BasicBlock *block) {
return FindVarInFunction(block->getParent(), kPCVariableName);
return FindVarInFunction(block->getParent(), kPCVariableName).first;
}

// Return a reference to the next program counter.
llvm::Value *LoadNextProgramCounterRef(llvm::BasicBlock *block) {
return FindVarInFunction(block->getParent(), kNextPCVariableName);
return FindVarInFunction(block->getParent(), kNextPCVariableName).first;
}

// Return the next program counter.
Expand All @@ -310,7 +313,7 @@ llvm::Value *LoadNextProgramCounter(llvm::BasicBlock *block,

// Return a reference to the return program counter.
llvm::Value *LoadReturnProgramCounterRef(llvm::BasicBlock *block) {
return FindVarInFunction(block->getParent(), kReturnPCVariableName);
return FindVarInFunction(block->getParent(), kReturnPCVariableName).first;
}

// Update the program counter in the state struct with a new value.
Expand Down Expand Up @@ -345,19 +348,19 @@ llvm::Value *LoadBranchTaken(llvm::BasicBlock *block) {
llvm::IRBuilder<> ir(block);
auto i8_type = llvm::Type::getInt8Ty(block->getContext());
auto cond = ir.CreateLoad(
i8_type, FindVarInFunction(block->getParent(), kBranchTakenVariableName));
i8_type, FindVarInFunction(block->getParent(), kBranchTakenVariableName).first);
auto true_val = llvm::ConstantInt::get(cond->getType(), 1);
return ir.CreateICmpEQ(cond, true_val);
}

// Return a reference to the branch taken.
llvm::Value *LoadBranchTakenRef(llvm::BasicBlock *block) {
return FindVarInFunction(block->getParent(), kBranchTakenVariableName);
return FindVarInFunction(block->getParent(), kBranchTakenVariableName).first;
}

// Return a reference to the memory pointer.
llvm::Value *LoadMemoryPointerRef(llvm::BasicBlock *block) {
return FindVarInFunction(block->getParent(), kMemoryVariableName);
return FindVarInFunction(block->getParent(), kMemoryVariableName).first;
}

// Find a function with name `name` in the module `M`.
Expand Down Expand Up @@ -676,7 +679,7 @@ LiftedFunctionArgs(llvm::BasicBlock *block, const IntrinsicTable &intrinsics) {
// Set up arguments according to our ABI.
std::array<llvm::Value *, kNumBlockArgs> args;

if (FindVarInFunction(func, kPCVariableName, true)) {
if (FindVarInFunction(func, kPCVariableName, true).first) {
args[kMemoryPointerArgNum] = LoadMemoryPointer(block, intrinsics);
args[kStatePointerArgNum] = LoadStatePointer(block);
args[kPCArgNum] = LoadProgramCounter(block, intrinsics);
Expand Down

0 comments on commit 7fdcb51

Please sign in to comment.