Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lang] Add support for struct on Dynamic SNode #6502

Merged
merged 6 commits into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numbers

from taichi._lib import core as _ti_core
from taichi.lang import expr, impl, matrix
from taichi.lang import expr, impl, matrix, struct
from taichi.lang.field import BitpackedFields, Field


Expand Down Expand Up @@ -371,11 +371,17 @@ def append(node, indices, val):
val (:mod:`~taichi.types.primitive_types`): the scalar data to be appended, only i32 value is support for now.
"""
if isinstance(val, matrix.Matrix):
raise ValueError("ti.append only supports appending a scalar value")
raise ValueError(
"ti.append only supports appending a scalar value or a struct")
ptrs = []
if isinstance(val, struct.Struct):
for item in val._members:
ptrs.append(expr.Expr(item).ptr)
else:
ptrs = [expr.Expr(val).ptr]
a = impl.expr_init(
_ti_core.expr_snode_append(node._snode.ptr,
expr.make_expr_group(indices),
expr.Expr(val).ptr))
expr.make_expr_group(indices), ptrs))
return a


Expand Down
2 changes: 1 addition & 1 deletion taichi/analysis/gen_offline_cache_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
emit(expr->op_type);
emit(expr->snode);
emit(expr->indices.exprs);
emit(expr->value);
emit(expr->values);
}

void visit(ConstExpression *expr) override {
Expand Down
4 changes: 3 additions & 1 deletion taichi/ir/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ Expr expr_rand(DataType dt) {
return Expr::make<RandExpression>(dt);
}

Expr snode_append(SNode *snode, const ExprGroup &indices, const Expr &val) {
Expr snode_append(SNode *snode,
const ExprGroup &indices,
const std::vector<Expr> &val) {
return Expr::make<SNodeOpExpression>(snode, SNodeOpType::append, indices,
val);
}
Expand Down
4 changes: 3 additions & 1 deletion taichi/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ Expr expr_rand() {
return taichi::lang::expr_rand(get_data_type<T>());
}

Expr snode_append(SNode *snode, const ExprGroup &indices, const Expr &val);
Expr snode_append(SNode *snode,
lin-hitonami marked this conversation as resolved.
Show resolved Hide resolved
const ExprGroup &indices,
const std::vector<Expr> &val);

Expr snode_is_active(SNode *snode, const ExprGroup &indices);

Expand Down
4 changes: 2 additions & 2 deletions taichi/ir/expression_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {
emit('(', expr->snode->get_node_type_name_hinted(), ", [");
emit_vector(expr->indices.exprs);
emit("]");
if (expr->value.expr) {
if (!expr->values.empty()) {
emit(' ');
expr->value->accept(this);
emit_vector(expr->values);
}
emit(')');
}
Expand Down
13 changes: 7 additions & 6 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -925,18 +925,19 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) {
} else if (op_type == SNodeOpType::get_addr) {
ctx->push_back<SNodeOpStmt>(SNodeOpType::get_addr, snode, ptr, nullptr);
} else if (op_type == SNodeOpType::append) {
flatten_rvalue(value, ctx);

for (auto &value : values) {
flatten_rvalue(value, ctx);
}
auto alloca = ctx->push_back<AllocaStmt>(PrimitiveType::i32);
auto addr =
ctx->push_back<SNodeOpStmt>(SNodeOpType::allocate, snode, ptr, alloca);
auto ch_addr = ctx->push_back<GetChStmt>(addr, snode, 0);
ctx->push_back<GlobalStoreStmt>(ch_addr, value->stmt);
for (int i = 0; i < values.size(); i++) {
auto ch_addr = ctx->push_back<GetChStmt>(addr, snode, i);
ctx->push_back<GlobalStoreStmt>(ch_addr, values[i]->stmt);
lin-hitonami marked this conversation as resolved.
Show resolved Hide resolved
}
ctx->push_back<LocalLoadStmt>(alloca);
TI_ERROR_IF(snode->type != SNodeType::dynamic,
"ti.append only works on dynamic nodes.");
TI_ERROR_IF(snode->ch.size() != 1,
"ti.append only works on single-child dynamic nodes.");
}
stmt = ctx->back_stmt();
}
Expand Down
6 changes: 3 additions & 3 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ class SNodeOpExpression : public Expression {
SNode *snode;
SNodeOpType op_type;
ExprGroup indices;
Expr value;
std::vector<Expr> values;
lin-hitonami marked this conversation as resolved.
Show resolved Hide resolved

SNodeOpExpression(SNode *snode, SNodeOpType op_type, const ExprGroup &indices)
: snode(snode), op_type(op_type), indices(indices) {
Expand All @@ -739,8 +739,8 @@ class SNodeOpExpression : public Expression {
SNodeOpExpression(SNode *snode,
SNodeOpType op_type,
const ExprGroup &indices,
const Expr &value)
: snode(snode), op_type(op_type), indices(indices), value(value) {
const std::vector<Expr> &values)
: snode(snode), op_type(op_type), indices(indices), values(values) {
}

void type_check(CompileConfig *config) override;
Expand Down
25 changes: 25 additions & 0 deletions tests/python/test_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,28 @@ def make_list():

for i in range(20):
assert x[i] == i * i * i * 10000000000


@test_utils.test(require=ti.extension.sparse, exclude=[ti.metal])
jim19930609 marked this conversation as resolved.
Show resolved Hide resolved
def test_append_struct():
struct = ti.types.struct(a=ti.u8, b=ti.u16, c=ti.u32, d=ti.u64)
x = struct.field()
pixel = ti.root.dense(ti.i, 10).dynamic(ti.j, 20, 5)
pixel.place(x)

@ti.kernel
def make_list():
for i in range(10):
for j in range(20):
x[i].append(
struct(i * j * 10, i * j * 10000, i * j * 100000000,
i * j * ti.u64(10000000000)))

make_list()

for i in range(10):
for j in range(20):
assert x[i, j].a == i * j * 10 % 256
assert x[i, j].b == i * j * 10000 % 65536
assert x[i, j].c == i * j * 100000000 % 4294967296
assert x[i, j].d == i * j * 10000000000