Skip to content

Commit

Permalink
Fix location emission
Browse files Browse the repository at this point in the history
  • Loading branch information
driazati committed Oct 25, 2022
1 parent 0cf4cb8 commit d66fe0d
Show file tree
Hide file tree
Showing 10 changed files with 77 additions and 72 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,6 @@ gallery/how_to/work_with_microtvm/micro_tvmc.py

# Printed TIR code on disk
*.tir

# GDB history file
.gdb_history
47 changes: 43 additions & 4 deletions src/printer/tir_text_printer_debug.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,64 @@

#include "tir_text_printer_debug.h"

#include <optional>
#include <string>

#include "text_printer.h"

namespace tvm {
namespace tir {

std::string span_text(const Span& span) {
std::optional<std::string> span_text(const Span& span) {
if (!span.defined()) {
return "missing";
return std::nullopt;
}

std::string source("main.tir");
if (span->source_name.defined() && span->source_name->name.get()) {
source = span->source_name->name;
}
std::string source("file");
return source + ":" + std::to_string(span->line) + ":" + std::to_string(span->column);
}

template <typename ObjectPtr>
void add_all_relevant_lines(const std::vector<std::tuple<const ObjectPtr*, size_t>>& data,
size_t current_line, Doc* output) {
ICHECK(output) << "output must be a valid Doc";
for (const auto& item : data) {
if (std::get<1>(item) != current_line - 1) {
// Item is not relevant for this line, skip it
continue;
}

// Print out the item's span info if present
auto text = span_text(std::get<0>(item)->span);
if (text.has_value()) {
*output << *text;
} else {
*output << "missing";
}
*output << ", ";
}
}

Doc TIRTextPrinterDebug::NewLine() {
current_line_ += 1;

return TIRTextPrinter::NewLine();
if (!show_spans_) {
return TIRTextPrinter::NewLine();
}

Doc output;

output << " [";

add_all_relevant_lines(exprs_by_line_, current_line_, &output);
add_all_relevant_lines(stmts_by_line_, current_line_, &output);

output << "]" << TIRTextPrinter::NewLine();

return output;
}

#define X(TypeName) \
Expand Down
6 changes: 5 additions & 1 deletion src/printer/tir_text_printer_debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ namespace tir {

class TIRTextPrinterDebug : public TIRTextPrinter {
public:
TIRTextPrinterDebug() : TIRTextPrinter(false, &meta_), current_line_(1) {}
explicit TIRTextPrinterDebug(bool show_spans)
: TIRTextPrinter(false, &meta_), current_line_(1), show_spans_(show_spans) {}

std::vector<std::tuple<const PrimExprNode*, size_t>> GetExprsByLine() const {
return exprs_by_line_;
Expand All @@ -61,6 +62,9 @@ class TIRTextPrinterDebug : public TIRTextPrinter {
// Line that the printer is currently printing
size_t current_line_;

// Whether to include spans relevant to each line before a newline or not
bool show_spans_;

// Record of all stmts and exprs and their corresponding line
std::vector<std::tuple<const StmtNode*, size_t>> stmts_by_line_;
std::vector<std::tuple<const PrimExprNode*, size_t>> exprs_by_line_;
Expand Down
4 changes: 2 additions & 2 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ llvm::DISubprogram* CodeGenCPU::CreateDebugFunction(const PrimFunc& f) {
/*Flags=*/llvm::DINode::FlagPrototyped, /*isOptimized=*/true);
#endif
return DIFunction;
#else
return nullptr;
#endif
}

Expand Down Expand Up @@ -952,7 +954,6 @@ llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op, bool use_string_lo
}

llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) {
EmitDebugLocation(op);
ICHECK_EQ(op->args.size(), 6U);
PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as<IntImmNode>()->value,
op->args[4].as<IntImmNode>()->value, true);
Expand Down Expand Up @@ -1388,7 +1389,6 @@ void CodeGenCPU::AddStartupFunction() {
}

llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) {
EmitDebugLocation(op);
if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
return CreateCallPacked(op, true /* use_string_lookup */);
} else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) {
Expand Down
39 changes: 2 additions & 37 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,6 @@ llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
}

llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
ICHECK(builder_->getCurrentDebugLocation() != llvm::DebugLoc()) << "Debug information missing";
llvm::Value* mask = llvm::UndefValue::get(DTypeToLLVMType(DataType::Int(32, target_lanes)));
int num_elems = GetVectorNumElements(vec);
if (num_elems == target_lanes) return vec;
Expand All @@ -680,7 +679,6 @@ llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
}

llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
ICHECK(builder_->getCurrentDebugLocation() != llvm::DebugLoc()) << "Debug information missing";
// To allow creating vectors from scalars, convert any scalars in "vecs" to single-lane
// LLVM vector types.
for (size_t i = 0, e = vecs.size(); i != e; ++i) {
Expand Down Expand Up @@ -764,7 +762,6 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va

// cast operatpr
llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) {
ICHECK(builder_->getCurrentDebugLocation() != llvm::DebugLoc()) << "Debug information missing";
llvm::Type* target = DTypeToLLVMType(to);
if (value->getType() == target) return value;
if (to.is_handle()) {
Expand Down Expand Up @@ -800,7 +797,6 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va

llvm::Constant* CodeGenLLVM::GetGlobalConstant(llvm::Constant* const_data, const std::string& name,
llvm::GlobalValue::LinkageTypes linkage_type) {
ICHECK(builder_->getCurrentDebugLocation() != llvm::DebugLoc()) << "Debug information missing";
llvm::Type* ty = const_data->getType();
llvm::GlobalVariable* global =
new llvm::GlobalVariable(*module_, ty, true, linkage_type, const_data, name);
Expand All @@ -816,7 +812,6 @@ llvm::Constant* CodeGenLLVM::GetGlobalConstant(llvm::Constant* const_data, const
}

llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) {
ICHECK(builder_->getCurrentDebugLocation() != llvm::DebugLoc()) << "Debug information missing";
auto it = str_map_.find(str);
if (it != str_map_.end()) return it->second;
auto llvm_str = llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), str);
Expand All @@ -829,7 +824,6 @@ CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(llvm::Value* buffer_ptr,
DataType buffer_element_dtype,
llvm::ArrayRef<llvm::Value*> indices,
DataType value_dtype) {
ICHECK(builder_->getCurrentDebugLocation() != llvm::DebugLoc()) << "Debug information missing";
ICHECK_EQ(indices.size(), 1) << "CodeGenLLVM requires all buffers to be flat 1-d buffers.";
llvm::Value* index = indices[0];

Expand Down Expand Up @@ -1189,7 +1183,6 @@ void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) {
}

llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
EmitDebugLocation(op);
if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
ICHECK_GE(op->args.size(), 2U);
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
Expand Down Expand Up @@ -1226,7 +1219,6 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
} else if (op->op.same_as(builtin::bitwise_not())) {
return builder_->CreateNot(MakeValue(op->args[0]));
} else if (op->op.same_as(builtin::bitwise_xor())) {
EmitDebugLocation(op);
return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1]));
} else if (op->op.same_as(builtin::shift_left())) {
return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1]));
Expand Down Expand Up @@ -1353,29 +1345,20 @@ void CodeGenLLVM::Scalarize(const PrimExpr& e, std::function<void(int i, llvm::V
}

// Visitors
llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) {
EmitDebugLocation(op);
return GetVarValue(op);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) { return GetVarValue(op); }

