From 49d7d07c03c29664cba074f62107e795b68d4144 Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Thu, 5 Jan 2023 10:22:30 +0800 Subject: [PATCH] [lang] Migrate TensorType expansion for subscription indices from Python to Frontend IR (#6942) Issue: https://github.com/taichi-dev/taichi/issues/5819 ### Brief Summary For indices of TensorType, instead of scalarizing them at Python level, it is up to the Frontend IR's consumer to decide whether TensorType'd indices are acceptable and if we should have it scalarized. This PR removes `expand_expr` in Expression subscription and migrate the scalarization logics to the following constructors: 1. MeshIndexConversionExpression::MeshIndexConversionExpression 2. IndexExpression::IndexExpression --- python/taichi/lang/_texture.py | 3 +- python/taichi/lang/any_array.py | 7 +- python/taichi/lang/ast/ast_transformer.py | 33 +-- python/taichi/lang/exception.py | 2 + python/taichi/lang/expr.py | 2 +- python/taichi/lang/impl.py | 55 ++--- python/taichi/lang/kernel_impl.py | 2 +- python/taichi/lang/matrix.py | 4 +- python/taichi/lang/mesh.py | 16 +- python/taichi/lang/simt/block.py | 7 +- taichi/common/exceptions.h | 4 + taichi/ir/expr.cpp | 6 - taichi/ir/expr.h | 2 - taichi/ir/frontend_ir.cpp | 201 +++++++++++++----- taichi/ir/frontend_ir.h | 25 +-- taichi/ir/ir_builder.cpp | 9 - taichi/ir/ir_builder.h | 4 - taichi/math/svd.h | 2 +- taichi/program/program.cpp | 17 +- taichi/python/export_lang.cpp | 26 +-- tests/cpp/ir/frontend_type_inference_test.cpp | 22 +- 21 files changed, 271 insertions(+), 178 deletions(-) diff --git a/python/taichi/lang/_texture.py b/python/taichi/lang/_texture.py index 9e1c33b467e29..ceb0e4be33950 100644 --- a/python/taichi/lang/_texture.py +++ b/python/taichi/lang/_texture.py @@ -13,7 +13,8 @@ def _get_entries(mat): if isinstance(mat, Matrix): return mat.entries assert isinstance(mat, Expr) and mat.is_tensor() - return impl.get_runtime().prog.current_ast_builder().expand_expr([mat.ptr]) + return impl.get_runtime().prog.current_ast_builder().expand_exprs( + [mat.ptr]) class TextureSampler: diff --git a/python/taichi/lang/any_array.py b/python/taichi/lang/any_array.py index c7c44ba20e244..4b980797cd243 100644 --- a/python/taichi/lang/any_array.py +++ b/python/taichi/lang/any_array.py @@ -78,6 +78,8 @@ def __init__(self, arr, indices_first): @taichi_scope def subscript(self, i, j): + ast_builder = impl.get_runtime().prog.current_ast_builder() + indices_second = (i, ) if len(self.arr.element_shape()) == 1 else (i, j) if self.arr.layout() == Layout.SOA: @@ -85,8 +87,9 @@ def subscript(self, i, j): else: indices = self.indices_first + indices_second return Expr( - _ti_core.subscript(self.arr.ptr, make_expr_group(*indices), - impl.get_runtime().get_current_src_info())) + ast_builder.expr_subscript( + self.arr.ptr, make_expr_group(*indices), + impl.get_runtime().get_current_src_info())) __all__ = [] diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 221ef10e7fdba..8e7d13d5215bc 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -15,7 +15,7 @@ ReturnStatus) from taichi.lang.ast.symbol_resolver import ASTResolver from taichi.lang.exception import (TaichiIndexError, TaichiSyntaxError, - TaichiTypeError) + TaichiTypeError, handle_exception_from_cpp) from taichi.lang.expr import Expr, make_expr_group from taichi.lang.field import Field from taichi.lang.matrix import Matrix, MatrixType, Vector, is_vector @@ -156,7 +156,7 @@ def build_assign_unpack(ctx, node_target, values, is_static_assign): raise ValueError( 'Matrices with more than one columns cannot be unpacked') - values = ctx.ast_builder.expand_expr([values.ptr]) + values = ctx.ast_builder.expand_exprs([values.ptr]) if len(values) == 1: values = values[0] @@ -302,7 +302,7 @@ def process_generators(ctx, node, now_comp, func, result): if isinstance(_iter, impl.Expr) and _iter.ptr.is_tensor(): shape = _iter.ptr.get_shape() flattened = [ - Expr(x) for x in ctx.ast_builder.expand_expr([_iter.ptr]) + Expr(x) for x in ctx.ast_builder.expand_exprs([_iter.ptr]) ] _iter = reshape_list(flattened, shape) @@ -514,7 +514,7 @@ def build_Call(ctx, node): # Expand Expr with Matrix-type return into list of Exprs arg_list = [ Expr(x) - for x in ctx.ast_builder.expand_expr([arg_list.ptr]) + for x in ctx.ast_builder.expand_exprs([arg_list.ptr]) ] for i in arg_list: @@ -730,7 +730,7 @@ def build_Return(ctx, node): elif isinstance(ctx.func.return_type, MatrixType): values = node.value.ptr if isinstance(values, Expr) and values.ptr.is_tensor(): - values = ctx.ast_builder.expand_expr([values.ptr]) + values = ctx.ast_builder.expand_exprs([values.ptr]) else: assert isinstance(values, Matrix) values = itertools.chain.from_iterable(values.to_list()) if\ @@ -819,12 +819,15 @@ def build_Attribute(ctx, node): # we continue to process it as a normal attribute node. try: build_stmt(ctx, node.value) - except TaichiIndexError as e: - node.value.ptr = None - if ASTTransformer.build_attribute_if_is_dynamic_snode_method( - ctx, node): - return node.ptr + except Exception as e: + e = handle_exception_from_cpp(e) + if isinstance(e, TaichiIndexError): + node.value.ptr = None + if ASTTransformer.build_attribute_if_is_dynamic_snode_method( + ctx, node): + return node.ptr raise e + if ASTTransformer.build_attribute_if_is_dynamic_snode_method( ctx, node): return node.ptr @@ -837,11 +840,11 @@ def build_Attribute(ctx, node): node.attr) attr_len = len(node.attr) if attr_len == 1: - node.ptr = Expr( - _ti_core.subscript( - node.value.ptr.ptr, - make_expr_group(keygroup.index(node.attr)), - impl.get_runtime().get_current_src_info())) + node.ptr = Expr(impl.get_runtime( + ).prog.current_ast_builder().expr_subscript( + node.value.ptr.ptr, + make_expr_group(keygroup.index(node.attr)), + impl.get_runtime().get_current_src_info())) else: node.ptr = Expr( _ti_core.subscript_with_multiple_indices( diff --git a/python/taichi/lang/exception.py b/python/taichi/lang/exception.py index 319ab0f4b5102..5d66bf47ad5be 100644 --- a/python/taichi/lang/exception.py +++ b/python/taichi/lang/exception.py @@ -56,6 +56,8 @@ def handle_exception_from_cpp(exc): return TaichiTypeError(str(exc)) if isinstance(exc, core.TaichiSyntaxError): return TaichiSyntaxError(str(exc)) + if isinstance(exc, core.TaichiIndexError): + return TaichiIndexError(str(exc)) if isinstance(exc, core.TaichiAssertionError): return TaichiAssertionError(str(exc)) return exc diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 1fe288585e1ed..29fcf8a436630 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -170,7 +170,7 @@ def _get_flattened_ptrs(val): ptrs.extend(_get_flattened_ptrs(item)) return ptrs if isinstance(val, Expr) and val.ptr.is_tensor(): - return impl.get_runtime().prog.current_ast_builder().expand_expr( + return impl.get_runtime().prog.current_ast_builder().expand_exprs( [val.ptr]) return [Expr(val).ptr] diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index ca99c9fa22465..330519c0e58d9 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -9,9 +9,8 @@ from taichi.lang._texture import RWTextureAccessor from taichi.lang.any_array import AnyArray from taichi.lang.enums import SNodeGradType -from taichi.lang.exception import (TaichiCompilationError, TaichiIndexError, - TaichiRuntimeError, TaichiSyntaxError, - TaichiTypeError) +from taichi.lang.exception import (TaichiCompilationError, TaichiRuntimeError, + TaichiSyntaxError, TaichiTypeError) from taichi.lang.expr import Expr, make_expr_group from taichi.lang.field import Field, ScalarField from taichi.lang.kernel_arguments import SparseMatrixProxy @@ -132,6 +131,7 @@ def check_validity(x): @taichi_scope def subscript(ast_builder, value, *_indices, skip_reordered=False): + ast_builder = get_runtime().prog.current_ast_builder() # Directly evaluate in Python for non-Taichi types if not isinstance( value, @@ -150,9 +150,6 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): elif isinstance(_index, slice): ind = [_index] has_slice = True - elif isinstance(_index, Expr) and _index.is_tensor(): - # Expand Expr with TensorType return - ind = [Expr(e) for e in ast_builder.expand_expr([_index.ptr])] else: ind = [_index] flattened_indices += ind @@ -167,7 +164,6 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): f"The type {type(value)} do not support index of slice type") else: indices_expr_group = make_expr_group(*indices) - index_dim = indices_expr_group.size() if isinstance(value, SharedArray): return value.subscript(*indices) @@ -178,13 +174,13 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): if isinstance(value, (MeshReorderedScalarFieldProxy, MeshReorderedMatrixFieldProxy)) and not skip_reordered: - assert index_dim == 1 + reordered_index = tuple([ Expr( - _ti_core.get_index_conversion(value.mesh_ptr, - value.element_type, - Expr(indices[0]).ptr, - ConvType.g2r)) + ast_builder.mesh_index_conversion(value.mesh_ptr, + value.element_type, + Expr(indices[0]).ptr, + ConvType.g2r)) ]) return subscript(ast_builder, value, @@ -203,13 +199,12 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): raise RuntimeError( f"Gradient {_var.get_expr_name()} has not been placed, check whether `needs_grad=True`" ) - field_dim = snode.num_active_indices() - if field_dim != index_dim: - raise TaichiIndexError( - f'Field with dim {field_dim} accessed with indices of dim {index_dim}' - ) + if isinstance(value, MatrixField): - return make_index_expr(value.ptr, indices_expr_group) + return Expr( + ast_builder.expr_subscript( + value.ptr, indices_expr_group, + get_runtime().get_current_src_info())) if isinstance(value, StructField): entries = { k: subscript(ast_builder, v, *indices) @@ -217,15 +212,13 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): } entries['__struct_methods'] = value.struct_methods return _IntermediateStruct(entries) - return make_index_expr(_var, indices_expr_group) + return Expr( + ast_builder.expr_subscript(_var, indices_expr_group, + get_runtime().get_current_src_info())) if isinstance(value, AnyArray): - dim = _ti_core.get_external_tensor_dim(value.ptr) - element_dim = len(value.element_shape()) - if dim != index_dim + element_dim: - raise IndexError( - f'Field with dim {dim - element_dim} accessed with indices of dim {index_dim}' - ) - return make_index_expr(value.ptr, indices_expr_group) + return Expr( + ast_builder.expr_subscript(value.ptr, indices_expr_group, + get_runtime().get_current_src_info())) assert isinstance(value, Expr) # Index into TensorType # value: IndexExpression with ret_type = TensorType @@ -249,18 +242,14 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): make_expr_group(i, j) for i in indices[0] for j in indices[1] ] return_shape = (len(indices[0]), len(indices[1])) + return Expr( _ti_core.subscript_with_multiple_indices( value.ptr, multiple_indices, return_shape, get_runtime().get_current_src_info())) - return make_index_expr(value.ptr, indices_expr_group) - - -@taichi_scope -def make_index_expr(_var, indices_expr_group): return Expr( - _ti_core.subscript(_var, indices_expr_group, - get_runtime().get_current_src_info())) + ast_builder.expr_subscript(value.ptr, indices_expr_group, + get_runtime().get_current_src_info())) class SrcInfoGuard: diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 670a0639d02b7..e933f8de190f1 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -262,7 +262,7 @@ def func_call_rvalue(self, key, args): impl.Expr) and args[i].ptr.is_tensor(): non_template_args.extend([ Expr(x) for x in impl.get_runtime().prog. - current_ast_builder().expand_expr([args[i].ptr]) + current_ast_builder().expand_exprs([args[i].ptr]) ]) else: non_template_args.append(args[i]) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index a4291535c0ddf..7dec71786400d 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1503,7 +1503,7 @@ def __call__(self, *args): elif isinstance(x, impl.Expr) and x.ptr.is_tensor(): entries += [ impl.Expr(e) for e in impl.get_runtime().prog. - current_ast_builder().expand_expr([x.ptr]) + current_ast_builder().expand_exprs([x.ptr]) ] elif isinstance(x, Matrix): entries += x.entries @@ -1616,7 +1616,7 @@ def __call__(self, *args): elif isinstance(x, impl.Expr) and x.ptr.is_tensor(): entries += [ impl.Expr(e) for e in impl.get_runtime().prog. - current_ast_builder().expand_expr([x.ptr]) + current_ast_builder().expand_exprs([x.ptr]) ] else: entries.append(x) diff --git a/python/taichi/lang/mesh.py b/python/taichi/lang/mesh.py index c1d7169c50c1e..4de03673abdaa 100644 --- a/python/taichi/lang/mesh.py +++ b/python/taichi/lang/mesh.py @@ -605,14 +605,17 @@ def _TetMesh(): class MeshElementFieldProxy: def __init__(self, mesh: MeshInstance, element_type: MeshElementType, entry_expr: impl.Expr): + ast_builder = impl.get_runtime().prog.current_ast_builder() + self.mesh = mesh self.element_type = element_type self.entry_expr = entry_expr element_field = self.mesh.fields[self.element_type] for key, attr in element_field.field_dict.items(): + global_entry_expr = impl.Expr( - _ti_core.get_index_conversion( + ast_builder.mesh_index_conversion( self.mesh.mesh_ptr, element_type, entry_expr, ConvType.l2r if element_field.attr_dict[key].reorder else ConvType.l2g)) # transform index space @@ -622,7 +625,7 @@ def __init__(self, mesh: MeshInstance, element_type: MeshElementType, setattr( self, key, impl.Expr( - _ti_core.subscript( + ast_builder.expr_subscript( attr.ptr, global_entry_expr_group, impl.get_runtime().get_current_src_info()))) elif isinstance(attr, StructField): @@ -633,7 +636,7 @@ def __init__(self, mesh: MeshInstance, element_type: MeshElementType, setattr( self, key, impl.Expr( - _ti_core.subscript( + ast_builder.expr_subscript( var, global_entry_expr_group, impl.get_runtime().get_current_src_info()))) @@ -650,10 +653,11 @@ def ptr(self): @property def id(self): # return the global non-reordered index + ast_builder = impl.get_runtime().prog.current_ast_builder() l2g_expr = impl.Expr( - _ti_core.get_index_conversion(self.mesh.mesh_ptr, - self.element_type, self.entry_expr, - ConvType.l2g)) + ast_builder.mesh_index_conversion(self.mesh.mesh_ptr, + self.element_type, + self.entry_expr, ConvType.l2g)) return l2g_expr diff --git a/python/taichi/lang/simt/block.py b/python/taichi/lang/simt/block.py index a9fc6af89eb94..c1154793df23f 100644 --- a/python/taichi/lang/simt/block.py +++ b/python/taichi/lang/simt/block.py @@ -54,5 +54,8 @@ def __init__(self, shape, dtype): @taichi_scope def subscript(self, *indices): - return impl.make_index_expr(self.shared_array_proxy, - make_expr_group(*indices)) + ast_builder = impl.get_runtime().prog.current_ast_builder() + return impl.Expr( + ast_builder.expr_subscript( + self.shared_array_proxy, make_expr_group(*indices), + impl.get_runtime().get_current_src_info())) diff --git a/taichi/common/exceptions.h b/taichi/common/exceptions.h index 5ad9d8d789609..eb3570ffe23b9 100644 --- a/taichi/common/exceptions.h +++ b/taichi/common/exceptions.h @@ -23,6 +23,10 @@ class TaichiSyntaxError : public TaichiExceptionImpl { using TaichiExceptionImpl::TaichiExceptionImpl; }; +class TaichiIndexError : public TaichiExceptionImpl { + using TaichiExceptionImpl::TaichiExceptionImpl; +}; + class TaichiRuntimeError : public TaichiExceptionImpl { using TaichiExceptionImpl::TaichiExceptionImpl; }; diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp index 767d05ef5eeef..da2c541fdf1de 100644 --- a/taichi/ir/expr.cpp +++ b/taichi/ir/expr.cpp @@ -26,12 +26,6 @@ Expr bit_cast(const Expr &input, DataType dt) { return Expr::make(UnaryOpType::cast_bits, input, dt); } -Expr Expr::operator[](const ExprGroup &indices) const { - TI_ASSERT(is() || is() || - is() || is_tensor(expr->ret_type)); - return Expr::make(*this, indices); -} - Expr &Expr::operator=(const Expr &o) { set(o); return *this; diff --git a/taichi/ir/expr.h b/taichi/ir/expr.h index 0d5f974c7aa0a..b8058c7f1ad22 100644 --- a/taichi/ir/expr.h +++ b/taichi/ir/expr.h @@ -83,8 +83,6 @@ class Expr { // std::variant in FrontendPrintStmt. Expr &operator=(const Expr &o); - Expr operator[](const ExprGroup &indices) const; - template static Expr make(Args &&...args) { return Expr(std::make_shared(std::forward(args)...)); diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 4a81f56287141..b440a316bc172 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -25,7 +25,7 @@ FrontendSNodeOpStmt::FrontendSNodeOpStmt(ASTBuilder *builder, const Expr &val) : op_type(op_type), snode(snode), val(val) { this->indices = indices; - std::vector expanded_exprs = builder->expand_expr(this->indices.exprs); + std::vector expanded_exprs = builder->expand_exprs(this->indices.exprs); this->indices.exprs = expanded_exprs; if (val.expr != nullptr) { @@ -688,6 +688,26 @@ void MatrixExpression::flatten(FlattenContext *ctx) { stmt->ret_type = this->dt; } +IndexExpression::IndexExpression(const Expr &var, + const ExprGroup &indices, + std::string tb) + : var(var), indices_group({indices}) { + this->tb = tb; +} + +IndexExpression::IndexExpression(const Expr &var, + const std::vector &indices_group, + const std::vector &ret_shape, + std::string tb) + : var(var), indices_group(indices_group), ret_shape(ret_shape) { + // IndexExpression with ret_shape is used for matrix slicing, where each entry + // of ExprGroup is interpreted as a group of indices to return within each + // axis. For example, mat[0, 3:5] has indices_group={0, [3, 4]}, where [3, 4] + // means "m"-axis will return a TensorType with size of 2. In this case, we + // should not expand indices_group due to its special semantics. + this->tb = tb; +} + bool IndexExpression::is_field() const { return var.is(); } @@ -721,20 +741,43 @@ bool IndexExpression::is_global() const { return is_field() || is_matrix_field() || is_ndarray(); } +static void field_validation(FieldExpression *field_expr, int index_dim) { + TI_ASSERT(field_expr != nullptr); + TI_ASSERT(field_expr->snode != nullptr); + int field_dim = field_expr->snode->num_active_indices; + + if (field_dim != index_dim) { + throw TaichiIndexError( + fmt::format("Field with dim {} accessed with indices of dim {}", + field_dim, index_dim)); + } +} + void IndexExpression::type_check(CompileConfig *) { // TODO: Change to type-based solution // Currently, dimension compatibility check happens in Python TI_ASSERT(indices_group.size() == std::accumulate(begin(ret_shape), end(ret_shape), 1, std::multiplies<>())); - if (!ret_shape.empty()) { + int index_dim = indices_group.empty() ? 0 : indices_group[0].size(); + bool has_slice = !ret_shape.empty(); + if (has_slice) { TI_ASSERT_INFO(is_tensor(), "Slice or swizzle can only apply on matrices"); auto element_type = var->ret_type->as()->get_element_type(); ret_type = TypeFactory::create_tensor_type(ret_shape, element_type); + } else if (is_field()) { // field - ret_type = var.cast()->dt->get_compute_type(); + auto field_expr = var.cast(); + field_validation(field_expr.get(), index_dim); + ret_type = field_expr->dt->get_compute_type(); + } else if (is_matrix_field()) { auto matrix_field_expr = var.cast(); + + TI_ASSERT(!matrix_field_expr->fields.empty()); + auto field_expr = matrix_field_expr->fields[0].cast(); + field_validation(field_expr.get(), index_dim); + ret_type = TypeFactory::create_tensor_type(matrix_field_expr->element_shape, matrix_field_expr->fields[0] .cast() @@ -742,7 +785,12 @@ void IndexExpression::type_check(CompileConfig *) { } else if (is_ndarray()) { // ndarray auto external_tensor_expr = var.cast(); int total_dim = external_tensor_expr->dim; - int index_dim = indices_group[0].exprs.size(); + int element_dim = external_tensor_expr->dt.get_shape().size(); + if (total_dim != index_dim + element_dim) { + throw TaichiTypeError( + fmt::format("Array with dim {} accessed with indices of dim {}", + total_dim - element_dim, index_dim)); + } if (index_dim == total_dim) { // Access all the way to a single element @@ -910,7 +958,7 @@ SNodeOpExpression::SNodeOpExpression(ASTBuilder *builder, SNodeOpType op_type, const ExprGroup &indices) : snode(snode), op_type(op_type) { - std::vector expanded_indices = builder->expand_expr(indices.exprs); + std::vector expanded_indices = builder->expand_exprs(indices.exprs); this->indices = indices; this->indices.exprs = std::move(expanded_indices); } @@ -921,7 +969,7 @@ SNodeOpExpression::SNodeOpExpression(ASTBuilder *builder, const ExprGroup &indices, const std::vector &values) : SNodeOpExpression(builder, snode, op_type, indices) { - this->values = builder->expand_expr(values); + this->values = builder->expand_exprs(values); } void SNodeOpExpression::type_check(CompileConfig *config) { @@ -1163,6 +1211,14 @@ void MeshRelationAccessExpression::flatten(FlattenContext *ctx) { stmt = ctx->back_stmt(); } +MeshIndexConversionExpression::MeshIndexConversionExpression( + mesh::Mesh *mesh, + mesh::MeshElementType idx_type, + const Expr idx, + mesh::ConvType conv_type) + : mesh(mesh), idx_type(idx_type), idx(idx), conv_type(conv_type) { +} + void MeshIndexConversionExpression::type_check(CompileConfig *) { ret_type = PrimitiveType::i32; } @@ -1359,6 +1415,25 @@ void ASTBuilder::expr_assign(const Expr &lhs, const Expr &rhs, std::string tb) { this->insert(std::move(stmt)); } +Expr ASTBuilder::expr_subscript(const Expr &expr, + const ExprGroup &indices, + std::string tb) { + TI_ASSERT(expr.is() || expr.is() || + expr.is() || + is_tensor(expr.expr->ret_type)); + + // IndexExpression without ret_shape is used for matrix indexing, + // where each entry of ExprGroup is interpreted as indexing into a specific + // axis. For example, mat[3, 4] has indices_group={[3, 4]}, where [3, 4] + // corresponds to "n"-axis and "m"-axis of the matrix. Therefore we expand + // indices_group={[3, 4]} into {3, 4} to avoid TensorType in indices. + std::vector expanded_indices = this->expand_exprs(indices.exprs); + auto expanded_expr_group = ExprGroup(); + expanded_expr_group.exprs = expanded_indices; + + return Expr::make(expr, expanded_expr_group, tb); +} + void ASTBuilder::create_assert_stmt(const Expr &cond, const std::string &msg, const std::vector &args) { @@ -1479,62 +1554,82 @@ Expr ASTBuilder::snode_get_addr(SNode *snode, const ExprGroup &indices) { indices); } -std::vector ASTBuilder::expand_expr(const std::vector &exprs) { - if (exprs.size() > 1 || exprs.size() == 0) { +std::vector ASTBuilder::expand_exprs(const std::vector &exprs) { + if (exprs.size() == 0) { return exprs; } - Expr index_expr = exprs[0]; - TI_ASSERT_TYPE_CHECKED(index_expr); - if (!index_expr->ret_type->is()) { - return exprs; - } - - // Expand TensorType expr - /* - Before: - TensorType<4 x i32> index = Expr; - - After: - TensorType<4 x i32>* id_expr = FrontendAllocaStmt(TensorType<4 x i32>) - i32 ind0 = IndexExpression(id_expr, 0) - i32 ind1 = IndexExpression(id_expr, 1) - i32 ind2 = IndexExpression(id_expr, 2) - i32 ind3 = IndexExpression(id_expr, 3) - - return {ind0, ind1, ind2, ind3} - - */ std::vector expanded_exprs; + for (auto expr : exprs) { + TI_ASSERT_TYPE_CHECKED(expr); + if (!expr->ret_type->is()) { + expanded_exprs.push_back(expr); + } else { + // Expand TensorType expr + /* + Before: + TensorType<4 x i32> index = Expr; + + After: + TensorType<4 x i32>* id_expr = FrontendAllocaStmt(TensorType<4 x i32>) + i32 ind0 = IndexExpression(id_expr, 0) + i32 ind1 = IndexExpression(id_expr, 1) + i32 ind2 = IndexExpression(id_expr, 2) + i32 ind3 = IndexExpression(id_expr, 3) + + return {ind0, ind1, ind2, ind3} + + */ + auto tensor_type = expr->ret_type->cast(); + + Expr id_expr; + if (expr.is()) { + id_expr = expr; + } else { + id_expr = make_var(expr, expr->tb); + } + auto shape = tensor_type->get_shape(); + if (shape.size() == 1) { + for (int i = 0; i < shape[0]; i++) { + auto ind = Expr(std::make_shared( + id_expr, ExprGroup(Expr(i)), expr->tb)); + ind.expr->ret_type = tensor_type->get_element_type(); + expanded_exprs.push_back(ind); + } + } else { + TI_ASSERT(shape.size() == 2); + for (int i = 0; i < shape[0]; i++) { + for (int j = 0; j < shape[1]; j++) { + auto ind = Expr(std::make_shared( + id_expr, ExprGroup(Expr(i), Expr(j)), expr->tb)); + ind.expr->ret_type = tensor_type->get_element_type(); + expanded_exprs.push_back(ind); + } + } + } + } + } - auto tensor_type = index_expr->ret_type->cast(); + return expanded_exprs; +} - Expr id_expr; - if (index_expr.is()) { - id_expr = index_expr; +Expr ASTBuilder::mesh_index_conversion(mesh::MeshPtr mesh_ptr, + mesh::MeshElementType idx_type, + const Expr &idx, + mesh::ConvType &conv_type) { + Expr expanded_idx; + if (idx.is() && idx.get_ret_type() == PrimitiveType::unknown) { + expanded_idx = idx; } else { - id_expr = make_var(index_expr, index_expr->tb); - } - auto shape = tensor_type->get_shape(); - if (shape.size() == 1) { - for (int i = 0; i < shape[0]; i++) { - auto ind = Expr(std::make_shared( - id_expr, ExprGroup(Expr(i)), index_expr->tb)); - ind.expr->ret_type = tensor_type->get_element_type(); - expanded_exprs.push_back(ind); - } - } else { - TI_ASSERT(shape.size() == 2); - for (int i = 0; i < shape[0]; i++) { - for (int j = 0; j < shape[1]; j++) { - auto ind = Expr(std::make_shared( - id_expr, ExprGroup(Expr(i), Expr(j)), index_expr->tb)); - ind.expr->ret_type = tensor_type->get_element_type(); - expanded_exprs.push_back(ind); - } + if (idx.expr->ret_type->is()) { + TI_ASSERT(idx.expr->ret_type->cast()->get_num_elements() == + 1); } + expanded_idx = this->expand_exprs({idx})[0]; } - return expanded_exprs; + + return Expr::make(mesh_ptr.ptr.get(), idx_type, + expanded_idx, conv_type); } void ASTBuilder::create_scope(std::unique_ptr &list, LoopType tp) { diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 955dd1d7c69fc..9e03107090bbe 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -590,18 +590,12 @@ class IndexExpression : public Expression { IndexExpression(const Expr &var, const ExprGroup &indices, - std::string tb = "") - : var(var), indices_group({indices}) { - this->tb = tb; - } + std::string tb = ""); IndexExpression(const Expr &var, const std::vector &indices_group, const std::vector &ret_shape, - std::string tb = "") - : var(var), indices_group(indices_group), ret_shape(ret_shape) { - this->tb = tb; - } + std::string tb = ""); void type_check(CompileConfig *config) override; @@ -868,9 +862,7 @@ class MeshIndexConversionExpression : public Expression { MeshIndexConversionExpression(mesh::Mesh *mesh, mesh::MeshElementType idx_type, const Expr idx, - mesh::ConvType conv_type) - : mesh(mesh), idx_type(idx_type), idx(idx), conv_type(conv_type) { - } + mesh::ConvType conv_type); void flatten(FlattenContext *ctx) override; @@ -960,6 +952,15 @@ class ASTBuilder { Expr expr_alloca(); Expr expr_alloca_shared_array(const std::vector &shape, const DataType &element_type); + Expr expr_subscript(const Expr &expr, + const ExprGroup &indices, + std::string tb = ""); + + Expr mesh_index_conversion(mesh::MeshPtr mesh_ptr, + mesh::MeshElementType idx_type, + const Expr &idx, + mesh::ConvType &conv_type); + void expr_assign(const Expr &lhs, const Expr &rhs, std::string tb); void create_assert_stmt(const Expr &cond, const std::string &msg, @@ -995,7 +996,7 @@ class ASTBuilder { Expr snode_length(SNode *snode, const ExprGroup &indices); Expr snode_get_addr(SNode *snode, const ExprGroup &indices); - std::vector expand_expr(const std::vector &exprs); + std::vector expand_exprs(const std::vector &exprs); void create_scope(std::unique_ptr &list, LoopType tp = NotLoop); void pop_scope(); diff --git a/taichi/ir/ir_builder.cpp b/taichi/ir/ir_builder.cpp index 15fbc2587a5e4..0d4d77f9662b7 100644 --- a/taichi/ir/ir_builder.cpp +++ b/taichi/ir/ir_builder.cpp @@ -475,15 +475,6 @@ MeshRelationAccessStmt *IRBuilder::get_relation_access( mesh, mesh_idx, to_type, neighbor_idx)); } -MeshIndexConversionStmt *IRBuilder::get_index_conversion( - mesh::Mesh *mesh, - mesh::MeshElementType idx_type, - Stmt *idx, - mesh::ConvType conv_type) { - return insert(Stmt::make_typed(mesh, idx_type, idx, - conv_type)); -} - MeshPatchIndexStmt *IRBuilder::get_patch_index() { return insert(Stmt::make_typed()); } diff --git a/taichi/ir/ir_builder.h b/taichi/ir/ir_builder.h index 316a2c14a39d2..08bdd1e83b9c1 100644 --- a/taichi/ir/ir_builder.h +++ b/taichi/ir/ir_builder.h @@ -278,10 +278,6 @@ class IRBuilder { Stmt *mesh_idx, mesh::MeshElementType to_type, Stmt *neighbor_idx); - MeshIndexConversionStmt *get_index_conversion(mesh::Mesh *mesh, - mesh::MeshElementType idx_type, - Stmt *idx, - mesh::ConvType conv_type); MeshPatchIndexStmt *get_patch_index(); private: diff --git a/taichi/math/svd.h b/taichi/math/svd.h index d0da10beb3da2..898d030660192 100644 --- a/taichi/math/svd.h +++ b/taichi/math/svd.h @@ -43,7 +43,7 @@ std::tuple sifakis_svd_export(ASTBuilder *ast_builder, const Expr &mat, int num_iters) { - auto expanded_exprs = ast_builder->expand_expr({mat}); + auto expanded_exprs = ast_builder->expand_exprs({mat}); TI_ASSERT(expanded_exprs.size() == 9); Expr a00 = expanded_exprs[0]; diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index 482dce4a5bf6c..5d7df34eb6290 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -362,10 +362,13 @@ Kernel &Program::get_snode_reader(SNode *snode) { auto &ker = kernel([snode, this] { ExprGroup indices; for (int i = 0; i < snode->num_active_indices; i++) { - indices.push_back(Expr::make(i, PrimitiveType::i32)); + auto argload_expr = Expr::make(i, PrimitiveType::i32); + argload_expr->type_check(&this->this_thread_config()); + indices.push_back(std::move(argload_expr)); } - auto ret = Stmt::make( - ExprGroup(Expr(snode_to_fields_.at(snode))[indices])); + ASTBuilder *builder = this->current_ast_builder(); + auto ret = Stmt::make(ExprGroup( + builder->expr_subscript(Expr(snode_to_fields_.at(snode)), indices))); this->current_ast_builder()->insert(std::move(ret)); }); ker.set_arch(get_accessor_arch()); @@ -383,9 +386,13 @@ Kernel &Program::get_snode_writer(SNode *snode) { auto &ker = kernel([snode, this] { ExprGroup indices; for (int i = 0; i < snode->num_active_indices; i++) { - indices.push_back(Expr::make(i, PrimitiveType::i32)); + auto argload_expr = Expr::make(i, PrimitiveType::i32); + argload_expr->type_check(&this->this_thread_config()); + indices.push_back(std::move(argload_expr)); } - auto expr = Expr(snode_to_fields_.at(snode))[indices]; + ASTBuilder *builder = current_ast_builder(); + auto expr = + builder->expr_subscript(Expr(snode_to_fields_.at(snode)), indices); this->current_ast_builder()->insert_assignment( expr, Expr::make(snode->num_active_indices, diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index d958e92a22325..d5485976eca1a 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -42,10 +42,6 @@ bool test_threading(); namespace taichi::lang { -Expr expr_index(const Expr &expr, const Expr &index) { - return expr[ExprGroup(index)]; -} - std::string libdevice_path(); } // namespace taichi::lang @@ -59,6 +55,8 @@ void export_lang(py::module &m) { PyExc_TypeError); py::register_exception(m, "TaichiSyntaxError", PyExc_SyntaxError); + py::register_exception(m, "TaichiIndexError", + PyExc_IndexError); py::register_exception(m, "TaichiRuntimeError", PyExc_RuntimeError); py::register_exception(m, "TaichiAssertionError", @@ -315,7 +313,9 @@ void export_lang(py::module &m) { .def("insert_expr_stmt", &ASTBuilder::insert_expr_stmt) .def("insert_thread_idx_expr", &ASTBuilder::insert_thread_idx_expr) .def("insert_patch_idx_expr", &ASTBuilder::insert_patch_idx_expr) - .def("expand_expr", &ASTBuilder::expand_expr) + .def("expand_exprs", &ASTBuilder::expand_exprs) + .def("mesh_index_conversion", &ASTBuilder::mesh_index_conversion) + .def("expr_subscript", &ASTBuilder::expr_subscript) .def("sifakis_svd_f32", sifakis_svd_export) .def("sifakis_svd_f64", sifakis_svd_export) .def("expr_var", &ASTBuilder::make_var) @@ -861,8 +861,6 @@ void export_lang(py::module &m) { return Expr::make(AtomicOpType::bit_xor, a, b); }); - m.def("expr_index", expr_index); - m.def("expr_assume_in_range", assume_range); m.def("expr_loop_unique", loop_unique); @@ -993,13 +991,6 @@ void export_lang(py::module &m) { m.def("data_type_name", data_type_name); - m.def("subscript", - [](const Expr &expr, const ExprGroup &expr_group, std::string tb) { - Expr idx_expr = expr[expr_group]; - idx_expr.set_tb(tb); - return idx_expr; - }); - m.def( "subscript_with_multiple_indices", Expr::make &, @@ -1044,13 +1035,6 @@ void export_lang(py::module &m) { mesh_ptr.ptr.get(), mesh_idx, to_type, neighbor_idx); }); - m.def("get_index_conversion", - [](mesh::MeshPtr mesh_ptr, mesh::MeshElementType idx_type, - const Expr &idx, mesh::ConvType &conv_type) { - return Expr::make( - mesh_ptr.ptr.get(), idx_type, idx, conv_type); - }); - py::class_(m, "FunctionKey") .def(py::init()) .def_readonly("instance_id", &FunctionKey::instance_id); diff --git a/tests/cpp/ir/frontend_type_inference_test.cpp b/tests/cpp/ir/frontend_type_inference_test.cpp index 11ad9269f6b67..558a72979b80a 100644 --- a/tests/cpp/ir/frontend_type_inference_test.cpp +++ b/tests/cpp/ir/frontend_type_inference_test.cpp @@ -86,21 +86,39 @@ TEST(FrontendTypeInference, TernaryOp) { } TEST(FrontendTypeInference, GlobalPtr_Field) { + auto prog = std::make_unique(Arch::x64); + auto func = []() {}; + auto kernel = std::make_unique(*prog, func, "fake_kernel"); + Callable::CurrentCallableGuard _(kernel->program, kernel.get()); + auto ast_builder = prog->current_ast_builder(); + auto global_var = Expr::make(PrimitiveType::u8, Identifier(0)); + SNode snode; + snode.num_active_indices = 1; + std::dynamic_pointer_cast(global_var.expr) + ->set_snode(&snode); + auto index = value(2); index->type_check(nullptr); - auto global_ptr = global_var[ExprGroup(index)]; + auto global_ptr = ast_builder->expr_subscript(global_var, ExprGroup(index)); global_ptr->type_check(nullptr); EXPECT_EQ(global_ptr->ret_type, PrimitiveType::u8); } TEST(FrontendTypeInference, GlobalPtr_ExternalTensor) { + auto prog = std::make_unique(Arch::x64); + auto func = []() {}; + auto kernel = std::make_unique(*prog, func, "fake_kernel"); + Callable::CurrentCallableGuard _(kernel->program, kernel.get()); + auto ast_builder = prog->current_ast_builder(); + auto index = value(2); index->type_check(nullptr); auto external_tensor = Expr::make(PrimitiveType::u16, 1, 0, 0); - auto global_ptr = external_tensor[ExprGroup(index)]; + auto global_ptr = + ast_builder->expr_subscript(external_tensor, ExprGroup(index)); EXPECT_THROW(global_ptr->type_check(nullptr), TaichiTypeError); }