[ir] [refactor] Let the type of Alloca be pointer
ghstack-source-id: 5ca25de2e989b93b5daa78ee5204f552d80bc2f0
Pull Request resolved: #8007
lin-hitonami committed May 30, 2023
1 parent afbc85b commit 01b06f6
Showing 27 changed files with 394 additions and 220 deletions.
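
In short: after this refactor, an AllocaStmt's ret_type is a PointerType wrapping the allocated type, so every pass that used to inspect ret_type directly must unwrap it first. A minimal C++ sketch of the pattern the changed visitors follow, assembled from the hunks below rather than taken verbatim from the commit:

    void visit(AllocaStmt *stmt) {
      // ret_type is now a pointer to the allocated storage ...
      TI_ASSERT(stmt->ret_type.is_pointer());
      // ... so peel it off to recover the type actually being allocated.
      auto alloca_type = stmt->ret_type.ptr_removed();
      if (auto tensor_type = alloca_type->cast<TensorType>()) {
        // tensor-typed local (matrix/vector locals, shared memory)
      } else {
        // scalar local, zero-initialized unless the pointee is itself a pointer
      }
    }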
python/taichi/lang/_ndrange.py (2 changes: 1 addition & 1 deletion)
@@ -23,7 +23,7 @@ def __init__(self, *args):
         for arg in args:
             for bound in arg:
                 if not isinstance(bound, (int, np.integer)) and not (
-                    isinstance(bound, Expr) and is_integral(bound.ptr.get_ret_type())
+                    isinstance(bound, Expr) and is_integral(bound.ptr.get_rvalue_type())
                 ):
                     raise TaichiTypeError(
                         "Every argument of ndrange should be an integer scalar or a tuple/list of (int, int)"
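
The frontend half of the change is mechanical: wherever Python code needs the value type of an expression, it now calls get_rvalue_type() instead of get_ret_type(), presumably because ret_type can now carry a pointer type. A one-line sketch, not from the commit (the variable name is made up):

    # Sketch: querying an Expr's value type after this commit.
    dt = some_expr.ptr.get_rvalue_type()  # e.g. ti.i32 for an integer expression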
python/taichi/lang/ast/ast_transformer.py (8 changes: 4 additions & 4 deletions)
@@ -50,7 +50,7 @@ def reshape_list(flat_list, target_shape):
 
 
 def boundary_type_cast_warning(expression):
-    expr_dtype = expression.ptr.get_ret_type()
+    expr_dtype = expression.ptr.get_rvalue_type()
     if not is_integral(expr_dtype) or expr_dtype in [
         primitive_types.i64,
         primitive_types.u64,
@@ -107,7 +107,7 @@ def build_assign_annotated(ctx, target, value, is_static_assign, annotation):
         ctx.create_variable(target.id, var)
     else:
         var = build_stmt(ctx, target)
-        if var.ptr.get_ret_type() != anno:
+        if var.ptr.get_rvalue_type() != anno:
             raise TaichiSyntaxError("Static assign cannot have type overloading")
         var._assign(value)
     return var
@@ -709,7 +709,7 @@ def transform_as_kernel():
                     f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix, but got {type(data)}."
                 )
 
-            element_shape = data.ptr.get_ret_type().shape()
+            element_shape = data.ptr.get_rvalue_type().shape()
             if len(element_shape) != ctx.func.arguments[i].annotation.ndim:
                 raise TaichiSyntaxError(
                     f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix with ndim {ctx.func.arguments[i].annotation.ndim}, but got {len(element_shape)}."
@@ -1449,7 +1449,7 @@ def ti_format_list_to_assert_msg(raw):
         if isinstance(entry, str):
            msg += entry
         elif isinstance(entry, _ti_core.Expr):
-            ty = entry.get_ret_type()
+            ty = entry.get_rvalue_type()
             if ty in primitive_types.real_types:
                 msg += "%f"
             elif ty in primitive_types.integer_types:
python/taichi/lang/expr.py (8 changes: 4 additions & 4 deletions)
@@ -48,25 +48,25 @@ def is_struct(self):
         return self.ptr.is_struct()
 
     def element_type(self):
-        return self.ptr.get_ret_type().element_type()
+        return self.ptr.get_rvalue_type().element_type()
 
     def get_shape(self):
         if not self.is_tensor():
-            raise TaichiCompilationError(f"Getting shape of non-tensor type: {self.ptr.get_ret_type()}")
+            raise TaichiCompilationError(f"Getting shape of non-tensor type: {self.ptr.get_rvalue_type()}")
         return tuple(self.ptr.get_shape())
 
     @property
     def n(self):
         shape = self.get_shape()
         if len(shape) < 1:
-            raise TaichiCompilationError(f"Getting n of tensor type < 1D: {self.ptr.get_ret_type()}")
+            raise TaichiCompilationError(f"Getting n of tensor type < 1D: {self.ptr.get_rvalue_type()}")
         return shape[0]
 
     @property
     def m(self):
         shape = self.get_shape()
         if len(shape) < 2:
-            raise TaichiCompilationError(f"Getting m of tensor type < 2D: {self.ptr.get_ret_type()}")
+            raise TaichiCompilationError(f"Getting m of tensor type < 2D: {self.ptr.get_rvalue_type()}")
         return shape[1]
 
     def __hash__(self):
python/taichi/lang/matrix.py (6 changes: 3 additions & 3 deletions)
@@ -133,7 +133,7 @@ def _infer_entry_dt(entry):
     if isinstance(entry, (float, np.floating)):
         return impl.get_runtime().default_fp
     if isinstance(entry, expr.Expr):
-        dt = entry.ptr.get_ret_type()
+        dt = entry.ptr.get_rvalue_type()
         if dt == ti_python_core.DataType_unknown:
             raise TaichiTypeError("Element type of the matrix cannot be inferred. Please set dt instead for now.")
         return dt
@@ -1412,7 +1412,7 @@ def __call__(self, *args):
         # Init from a real Matrix
         if isinstance(args[0], expr.Expr) and args[0].ptr.is_tensor():
             arg = args[0]
-            shape = arg.ptr.get_ret_type().shape()
+            shape = arg.ptr.get_rvalue_type().shape()
             assert self.ndim == len(shape)
             assert self.n == shape[0]
             if self.ndim > 1:
@@ -1554,7 +1554,7 @@ def __call__(self, *args):
         # Init from a real Matrix
         if isinstance(args[0], expr.Expr) and args[0].ptr.is_tensor():
             arg = args[0]
-            shape = arg.ptr.get_ret_type().shape()
+            shape = arg.ptr.get_rvalue_type().shape()
             assert len(shape) == 1
             assert self.n == shape[0]
             return expr.Expr(arg.ptr)
python/taichi/lang/matrix_ops_utils.py (2 changes: 1 addition & 1 deletion)
@@ -122,7 +122,7 @@ def same_shapes(*xs):
         elif isinstance(x, list):
             shapes.append(tuple(get_list_shape(x)))
         elif isinstance(x, Expr):
-            shapes.append(tuple(x.ptr.get_ret_type().shape()))
+            shapes.append(tuple(x.ptr.get_rvalue_type().shape()))
         else:
             return False, f"same_shapes() received an unexpected argument of type: {x}"
 
taichi/analysis/gen_offline_cache_key.cpp (18 changes: 1 addition & 17 deletions)
@@ -12,23 +12,6 @@ namespace taichi::lang {
 
 namespace {
 
-enum class ExprOpCode : std::uint8_t {
-  NIL,
-#define PER_EXPRESSION(x) x,
-#include "taichi/inc/expressions.inc.h"
-#undef PER_EXPRESSION
-};
-
-enum class StmtOpCode : std::uint8_t {
-  NIL,
-  EnterBlock,
-  ExitBlock,
-  StopGrad,
-#define PER_STATEMENT(x) x,
-#include "taichi/inc/frontend_statements.inc.h"
-#undef PER_STATEMENT
-};
-
 enum class ForLoopType : std::uint8_t {
   StructForOnSNode,
   StructForOnExternalTensor,
@@ -198,6 +181,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
   void visit(IdExpression *expr) override {
     emit(ExprOpCode::IdExpression);
     emit(expr->id);
+    emit(expr->op);
   }
 
   void visit(AtomicOpExpression *expr) override {
taichi/codegen/cuda/codegen_cuda.cpp (4 changes: 2 additions & 2 deletions)
@@ -172,8 +172,8 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
 
   void visit(AllocaStmt *stmt) override {
     // Override shared memory codegen logic for large shared memory
-    if (stmt->ret_type->is<TensorType>() && stmt->is_shared) {
-      auto tensor_type = stmt->ret_type->cast<TensorType>();
+    auto tensor_type = stmt->ret_type.ptr_removed()->cast<TensorType>();
+    if (tensor_type && stmt->is_shared) {
      size_t shared_array_bytes =
          tensor_type->get_num_elements() *
          data_type_size(tensor_type->get_element_type());
taichi/codegen/llvm/codegen_llvm.cpp (18 changes: 11 additions & 7 deletions)
@@ -126,8 +126,10 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) {
 }
 
 void TaskCodeGenLLVM::visit(AllocaStmt *stmt) {
-  if (stmt->ret_type->is<TensorType>()) {
-    auto tensor_type = stmt->ret_type->cast<TensorType>();
+  TI_ASSERT(stmt->ret_type.is_pointer());
+  auto alloca_type = stmt->ret_type->as<PointerType>()->get_pointee_type();
+  if (alloca_type->is<TensorType>()) {
+    auto tensor_type = alloca_type->cast<TensorType>();
     auto type = tlctx->get_data_type(tensor_type);
     if (stmt->is_shared) {
       auto base = new llvm::GlobalVariable(
@@ -141,12 +143,10 @@
       llvm_val[stmt] = create_entry_block_alloca(type);
     }
   } else {
-    llvm_val[stmt] =
-        create_entry_block_alloca(stmt->ret_type, stmt->ret_type.is_pointer());
+    llvm_val[stmt] = create_entry_block_alloca(alloca_type);
     // initialize as zero if element is not a pointer
-    if (!stmt->ret_type.is_pointer())
-      builder->CreateStore(tlctx->get_constant(stmt->ret_type, 0),
-                           llvm_val[stmt]);
+    if (!alloca_type->is<PointerType>())
+      builder->CreateStore(tlctx->get_constant(alloca_type, 0), llvm_val[stmt]);
   }
 }
@@ -878,6 +878,7 @@ void TaskCodeGenLLVM::visit(IfStmt *if_stmt) {
       llvm::BasicBlock::Create(*llvm_context, "false_block", func);
   llvm::BasicBlock *after_if =
       llvm::BasicBlock::Create(*llvm_context, "after_if", func);
+  llvm_val[if_stmt->cond]->dump();
   llvm::Value *cond = builder->CreateIsNotNull(llvm_val[if_stmt->cond]);
   builder->CreateCondBr(cond, true_block, false_block);
   builder->SetInsertPoint(true_block);
@@ -1310,6 +1311,9 @@ void TaskCodeGenLLVM::visit(LocalLoadStmt *stmt) {
 }
 
 void TaskCodeGenLLVM::visit(LocalStoreStmt *stmt) {
+  // irpass::print(stmt);
+  // llvm_val[stmt->val]->dump();
+  // llvm_val[stmt->dest]->dump();
   builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]);
 }
 
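
For a scalar local, the else branch above now allocates the pointee type and zero-initializes it. Conceptually, with alloca_type = i32 (an illustration using the surrounding helpers, not literal compiler output):

    // Sketch: the scalar path of the new AllocaStmt lowering.
    llvm_val[stmt] = create_entry_block_alloca(alloca_type);  // %x = alloca i32
    builder->CreateStore(tlctx->get_constant(alloca_type, 0),
                         llvm_val[stmt]);                     // store i32 0, i32* %x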
taichi/codegen/llvm/llvm_codegen_utils.h (4 changes: 2 additions & 2 deletions)
@@ -83,8 +83,8 @@ class LLVMModuleBuilder {
 
   llvm::Value *create_entry_block_alloca(DataType dt, bool is_pointer = false) {
     auto type = tlctx->get_data_type(dt);
-    if (is_pointer)
-      type = llvm::PointerType::get(type, 0);
+    // if (is_pointer)
+    //   type = llvm::PointerType::get(type, 0);
     return create_entry_block_alloca(type);
   }
 
taichi/codegen/spirv/spirv_codegen.cpp (6 changes: 3 additions & 3 deletions)
@@ -274,8 +274,8 @@ class TaskCodegen : public IRVisitor {
 
   void visit(AllocaStmt *alloca) override {
     spirv::Value ptr_val;
-    if (alloca->ret_type->is<TensorType>()) {
-      auto tensor_type = alloca->ret_type->cast<TensorType>();
+    auto alloca_type = alloca->ret_type->as<PointerType>()->get_pointee_type();
+    if (auto tensor_type = alloca_type->cast<TensorType>()) {
       auto elem_num = tensor_type->get_num_elements();
       spirv::SType elem_type =
           ir_->get_primitive_type(tensor_type->get_element_type());
@@ -288,7 +288,7 @@
      }
    } else {
      // Alloca for a single variable
-     spirv::SType src_type = ir_->get_primitive_type(alloca->element_type());
+     spirv::SType src_type = ir_->get_primitive_type(alloca_type);
      ptr_val = ir_->alloca_variable(src_type);
      ir_->store_variable(ptr_val, ir_->get_zero(src_type));
    }
taichi/ir/control_flow_graph.cpp (7 changes: 4 additions & 3 deletions)
@@ -247,7 +247,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const {
   // result: the value to store
   Stmt *result = irpass::analysis::get_store_data(
       block->statements[last_def_position].get());
-  bool is_tensor_involved = var->ret_type->is<TensorType>();
+  bool is_tensor_involved = var->ret_type.ptr_removed()->is<TensorType>();
   if (!(var->is<AllocaStmt>() && !is_tensor_involved)) {
     // In between the store stmt and current stmt,
     // if there's a third-stmt that "may" have stored a "different value" to
@@ -355,7 +355,7 @@
 
   // Check for aliased address
   // There's a store to the same dest_addr before this stmt
-  bool is_tensor_involved = var->ret_type->is<TensorType>();
+  bool is_tensor_involved = var->ret_type.ptr_removed()->is<TensorType>();
   if (!(var->is<AllocaStmt>() && !is_tensor_involved)) {
     // In between the store stmt and current stmt,
     // if there's a third-stmt that "may" have stored a "different value" to
@@ -440,7 +440,8 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access,
         continue;
 
       // special case of alloca (initialized to 0)
-      auto zero = Stmt::make<ConstStmt>(TypedConstant(result->ret_type, 0));
+      auto zero = Stmt::make<ConstStmt>(
+          TypedConstant(result->ret_type.ptr_removed(), 0));
       replace_with(i, std::move(zero), true);
     } else {
       if (result->ret_type.ptr_removed()->is<TensorType>() &&
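
The last hunk is the store-to-load forwarding special case: a load from an alloca with no reaching store is replaced by the constant 0 the alloca was initialized to. Since result here is the AllocaStmt itself and its ret_type is now a pointer, the value type has to be peeled off before building the constant; roughly (a sketch restating the change):

    // Sketch: build the zero constant from the pointee type, not the pointer.
    DataType value_type = result->ret_type.ptr_removed();
    auto zero = Stmt::make<ConstStmt>(TypedConstant(value_type, 0));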
taichi/ir/expression_printer.h (17 changes: 14 additions & 3 deletions)
@@ -76,11 +76,11 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {
   }
 
   void visit(TernaryOpExpression *expr) override {
-    emit(ternary_type_name(expr->type), '(');
+    emit(ternary_type_name(expr->type), "(op1: ");
     expr->op1->accept(this);
-    emit(' ');
+    emit(", op2: ");
     expr->op2->accept(this);
-    emit(' ');
+    emit(", op3: ");
     expr->op3->accept(this);
     emit(')');
   }
@@ -125,6 +125,7 @@
   }
 
   void visit(IndexExpression *expr) override {
+    emit("<" + expr->ret_type->to_string() + ">");
     expr->var->accept(this);
     emit('[');
     if (expr->ret_shape.empty()) {
@@ -164,7 +165,10 @@
   }
 
   void visit(IdExpression *expr) override {
+    emit("<" + expr->ret_type->to_string() + ">");
     emit(expr->id.name());
+    emit(": ");
+    emit(to_string(expr->op));
   }
 
   void visit(AtomicOpExpression *expr) override {
@@ -251,6 +255,13 @@
     return oss.str();
   }
 
+  static std::string expr_to_string(Expression *expr) {
+    std::ostringstream oss;
+    ExpressionHumanFriendlyPrinter printer(&oss);
+    expr->accept(&printer);
+    return oss.str();
+  }
+
  protected:
   template <typename... Args>
   void emit(Args &&...args) {
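
The new expr_to_string helper gives passes a one-call way to render a single expression with this printer, e.g. for debugging. A hypothetical call site, not part of this commit (TI_INFO is Taichi's fmt-style logging macro):

    // Hypothetical usage of the new helper.
    TI_INFO("expr = {}", ExpressionHumanFriendlyPrinter::expr_to_string(expr));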
(Diffs for the remaining 15 changed files are not shown.)
