Skip to content

Commit

Permalink
[lang] Migrate TensorType expansion for ReturnStmt from Python code t…
Browse files Browse the repository at this point in the history
…o Frontend IR (#6946)

Issue: #5819

### Brief Summary
  • Loading branch information
jim19930609 authored Jan 9, 2023
1 parent 5ef5b63 commit ecc9664
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
3 changes: 0 additions & 3 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,6 @@ def _get_flattened_ptrs(val):
for item in val._members:
ptrs.extend(_get_flattened_ptrs(item))
return ptrs
if isinstance(val, Expr) and val.ptr.is_tensor():
return impl.get_runtime().prog.current_ast_builder().expand_exprs(
[val.ptr])
return [Expr(val).ptr]


Expand Down
8 changes: 7 additions & 1 deletion taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ FrontendSNodeOpStmt::FrontendSNodeOpStmt(ASTBuilder *builder,
}
}

FrontendReturnStmt::FrontendReturnStmt(const ExprGroup &group) : values(group) {
}

FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs)
: lhs(lhs), rhs(rhs) {
TI_ASSERT(lhs->is_lvalue());
Expand Down Expand Up @@ -1315,7 +1318,10 @@ Expr ASTBuilder::insert_patch_idx_expr() {
}

void ASTBuilder::create_kernel_exprgroup_return(const ExprGroup &group) {
this->insert(Stmt::make<FrontendReturnStmt>(group));
auto expanded_exprs = this->expand_exprs(group.exprs);
ExprGroup expanded_expr_group;
expanded_expr_group.exprs = std::move(expanded_exprs);
this->insert(Stmt::make<FrontendReturnStmt>(expanded_expr_group));
}

void ASTBuilder::create_print(
Expand Down
3 changes: 1 addition & 2 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,7 @@ class FrontendReturnStmt : public Stmt {
public:
ExprGroup values;

explicit FrontendReturnStmt(const ExprGroup &group) : values(group) {
}
explicit FrontendReturnStmt(const ExprGroup &group);

bool is_container_statement() const override {
return false;
Expand Down

0 comments on commit ecc9664

Please sign in to comment.