[ir] MatrixField refactor 1/n: Make a GlobalVariableExpression solely represent a field (#5980)

* [ir] MatrixField refactor 1/n: Make a GlobalVariableExpression solely represent a field

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix cpp tests

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
strongoier and pre-commit-ci[bot] authored Sep 6, 2022
1 parent 5b18fc8 commit a88a0e2
Showing 12 changed files with 223 additions and 216 deletions.
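In user-facing terms, the refactor changes how a Python-level struct-for hands its loop range to the C++ frontend: a loop over a field now passes the field's SNode directly, and a loop over an external tensor keeps passing an Expr; previously a field's loop range was an Expr wrapping a GlobalVariableExpression manufactured from its SNode. A minimal sketch of the field case, using only public Taichi API (the comment about which C++ entry point is reached reflects the diffs below):

    import taichi as ti

    ti.init(arch=ti.cpu)
    x = ti.field(dtype=ti.f32, shape=(4, 8))

    @ti.kernel
    def fill():
        # Struct-for over a field: after this commit, the loop range sent to
        # the frontend is x's SNode, via begin_frontend_struct_for_on_snode.
        for i, j in x:
            x[i, j] = i * 10 + j

    fill()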
2 changes: 0 additions & 2 deletions python/taichi/lang/any_array.py
@@ -58,8 +58,6 @@ def shape(self):
     def _loop_range(self):
         """Gets the corresponding taichi_python.Expr to serve as loop range.
         This is not in use now because struct fors on AnyArrays are not supported yet.
-        Returns:
-            taichi_python.Expr: See above.
         """
6 changes: 3 additions & 3 deletions python/taichi/lang/field.py
@@ -87,12 +87,12 @@ def _get_field_members(self):
         return self.vars

     def _loop_range(self):
-        """Gets representative field member for loop range info.
+        """Gets SNode of representative field member for loop range info.
         Returns:
-            taichi_python.Expr: Representative (first) field member.
+            taichi_python.SNode: SNode of representative (first) field member.
         """
-        return self.vars[0].ptr
+        return self.vars[0].ptr.snode()

     def _set_grad(self, grad):
         """Sets corresponding grad field (reverse mode).
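The private _loop_range helper's contract changes from "representative field member (an Expr)" to "SNode of the representative member". A quick way to observe this at the Python layer (internal API, so only a sketch; the exact pybind type name printed is an assumption):

    import taichi as ti

    ti.init(arch=ti.cpu)
    x = ti.field(dtype=ti.f32, shape=8)

    # Before this commit: x.vars[0].ptr, a taichi_python.Expr wrapping the field.
    # After: x.vars[0].ptr.snode(), a taichi_python.SNode.
    print(type(x._loop_range()))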
7 changes: 6 additions & 1 deletion python/taichi/lang/impl.py
@@ -110,7 +110,12 @@ def begin_frontend_struct_for(ast_builder, group, loop_range):
             f'({group.size()} != {len(loop_range.shape)}). Maybe you wanted to '
             'use "for I in ti.grouped(x)" to group all indices into a single vector I?'
         )
-    ast_builder.begin_frontend_struct_for(group, loop_range._loop_range())
+    if isinstance(loop_range, AnyArray):
+        ast_builder.begin_frontend_struct_for_on_external_tensor(
+            group, loop_range._loop_range())
+    else:
+        ast_builder.begin_frontend_struct_for_on_snode(
+            group, loop_range._loop_range())


 def begin_frontend_if(ast_builder, cond):
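The Python dispatcher now selects the C++ entry point from the type of the loop range: AnyArray (the wrapper for external tensors) goes to the external-tensor variant, everything else (fields, SNodes, struct fields) to the SNode variant. The user-level syntax is unchanged; for the common field case the SNode path fires, e.g. (public API; ti.grouped is the form mentioned in the error message above):

    import taichi as ti

    ti.init(arch=ti.cpu)
    x = ti.field(ti.f32, shape=(2, 3))

    @ti.kernel
    def clear():
        for I in ti.grouped(x):  # SNode path: begin_frontend_struct_for_on_snode
            x[I] = 0.0

    clear()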
6 changes: 3 additions & 3 deletions python/taichi/lang/snode.py
@@ -234,12 +234,12 @@ def shape(self):
         return ret

     def _loop_range(self):
-        """Gets the taichi_python.Expr wrapping the taichi_python.GlobalVariableExpression corresponding to `self` to serve as loop range.
+        """Gets the taichi_python.SNode to serve as loop range.
         Returns:
-            taichi_python.Expr: See above.
+            taichi_python.SNode: See above.
         """
-        return impl.get_runtime().prog.global_var_expr_from_snode(self.ptr)
+        return self.ptr

     @property
     def _name(self):
4 changes: 2 additions & 2 deletions python/taichi/lang/struct.py
@@ -510,10 +510,10 @@ def _snode(self):
         return self._members[0]._snode

     def _loop_range(self):
-        """Gets representative field member for loop range info.
+        """Gets SNode of representative field member for loop range info.
         Returns:
-            taichi_python.Expr: Representative (first) field member.
+            taichi_python.SNode: SNode of representative (first) field member.
         """
         return self._members[0]._loop_range()
25 changes: 14 additions & 11 deletions taichi/analysis/gen_offline_cache_key.cpp
@@ -31,9 +31,10 @@ enum class StmtOpCode : std::uint8_t {
 };

 enum class ForLoopType : std::uint8_t {
-  RangeFor,
-  StructFor,
+  StructForOnSNode,
+  StructForOnExternalTensor,
   MeshFor,
+  RangeFor
 };

 enum class ExternalFuncType : std::uint8_t {
@@ -357,20 +358,22 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {

   void visit(FrontendForStmt *stmt) override {
     emit(StmtOpCode::FrontendForStmt);
-    if (stmt->is_ranged()) {
-      emit(ForLoopType::RangeFor);
-      emit(stmt->loop_var_id);
-      emit(stmt->begin);
-      emit(stmt->end);
-    } else if (stmt->mesh_for) {
+    if (stmt->snode) {
+      emit(ForLoopType::StructForOnSNode);
+      emit(stmt->snode);
+    } else if (stmt->external_tensor) {
+      emit(ForLoopType::StructForOnExternalTensor);
+      emit(stmt->external_tensor);
+    } else if (stmt->mesh) {
       emit(ForLoopType::MeshFor);
       emit(stmt->element_type);
       emit(stmt->mesh);
     } else {
-      emit(ForLoopType::StructFor);
-      emit(stmt->loop_var_id);
-      emit(stmt->global_var);
+      emit(ForLoopType::RangeFor);
+      emit(stmt->begin);
+      emit(stmt->end);
     }
+    emit(stmt->loop_var_ids);
     emit(stmt->is_bit_vectorized);
     emit(stmt->num_cpu_threads);
     emit(stmt->strictly_serialized);
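With is_ranged() and mesh_for gone, the serializer infers the loop kind from which FrontendForStmt member is populated, and loop_var_ids is now emitted once for all four kinds. A Python mirror of the decision chain, for illustration only (the authoritative logic is the C++ above):

    def loop_type(stmt):
        # Priority follows the C++: snode, then external_tensor, then mesh;
        # a statement with none of them set is a range-for.
        if stmt.snode is not None:
            return "StructForOnSNode"
        if stmt.external_tensor is not None:
            return "StructForOnExternalTensor"
        if stmt.mesh is not None:
            return "MeshFor"
        return "RangeFor"

Note that the ForLoopType values are also renumbered, so offline cache keys generated before this change will presumably no longer match.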
128 changes: 64 additions & 64 deletions taichi/ir/frontend_ir.cpp
@@ -32,85 +32,70 @@ FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs)
   }
 }

-IRNode *FrontendContext::root() {
-  return static_cast<IRNode *>(root_node_.get());
+FrontendForStmt::FrontendForStmt(const ExprGroup &loop_vars,
+                                 SNode *snode,
+                                 Arch arch,
+                                 const ForLoopConfig &config)
+    : snode(snode) {
+  init_config(arch, config);
+  init_loop_vars(loop_vars);
 }

-FrontendForStmt::FrontendForStmt(const ExprGroup &loop_var,
-                                 const Expr &global_var,
+FrontendForStmt::FrontendForStmt(const ExprGroup &loop_vars,
+                                 const Expr &external_tensor,
                                  Arch arch,
                                  const ForLoopConfig &config)
-    : global_var(global_var),
-      is_bit_vectorized(config.is_bit_vectorized),
-      num_cpu_threads(config.num_cpu_threads),
-      strictly_serialized(config.strictly_serialized),
-      mem_access_opt(config.mem_access_opt),
-      block_dim(config.block_dim) {
-  if (arch == Arch::cuda) {
-    this->num_cpu_threads = 1;
-    TI_ASSERT(this->block_dim <= taichi_max_gpu_block_dim);
-  } else {
-    // cpu
-    if (this->num_cpu_threads == 0)
-      this->num_cpu_threads = std::thread::hardware_concurrency();
-  }
-  loop_var_id.reserve(loop_var.size());
-  for (int i = 0; i < (int)loop_var.size(); i++) {
-    loop_var_id.push_back(loop_var[i].cast<IdExpression>()->id);
-    loop_var[i].expr->ret_type = PrimitiveType::i32;
-  }
+    : external_tensor(external_tensor) {
+  init_config(arch, config);
+  init_loop_vars(loop_vars);
 }

-FrontendForStmt::FrontendForStmt(const ExprGroup &loop_var,
+FrontendForStmt::FrontendForStmt(const ExprGroup &loop_vars,
                                  const mesh::MeshPtr &mesh,
                                  const mesh::MeshElementType &element_type,
                                  Arch arch,
                                  const ForLoopConfig &config)
-    : is_bit_vectorized(config.is_bit_vectorized),
-      num_cpu_threads(config.num_cpu_threads),
-      mem_access_opt(config.mem_access_opt),
-      block_dim(config.block_dim),
-      mesh_for(true),
-      mesh(mesh.ptr.get()),
-      element_type(element_type) {
-  if (arch == Arch::cuda) {
-    this->num_cpu_threads = 1;
-    TI_ASSERT(this->block_dim <= taichi_max_gpu_block_dim);
-  } else {
-    // cpu
-    if (this->num_cpu_threads == 0)
-      this->num_cpu_threads = std::thread::hardware_concurrency();
-  }
-  loop_var_id.reserve(loop_var.size());
-  for (int i = 0; i < (int)loop_var.size(); i++) {
-    loop_var_id.push_back(loop_var[i].cast<IdExpression>()->id);
-  }
-}
-
-FrontendContext::FrontendContext(Arch arch) {
-  root_node_ = std::make_unique<Block>();
-  current_builder_ = std::make_unique<ASTBuilder>(root_node_.get(), arch);
+    : mesh(mesh.ptr.get()), element_type(element_type) {
+  init_config(arch, config);
+  init_loop_vars(loop_vars);
 }

 FrontendForStmt::FrontendForStmt(const Expr &loop_var,
                                  const Expr &begin,
                                  const Expr &end,
                                  Arch arch,
                                  const ForLoopConfig &config)
-    : begin(begin),
-      end(end),
-      is_bit_vectorized(config.is_bit_vectorized),
-      num_cpu_threads(config.num_cpu_threads),
-      strictly_serialized(config.strictly_serialized),
-      mem_access_opt(config.mem_access_opt),
-      block_dim(config.block_dim) {
+    : begin(begin), end(end) {
+  init_config(arch, config);
+  add_loop_var(loop_var);
+}
+
+void FrontendForStmt::init_config(Arch arch, const ForLoopConfig &config) {
+  is_bit_vectorized = config.is_bit_vectorized;
+  strictly_serialized = config.strictly_serialized;
+  mem_access_opt = config.mem_access_opt;
+  block_dim = config.block_dim;
   if (arch == Arch::cuda) {
-    this->num_cpu_threads = 1;
-  } else {
-    if (this->num_cpu_threads == 0)
-      this->num_cpu_threads = std::thread::hardware_concurrency();
+    num_cpu_threads = 1;
+    TI_ASSERT(block_dim <= taichi_max_gpu_block_dim);
+  } else {  // cpu
+    if (config.num_cpu_threads == 0) {
+      num_cpu_threads = std::thread::hardware_concurrency();
+    } else {
+      num_cpu_threads = config.num_cpu_threads;
+    }
   }
-  loop_var_id.push_back(loop_var.cast<IdExpression>()->id);
 }
+
+void FrontendForStmt::init_loop_vars(const ExprGroup &loop_vars) {
+  loop_var_ids.reserve(loop_vars.size());
+  for (int i = 0; i < (int)loop_vars.size(); i++) {
+    add_loop_var(loop_vars[i]);
+  }
+}
+
+void FrontendForStmt::add_loop_var(const Expr &loop_var) {
+  loop_var_ids.push_back(loop_var.cast<IdExpression>()->id);
+  loop_var.expr->ret_type = PrimitiveType::i32;
+}

@@ -956,7 +941,7 @@ Expr ASTBuilder::insert_patch_idx_expr() {
     }
   }
   TI_ERROR_IF(!(loop && loop->is<FrontendForStmt>() &&
-                loop->as<FrontendForStmt>()->mesh_for),
+                loop->as<FrontendForStmt>()->mesh),
              "ti.mesh_patch_idx() is only valid within mesh-for loops.");
   return Expr::make<MeshPatchIndexExpression>();
 }
@@ -1083,20 +1068,35 @@ void ASTBuilder::begin_frontend_range_for(const Expr &i,
   for_loop_dec_.reset();
 }

-void ASTBuilder::begin_frontend_struct_for(const ExprGroup &loop_vars,
-                                           const Expr &global) {
+void ASTBuilder::begin_frontend_struct_for_on_snode(const ExprGroup &loop_vars,
+                                                    SNode *snode) {
   TI_WARN_IF(
       for_loop_dec_.config.strictly_serialized,
       "ti.loop_config(serialize=True) does not have effect on the struct for. "
       "The execution order is not guaranteed.");
-  auto stmt_unique = std::make_unique<FrontendForStmt>(loop_vars, global, arch_,
+  auto stmt_unique = std::make_unique<FrontendForStmt>(loop_vars, snode, arch_,
                                                        for_loop_dec_.config);
   for_loop_dec_.reset();
   auto stmt = stmt_unique.get();
   this->insert(std::move(stmt_unique));
   this->create_scope(stmt->body, For);
 }

+void ASTBuilder::begin_frontend_struct_for_on_external_tensor(
+    const ExprGroup &loop_vars,
+    const Expr &external_tensor) {
+  TI_WARN_IF(
+      for_loop_dec_.config.strictly_serialized,
+      "ti.loop_config(serialize=True) does not have effect on the struct for. "
+      "The execution order is not guaranteed.");
+  auto stmt_unique = std::make_unique<FrontendForStmt>(
+      loop_vars, external_tensor, arch_, for_loop_dec_.config);
+  for_loop_dec_.reset();
+  auto stmt = stmt_unique.get();
+  this->insert(std::move(stmt_unique));
+  this->create_scope(stmt->body, For);
+}
+
 void ASTBuilder::begin_frontend_mesh_for(
     const Expr &i,
     const mesh::MeshPtr &mesh_ptr,
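Both struct-for entry points keep the existing warning that ti.loop_config(serialize=True) only affects range-fors. A small demonstration with public API (the warning text is the one quoted in the diff above):

    import taichi as ti

    ti.init(arch=ti.cpu)
    x = ti.field(ti.f32, shape=64)

    @ti.kernel
    def run():
        ti.loop_config(serialize=True)  # honored: the range-for runs serially
        for i in range(64):
            x[i] = float(i)
        ti.loop_config(serialize=True)  # warns: struct-for order is not guaranteed
        for i in x:
            x[i] += 1.0

    run()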
53 changes: 28 additions & 25 deletions taichi/ir/frontend_ir.h
@@ -161,34 +161,30 @@ class FrontendPrintStmt : public Stmt {

 class FrontendForStmt : public Stmt {
  public:
+  SNode *snode{nullptr};
+  Expr external_tensor;
+  mesh::Mesh *mesh{nullptr};
+  mesh::MeshElementType element_type;
   Expr begin, end;
-  Expr global_var;
   std::unique_ptr<Block> body;
-  std::vector<Identifier> loop_var_id;
+  std::vector<Identifier> loop_var_ids;
   bool is_bit_vectorized;
   int num_cpu_threads;
   bool strictly_serialized;
   MemoryAccessOptions mem_access_opt;
   int block_dim;

-  bool mesh_for = false;
-  mesh::Mesh *mesh;
-  mesh::MeshElementType element_type;
-
-  bool is_ranged() const {
-    if (global_var.expr == nullptr && !mesh_for) {
-      return true;
-    } else {
-      return false;
-    }
-  }
+  FrontendForStmt(const ExprGroup &loop_vars,
+                  SNode *snode,
+                  Arch arch,
+                  const ForLoopConfig &config);

-  FrontendForStmt(const ExprGroup &loop_var,
-                  const Expr &global_var,
+  FrontendForStmt(const ExprGroup &loop_vars,
+                  const Expr &external_tensor,
                   Arch arch,
                   const ForLoopConfig &config);

-  FrontendForStmt(const ExprGroup &loop_var,
+  FrontendForStmt(const ExprGroup &loop_vars,
                   const mesh::MeshPtr &mesh,
                   const mesh::MeshElementType &element_type,
                   Arch arch,
@@ -205,6 +201,13 @@ class FrontendForStmt : public Stmt {
   }

   TI_DEFINE_ACCEPT
+
+ private:
+  void init_config(Arch arch, const ForLoopConfig &config);
+
+  void init_loop_vars(const ExprGroup &loop_vars);
+
+  void add_loop_var(const Expr &loop_var);
 };

 class FrontendFuncDefStmt : public Stmt {
@@ -499,10 +502,6 @@ class GlobalVariableExpression : public Expression {
       : ident(ident), dt(dt) {
   }

-  GlobalVariableExpression(SNode *snode, const Identifier &ident)
-      : ident(ident), dt(snode->dt), snode(snode) {
-  }
-
   void type_check(CompileConfig *config) override {
   }

@@ -929,8 +928,11 @@ class ASTBuilder {
                      const std::string &msg,
                      const std::vector<Expr> &args);
   void begin_frontend_range_for(const Expr &i, const Expr &s, const Expr &e);
-  void begin_frontend_struct_for(const ExprGroup &loop_vars,
-                                 const Expr &global);
+  void begin_frontend_struct_for_on_snode(const ExprGroup &loop_vars,
+                                          SNode *snode);
+  void begin_frontend_struct_for_on_external_tensor(
+      const ExprGroup &loop_vars,
+      const Expr &external_tensor);
   void begin_frontend_mesh_for(const Expr &i,
                                const mesh::MeshPtr &mesh_ptr,
                                const mesh::MeshElementType &element_type);
@@ -984,14 +986,15 @@ class FrontendContext {
   std::unique_ptr<Block> root_node_;

  public:
-  FrontendContext(Arch arch);
+  FrontendContext(Arch arch) {
+    root_node_ = std::make_unique<Block>();
+    current_builder_ = std::make_unique<ASTBuilder>(root_node_.get(), arch);
+  }

   ASTBuilder &builder() {
     return *current_builder_;
   }

-  IRNode *root();
-
   std::unique_ptr<Block> get_root() {
     return std::move(root_node_);
   }
(Diffs for the remaining 4 changed files are not shown.)