Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ir] Remove the usage of kernel from type_check() pass #1848

Merged
merged 2 commits into from
Sep 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def extract(self, x):
def decl_scalar_arg(dtype):
dtype = cook_dtype(dtype)
id = taichi_lang_core.decl_arg(dtype, False)
return Expr(taichi_lang_core.make_arg_load_expr(id))
return Expr(taichi_lang_core.make_arg_load_expr(id, dtype))


def decl_ext_arr_arg(dtype, dim):
Expand Down
8 changes: 7 additions & 1 deletion python/taichi/lang/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,9 +866,15 @@ def visit_Return(self, node):
ret_expr = self.parse_expr('ti.cast(ti.Expr(0), 0)')
ret_expr.args[0].args[0] = node.value
ret_expr.args[1] = self.returns
dt_expr = self.parse_expr('ti.cook_dtype(0)')
dt_expr.args[0] = self.returns
ret_stmt = self.parse_stmt(
'ti.core.create_kernel_return(ret.ptr)')
'ti.core.create_kernel_return(ret.ptr, 0)')
# For args[0], it is an ast.Attribute, because it loads the
# attribute, |ptr|, of the expression |ret_expr|. Therefore we
# only need to replace the object part, i.e. args[0].value
ret_stmt.value.args[0].value = ret_expr
ret_stmt.value.args[1] = dt_expr
return ret_stmt
return node