llvm::Value* CodeGenLLVM::VisitExpr_(const CastNode* op) {
EmitDebugLocation(op);
return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) {
EmitDebugLocation(op);
return llvm::ConstantInt::getSigned(DTypeToLLVMType(op->dtype), op->value);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) {
EmitDebugLocation(op);
return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) {
EmitDebugLocation(op);
return GetConstString(op->value);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { return GetConstString(op->value); }

#define DEFINE_CODEGEN_BINARY_OP(Op) \
llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \
Expand All @@ -1397,7 +1380,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) {
} \
} \
llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \
EmitDebugLocation(op); \
return Create##Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \
}

Expand All @@ -1417,7 +1399,6 @@ DEFINE_CODEGEN_BINARY_OP(Mul);
} \
} \
llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \
EmitDebugLocation(op); \
return Create##Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b)); \
}

Expand All @@ -1427,7 +1408,6 @@ DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);

llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) {
EmitDebugLocation(op);
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->dtype.is_int()) {
Expand All @@ -1441,7 +1421,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) {
EmitDebugLocation(op);
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->dtype.is_int()) {
Expand All @@ -1455,21 +1434,18 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) {
EmitDebugLocation(op);
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) {
EmitDebugLocation(op);
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b);
}

llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) {
EmitDebugLocation(op);
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
Expand All @@ -1480,7 +1456,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const NENode* op) {
EmitDebugLocation(op);
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
if (op->a.dtype().is_int() || op->a.dtype().is_uint()) {
Expand All @@ -1491,28 +1466,23 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const NENode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const AndNode* op) {
EmitDebugLocation(op);
return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b));
}

