Skip to content

Commit

Permalink
[lang] Migrate TensorType expansion for SNode indices from Python to …
Browse files Browse the repository at this point in the history
…Frontend IR (taichi-dev#6934)

Issue: taichi-dev#5819

### Brief Summary

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent 8a1a3fa commit 7d3505e
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 105 deletions.
54 changes: 10 additions & 44 deletions python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import functools
import numbers
import warnings

Expand All @@ -8,34 +7,6 @@
from taichi.lang.util import get_traceback


def _get_expanded_indices(indices):
if isinstance(indices, matrix.Matrix):
indices = indices.entries
elif isinstance(indices, expr.Expr) and indices.is_tensor():
indices = [
expr.Expr(x)
for x in impl.get_runtime().prog.current_ast_builder().expand_expr(
[indices.ptr])
]
return indices


def _expand_indices(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# indices is the second argument to ti.append, ti.activate, ...
if len(args) > 1:
args = list(args)
args[1] = _get_expanded_indices(args[1])
else:
assert "indices" in kwargs.keys()
kwargs["indices"] = _get_expanded_indices(kwargs["indices"])

return func(*args, **kwargs)

return wrapper


class SNode:
"""A Python-side SNode wrapper.
Expand Down Expand Up @@ -414,7 +385,6 @@ def _rescale_index():
return _rescale_index()


@_expand_indices
def append(node, indices, val):
"""Append a value `val` to a SNode `node` at index `indices`.
Expand All @@ -424,14 +394,14 @@ def append(node, indices, val):
val (:mod:`~taichi.types.primitive_types`): the scalar data to be appended, only i32 value is support for now.
"""
ptrs = expr._get_flattened_ptrs(val)
append_expr = expr.Expr(_ti_core.expr_snode_append(
node._snode.ptr, expr.make_expr_group(indices), ptrs),
tb=impl.get_runtime().get_current_src_info())
append_expr = expr.Expr(
impl.get_runtime().prog.current_ast_builder().expr_snode_append(
node._snode.ptr, expr.make_expr_group(indices), ptrs),
tb=impl.get_runtime().get_current_src_info())
a = impl.expr_init(append_expr)
return a


@_expand_indices
def is_active(node, indices):
"""Explicitly query whether a cell in a SNode `node` at location
`indices` is active or not.
Expand All @@ -444,11 +414,10 @@ def is_active(node, indices):
bool: the cell `node[indices]` is active or not.
"""
return expr.Expr(
_ti_core.expr_snode_is_active(node._snode.ptr,
expr.make_expr_group(indices)))
impl.get_runtime().prog.current_ast_builder().expr_snode_is_active(
node._snode.ptr, expr.make_expr_group(indices)))


@_expand_indices
def activate(node, indices):
"""Explicitly activate a cell of `node` at location `indices`.
Expand All @@ -460,7 +429,6 @@ def activate(node, indices):
node._snode.ptr, expr.make_expr_group(indices))


@_expand_indices
def deactivate(node, indices):
"""Explicitly deactivate a cell of `node` at location `indices`.
Expand All @@ -475,7 +443,6 @@ def deactivate(node, indices):
node._snode.ptr, expr.make_expr_group(indices))


@_expand_indices
def length(node, indices):
"""Return the length of the dynamic SNode `node` at index `indices`.
Expand All @@ -487,11 +454,10 @@ def length(node, indices):
int: the length of cell `node[indices]`.
"""
return expr.Expr(
_ti_core.expr_snode_length(node._snode.ptr,
expr.make_expr_group(indices)))
impl.get_runtime().prog.current_ast_builder().expr_snode_length(
node._snode.ptr, expr.make_expr_group(indices)))


@_expand_indices
def get_addr(f, indices):
"""Query the memory address (on CUDA/x64) of field `f` at index `indices`.
Expand All @@ -505,8 +471,8 @@ def get_addr(f, indices):
ti.u64: The memory address of `f[indices]`.
"""
return expr.Expr(
_ti_core.expr_snode_get_addr(f._snode.ptr,
expr.make_expr_group(indices)))
impl.get_runtime().prog.current_ast_builder().expr_snode_get_addr(
f._snode.ptr, expr.make_expr_group(indices)))


__all__ = [
Expand Down
19 changes: 0 additions & 19 deletions taichi/ir/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,25 +83,6 @@ Expr expr_rand(DataType dt) {
return Expr::make<RandExpression>(dt);
}

Expr snode_append(SNode *snode,
const ExprGroup &indices,
const std::vector<Expr> &vals) {
return Expr::make<SNodeOpExpression>(snode, SNodeOpType::append, indices,
vals);
}

Expr snode_is_active(SNode *snode, const ExprGroup &indices) {
return Expr::make<SNodeOpExpression>(snode, SNodeOpType::is_active, indices);
}

Expr snode_length(SNode *snode, const ExprGroup &indices) {
return Expr::make<SNodeOpExpression>(snode, SNodeOpType::length, indices);
}

Expr snode_get_addr(SNode *snode, const ExprGroup &indices) {
return Expr::make<SNodeOpExpression>(snode, SNodeOpType::get_addr, indices);
}

Expr assume_range(const Expr &expr, const Expr &base, int low, int high) {
return Expr::make<RangeAssumptionExpression>(expr, base, low, high);
}
Expand Down
20 changes: 1 addition & 19 deletions taichi/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class Expression;
class Identifier;
class ExprGroup;
class SNode;
class ASTBuilder;

class Expr {
public:
Expand Down Expand Up @@ -133,25 +134,6 @@ Expr expr_rand() {
return taichi::lang::expr_rand(get_data_type<T>());
}

/*
* This function allocates the space for a new item (a struct or a scalar)
* in the Dynamic SNode, and assigns values to the elements inside it.
*
* When appending a struct, the size of vals must be equal to
* the number of elements in the struct. When appending a scalar,
* the size of vals must be one.
*/

Expr snode_append(SNode *snode,
const ExprGroup &indices,
const std::vector<Expr> &vals);

Expr snode_is_active(SNode *snode, const ExprGroup &indices);

Expr snode_length(SNode *snode, const ExprGroup &indices);

Expr snode_get_addr(SNode *snode, const ExprGroup &indices);

Expr assume_range(const Expr &expr, const Expr &base, int low, int high);

Expr loop_unique(const Expr &input, const std::vector<SNode *> &covers);
Expand Down
62 changes: 53 additions & 9 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,16 @@ static bool is_primitive_or_tensor_type(DataType &type) {
return type->is<PrimitiveType>() || type->is<TensorType>();
}

FrontendSNodeOpStmt::FrontendSNodeOpStmt(SNodeOpType op_type,
FrontendSNodeOpStmt::FrontendSNodeOpStmt(ASTBuilder *builder,
SNodeOpType op_type,
SNode *snode,
const ExprGroup &indices,
const Expr &val)
: op_type(op_type), snode(snode), indices(indices), val(val) {
: op_type(op_type), snode(snode), val(val) {
this->indices = indices;
std::vector<Expr> expanded_exprs = builder->expand_expr(this->indices.exprs);
this->indices.exprs = expanded_exprs;

if (val.expr != nullptr) {
TI_ASSERT(op_type == SNodeOpType::append);
} else {
Expand Down Expand Up @@ -923,6 +928,25 @@ void AtomicOpExpression::flatten(FlattenContext *ctx) {
stmt->tb = tb;
}

SNodeOpExpression::SNodeOpExpression(ASTBuilder *builder,
SNode *snode,
SNodeOpType op_type,
const ExprGroup &indices)
: snode(snode), op_type(op_type) {
std::vector<Expr> expanded_indices = builder->expand_expr(indices.exprs);
this->indices = indices;
this->indices.exprs = std::move(expanded_indices);
}

SNodeOpExpression::SNodeOpExpression(ASTBuilder *builder,
SNode *snode,
SNodeOpType op_type,
const ExprGroup &indices,
const std::vector<Expr> &values)
: SNodeOpExpression(builder, snode, op_type, indices) {
this->values = builder->expand_expr(values);
}

void SNodeOpExpression::type_check(CompileConfig *config) {
if (op_type == SNodeOpType::get_addr) {
ret_type = PrimitiveType::u64;
Expand Down Expand Up @@ -1469,20 +1493,40 @@ void ASTBuilder::insert_expr_stmt(const Expr &val) {

void ASTBuilder::insert_snode_activate(SNode *snode,
const ExprGroup &expr_group) {
this->insert(Stmt::make<FrontendSNodeOpStmt>(SNodeOpType::activate, snode,
expr_group));
this->insert(Stmt::make<FrontendSNodeOpStmt>(this, SNodeOpType::activate,
snode, expr_group));
}

void ASTBuilder::insert_snode_deactivate(SNode *snode,
const ExprGroup &expr_group) {
this->insert(Stmt::make<FrontendSNodeOpStmt>(SNodeOpType::deactivate, snode,
expr_group));
this->insert(Stmt::make<FrontendSNodeOpStmt>(this, SNodeOpType::deactivate,
snode, expr_group));
}

std::vector<Expr> ASTBuilder::expand_expr(const std::vector<Expr> &exprs) {
TI_ASSERT(exprs.size() > 0);
Expr ASTBuilder::snode_append(SNode *snode,
const ExprGroup &indices,
const std::vector<Expr> &vals) {
return Expr::make<SNodeOpExpression>(this, snode, SNodeOpType::append,
indices, vals);
}

if (exprs.size() > 1) {
Expr ASTBuilder::snode_is_active(SNode *snode, const ExprGroup &indices) {
return Expr::make<SNodeOpExpression>(this, snode, SNodeOpType::is_active,
indices);
}

Expr ASTBuilder::snode_length(SNode *snode, const ExprGroup &indices) {
return Expr::make<SNodeOpExpression>(this, snode, SNodeOpType::length,
indices);
}

Expr ASTBuilder::snode_get_addr(SNode *snode, const ExprGroup &indices) {
return Expr::make<SNodeOpExpression>(this, snode, SNodeOpType::get_addr,
indices);
}

std::vector<Expr> ASTBuilder::expand_expr(const std::vector<Expr> &exprs) {
if (exprs.size() > 1 || exprs.size() == 0) {
return exprs;
}

Expand Down
34 changes: 26 additions & 8 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

namespace taichi::lang {

class ASTBuilder;

struct ForLoopConfig {
bool is_bit_vectorized{false};
int num_cpu_threads{0};
Expand Down Expand Up @@ -88,7 +90,8 @@ class FrontendSNodeOpStmt : public Stmt {
ExprGroup indices;
Expr val;

FrontendSNodeOpStmt(SNodeOpType op_type,
FrontendSNodeOpStmt(ASTBuilder *builder,
SNodeOpType op_type,
SNode *snode,
const ExprGroup &indices,
const Expr &val = Expr(nullptr));
Expand Down Expand Up @@ -732,16 +735,16 @@ class SNodeOpExpression : public Expression {
ExprGroup indices;
std::vector<Expr> values; // Only for op_type==append

SNodeOpExpression(SNode *snode, SNodeOpType op_type, const ExprGroup &indices)
: snode(snode), op_type(op_type), indices(indices) {
}
SNodeOpExpression(ASTBuilder *builder,
SNode *snode,
SNodeOpType op_type,
const ExprGroup &indices);

SNodeOpExpression(SNode *snode,
SNodeOpExpression(ASTBuilder *builder,
SNode *snode,
SNodeOpType op_type,
const ExprGroup &indices,
const std::vector<Expr> &values)
: snode(snode), op_type(op_type), indices(indices), values(values) {
}
const std::vector<Expr> &values);

void type_check(CompileConfig *config) override;

Expand Down Expand Up @@ -1008,6 +1011,21 @@ class ASTBuilder {
void insert_snode_activate(SNode *snode, const ExprGroup &expr_group);
void insert_snode_deactivate(SNode *snode, const ExprGroup &expr_group);

/*
* This function allocates the space for a new item (a struct or a scalar)
* in the Dynamic SNode, and assigns values to the elements inside it.
*
* When appending a struct, the size of vals must be equal to
* the number of elements in the struct. When appending a scalar,
* the size of vals must be one.
*/
Expr snode_append(SNode *snode,
const ExprGroup &indices,
const std::vector<Expr> &vals);
Expr snode_is_active(SNode *snode, const ExprGroup &indices);
Expr snode_length(SNode *snode, const ExprGroup &indices);
Expr snode_get_addr(SNode *snode, const ExprGroup &indices);

std::vector<Expr> expand_expr(const std::vector<Expr> &exprs);

void create_scope(std::unique_ptr<Block> &list, LoopType tp = NotLoop);
Expand Down
9 changes: 4 additions & 5 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,10 @@ void export_lang(py::module &m) {
.def("begin_frontend_if_false", &ASTBuilder::begin_frontend_if_false)
.def("insert_deactivate", &ASTBuilder::insert_snode_deactivate)
.def("insert_activate", &ASTBuilder::insert_snode_activate)
.def("expr_snode_get_addr", &ASTBuilder::snode_get_addr)
.def("expr_snode_append", &ASTBuilder::snode_append)
.def("expr_snode_is_active", &ASTBuilder::snode_is_active)
.def("expr_snode_length", &ASTBuilder::snode_length)
.def("insert_external_func_call", &ASTBuilder::insert_external_func_call)
.def("make_matrix_expr", &ASTBuilder::make_matrix_expr)
.def("expr_alloca", &ASTBuilder::expr_alloca)
Expand Down Expand Up @@ -811,11 +815,6 @@ void export_lang(py::module &m) {

py::class_<Stmt>(m, "Stmt"); // NOLINT(bugprone-unused-raii)

m.def("expr_snode_get_addr", &snode_get_addr);
m.def("expr_snode_append", &snode_append);
m.def("expr_snode_is_active", &snode_is_active);
m.def("expr_snode_length", &snode_length);

m.def("insert_internal_func_call",
[&](const std::string &func_name, const ExprGroup &args,
bool with_runtime_context) {
Expand Down
4 changes: 3 additions & 1 deletion tests/cpp/ir/frontend_type_inference_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,13 @@ TEST(FrontendTypeInference, AtomicOp) {
}

TEST(FrontendTypeInference, SNodeOp) {
auto prog = std::make_unique<Program>(Arch::x64);
auto snode = std::make_unique<SNode>(0, SNodeType::root);
snode->dt = PrimitiveType::u8;
auto index = value<int32>(2);
index->type_check(nullptr);
auto snode_op = snode_get_addr(snode.get(), ExprGroup(index));
auto snode_op = prog->current_ast_builder()->snode_get_addr(snode.get(),
ExprGroup(index));
snode_op->type_check(nullptr);
EXPECT_EQ(snode_op->ret_type, PrimitiveType::u64);
}
Expand Down

0 comments on commit 7d3505e

Please sign in to comment.