Expand Down
14 changes: 8 additions & 6 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ class FrontendKernelReturnStmt : public Stmt {
public:
Expr value;

FrontendKernelReturnStmt(const Expr &value) : value(value) {
FrontendKernelReturnStmt(const Expr &value, DataType dt) : value(value) {
ret_type = VectorType(1, dt);
}

bool is_container_statement() const override {
Expand All @@ -217,17 +218,18 @@ class FrontendKernelReturnStmt : public Stmt {
class ArgLoadExpression : public Expression {
public:
int arg_id;
DataType dt;

ArgLoadExpression(int arg_id) : arg_id(arg_id) {
ArgLoadExpression(int arg_id, DataType dt) : arg_id(arg_id), dt(dt) {
}

std::string serialize() override {
return fmt::format("arg[{}]", arg_id);
return fmt::format("arg[{}] (dt={})", arg_id, data_type_name(dt));
}

void flatten(FlattenContext *ctx) override {
auto ran = std::make_unique<ArgLoadStmt>(arg_id);
ctx->push_back(std::move(ran));
auto argl = std::make_unique<ArgLoadStmt>(arg_id, dt);
ctx->push_back(std::move(argl));
stmt = ctx->back_stmt();
}
};
Expand Down Expand Up @@ -380,7 +382,7 @@ class ExternalTensorExpression : public Expression {
}

void flatten(FlattenContext *ctx) override {
auto ptr = Stmt::make<ArgLoadStmt>(arg_id, true);
auto ptr = Stmt::make<ArgLoadStmt>(arg_id, dt, /*is_ptr=*/true);
ctx->push_back(std::move(ptr));
stmt = ctx->back_stmt();
}
Expand Down
6 changes: 4 additions & 2 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,8 @@ class ArgLoadStmt : public Stmt {
public:
int arg_id;

ArgLoadStmt(int arg_id, bool is_ptr = false) : arg_id(arg_id) {
ArgLoadStmt(int arg_id, DataType dt, bool is_ptr = false) : arg_id(arg_id) {
this->ret_type = VectorType(1, dt);
this->is_ptr = is_ptr;
TI_STMT_REG_FIELDS;
}
Expand Down Expand Up @@ -1336,7 +1337,8 @@ class KernelReturnStmt : public Stmt {
public:
Stmt *value;

KernelReturnStmt(Stmt *value) : value(value) {
KernelReturnStmt(Stmt *value, DataType dt) : value(value) {
this->ret_type = VectorType(1, dt);
TI_STMT_REG_FIELDS;
}

Expand Down
10 changes: 5 additions & 5 deletions taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,13 +591,13 @@ Arch Program::get_snode_accessor_arch() {
Kernel &Program::get_snode_reader(SNode *snode) {
TI_ASSERT(snode->type == SNodeType::place);
auto kernel_name = fmt::format("snode_reader_{}", snode->id);
auto &ker = kernel([&] {
auto &ker = kernel([snode] {
ExprGroup indices;
for (int i = 0; i < snode->num_active_indices; i++) {
indices.push_back(Expr::make<ArgLoadExpression>(i));
indices.push_back(Expr::make<ArgLoadExpression>(i, DataType::i32));
}
auto ret = Stmt::make<FrontendKernelReturnStmt>(
load_if_ptr((snode->expr)[indices]));
load_if_ptr((snode->expr)[indices]), snode->dt);
current_ast_builder().insert(std::move(ret));
});
ker.set_arch(get_snode_accessor_arch());
Expand All @@ -615,10 +615,10 @@ Kernel &Program::get_snode_writer(SNode *snode) {
auto &ker = kernel([&] {
ExprGroup indices;
for (int i = 0; i < snode->num_active_indices; i++) {
indices.push_back(Expr::make<ArgLoadExpression>(i));
indices.push_back(Expr::make<ArgLoadExpression>(i, DataType::i32));
}
(snode->expr)[indices] =
Expr::make<ArgLoadExpression>(snode->num_active_indices);
Expr::make<ArgLoadExpression>(snode->num_active_indices, snode->dt);
});
ker.set_arch(get_snode_accessor_arch());
ker.name = kernel_name;
Expand Down
8 changes: 5 additions & 3 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,9 @@ void export_lang(py::module &m) {
current_ast_builder().insert(Stmt::make<FrontendBreakStmt>());
});

m.def("create_kernel_return", [&](const Expr &value) {
current_ast_builder().insert(Stmt::make<FrontendKernelReturnStmt>(value));
m.def("create_kernel_return", [&](const Expr &value, const DataType &dt) {
current_ast_builder().insert(
Stmt::make<FrontendKernelReturnStmt>(value, dt));
});

m.def("insert_continue_stmt", [&]() {
Expand Down Expand Up @@ -477,7 +478,8 @@ void export_lang(py::module &m) {
m.def("make_frontend_assign_stmt",
Stmt::make<FrontendAssignStmt, const Expr &, const Expr &>);

m.def("make_arg_load_expr", Expr::make<ArgLoadExpression, int>);
m.def("make_arg_load_expr",
Expr::make<ArgLoadExpression, int, const DataType &>);

m.def("make_external_tensor_expr",
Expr::make<ExternalTensorExpression, const DataType &, int, int>);
Expand Down
16 changes: 8 additions & 8 deletions taichi/transforms/constant_fold.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
#include <cmath>
#include <deque>
#include <set>
#include <cmath>
#include <thread>

#include "taichi/ir/ir.h"
#include "taichi/ir/snode.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/program/program.h"
#include "taichi/ir/ir.h"
#include "taichi/program/program.h"
#include "taichi/ir/snode.h"

TLANG_NAMESPACE_BEGIN

Expand All @@ -34,9 +32,11 @@ class ConstantFold : public BasicStmtVisitor {
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Off-topic: lines 10-11 are redundant in this file.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done (removed + reordered)


auto kernel_name = fmt::format("jit_evaluator_{}", cache.size());
auto func = [&]() {
auto lhstmt = Stmt::make<ArgLoadStmt>(0, false);
auto rhstmt = Stmt::make<ArgLoadStmt>(1, false);
auto func = [&id]() {
auto lhstmt =
Stmt::make<ArgLoadStmt>(/*arg_id=*/0, id.lhs, /*is_ptr=*/false);
auto rhstmt =
Stmt::make<ArgLoadStmt>(/*arg_id=*/1, id.rhs, /*is_ptr=*/false);
pStmt oper;
if (id.is_binary) {
oper = Stmt::make<BinaryOpStmt>(id.binary_op(), lhstmt.get(),
Expand All @@ -47,7 +47,7 @@ class ConstantFold : public BasicStmtVisitor {
oper->cast<UnaryOpStmt>()->cast_type = id.rhs;
}
}
auto ret = Stmt::make<KernelReturnStmt>(oper.get());
auto ret = Stmt::make<KernelReturnStmt>(oper.get(), id.ret);
current_ast_builder().insert(std::move(lhstmt));
if (id.is_binary)
current_ast_builder().insert(std::move(rhstmt));
Expand Down
4 changes: 3 additions & 1 deletion taichi/transforms/lower_ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,9 @@ class LowerAST : public IRVisitor {
auto expr = stmt->value;
auto fctx = make_flatten_ctx();
expr->flatten(&fctx);
fctx.push_back<KernelReturnStmt>(fctx.back_stmt());
const auto dt = stmt->element_type();
TI_ASSERT(dt != DataType::unknown);
fctx.push_back<KernelReturnStmt>(fctx.back_stmt(), dt);
stmt->parent->replace_with(stmt, std::move(fctx.stmts));
throw IRModified();
}
Expand Down
31 changes: 11 additions & 20 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@ TLANG_NAMESPACE_BEGIN
// Var lookup and Type inference
class TypeCheck : public IRVisitor {
private:
Kernel *kernel;
CompileConfig config;

public:
TypeCheck(IRNode *root) {
kernel = root->get_kernel();
auto *kernel = root->get_kernel();
if (kernel != nullptr) {
config = kernel->program.config;
}
Expand Down Expand Up @@ -324,27 +323,19 @@ class TypeCheck : public IRVisitor {
}

void visit(ArgLoadStmt *stmt) {
Kernel *current_kernel = kernel;
if (current_kernel == nullptr) {
current_kernel = stmt->get_kernel();
}
TI_ASSERT(current_kernel != nullptr);
auto &args = current_kernel->args;
TI_ASSERT(0 <= stmt->arg_id && stmt->arg_id < args.size());
stmt->ret_type = VectorType(1, args[stmt->arg_id].dt);
const auto &rt = stmt->ret_type;
// TODO: Maybe have a type_inference() pass, which takes in the args/rets
// defined by the kernel. After that, type_check() pass will purely do
// verification, without modifying any types.
TI_ASSERT(rt.data_type != DataType::unknown);
TI_ASSERT(rt.width == 1);
}

void visit(KernelReturnStmt *stmt) {
Kernel *current_kernel = kernel;
if (current_kernel == nullptr) {
current_kernel = stmt->get_kernel();
}
auto &rets = current_kernel->rets;
TI_ASSERT(rets.size() >= 1);
auto ret = rets[0]; // TODO: stmt->ret_id?
auto ret_type = ret.dt;
TI_ASSERT(stmt->value->ret_type.data_type == ret_type);
stmt->ret_type = VectorType(1, ret_type);
// TODO: Support stmt->ret_id?
const auto &rt = stmt->ret_type;
TI_ASSERT(stmt->value->element_type() == rt.data_type);
TI_ASSERT(rt.width == 1);
}

void visit(ExternalPtrStmt *stmt) {
Expand Down