llvm::Value* CodeGenLLVM::VisitExpr_(const OrNode* op) {
EmitDebugLocation(op);
return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b));
}

llvm::Value* CodeGenLLVM::VisitExpr_(const NotNode* op) {
EmitDebugLocation(op);
return builder_->CreateNot(MakeValue(op->a));
}

llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) {
EmitDebugLocation(op);
return builder_->CreateSelect(MakeValue(op->condition), MakeValue(op->true_value),
MakeValue(op->false_value));
}

llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) {
EmitDebugLocation(op);
auto it = let_binding_.find(op->var);
if (it != let_binding_.end()) {
ICHECK(deep_equal_(it->second->value, op->value))
Expand Down Expand Up @@ -1630,7 +1600,6 @@ void CodeGenLLVM::BufferAccessHelper(
}

llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
EmitDebugLocation(op);
DataType value_dtype = op->dtype;

std::vector<llvm::Value*> loads;
Expand Down Expand Up @@ -1668,7 +1637,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
EmitDebugLocation(op);
if (auto* ptr_op = op->op.as<OpNode>()) {
auto call_op = GetRef<Op>(ptr_op);
if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
Expand All @@ -1695,7 +1663,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) {
EmitDebugLocation(op);
llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype));
for (int i = 0; i < op->lanes; ++i) {
vec = builder_->CreateInsertElement(
Expand All @@ -1705,7 +1672,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) {
EmitDebugLocation(op);
std::vector<llvm::Value*> vecs(op->vectors.size());
int total_lanes = 0;
for (int i = 0, e = op->vectors.size(); i < e; ++i) {
Expand All @@ -1730,7 +1696,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) {
}

llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) {
EmitDebugLocation(op);
return CreateBroadcast(MakeValue(op->value), op->lanes);
}

Expand Down
17 changes: 7 additions & 10 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@
#else
#include <llvm/IR/Operator.h>
#endif
#include <llvm/IR/DebugInfoMetadata.h>
#include <llvm/IR/GlobalValue.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Intrinsics.h>
#include <llvm/Support/Casting.h>
#include <llvm/IR/DebugInfoMetadata.h>
#if TVM_LLVM_VERSION >= 140
#include <llvm/MC/TargetRegistry.h>
#else
Expand Down Expand Up @@ -534,21 +534,18 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,

void EmitDebugLocation(const Span& span) {
ICHECK(di_subprogram_ != nullptr) << "DISubprogram not initialized";
llvm::LLVMContext* ctx = llvm_target_->GetContext();
if (!span.defined()) {
auto loc = llvm::DebugLoc(llvm::DILocation::get(*ctx, 212, 212, di_subprogram_));
builder_->SetCurrentDebugLocation(loc);
} else {
auto loc =
llvm::DebugLoc(llvm::DILocation::get(*ctx, span->line, span->column, di_subprogram_));
builder_->SetCurrentDebugLocation(loc);
VLOG(0) << "Cannot emit debug location for undefined span";
return;
}
llvm::LLVMContext* ctx = llvm_target_->GetContext();
auto loc =
llvm::DebugLoc(llvm::DILocation::get(*ctx, span->line, span->column, di_subprogram_));
builder_->SetCurrentDebugLocation(loc);
}

void EmitDebugLocation() { builder_->SetCurrentDebugLocation(nullptr); }

void EmitDebugLocation(const StmtNode* op) { EmitDebugLocation(op->span); }
void EmitDebugLocation(const PrimExprNode* op) { EmitDebugLocation(op->span); }

/*! \brief Helper struct for debug infos. */
struct DebugInfo {
Expand Down
Loading

0 comments on commit d66fe0d

Please sign in to comment.