Skip to content

Commit

Permalink
[lang] Add support for struct on Dynamic SNode
Browse files Browse the repository at this point in the history
  • Loading branch information
lin-hitonami committed Nov 2, 2022
1 parent 2ee18c5 commit 27e6c19
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 17 deletions.
13 changes: 10 additions & 3 deletions python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import collections.abc
import numbers

from taichi._lib import core as _ti_core
from taichi.lang import expr, impl, matrix
from taichi.lang import expr, impl, matrix, struct
from taichi.lang.field import BitpackedFields, Field


Expand Down Expand Up @@ -371,11 +372,17 @@ def append(node, indices, val):
val (:mod:`~taichi.types.primitive_types`): the scalar data to be appended, only i32 value is support for now.
"""
if isinstance(val, matrix.Matrix):
raise ValueError("ti.append only supports appending a scalar value")
raise ValueError("ti.append only supports appending a scalar value or a struct")
ptrs = []
if isinstance(val, struct.Struct):
for item in val._members:
ptrs.append(expr.Expr(item).ptr)
else:
ptrs = [expr.Expr(val).ptr]
a = impl.expr_init(
_ti_core.expr_snode_append(node._snode.ptr,
expr.make_expr_group(indices),
expr.Expr(val).ptr))
ptrs))
return a


Expand Down
2 changes: 1 addition & 1 deletion taichi/analysis/gen_offline_cache_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
emit(expr->op_type);
emit(expr->snode);
emit(expr->indices.exprs);
emit(expr->value);
emit(expr->values);
}

void visit(ConstExpression *expr) override {
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ Expr expr_rand(DataType dt) {
return Expr::make<RandExpression>(dt);
}

Expr snode_append(SNode *snode, const ExprGroup &indices, const Expr &val) {
Expr snode_append(SNode *snode, const ExprGroup &indices, const std::vector<Expr> &val) {
return Expr::make<SNodeOpExpression>(snode, SNodeOpType::append, indices,
val);
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ Expr expr_rand() {
return taichi::lang::expr_rand(get_data_type<T>());
}

Expr snode_append(SNode *snode, const ExprGroup &indices, const Expr &val);
Expr snode_append(SNode *snode, const ExprGroup &indices, const std::vector<Expr> &val);

Expr snode_is_active(SNode *snode, const ExprGroup &indices);

Expand Down
4 changes: 2 additions & 2 deletions taichi/ir/expression_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {
emit('(', expr->snode->get_node_type_name_hinted(), ", [");
emit_vector(expr->indices.exprs);
emit("]");
if (expr->value.expr) {
if (!expr->values.empty()) {
emit(' ');
expr->value->accept(this);
emit_vector(expr->values);
}
emit(')');
}
Expand Down
13 changes: 7 additions & 6 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -925,18 +925,19 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) {
} else if (op_type == SNodeOpType::get_addr) {
ctx->push_back<SNodeOpStmt>(SNodeOpType::get_addr, snode, ptr, nullptr);
} else if (op_type == SNodeOpType::append) {
flatten_rvalue(value, ctx);

for (auto &value : values) {
flatten_rvalue(value, ctx);
}
auto alloca = ctx->push_back<AllocaStmt>(PrimitiveType::i32);
auto addr =
ctx->push_back<SNodeOpStmt>(SNodeOpType::allocate, snode, ptr, alloca);
auto ch_addr = ctx->push_back<GetChStmt>(addr, snode, 0);
ctx->push_back<GlobalStoreStmt>(ch_addr, value->stmt);
for (int i = 0; i < values.size(); i++) {
auto ch_addr = ctx->push_back<GetChStmt>(addr, snode, i);
ctx->push_back<GlobalStoreStmt>(ch_addr, values[i]->stmt);
}
ctx->push_back<LocalLoadStmt>(alloca);
TI_ERROR_IF(snode->type != SNodeType::dynamic,
"ti.append only works on dynamic nodes.");
TI_ERROR_IF(snode->ch.size() != 1,
"ti.append only works on single-child dynamic nodes.");
}
stmt = ctx->back_stmt();
}
Expand Down
6 changes: 3 additions & 3 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ class SNodeOpExpression : public Expression {
SNode *snode;
SNodeOpType op_type;
ExprGroup indices;
Expr value;
std::vector<Expr> values;

SNodeOpExpression(SNode *snode, SNodeOpType op_type, const ExprGroup &indices)
: snode(snode), op_type(op_type), indices(indices) {
Expand All @@ -739,8 +739,8 @@ class SNodeOpExpression : public Expression {
SNodeOpExpression(SNode *snode,
SNodeOpType op_type,
const ExprGroup &indices,
const Expr &value)
: snode(snode), op_type(op_type), indices(indices), value(value) {
const std::vector<Expr> &values)
: snode(snode), op_type(op_type), indices(indices), values(values) {
}

void type_check(CompileConfig *config) override;
Expand Down

0 comments on commit 27e6c19

Please sign in to comment.