From 4ce08c87305de093850ba22518fe14dd1d6159ca Mon Sep 17 00:00:00 2001 From: AD1024 Date: Mon, 15 Aug 2022 13:39:32 -0400 Subject: [PATCH 01/92] cherrypick Matrix repr support --- python/taichi/lang/ast/ast_transformer.py | 7 ++++- python/taichi/lang/impl.py | 6 ++++ python/taichi/lang/matrix.py | 16 ++++++++++ taichi/analysis/data_source_analysis.cpp | 6 +++- taichi/analysis/gen_offline_cache_key.cpp | 9 ++++++ taichi/analysis/offline_cache_util.cpp | 1 + taichi/analysis/same_statements.cpp | 18 +++++++++++ taichi/codegen/llvm/codegen_llvm.cpp | 37 ++++++++++++++++++----- taichi/codegen/llvm/codegen_llvm.h | 2 ++ taichi/inc/expressions.inc.h | 1 + taichi/inc/statements.inc.h | 1 + taichi/ir/control_flow_graph.cpp | 3 +- taichi/ir/expression_printer.h | 8 +++++ taichi/ir/frontend_ir.cpp | 26 ++++++++++++++++ taichi/ir/frontend_ir.h | 23 ++++++++++++++ taichi/ir/statements.h | 12 ++++++++ taichi/ir/type.cpp | 8 +++++ taichi/ir/type.h | 2 ++ taichi/ir/type_utils.cpp | 32 ++++++++++++++++++++ taichi/program/compile_config.cpp | 1 + taichi/program/compile_config.h | 1 + taichi/python/export_lang.cpp | 2 ++ taichi/runtime/llvm/llvm_context.cpp | 4 +++ taichi/transforms/alg_simp.cpp | 28 +++++++++++++++-- taichi/transforms/die.cpp | 7 +++++ taichi/transforms/ir_printer.cpp | 13 ++++++++ taichi/transforms/type_check.cpp | 20 ++++++++++++ 27 files changed, 281 insertions(+), 13 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 1d30ffceae77d..aeca62a0c7421 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -14,7 +14,7 @@ from taichi.lang.ast.symbol_resolver import ASTResolver from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError from taichi.lang.field import Field -from taichi.lang.matrix import (Matrix, MatrixType, _PyScopeMatrixImpl, +from taichi.lang.matrix import (Matrix, Vector, MatrixType, _PyScopeMatrixImpl, _TiScopeMatrixImpl) from taichi.lang.snode import append from taichi.lang.util import in_taichi_scope, is_taichi_class, to_taichi_type @@ -487,6 +487,11 @@ def build_Call(ctx, node): args.insert(0, node.func.value.ptr) node.ptr = impl.ti_format(*args, **keywords) return node.ptr + + if isinstance(node.func, + ast.Attribute) and func == Matrix or func == Vector: + node.ptr = matrix.make_matrix(*args, **keywords) + return node.ptr if ASTTransformer.build_call_if_is_builtin(ctx, node, args, keywords): return node.ptr diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 1cf13596ac70f..d9ab715b28dad 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -35,6 +35,12 @@ def expr_init_local_tensor(shape, element_type, elements): shape, element_type, elements) +@taichi_scope +def expr_init_matrix(shape, element_type, elements): + return get_runtime().prog.current_ast_builder().expr_alloca_matrix( + shape, element_type, elements) + + @taichi_scope def expr_init_shared_array(shape, element_type): return get_runtime().prog.current_ast_builder().expr_alloca_shared_array( diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index e72831224738b..d694659de0975 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -97,6 +97,22 @@ def prop_setter(instance, value): return cls +def make_matrix(arr, dt=None, suppress_warning=False, is_ref=False, **kwargs): + if not impl.current_cfg().real_matrix or in_python_scope(): + return Matrix(arr, dt, suppress_warning, is_ref, **kwargs) + cast = (lambda x: ops_mod.cast(x, dt)) if dt else ( + lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x)) + if len(arr) == 0: + return impl.expr_init(impl.expr_init_matrix([0], dt, [])) + if not isinstance(arr[0], Iterable): + return impl.expr_init( + impl.expr_init_matrix([len(arr)], dt, + [cast(elt).ptr for elt in arr])) + return impl.expr_init( + impl.expr_init_matrix([len(arr), len(arr[0])], dt, + [cast(elt).ptr for row in arr for elt in row])) + + class _MatrixBaseImpl: def __init__(self, m, n, entries): self.m = m diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index 4a018afa6bf47..9baad69fb80fb 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -37,6 +37,10 @@ std::vector get_load_pointers(Stmt *load_stmt) { return external_func->arg_stmts; } else if (auto ref = load_stmt->cast()) { return {ref->var}; + } else if (auto matrix_init = load_stmt->cast()) { + return matrix_init->values; + } else if (auto ptr_offset = load_stmt->cast()) { + return {ptr_offset->origin}; } else { return std::vector(); } @@ -59,7 +63,7 @@ Stmt *get_store_data(Stmt *store_stmt) { std::vector get_store_destination(Stmt *store_stmt) { // If store_stmt provides some data sources, return the pointers of the data. - if (store_stmt->is() && !store_stmt->ret_type->is()) { + if (store_stmt->is()) { // The statement itself provides a data source (const [0]). return std::vector(1, store_stmt); } else if (auto local_store = store_stmt->cast()) { diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp index d1dfd5166c8ba..06c4a012e02b3 100644 --- a/taichi/analysis/gen_offline_cache_key.cpp +++ b/taichi/analysis/gen_offline_cache_key.cpp @@ -165,6 +165,15 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { emit(expr->indices.exprs); } + void visit(MatrixExpression *expr) override { + emit(ExprOpCode::MatrixExpression); + emit(expr->dt); + for (auto elt : expr->elements) { + emit(elt); + } + } + + void visit(StrideExpression *expr) override { emit(ExprOpCode::StrideExpression); emit(expr->var); diff --git a/taichi/analysis/offline_cache_util.cpp b/taichi/analysis/offline_cache_util.cpp index 19f8a32c01dbf..28263b0f0c3d0 100644 --- a/taichi/analysis/offline_cache_util.cpp +++ b/taichi/analysis/offline_cache_util.cpp @@ -65,6 +65,7 @@ static std::vector get_offline_cache_key_of_compile_config( serializer(config->demote_no_access_mesh_fors); serializer(config->experimental_auto_mesh_local); serializer(config->auto_mesh_local_default_occupacy); + serializer(config->real_matrix); serializer.finalize(); return serializer.data; diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index c26957e81f4a8..5f908603c9497 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -196,6 +196,24 @@ class IRNodeComparator : public IRVisitor { basic_check(stmt); } + void visit(MatrixInitStmt *stmt) override { + basic_check(stmt); + if (!same) + return; + auto o = other_node_->as(); + if (stmt->values.size() != o->values.size()) { + same = false; + return; + } + for (int i = 0; i < stmt->values.size(); ++i) { + other_node_ = o->values[i]; + stmt->values[i]->accept(this); + other_node_ = o; + if (!same) + return; + } + } + void visit(IfStmt *stmt) override { basic_check(stmt); if (!same) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 771f9f055dbf0..5b6d31a345f70 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -124,7 +124,7 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) { void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { if (stmt->ret_type->is()) { auto tensor_type = stmt->ret_type->cast(); - auto type = tlctx->get_data_type(tensor_type->get_element_type()); + auto type = tlctx->get_data_type(tensor_type); auto array_size = tlctx->get_constant(tensor_type->get_num_elements()); // Return type is [array_size x type]*. if (stmt->is_shared) { @@ -688,6 +688,11 @@ llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) { return llvm::Type::getDoubleTy(*llvm_context); } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*llvm_context); + } else if (dt->is()) { + auto tensor_type = dt->cast(); + auto element_type = llvm_type(tensor_type->get_element_type()); + return llvm::VectorType::get(element_type, tensor_type->get_num_elements(), + false); } else { TI_NOT_IMPLEMENTED; } @@ -800,12 +805,20 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) { if (std::holds_alternative(content)) { auto arg_stmt = std::get(content); auto value = llvm_val[arg_stmt]; - if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) || - arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) - value = builder->CreateFPExt(value, - tlctx->get_data_type(PrimitiveType::f64)); - args.push_back(value); - formats += data_type_format(arg_stmt->ret_type); + if (arg_stmt->ret_type->is()) { + auto dtype = arg_stmt->ret_type->cast(); + for (int i = 0; i < dtype->get_num_elements(); ++i) { + args.push_back(builder->CreateExtractElement(value, i)); + } + formats += data_type_format(arg_stmt->ret_type); + } else { + if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) || + arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) + value = builder->CreateFPExt( + value, tlctx->get_data_type(PrimitiveType::f64)); + args.push_back(value); + formats += data_type_format(arg_stmt->ret_type); + } } else { auto arg_str = std::get(content); auto value = builder->CreateGlobalStringPtr(arg_str, "content_string"); @@ -2515,6 +2528,16 @@ void TaskCodeGenLLVM::visit(MeshPatchIndexStmt *stmt) { llvm_val[stmt] = get_arg(2); } +void TaskCodeGenLLVM::visit(MatrixInitStmt *stmt) { + auto type = tlctx->get_data_type(stmt->ret_type->as()); + llvm::Value *vec = llvm::UndefValue::get(type); + for (int i = 0; i < stmt->values.size(); ++i) { + auto *elem = llvm_val[stmt->values[i]]; + vec = builder->CreateInsertElement(vec, elem, i); + } + llvm_val[stmt] = vec; +} + void TaskCodeGenLLVM::eliminate_unused_functions() { TaichiLLVMContext::eliminate_unused_functions( module.get(), [&](std::string func_name) { diff --git a/taichi/codegen/llvm/codegen_llvm.h b/taichi/codegen/llvm/codegen_llvm.h index 6f97ed7dff0f4..356866b12b8c4 100644 --- a/taichi/codegen/llvm/codegen_llvm.h +++ b/taichi/codegen/llvm/codegen_llvm.h @@ -369,6 +369,8 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(ReferenceStmt *stmt) override; + void visit(MatrixInitStmt *stmt) override; + llvm::Value *create_xlogue(std::unique_ptr &block); llvm::Value *create_mesh_xlogue(std::unique_ptr &block); diff --git a/taichi/inc/expressions.inc.h b/taichi/inc/expressions.inc.h index 9b20ba86bd80a..ac3d3b7bc9b1b 100644 --- a/taichi/inc/expressions.inc.h +++ b/taichi/inc/expressions.inc.h @@ -7,6 +7,7 @@ PER_EXPRESSION(InternalFuncCallExpression) PER_EXPRESSION(ExternalTensorExpression) PER_EXPRESSION(GlobalVariableExpression) PER_EXPRESSION(IndexExpression) +PER_EXPRESSION(MatrixExpression) PER_EXPRESSION(StrideExpression) PER_EXPRESSION(RangeAssumptionExpression) PER_EXPRESSION(LoopUniqueExpression) diff --git a/taichi/inc/statements.inc.h b/taichi/inc/statements.inc.h index fe12a8941f7f5..05056ce46b9ae 100644 --- a/taichi/inc/statements.inc.h +++ b/taichi/inc/statements.inc.h @@ -38,6 +38,7 @@ PER_STATEMENT(LoopUniqueStmt) PER_STATEMENT(AssertStmt) PER_STATEMENT(ExternalFuncCallStmt) PER_STATEMENT(ExternalTensorShapeAlongAxisStmt) +PER_STATEMENT(MatrixInitStmt) // Locals with reverse-mode autodiff PER_STATEMENT(AdStackAllocaStmt) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 3c4ddfedf2cac..0460894fb65f6 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -501,7 +501,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { } } auto load_ptrs = irpass::analysis::get_load_pointers(stmt); - if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1) { + if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1 && + !stmt->is()) { // Identical load elimination auto load_ptr = load_ptrs.front(); if (!after_lower_access || diff --git a/taichi/ir/expression_printer.h b/taichi/ir/expression_printer.h index 74887c890e099..5d7bb1c012292 100644 --- a/taichi/ir/expression_printer.h +++ b/taichi/ir/expression_printer.h @@ -110,6 +110,14 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter { } } + void visit(MatrixExpression *expr) override { + emit('['); + emit_vector(expr->elements); + emit(']'); + emit(fmt::format(" (dt={})", expr->dt->to_string())); + } + + void visit(IndexExpression *expr) override { expr->var->accept(this); emit('['); diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 56d5de025a9e8..6a00ed9e62121 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -429,6 +429,25 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx, return ctx->push_back(var->stmt, offset_stmt); } +void MatrixExpression::type_check(CompileConfig *config) { + // TODO: typecheck matrix + for (auto &arg : elements) { + TI_ASSERT_TYPE_CHECKED(arg); + } +} + +void MatrixExpression::flatten(FlattenContext *ctx) { + // TODO: implement flatten + TI_ASSERT(this->dt->is()); + std::vector values; + for (auto &elt : elements) { + flatten_rvalue(elt, ctx); + values.push_back(elt->stmt); + } + stmt = ctx->push_back(values); + stmt->ret_type = this->dt; +} + bool IndexExpression::is_field() const { return var.is(); } @@ -960,6 +979,13 @@ Expr ASTBuilder::expr_alloca() { return var; } +Expr ASTBuilder::expr_alloca_local_matrix(const std::vector &shape, + const std::optional &dt, + const std::vector &elements) { + auto dtype = dt.value_or(PrimitiveType::unknown); + return Expr(std::make_shared(elements, shape, dtype)); +} + Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements) { diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index b56c3ec48f416..d1dfcbe56d2d5 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -504,6 +504,26 @@ class GlobalVariableExpression : public Expression { TI_DEFINE_ACCEPT_FOR_EXPRESSION }; +class MatrixExpression : public Expression { + public: + std::vector elements; + DataType dt; + + MatrixExpression(const std::vector &elements, + std::vector shape, + DataType element_type) + : elements(elements) { + this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type)); + this->ret_type = this->dt; + } + + void type_check(CompileConfig *config) override; + + void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION +}; + class IndexExpression : public Expression { public: // `var` is one of GlobalVariableExpression, ExternalTensorExpression, @@ -876,6 +896,9 @@ class ASTBuilder { const ExprGroup &args, const ExprGroup &outputs); Expr expr_alloca(); + Expr expr_alloca_local_matrix(const std::vector &shape, + const std::optional &dt, + const std::vector &elements); Expr expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements); diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 9a2ea841e6a66..784c0a499c202 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -1807,5 +1807,17 @@ class MeshPatchIndexStmt : public Stmt { TI_DEFINE_ACCEPT_AND_CLONE }; +class MatrixInitStmt : public Stmt { + public: + std::vector values; + + MatrixInitStmt(const std::vector &values) : values(values) { + TI_STMT_REG_FIELDS; + } + + TI_STMT_DEF_FIELDS(ret_type, values); + TI_DEFINE_ACCEPT_AND_CLONE +}; + } // namespace lang } // namespace taichi diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 6b9d1a51e7990..0188722f15e6f 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -87,6 +87,14 @@ std::string TensorType::to_string() const { return s; } +int TensorType::vector_width() const { + int vw = 1; + for (auto dim : shape_) { + vw *= dim; + } + return vw; +} + int Type::vector_width() const { return 1; // TODO: CPU vectorization } diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 339e2553ffb32..7588093588bcc 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -175,6 +175,8 @@ class TensorType : public Type { return shape_; } + int vector_width() const; + Type *get_compute_type() override { return this; } diff --git a/taichi/ir/type_utils.cpp b/taichi/ir/type_utils.cpp index e49428b022445..dc5f816ecfa19 100644 --- a/taichi/ir/type_utils.cpp +++ b/taichi/ir/type_utils.cpp @@ -53,6 +53,36 @@ int data_type_size(DataType t) { } } +std::string tensor_type_format_helper(const std::vector &shape, + std::string format_str, + int dim) { + std::string fmt = "["; + for (int i = 0; i < shape[dim]; ++i) { + if (dim != shape.size() - 1) { + fmt += tensor_type_format_helper(shape, format_str, dim + 1); + } else { + fmt += format_str; + } + if (i != shape[dim] - 1) { + fmt += ", "; + if (dim == 0) { + fmt += "\n"; + } + } + } + fmt += "]"; + return fmt; +} + +std::string tensor_type_format(DataType t) { + TI_ASSERT(t->is()); + auto tensor_type = t->as(); + auto shape = tensor_type->get_shape(); + auto element_type = tensor_type->get_element_type(); + auto element_type_format = data_type_format(element_type); + return tensor_type_format_helper(shape, element_type_format, 0); +} + std::string data_type_format(DataType dt) { if (dt->is_primitive(PrimitiveTypeID::i16)) { return "%hd"; @@ -79,6 +109,8 @@ std::string data_type_format(DataType dt) { // TaskCodeGenLLVM::visit(PrintStmt *stmt) and // TaskCodeGenCUDA::visit(PrintStmt *stmt) for more details. return "%f"; + } else if (dt->is()) { + return tensor_type_format(dt); } else { TI_NOT_IMPLEMENTED } diff --git a/taichi/program/compile_config.cpp b/taichi/program/compile_config.cpp index a15ebcf7f9c5a..6f3d7a16deb2f 100644 --- a/taichi/program/compile_config.cpp +++ b/taichi/program/compile_config.cpp @@ -48,6 +48,7 @@ CompileConfig::CompileConfig() { detect_read_only = true; ndarray_use_cached_allocator = true; use_mesh = false; + real_matrix = false; saturating_grid_dim = 0; max_block_dim = 0; diff --git a/taichi/program/compile_config.h b/taichi/program/compile_config.h index 7820834066041..14f17e419fd6d 100644 --- a/taichi/program/compile_config.h +++ b/taichi/program/compile_config.h @@ -44,6 +44,7 @@ struct CompileConfig { bool detect_read_only; bool ndarray_use_cached_allocator; bool use_mesh; + bool real_matrix; DataType default_fp; DataType default_ip; DataType default_up; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 39a2149db656f..f210522a92996 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -193,6 +193,7 @@ void export_lang(py::module &m) { .def_readwrite("ndarray_use_cached_allocator", &CompileConfig::ndarray_use_cached_allocator) .def_readwrite("use_mesh", &CompileConfig::use_mesh) + .def_readwrite("real_matrix", &CompileConfig::real_matrix) .def_readwrite("cc_compile_cmd", &CompileConfig::cc_compile_cmd) .def_readwrite("cc_link_cmd", &CompileConfig::cc_link_cmd) .def_readwrite("quant_opt_store_fusion", @@ -285,6 +286,7 @@ void export_lang(py::module &m) { .def("insert_external_func_call", &ASTBuilder::insert_external_func_call) .def("expr_alloca", &ASTBuilder::expr_alloca) .def("expr_alloca_local_tensor", &ASTBuilder::expr_alloca_local_tensor) + .def("expr_alloca_matrix", &ASTBuilder::expr_alloca_local_matrix) .def("expr_alloca_shared_array", &ASTBuilder::expr_alloca_shared_array) .def("create_assert_stmt", &ASTBuilder::create_assert_stmt) .def("expr_assign", &ASTBuilder::expr_assign) diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index 379acb1efdae0..edde242f48f96 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -135,6 +135,10 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { return llvm::Type::getInt64Ty(*ctx); } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*ctx); + } else if (dt->is()) { + auto vectorty = dt->as(); + auto dtype = this->get_data_type(vectorty->get_element_type()); + return llvm::VectorType::get(dtype, vectorty->get_num_elements(), false); } else { TI_INFO(data_type_name(dt)); TI_NOT_IMPLEMENTED diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 4ce787066f230..75ed12641e31b 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -164,8 +164,14 @@ class AlgSimp : public BasicStmtVisitor { if ((fast_math || is_integral(stmt->ret_type)) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a / a -> 1 - replace_with_one(stmt); - return true; + if (stmt->lhs->ret_type->is() && + stmt->rhs->ret_type->is()) { + replace_with_one(stmt); + return true; + } else { + // TODO: handle tensor division + return false; + } } if (fast_math && rhs && is_real(rhs->ret_type) && stmt->op_type != BinaryOpType::floordiv) { @@ -245,7 +251,13 @@ class AlgSimp : public BasicStmtVisitor { (fast_math || is_integral(stmt->ret_type)) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a -^ a -> 0 - replace_with_zero(stmt); + if (stmt->lhs->ret_type->is() && + stmt->rhs->ret_type->is()) { + replace_with_zero(stmt); + } else { + // TODO: handle tensor operations + return; + } } } else if (rhs && stmt->op_type == BinaryOpType::pow) { float64 exponent = rhs->val[0].val_cast_to_float64(); @@ -330,6 +342,11 @@ class AlgSimp : public BasicStmtVisitor { modifier.erase(stmt); } else if (alg_is_zero(lhs) || alg_is_zero(rhs)) { // 0 & a -> 0, a & 0 -> 0 + if (stmt->ret_type->is() || + stmt->rhs->ret_type->is()) { + // TODO: support tensor type + return; + } replace_with_zero(stmt); } else if (irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // a & a -> a @@ -344,6 +361,11 @@ class AlgSimp : public BasicStmtVisitor { // a << 0 -> a // 0 << a -> 0 // 0 >> a -> 0 + if (stmt->ret_type->is() || + stmt->rhs->ret_type->is()) { + // TODO: support tensor type + return; + } TI_ASSERT(stmt->lhs->ret_type == stmt->ret_type); stmt->replace_usages_with(stmt->lhs); modifier.erase(stmt); diff --git a/taichi/transforms/die.cpp b/taichi/transforms/die.cpp index 3176d8f576949..f6d696c4da498 100644 --- a/taichi/transforms/die.cpp +++ b/taichi/transforms/die.cpp @@ -108,6 +108,13 @@ class DIE : public IRVisitor { } stmt->all_blocks_accept(this, true); } + + void visit(MatrixInitStmt *stmt) override { + register_usage(stmt); + for (auto &elts : stmt->values) { + elts->accept(this); + } + } }; namespace irpass { diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index ca462e42773e8..97538ee93acce 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -794,6 +794,19 @@ class IRPrinter : public IRVisitor { print("{}{} = ref({})", stmt->type_hint(), stmt->name(), stmt->var->name()); } + void visit(MatrixInitStmt *stmt) override { + std::string result = ""; + result += fmt::format("{}{} = [", stmt->type_hint(), stmt->name()); + for (int i = 0; i < stmt->values.size(); ++i) { + result += stmt->values[i]->name(); + if (i != stmt->values.size() - 1) { + result += ", "; + } + } + result += "]"; + print(result); + } + private: std::string expr_to_string(Expr &expr) { return expr_to_string(expr.expr.get()); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index e16c62501ba3c..cd4d4f7e6cf59 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -553,6 +553,26 @@ class TypeCheck : public IRVisitor { stmt->ret_type = stmt->var->ret_type; stmt->ret_type.set_is_pointer(true); } + + void visit(MatrixInitStmt *stmt) override { + TI_ASSERT_INFO(stmt->ret_type->is(), + "Matrix should have tensor type, got {}", + stmt->ret_type->to_string()); + auto tensor_type = stmt->ret_type->as(); + auto element_dtype = tensor_type->get_element_type(); + for (auto elt : stmt->values) { + element_dtype = promoted_type(element_dtype, elt->ret_type); + } + for (int i = 0; i < stmt->values.size(); ++i) { + if (element_dtype != stmt->values[i]->ret_type) { + cast(stmt->values[i], element_dtype); + } + } + if (element_dtype != tensor_type->get_element_type()) { + stmt->ret_type = TypeFactory::create_tensor_type(tensor_type->get_shape(), + element_dtype); + } + } }; namespace irpass { From 549a359e5e69d648a047d8a6698db8decfcf2852 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Aug 2022 17:51:12 +0000 Subject: [PATCH 02/92] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/taichi/lang/ast/ast_transformer.py | 6 +++--- taichi/analysis/gen_offline_cache_key.cpp | 3 +-- taichi/ir/expression_printer.h | 1 - taichi/transforms/alg_simp.cpp | 6 +++--- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index aeca62a0c7421..f70b9e20c494e 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -14,7 +14,7 @@ from taichi.lang.ast.symbol_resolver import ASTResolver from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError from taichi.lang.field import Field -from taichi.lang.matrix import (Matrix, Vector, MatrixType, _PyScopeMatrixImpl, +from taichi.lang.matrix import (Matrix, MatrixType, Vector, _PyScopeMatrixImpl, _TiScopeMatrixImpl) from taichi.lang.snode import append from taichi.lang.util import in_taichi_scope, is_taichi_class, to_taichi_type @@ -487,9 +487,9 @@ def build_Call(ctx, node): args.insert(0, node.func.value.ptr) node.ptr = impl.ti_format(*args, **keywords) return node.ptr - + if isinstance(node.func, - ast.Attribute) and func == Matrix or func == Vector: + ast.Attribute) and func == Matrix or func == Vector: node.ptr = matrix.make_matrix(*args, **keywords) return node.ptr diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp index 06c4a012e02b3..e0c825761c973 100644 --- a/taichi/analysis/gen_offline_cache_key.cpp +++ b/taichi/analysis/gen_offline_cache_key.cpp @@ -170,10 +170,9 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { emit(expr->dt); for (auto elt : expr->elements) { emit(elt); - } + } } - void visit(StrideExpression *expr) override { emit(ExprOpCode::StrideExpression); emit(expr->var); diff --git a/taichi/ir/expression_printer.h b/taichi/ir/expression_printer.h index 5d7bb1c012292..a32a0468d6907 100644 --- a/taichi/ir/expression_printer.h +++ b/taichi/ir/expression_printer.h @@ -117,7 +117,6 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter { emit(fmt::format(" (dt={})", expr->dt->to_string())); } - void visit(IndexExpression *expr) override { expr->var->accept(this); emit('['); diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 75ed12641e31b..f4892fdc439a8 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -165,7 +165,7 @@ class AlgSimp : public BasicStmtVisitor { irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a / a -> 1 if (stmt->lhs->ret_type->is() && - stmt->rhs->ret_type->is()) { + stmt->rhs->ret_type->is()) { replace_with_one(stmt); return true; } else { @@ -252,7 +252,7 @@ class AlgSimp : public BasicStmtVisitor { irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a -^ a -> 0 if (stmt->lhs->ret_type->is() && - stmt->rhs->ret_type->is()) { + stmt->rhs->ret_type->is()) { replace_with_zero(stmt); } else { // TODO: handle tensor operations @@ -362,7 +362,7 @@ class AlgSimp : public BasicStmtVisitor { // 0 << a -> 0 // 0 >> a -> 0 if (stmt->ret_type->is() || - stmt->rhs->ret_type->is()) { + stmt->rhs->ret_type->is()) { // TODO: support tensor type return; } From 07a4dc12b34d98da7267cee950cc83c5b0a44381 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Mon, 15 Aug 2022 13:57:46 -0400 Subject: [PATCH 03/92] matrix assign --- python/taichi/lang/impl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index d9ab715b28dad..60389ce57227d 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -54,6 +54,8 @@ def expr_init(rhs): if isinstance(rhs, Matrix) and (hasattr(rhs, "_DIM")): return Matrix(*rhs.to_list(), ndim=rhs.ndim) if isinstance(rhs, Matrix): + if current_cfg().real_matrix: + return rhs return Matrix(rhs.to_list(), ndim=rhs.ndim) if isinstance(rhs, SharedArray): return rhs From efca3f04a08fe1683d3801c8b83dd7575d573f37 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Mon, 15 Aug 2022 14:41:39 -0400 Subject: [PATCH 04/92] move checks to caller side --- python/taichi/lang/ast/ast_transformer.py | 4 ++-- python/taichi/lang/matrix.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index f70b9e20c494e..45e01b18a8ab2 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -488,8 +488,8 @@ def build_Call(ctx, node): node.ptr = impl.ti_format(*args, **keywords) return node.ptr - if isinstance(node.func, - ast.Attribute) and func == Matrix or func == Vector: + if (isinstance(node.func, + ast.Attribute) and (func in {Matrix, Vector})) and impl.current_cfg().real_matrix and in_taichi_scope(): node.ptr = matrix.make_matrix(*args, **keywords) return node.ptr diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index d694659de0975..8d8c59d9f2922 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -97,9 +97,7 @@ def prop_setter(instance, value): return cls -def make_matrix(arr, dt=None, suppress_warning=False, is_ref=False, **kwargs): - if not impl.current_cfg().real_matrix or in_python_scope(): - return Matrix(arr, dt, suppress_warning, is_ref, **kwargs) +def make_matrix(arr, dt=None): cast = (lambda x: ops_mod.cast(x, dt)) if dt else ( lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x)) if len(arr) == 0: From 82c941383f29067ea973516bf1ee2a332038c703 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Aug 2022 18:43:03 +0000 Subject: [PATCH 05/92] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/taichi/lang/ast/ast_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 45e01b18a8ab2..4230da35f78b9 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -488,8 +488,8 @@ def build_Call(ctx, node): node.ptr = impl.ti_format(*args, **keywords) return node.ptr - if (isinstance(node.func, - ast.Attribute) and (func in {Matrix, Vector})) and impl.current_cfg().real_matrix and in_taichi_scope(): + if (isinstance(node.func, ast.Attribute) and (func in {Matrix, Vector}) + ) and impl.current_cfg().real_matrix and in_taichi_scope(): node.ptr = matrix.make_matrix(*args, **keywords) return node.ptr From 38ec7501c5b24c1e1e14f2ec480d4a556f4d84a6 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Mon, 15 Aug 2022 15:34:07 -0400 Subject: [PATCH 06/92] use == --- python/taichi/lang/ast/ast_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 45e01b18a8ab2..85cf388c8aa46 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -489,7 +489,7 @@ def build_Call(ctx, node): return node.ptr if (isinstance(node.func, - ast.Attribute) and (func in {Matrix, Vector})) and impl.current_cfg().real_matrix and in_taichi_scope(): + ast.Attribute) and (func == Matrix or func == Vector)) and impl.current_cfg().real_matrix and in_taichi_scope(): node.ptr = matrix.make_matrix(*args, **keywords) return node.ptr From 9c91103358a48b489cfee64c1a502c10552cca52 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 17 Aug 2022 13:37:56 -0400 Subject: [PATCH 07/92] refine impl --- python/taichi/lang/matrix.py | 17 +++++++++-------- taichi/ir/frontend_ir.cpp | 16 +++++++++------- taichi/ir/frontend_ir.h | 5 +++-- taichi/program/function.cpp | 2 +- taichi/program/kernel.cpp | 2 +- taichi/python/export_lang.cpp | 1 - taichi/runtime/llvm/llvm_context.cpp | 2 +- taichi/transforms/alg_simp.cpp | 24 +++++++++++++----------- 8 files changed, 37 insertions(+), 32 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 8d8c59d9f2922..06dea39d7adb3 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -98,17 +98,18 @@ def prop_setter(instance, value): def make_matrix(arr, dt=None): - cast = (lambda x: ops_mod.cast(x, dt)) if dt else ( - lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x)) if len(arr) == 0: - return impl.expr_init(impl.expr_init_matrix([0], dt, [])) - if not isinstance(arr[0], Iterable): + return impl.expr_init(impl.expr_init_local_tensor([0], ti_python_core.DataType_unknown, impl.make_expr_group([]))) + is_matrix = isinstance(arr[0], Iterable) + if dt is None: + dt = _make_entries_initializer(is_matrix).infer_dt(arr) + if not is_matrix: return impl.expr_init( - impl.expr_init_matrix([len(arr)], dt, - [cast(elt).ptr for elt in arr])) + impl.expr_init_local_tensor([len(arr)], dt, + impl.make_expr_group([expr.Expr(elt) for elt in arr]))) return impl.expr_init( - impl.expr_init_matrix([len(arr), len(arr[0])], dt, - [cast(elt).ptr for row in arr for elt in row])) + impl.expr_init_local_tensor([len(arr), len(arr[0])], dt, + impl.make_expr_group([expr.Expr(elt) for row in arr for elt in row]))) class _MatrixBaseImpl: diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 6a00ed9e62121..3984a4dfa474e 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -87,9 +87,9 @@ FrontendForStmt::FrontendForStmt(const ExprGroup &loop_var, } } -FrontendContext::FrontendContext(Arch arch) { +FrontendContext::FrontendContext(Arch arch, bool real_matrix) { root_node_ = std::make_unique(); - current_builder_ = std::make_unique(root_node_.get(), arch); + current_builder_ = std::make_unique(root_node_.get(), arch, real_matrix); } FrontendForStmt::FrontendForStmt(const Expr &loop_var, @@ -979,16 +979,18 @@ Expr ASTBuilder::expr_alloca() { return var; } -Expr ASTBuilder::expr_alloca_local_matrix(const std::vector &shape, - const std::optional &dt, - const std::vector &elements) { - auto dtype = dt.value_or(PrimitiveType::unknown); - return Expr(std::make_shared(elements, shape, dtype)); +Expr make_local_matrix(const std::vector &shape, + const DataType &dt, + const std::vector &elements) { + return Expr(std::make_shared(elements, shape, dt)); } Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements) { + if (this->use_real_matrix_) { + return make_local_matrix(shape,element_type, elements.exprs); + } auto var = Expr(std::make_shared(get_next_id())); this->insert(std::make_unique( std::static_pointer_cast(var.expr)->id, shape, diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index d1dfcbe56d2d5..8e774ebfcccfd 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -861,9 +861,10 @@ class ASTBuilder { Arch arch_; ForLoopDecoratorRecorder for_loop_dec_; int id_counter_{0}; + bool use_real_matrix_{false}; public: - ASTBuilder(Block *initial, Arch arch) : arch_(arch) { + ASTBuilder(Block *initial, Arch arch, bool real_matrix) : arch_(arch), use_real_matrix_(real_matrix) { stack_.push_back(initial); loop_state_stack_.push_back(None); } @@ -964,7 +965,7 @@ class FrontendContext { std::unique_ptr root_node_; public: - FrontendContext(Arch arch); + FrontendContext(Arch arch, bool real_matrix); ASTBuilder &builder() { return *current_builder_; diff --git a/taichi/program/function.cpp b/taichi/program/function.cpp index e2655fed02432..2ff368711fddd 100644 --- a/taichi/program/function.cpp +++ b/taichi/program/function.cpp @@ -12,7 +12,7 @@ Function::Function(Program *program, const FunctionKey &func_key) } void Function::set_function_body(const std::function &func) { - context = std::make_unique(program->config.arch); + context = std::make_unique(program->config.arch, program->config.real_matrix); ir = context->get_root(); { // Note: this is not a mutex diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 9f417fb408672..89cf804a2961d 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -391,7 +391,7 @@ void Kernel::init(Program &program, is_accessor = false; is_evaluator = false; compiled_ = nullptr; - context = std::make_unique(program.config.arch); + context = std::make_unique(program.config.arch, program.config.real_matrix); ir = context->get_root(); ir_is_ast_ = true; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index f210522a92996..b9ceda1b6cbfe 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -286,7 +286,6 @@ void export_lang(py::module &m) { .def("insert_external_func_call", &ASTBuilder::insert_external_func_call) .def("expr_alloca", &ASTBuilder::expr_alloca) .def("expr_alloca_local_tensor", &ASTBuilder::expr_alloca_local_tensor) - .def("expr_alloca_matrix", &ASTBuilder::expr_alloca_local_matrix) .def("expr_alloca_shared_array", &ASTBuilder::expr_alloca_shared_array) .def("create_assert_stmt", &ASTBuilder::create_assert_stmt) .def("expr_assign", &ASTBuilder::expr_assign) diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index edde242f48f96..75cd098bf7b3e 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -138,7 +138,7 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { } else if (dt->is()) { auto vectorty = dt->as(); auto dtype = this->get_data_type(vectorty->get_element_type()); - return llvm::VectorType::get(dtype, vectorty->get_num_elements(), false); + return llvm::VectorType::get(dtype, vectorty->get_num_elements(), /*scalable=*/false); } else { TI_INFO(data_type_name(dt)); TI_NOT_IMPLEMENTED diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index f4892fdc439a8..cb8b4c1741e13 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -164,14 +164,13 @@ class AlgSimp : public BasicStmtVisitor { if ((fast_math || is_integral(stmt->ret_type)) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a / a -> 1 - if (stmt->lhs->ret_type->is() && - stmt->rhs->ret_type->is()) { - replace_with_one(stmt); - return true; - } else { - // TODO: handle tensor division + if (stmt->lhs->ret_type->is() || + stmt->rhs->ret_type->is()) { + // TODO: support tensor type return false; } + replace_with_one(stmt); + return true; } if (fast_math && rhs && is_real(rhs->ret_type) && stmt->op_type != BinaryOpType::floordiv) { @@ -251,12 +250,11 @@ class AlgSimp : public BasicStmtVisitor { (fast_math || is_integral(stmt->ret_type)) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a -^ a -> 0 - if (stmt->lhs->ret_type->is() && - stmt->rhs->ret_type->is()) { - replace_with_zero(stmt); + if (stmt->lhs->ret_type->is() && + stmt->rhs->ret_type->is()) { + // TODO: support tensor type } else { - // TODO: handle tensor operations - return; + replace_with_zero(stmt); } } } else if (rhs && stmt->op_type == BinaryOpType::pow) { @@ -267,6 +265,10 @@ class AlgSimp : public BasicStmtVisitor { modifier.erase(stmt); } else if (exponent == 0) { // a ** 0 -> 1 + if (stmt->ret_type->is()) { + // TODO: support tensor type + return; + } replace_with_one(stmt); } else if (exponent == 0.5) { // a ** 0.5 -> sqrt(a) From 28f3e0ac57c20f1d3db07b9d5000d85845af0471 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Aug 2022 17:39:30 +0000 Subject: [PATCH 08/92] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/taichi/lang/matrix.py | 12 +++++++++--- taichi/ir/frontend_ir.cpp | 5 +++-- taichi/ir/frontend_ir.h | 3 ++- taichi/program/function.cpp | 3 ++- taichi/program/kernel.cpp | 3 ++- taichi/runtime/llvm/llvm_context.cpp | 3 ++- 6 files changed, 20 insertions(+), 9 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 06dea39d7adb3..0ab648ddb5155 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -99,17 +99,23 @@ def prop_setter(instance, value): def make_matrix(arr, dt=None): if len(arr) == 0: - return impl.expr_init(impl.expr_init_local_tensor([0], ti_python_core.DataType_unknown, impl.make_expr_group([]))) + return impl.expr_init( + impl.expr_init_local_tensor([0], ti_python_core.DataType_unknown, + impl.make_expr_group([]))) is_matrix = isinstance(arr[0], Iterable) if dt is None: dt = _make_entries_initializer(is_matrix).infer_dt(arr) if not is_matrix: return impl.expr_init( impl.expr_init_local_tensor([len(arr)], dt, - impl.make_expr_group([expr.Expr(elt) for elt in arr]))) + impl.make_expr_group( + [expr.Expr(elt) for elt in arr]))) return impl.expr_init( impl.expr_init_local_tensor([len(arr), len(arr[0])], dt, - impl.make_expr_group([expr.Expr(elt) for row in arr for elt in row]))) + impl.make_expr_group([ + expr.Expr(elt) for row in arr + for elt in row + ]))) class _MatrixBaseImpl: diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 3984a4dfa474e..edb1aca141b47 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -89,7 +89,8 @@ FrontendForStmt::FrontendForStmt(const ExprGroup &loop_var, FrontendContext::FrontendContext(Arch arch, bool real_matrix) { root_node_ = std::make_unique(); - current_builder_ = std::make_unique(root_node_.get(), arch, real_matrix); + current_builder_ = + std::make_unique(root_node_.get(), arch, real_matrix); } FrontendForStmt::FrontendForStmt(const Expr &loop_var, @@ -989,7 +990,7 @@ Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements) { if (this->use_real_matrix_) { - return make_local_matrix(shape,element_type, elements.exprs); + return make_local_matrix(shape, element_type, elements.exprs); } auto var = Expr(std::make_shared(get_next_id())); this->insert(std::make_unique( diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 8e774ebfcccfd..effd8abd566e9 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -864,7 +864,8 @@ class ASTBuilder { bool use_real_matrix_{false}; public: - ASTBuilder(Block *initial, Arch arch, bool real_matrix) : arch_(arch), use_real_matrix_(real_matrix) { + ASTBuilder(Block *initial, Arch arch, bool real_matrix) + : arch_(arch), use_real_matrix_(real_matrix) { stack_.push_back(initial); loop_state_stack_.push_back(None); } diff --git a/taichi/program/function.cpp b/taichi/program/function.cpp index 2ff368711fddd..64dd970cde4ea 100644 --- a/taichi/program/function.cpp +++ b/taichi/program/function.cpp @@ -12,7 +12,8 @@ Function::Function(Program *program, const FunctionKey &func_key) } void Function::set_function_body(const std::function &func) { - context = std::make_unique(program->config.arch, program->config.real_matrix); + context = std::make_unique(program->config.arch, + program->config.real_matrix); ir = context->get_root(); { // Note: this is not a mutex diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 89cf804a2961d..5f7d5ea31db31 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -391,7 +391,8 @@ void Kernel::init(Program &program, is_accessor = false; is_evaluator = false; compiled_ = nullptr; - context = std::make_unique(program.config.arch, program.config.real_matrix); + context = std::make_unique(program.config.arch, + program.config.real_matrix); ir = context->get_root(); ir_is_ast_ = true; diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index 75cd098bf7b3e..076e7ad8788e0 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -138,7 +138,8 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { } else if (dt->is()) { auto vectorty = dt->as(); auto dtype = this->get_data_type(vectorty->get_element_type()); - return llvm::VectorType::get(dtype, vectorty->get_num_elements(), /*scalable=*/false); + return llvm::VectorType::get(dtype, vectorty->get_num_elements(), + /*scalable=*/false); } else { TI_INFO(data_type_name(dt)); TI_NOT_IMPLEMENTED From bf719a3ddd93c79839cbb6c722f0bc280f3fd4cb Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 17 Aug 2022 14:06:07 -0400 Subject: [PATCH 09/92] no long in use --- taichi/ir/type.cpp | 8 -------- taichi/ir/type.h | 2 -- 2 files changed, 10 deletions(-) diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 0188722f15e6f..6b9d1a51e7990 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -87,14 +87,6 @@ std::string TensorType::to_string() const { return s; } -int TensorType::vector_width() const { - int vw = 1; - for (auto dim : shape_) { - vw *= dim; - } - return vw; -} - int Type::vector_width() const { return 1; // TODO: CPU vectorization } diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 7588093588bcc..339e2553ffb32 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -175,8 +175,6 @@ class TensorType : public Type { return shape_; } - int vector_width() const; - Type *get_compute_type() override { return this; } From 65199eabc6e0f28315688dc666e4b6e543a8d10b Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 17 Aug 2022 14:14:29 -0400 Subject: [PATCH 10/92] add some comments --- taichi/codegen/llvm/codegen_llvm.cpp | 2 +- taichi/ir/frontend_ir.h | 4 ++++ taichi/ir/statements.h | 3 +++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 5b6d31a345f70..bc54a5a2e8f51 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -692,7 +692,7 @@ llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) { auto tensor_type = dt->cast(); auto element_type = llvm_type(tensor_type->get_element_type()); return llvm::VectorType::get(element_type, tensor_type->get_num_elements(), - false); + /*scalable=*/false); } else { TI_NOT_IMPLEMENTED; } diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index effd8abd566e9..a544d15793dc9 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -504,6 +504,10 @@ class GlobalVariableExpression : public Expression { TI_DEFINE_ACCEPT_FOR_EXPRESSION }; +/** + * Creating a local matrix; + * lowered from ti.Matrix with real_matrix=True + */ class MatrixExpression : public Expression { public: std::vector elements; diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 784c0a499c202..11c87dd98b67c 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -1807,6 +1807,9 @@ class MeshPatchIndexStmt : public Stmt { TI_DEFINE_ACCEPT_AND_CLONE }; +/** + * Initialization of a local matrix + */ class MatrixInitStmt : public Stmt { public: std::vector values; From cbf1ea84ec52f9c721e1b687e06068de1fdb2386 Mon Sep 17 00:00:00 2001 From: Mike He Date: Tue, 23 Aug 2022 19:17:51 -0400 Subject: [PATCH 11/92] get rid of always-true condition Co-authored-by: Yi Xu --- python/taichi/lang/ast/ast_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 2376f77c79092..fcc0b46cfebb8 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -490,7 +490,7 @@ def build_Call(ctx, node): if (isinstance(node.func, ast.Attribute) and (func == Matrix or func == Vector) - ) and impl.current_cfg().real_matrix and in_taichi_scope(): + ) and impl.current_cfg().real_matrix: node.ptr = matrix.make_matrix(*args, **keywords) return node.ptr From 4c8d6b7d7dc9e957d180baca78ad531c366b1e53 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Aug 2022 23:19:08 +0000 Subject: [PATCH 12/92] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/taichi/lang/ast/ast_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index fcc0b46cfebb8..135a8f14d175c 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -489,8 +489,8 @@ def build_Call(ctx, node): return node.ptr if (isinstance(node.func, ast.Attribute) and - (func == Matrix or func == Vector) - ) and impl.current_cfg().real_matrix: + (func == Matrix + or func == Vector)) and impl.current_cfg().real_matrix: node.ptr = matrix.make_matrix(*args, **keywords) return node.ptr From 1a9df8c7107c0b48dac477253ba61821b80b3edf Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 23 Aug 2022 19:40:49 -0400 Subject: [PATCH 13/92] save --- python/taichi/lang/impl.py | 16 ++++++++-------- python/taichi/lang/matrix.py | 6 +++--- taichi/ir/frontend_ir.h | 3 --- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 60389ce57227d..4f45a8fb4c8d2 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -15,7 +15,8 @@ from taichi.lang.field import Field, ScalarField from taichi.lang.kernel_arguments import SparseMatrixProxy from taichi.lang.matrix import (Matrix, MatrixField, MatrixNdarray, MatrixType, - _IntermediateMatrix, _MatrixFieldElement) + _IntermediateMatrix, _MatrixFieldElement, + make_matrix) from taichi.lang.mesh import (ConvType, MeshElementFieldProxy, MeshInstance, MeshRelationAccessProxy, MeshReorderedMatrixFieldProxy, @@ -35,12 +36,6 @@ def expr_init_local_tensor(shape, element_type, elements): shape, element_type, elements) -@taichi_scope -def expr_init_matrix(shape, element_type, elements): - return get_runtime().prog.current_ast_builder().expr_alloca_matrix( - shape, element_type, elements) - - @taichi_scope def expr_init_shared_array(shape, element_type): return get_runtime().prog.current_ast_builder().expr_alloca_shared_array( @@ -55,7 +50,12 @@ def expr_init(rhs): return Matrix(*rhs.to_list(), ndim=rhs.ndim) if isinstance(rhs, Matrix): if current_cfg().real_matrix: - return rhs + if rhs.ndim == 1: + entries = [rhs(i) for i in range(rhs.n)] + else: + entries = [[rhs(i, j) for j in range(rhs.m)] + for i in range(rhs.n)] + return make_matrix(entries) return Matrix(rhs.to_list(), ndim=rhs.ndim) if isinstance(rhs, SharedArray): return rhs diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 0ab648ddb5155..1a755b6492004 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -99,18 +99,18 @@ def prop_setter(instance, value): def make_matrix(arr, dt=None): if len(arr) == 0: - return impl.expr_init( + return impl.Expr( impl.expr_init_local_tensor([0], ti_python_core.DataType_unknown, impl.make_expr_group([]))) is_matrix = isinstance(arr[0], Iterable) if dt is None: dt = _make_entries_initializer(is_matrix).infer_dt(arr) if not is_matrix: - return impl.expr_init( + return impl.Expr( impl.expr_init_local_tensor([len(arr)], dt, impl.make_expr_group( [expr.Expr(elt) for elt in arr]))) - return impl.expr_init( + return impl.Expr( impl.expr_init_local_tensor([len(arr), len(arr[0])], dt, impl.make_expr_group([ expr.Expr(elt) for row in arr diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index a544d15793dc9..afaf6db776bc3 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -902,9 +902,6 @@ class ASTBuilder { const ExprGroup &args, const ExprGroup &outputs); Expr expr_alloca(); - Expr expr_alloca_local_matrix(const std::vector &shape, - const std::optional &dt, - const std::vector &elements); Expr expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements); From 834699e1bc47d1883f605489957cbc1e1192438f Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 23 Aug 2022 19:50:06 -0400 Subject: [PATCH 14/92] some fixes for print and matrix expr --- taichi/codegen/llvm/codegen_llvm.cpp | 8 +++++++- taichi/ir/frontend_ir.cpp | 1 + taichi/ir/frontend_ir.h | 1 - 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index bc54a5a2e8f51..4609efa1d43f8 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -807,8 +807,14 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) { auto value = llvm_val[arg_stmt]; if (arg_stmt->ret_type->is()) { auto dtype = arg_stmt->ret_type->cast(); + auto elem_type = dtype->get_element_type(); for (int i = 0; i < dtype->get_num_elements(); ++i) { - args.push_back(builder->CreateExtractElement(value, i)); + auto elem_value = builder->CreateExtractElement(value, i); + if (elem_type->is_primitive(PrimitiveTypeID::f32) || + elem_type->is_primitive(PrimitiveTypeID::f16)) + elem_value = builder->CreateFPExt( + elem_value, tlctx->get_data_type(PrimitiveType::f64)); + args.push_back(elem_value); } formats += data_type_format(arg_stmt->ret_type); } else { diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index edb1aca141b47..d272d295c98df 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -435,6 +435,7 @@ void MatrixExpression::type_check(CompileConfig *config) { for (auto &arg : elements) { TI_ASSERT_TYPE_CHECKED(arg); } + ret_type = dt; } void MatrixExpression::flatten(FlattenContext *ctx) { diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index afaf6db776bc3..df472b954377e 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -518,7 +518,6 @@ class MatrixExpression : public Expression { DataType element_type) : elements(elements) { this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type)); - this->ret_type = this->dt; } void type_check(CompileConfig *config) override; From 23d7bf7bca927d138384cb5666613ac8c8e7b291 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 23 Aug 2022 19:53:20 -0400 Subject: [PATCH 15/92] fix codegen alloca size --- taichi/codegen/llvm/codegen_llvm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 4609efa1d43f8..397da5cf0b42b 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -124,7 +124,7 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) { void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { if (stmt->ret_type->is()) { auto tensor_type = stmt->ret_type->cast(); - auto type = tlctx->get_data_type(tensor_type); + auto type = tlctx->get_data_type(tensor_type->get_element_type()); auto array_size = tlctx->get_constant(tensor_type->get_num_elements()); // Return type is [array_size x type]*. if (stmt->is_shared) { From 6a8a8cbf4eb1b38a0e159a89fc14879f2ab03629 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 02:34:04 -0400 Subject: [PATCH 16/92] unsupport empty matrix --- python/taichi/lang/matrix.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 1a755b6492004..2d68b83d2b248 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -98,10 +98,7 @@ def prop_setter(instance, value): def make_matrix(arr, dt=None): - if len(arr) == 0: - return impl.Expr( - impl.expr_init_local_tensor([0], ti_python_core.DataType_unknown, - impl.make_expr_group([]))) + assert len(arr) > 0, "Cannot create empty matrix" is_matrix = isinstance(arr[0], Iterable) if dt is None: dt = _make_entries_initializer(is_matrix).infer_dt(arr) @@ -440,10 +437,8 @@ def __init__(self, raise TaichiTypeError( "An Matrix/Vector can only be initialized with an array-like object" ) - if len(arr) == 0: - mat = [] - self.ndim = 0 - elif isinstance(arr[0], Matrix): + assert len(arr) > 0, "Cannot create empty matrix" + if isinstance(arr[0], Matrix): raise Exception('cols/rows required when using list of vectors') else: is_matrix = isinstance(arr[0], Iterable) From 08926eff85bc8465e945ed350fd5c5f6e797d738 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 02:41:31 -0400 Subject: [PATCH 17/92] only check and cast elements --- taichi/transforms/type_check.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index cd4d4f7e6cf59..c0a24954b2bad 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -560,18 +560,11 @@ class TypeCheck : public IRVisitor { stmt->ret_type->to_string()); auto tensor_type = stmt->ret_type->as(); auto element_dtype = tensor_type->get_element_type(); - for (auto elt : stmt->values) { - element_dtype = promoted_type(element_dtype, elt->ret_type); - } for (int i = 0; i < stmt->values.size(); ++i) { if (element_dtype != stmt->values[i]->ret_type) { cast(stmt->values[i], element_dtype); } } - if (element_dtype != tensor_type->get_element_type()) { - stmt->ret_type = TypeFactory::create_tensor_type(tensor_type->get_shape(), - element_dtype); - } } }; From 1ae02aae51b5b1fa86d7656b1b007da86b28907a Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 02:44:02 -0400 Subject: [PATCH 18/92] fmt Vectors to one line --- taichi/ir/type_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/ir/type_utils.cpp b/taichi/ir/type_utils.cpp index dc5f816ecfa19..76ee3aa1f7b30 100644 --- a/taichi/ir/type_utils.cpp +++ b/taichi/ir/type_utils.cpp @@ -65,7 +65,7 @@ std::string tensor_type_format_helper(const std::vector &shape, } if (i != shape[dim] - 1) { fmt += ", "; - if (dim == 0) { + if (dim == 0 && dim != shape.size() - 1) { fmt += "\n"; } } From 17412e81060c2212a09d129eea6c3c86d5f3258b Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 02:52:50 -0400 Subject: [PATCH 19/92] lift duplicate part --- taichi/codegen/llvm/codegen_llvm.cpp | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 397da5cf0b42b..695e28e96a310 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -801,6 +801,13 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) { TI_ASSERT(stmt->width() == 1); std::vector args; std::string formats; + auto value_for_printf = [this](llvm::Value *to_print, DataType dtype) { + if (dtype->is_primitive(PrimitiveTypeID::f32) || + dtype->is_primitive(PrimitiveTypeID::f16)) + return this->builder->CreateFPExt( + to_print, this->tlctx->get_data_type(PrimitiveType::f64)); + return to_print; + }; for (auto const &content : stmt->contents) { if (std::holds_alternative(content)) { auto arg_stmt = std::get(content); @@ -810,19 +817,11 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) { auto elem_type = dtype->get_element_type(); for (int i = 0; i < dtype->get_num_elements(); ++i) { auto elem_value = builder->CreateExtractElement(value, i); - if (elem_type->is_primitive(PrimitiveTypeID::f32) || - elem_type->is_primitive(PrimitiveTypeID::f16)) - elem_value = builder->CreateFPExt( - elem_value, tlctx->get_data_type(PrimitiveType::f64)); - args.push_back(elem_value); + args.push_back(value_for_printf(elem_value, elem_type)); } formats += data_type_format(arg_stmt->ret_type); } else { - if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) || - arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) - value = builder->CreateFPExt( - value, tlctx->get_data_type(PrimitiveType::f64)); - args.push_back(value); + args.push_back(value_for_printf(value, arg_stmt->ret_type)); formats += data_type_format(arg_stmt->ret_type); } } else { From 5421fe1a8eae6c462f141567d603ac50bca4216c Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 03:02:49 -0400 Subject: [PATCH 20/92] clean-up --- taichi/analysis/data_source_analysis.cpp | 2 -- taichi/ir/control_flow_graph.cpp | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index 9baad69fb80fb..39bfa6750a0b9 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -37,8 +37,6 @@ std::vector get_load_pointers(Stmt *load_stmt) { return external_func->arg_stmts; } else if (auto ref = load_stmt->cast()) { return {ref->var}; - } else if (auto matrix_init = load_stmt->cast()) { - return matrix_init->values; } else if (auto ptr_offset = load_stmt->cast()) { return {ptr_offset->origin}; } else { diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 0460894fb65f6..3c4ddfedf2cac 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -501,8 +501,7 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { } } auto load_ptrs = irpass::analysis::get_load_pointers(stmt); - if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1 && - !stmt->is()) { + if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1) { // Identical load elimination auto load_ptr = load_ptrs.front(); if (!after_lower_access || From b9fd3a941629d9ded1fa53bc446f86084307478c Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 12:53:29 -0400 Subject: [PATCH 21/92] clean-up cse code --- taichi/analysis/same_statements.cpp | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index 5f908603c9497..c26957e81f4a8 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -196,24 +196,6 @@ class IRNodeComparator : public IRVisitor { basic_check(stmt); } - void visit(MatrixInitStmt *stmt) override { - basic_check(stmt); - if (!same) - return; - auto o = other_node_->as(); - if (stmt->values.size() != o->values.size()) { - same = false; - return; - } - for (int i = 0; i < stmt->values.size(); ++i) { - other_node_ = o->values[i]; - stmt->values[i]->accept(this); - other_node_ = o; - if (!same) - return; - } - } - void visit(IfStmt *stmt) override { basic_check(stmt); if (!same) From 43a456a1c97db6fa4a78d77a8ea90570afad2b78 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 13:03:56 -0400 Subject: [PATCH 22/92] breaks ci; keep as original impl --- python/taichi/lang/matrix.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 2d68b83d2b248..fe429e7238e59 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -437,8 +437,10 @@ def __init__(self, raise TaichiTypeError( "An Matrix/Vector can only be initialized with an array-like object" ) - assert len(arr) > 0, "Cannot create empty matrix" - if isinstance(arr[0], Matrix): + if len(arr) == 0: + mat = [] + self.ndim = 0 + elif isinstance(arr[0], Matrix): raise Exception('cols/rows required when using list of vectors') else: is_matrix = isinstance(arr[0], Iterable) From 78ad14ac822ba54621dc5d7f86ebb0f85eaeb7d6 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 13:10:13 -0400 Subject: [PATCH 23/92] handle alloca --- taichi/ir/frontend_ir.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index d272d295c98df..361d77945507c 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -991,7 +991,10 @@ Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements) { if (this->use_real_matrix_) { - return make_local_matrix(shape, element_type, elements.exprs); + auto matrix_expr = make_local_matrix(shape, element_type, elements.exprs); + auto v = this->expr_alloca(); + this->insert(std::make_unique(v, matrix_expr)); + return v; } auto var = Expr(std::make_shared(get_next_id())); this->insert(std::make_unique( From 40825f466a9b1a55218fcac2163b77b03e555852 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 13:18:40 -0400 Subject: [PATCH 24/92] move checks to front --- taichi/transforms/alg_simp.cpp | 41 +++++++++++++--------------------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index cb8b4c1741e13..f7830fad9d850 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -100,6 +100,11 @@ class AlgSimp : public BasicStmtVisitor { bool optimize_multiplication(BinaryOpStmt *stmt) { // return true iff the IR is modified + if (stmt->lhs->ret_type->is() || + stmt->rhs->ret_type->is()) { + // TODO: support tensor type + return false; + } auto lhs = stmt->lhs->cast(); auto rhs = stmt->rhs->cast(); TI_ASSERT(stmt->op_type == BinaryOpType::mul); @@ -151,6 +156,11 @@ class AlgSimp : public BasicStmtVisitor { bool optimize_division(BinaryOpStmt *stmt) { // return true iff the IR is modified + if (stmt->lhs->ret_type->is() || + stmt->rhs->ret_type->is()) { + // TODO: support tensor type + return false; + } auto rhs = stmt->rhs->cast(); TI_ASSERT(stmt->op_type == BinaryOpType::div || stmt->op_type == BinaryOpType::floordiv); @@ -164,11 +174,6 @@ class AlgSimp : public BasicStmtVisitor { if ((fast_math || is_integral(stmt->ret_type)) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a / a -> 1 - if (stmt->lhs->ret_type->is() || - stmt->rhs->ret_type->is()) { - // TODO: support tensor type - return false; - } replace_with_one(stmt); return true; } @@ -218,6 +223,11 @@ class AlgSimp : public BasicStmtVisitor { } void visit(BinaryOpStmt *stmt) override { + if (stmt->lhs->ret_type->is() || + stmt->rhs->ret_type->is()) { + // TODO: support tensor type + return; + } auto lhs = stmt->lhs->cast(); auto rhs = stmt->rhs->cast(); if (stmt->width() != 1) { @@ -250,12 +260,7 @@ class AlgSimp : public BasicStmtVisitor { (fast_math || is_integral(stmt->ret_type)) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a -^ a -> 0 - if (stmt->lhs->ret_type->is() && - stmt->rhs->ret_type->is()) { - // TODO: support tensor type - } else { - replace_with_zero(stmt); - } + replace_with_zero(stmt); } } else if (rhs && stmt->op_type == BinaryOpType::pow) { float64 exponent = rhs->val[0].val_cast_to_float64(); @@ -265,10 +270,6 @@ class AlgSimp : public BasicStmtVisitor { modifier.erase(stmt); } else if (exponent == 0) { // a ** 0 -> 1 - if (stmt->ret_type->is()) { - // TODO: support tensor type - return; - } replace_with_one(stmt); } else if (exponent == 0.5) { // a ** 0.5 -> sqrt(a) @@ -344,11 +345,6 @@ class AlgSimp : public BasicStmtVisitor { modifier.erase(stmt); } else if (alg_is_zero(lhs) || alg_is_zero(rhs)) { // 0 & a -> 0, a & 0 -> 0 - if (stmt->ret_type->is() || - stmt->rhs->ret_type->is()) { - // TODO: support tensor type - return; - } replace_with_zero(stmt); } else if (irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // a & a -> a @@ -363,11 +359,6 @@ class AlgSimp : public BasicStmtVisitor { // a << 0 -> a // 0 << a -> 0 // 0 >> a -> 0 - if (stmt->ret_type->is() || - stmt->rhs->ret_type->is()) { - // TODO: support tensor type - return; - } TI_ASSERT(stmt->lhs->ret_type == stmt->ret_type); stmt->replace_usages_with(stmt->lhs); modifier.erase(stmt); From f395d2a3bda9fee931412384427b0339fa90aed1 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 13:27:43 -0400 Subject: [PATCH 25/92] reuse code --- taichi/ir/frontend_ir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index e65e5009ceade..7627fceb79cd9 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -1004,7 +1004,7 @@ Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, if (this->use_real_matrix_) { auto matrix_expr = make_local_matrix(shape, element_type, elements.exprs); auto v = this->expr_alloca(); - this->insert(std::make_unique(v, matrix_expr)); + this->expr_assign(v, matrix_expr, tb); return v; } auto var = Expr(std::make_shared(get_next_id())); From 2a6a8e64828fad73ba21259cc7f1c052fe3d36b5 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 13:45:33 -0400 Subject: [PATCH 26/92] Revert "clean-up cse code" This reverts commit b9fd3a941629d9ded1fa53bc446f86084307478c. --- taichi/analysis/same_statements.cpp | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index c26957e81f4a8..5f908603c9497 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -196,6 +196,24 @@ class IRNodeComparator : public IRVisitor { basic_check(stmt); } + void visit(MatrixInitStmt *stmt) override { + basic_check(stmt); + if (!same) + return; + auto o = other_node_->as(); + if (stmt->values.size() != o->values.size()) { + same = false; + return; + } + for (int i = 0; i < stmt->values.size(); ++i) { + other_node_ = o->values[i]; + stmt->values[i]->accept(this); + other_node_ = o; + if (!same) + return; + } + } + void visit(IfStmt *stmt) override { basic_check(stmt); if (!same) From 7f8ca37da666d389ac6bf07ca6e098f3bf40c418 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 13:54:46 -0400 Subject: [PATCH 27/92] clean up together --- taichi/analysis/data_source_analysis.cpp | 2 -- taichi/analysis/same_statements.cpp | 18 ------------------ taichi/codegen/llvm/codegen_llvm.cpp | 1 + 3 files changed, 1 insertion(+), 20 deletions(-) diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index 39bfa6750a0b9..8bdeb72e9a70b 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -37,8 +37,6 @@ std::vector get_load_pointers(Stmt *load_stmt) { return external_func->arg_stmts; } else if (auto ref = load_stmt->cast()) { return {ref->var}; - } else if (auto ptr_offset = load_stmt->cast()) { - return {ptr_offset->origin}; } else { return std::vector(); } diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index 5f908603c9497..c26957e81f4a8 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -196,24 +196,6 @@ class IRNodeComparator : public IRVisitor { basic_check(stmt); } - void visit(MatrixInitStmt *stmt) override { - basic_check(stmt); - if (!same) - return; - auto o = other_node_->as(); - if (stmt->values.size() != o->values.size()) { - same = false; - return; - } - for (int i = 0; i < stmt->values.size(); ++i) { - other_node_ = o->values[i]; - stmt->values[i]->accept(this); - other_node_ = o; - if (!same) - return; - } - } - void visit(IfStmt *stmt) override { basic_check(stmt); if (!same) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 695e28e96a310..c436484d32351 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -689,6 +689,7 @@ llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) { } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*llvm_context); } else if (dt->is()) { + TI_ASSERT_INFO(kernel->program->config.real_matrix, "Real matrix not enabled but got TensorType"); auto tensor_type = dt->cast(); auto element_type = llvm_type(tensor_type->get_element_type()); return llvm::VectorType::get(element_type, tensor_type->get_num_elements(), From 988abb3c36b2854e5c0b8c5509a91f1058c0946a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Aug 2022 17:56:37 +0000 Subject: [PATCH 28/92] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/codegen/llvm/codegen_llvm.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index c436484d32351..ccaef8001cecb 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -689,7 +689,8 @@ llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) { } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*llvm_context); } else if (dt->is()) { - TI_ASSERT_INFO(kernel->program->config.real_matrix, "Real matrix not enabled but got TensorType"); + TI_ASSERT_INFO(kernel->program->config.real_matrix, + "Real matrix not enabled but got TensorType"); auto tensor_type = dt->cast(); auto element_type = llvm_type(tensor_type->get_element_type()); return llvm::VectorType::get(element_type, tensor_type->get_num_elements(), From 93b3a036c824b9906673c1558ea57502a6ab74a0 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 13:58:07 -0400 Subject: [PATCH 29/92] also checks for tlctx --- taichi/runtime/llvm/llvm_context.cpp | 3 ++- taichi/runtime/llvm/llvm_context.h | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index f0d2acbd450f7..a8d857d2c2ead 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -64,7 +64,7 @@ namespace lang { using namespace llvm; TaichiLLVMContext::TaichiLLVMContext(CompileConfig *config, Arch arch) - : arch_(arch) { + : config_(config), arch_(arch) { TI_TRACE("Creating Taichi llvm context for arch: {}", arch_name(arch)); main_thread_id_ = std::this_thread::get_id(); main_thread_data_ = get_this_thread_data(); @@ -143,6 +143,7 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*ctx); } else if (dt->is()) { + TI_ASSERT_INFO(config_->real_matrix, "Real matrix not enabled but got TensorType"); auto vectorty = dt->as(); auto dtype = this->get_data_type(vectorty->get_element_type()); return llvm::VectorType::get(dtype, vectorty->get_num_elements(), diff --git a/taichi/runtime/llvm/llvm_context.h b/taichi/runtime/llvm/llvm_context.h index afcccebbbbcfe..ae87699f48484 100644 --- a/taichi/runtime/llvm/llvm_context.h +++ b/taichi/runtime/llvm/llvm_context.h @@ -33,6 +33,7 @@ class TaichiLLVMContext { std::unique_ptr struct_module{nullptr}; ~ThreadLocalData(); }; + CompileConfig *config_; public: std::unique_ptr jit{nullptr}; From b2e101af0e0cca1e52400241caed7ca13cb48002 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 13:58:42 -0400 Subject: [PATCH 30/92] format --- taichi/runtime/llvm/llvm_context.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index a8d857d2c2ead..8e99f32503b58 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -143,7 +143,8 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*ctx); } else if (dt->is()) { - TI_ASSERT_INFO(config_->real_matrix, "Real matrix not enabled but got TensorType"); + TI_ASSERT_INFO(config_->real_matrix, + "Real matrix not enabled but got TensorType"); auto vectorty = dt->as(); auto dtype = this->get_data_type(vectorty->get_element_type()); return llvm::VectorType::get(dtype, vectorty->get_num_elements(), From 6fcf07026ebcf2535342035651bedb72581dd4b0 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 14:57:54 -0400 Subject: [PATCH 31/92] fix codegen: allocate pointer to vector --- taichi/codegen/llvm/codegen_llvm.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index ccaef8001cecb..4bb9c9f12c2fc 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -124,9 +124,8 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) { void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { if (stmt->ret_type->is()) { auto tensor_type = stmt->ret_type->cast(); - auto type = tlctx->get_data_type(tensor_type->get_element_type()); - auto array_size = tlctx->get_constant(tensor_type->get_num_elements()); - // Return type is [array_size x type]*. + auto type = tlctx->get_data_type(tensor_type); + // Return type is vector*. if (stmt->is_shared) { size_t data_element_size = tlctx->get_type_size( tlctx->get_data_type(tensor_type->get_element_type())); @@ -148,7 +147,7 @@ void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { tlctx->get_data_type(tensor_type->get_element_type()), 0); llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type); } else { - llvm_val[stmt] = create_entry_block_alloca(type, 0, array_size); + llvm_val[stmt] = create_entry_block_alloca(type, stmt->ret_type.is_pointer()); } } else { TI_ASSERT(stmt->width() == 1); From f43d2a80928529820b08058de7ab8e5529f8e574 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Aug 2022 18:59:26 +0000 Subject: [PATCH 32/92] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/codegen/llvm/codegen_llvm.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 4bb9c9f12c2fc..c91a072898230 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -147,7 +147,8 @@ void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { tlctx->get_data_type(tensor_type->get_element_type()), 0); llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type); } else { - llvm_val[stmt] = create_entry_block_alloca(type, stmt->ret_type.is_pointer()); + llvm_val[stmt] = + create_entry_block_alloca(type, stmt->ret_type.is_pointer()); } } else { TI_ASSERT(stmt->width() == 1); From c61331885b604a676ff3e35e852da30e93d70a36 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 17:19:09 -0400 Subject: [PATCH 33/92] check real matrix when allocating memory --- taichi/codegen/llvm/codegen_llvm.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 4bb9c9f12c2fc..9ac9b286ffdaa 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -124,8 +124,9 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) { void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { if (stmt->ret_type->is()) { auto tensor_type = stmt->ret_type->cast(); - auto type = tlctx->get_data_type(tensor_type); - // Return type is vector*. + auto type = kernel->program->config.real_matrix ? tlctx->get_data_type(tensor_type) : tlctx->get_data_type(tensor_type->get_element_type()); + // Return type is vector* if use real matrix. + // otherwise the return type is [type * array_size]* if (stmt->is_shared) { size_t data_element_size = tlctx->get_type_size( tlctx->get_data_type(tensor_type->get_element_type())); @@ -147,7 +148,12 @@ void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { tlctx->get_data_type(tensor_type->get_element_type()), 0); llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type); } else { - llvm_val[stmt] = create_entry_block_alloca(type, stmt->ret_type.is_pointer()); + if (kernel->program->config.real_matrix) + llvm_val[stmt] = + create_entry_block_alloca(type, stmt->ret_type.is_pointer()); + else + llvm_val[stmt] = + create_entry_block_alloca(type, 0, tlctx->get_constant(tensor_type->get_num_elements())); } } else { TI_ASSERT(stmt->width() == 1); From 9fc758d06ffb0d6f45186225162088b41ce74d0f Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 17:23:05 -0400 Subject: [PATCH 34/92] format and fix tc for variable holding matrix expression --- taichi/codegen/llvm/codegen_llvm.cpp | 8 +++++--- taichi/ir/frontend_ir.cpp | 3 +++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 9ac9b286ffdaa..a6ddb17d88555 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -124,7 +124,9 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) { void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { if (stmt->ret_type->is()) { auto tensor_type = stmt->ret_type->cast(); - auto type = kernel->program->config.real_matrix ? tlctx->get_data_type(tensor_type) : tlctx->get_data_type(tensor_type->get_element_type()); + auto type = kernel->program->config.real_matrix + ? tlctx->get_data_type(tensor_type) + : tlctx->get_data_type(tensor_type->get_element_type()); // Return type is vector* if use real matrix. // otherwise the return type is [type * array_size]* if (stmt->is_shared) { @@ -152,8 +154,8 @@ void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { llvm_val[stmt] = create_entry_block_alloca(type, stmt->ret_type.is_pointer()); else - llvm_val[stmt] = - create_entry_block_alloca(type, 0, tlctx->get_constant(tensor_type->get_num_elements())); + llvm_val[stmt] = create_entry_block_alloca( + type, 0, tlctx->get_constant(tensor_type->get_num_elements())); } } else { TI_ASSERT(stmt->width() == 1); diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 7627fceb79cd9..09f8c33e3b569 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -1005,6 +1005,9 @@ Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, auto matrix_expr = make_local_matrix(shape, element_type, elements.exprs); auto v = this->expr_alloca(); this->expr_assign(v, matrix_expr, tb); + // type check for variable `v` since + // expr_assign couldn't propagate the info + v->ret_type = matrix_expr.cast()->dt; return v; } auto var = Expr(std::make_shared(get_next_id())); From f635b7c20d46f2f48bcc21f22c567255dcd1aa24 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 25 Aug 2022 00:23:20 -0400 Subject: [PATCH 35/92] refactor: change to `make_local_matrix` which returns only an Expr; postpone var alloca to ast transformer --- python/taichi/lang/impl.py | 6 ++++++ python/taichi/lang/matrix.py | 15 +++++---------- taichi/analysis/data_source_analysis.cpp | 2 +- taichi/ir/frontend_ir.cpp | 20 +++++--------------- taichi/ir/frontend_ir.h | 9 +++++---- taichi/program/function.cpp | 3 +-- taichi/program/kernel.cpp | 3 +-- taichi/python/export_lang.cpp | 1 + 8 files changed, 25 insertions(+), 34 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 3fd16ec2ade16..1727d4d5e44cf 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -37,6 +37,12 @@ def expr_init_local_tensor(shape, element_type, elements): get_runtime().get_current_src_info()) +@taichi_scope +def make_local_matrix(shape, element_type, elements): + return get_runtime().prog.current_ast_builder().make_local_matrix( + shape, element_type, elements) + + @taichi_scope def expr_init_shared_array(shape, element_type): return get_runtime().prog.current_ast_builder().expr_alloca_shared_array( diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 336fb534978d9..026a966de024b 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -104,16 +104,11 @@ def make_matrix(arr, dt=None): if dt is None: dt = _make_entries_initializer(is_matrix).infer_dt(arr) if not is_matrix: - return impl.Expr( - impl.expr_init_local_tensor([len(arr)], dt, - impl.make_expr_group( - [expr.Expr(elt) for elt in arr]))) - return impl.Expr( - impl.expr_init_local_tensor([len(arr), len(arr[0])], dt, - impl.make_expr_group([ - expr.Expr(elt) for row in arr - for elt in row - ]))) + return impl.make_local_matrix([len(arr)], dt, + [expr.Expr(elt).ptr for elt in arr]) + return impl.make_local_matrix( + [len(arr), len(arr[0])], dt, + [expr.Expr(elt).ptr for row in arr for elt in row]) class _MatrixBaseImpl: diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index 8bdeb72e9a70b..4a018afa6bf47 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -59,7 +59,7 @@ Stmt *get_store_data(Stmt *store_stmt) { std::vector get_store_destination(Stmt *store_stmt) { // If store_stmt provides some data sources, return the pointers of the data. - if (store_stmt->is()) { + if (store_stmt->is() && !store_stmt->ret_type->is()) { // The statement itself provides a data source (const [0]). return std::vector(1, store_stmt); } else if (auto local_store = store_stmt->cast()) { diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 09f8c33e3b569..1cc8404c73c77 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -87,10 +87,9 @@ FrontendForStmt::FrontendForStmt(const ExprGroup &loop_var, } } -FrontendContext::FrontendContext(Arch arch, bool real_matrix) { +FrontendContext::FrontendContext(Arch arch) { root_node_ = std::make_unique(); - current_builder_ = - std::make_unique(root_node_.get(), arch, real_matrix); + current_builder_ = std::make_unique(root_node_.get(), arch); } FrontendForStmt::FrontendForStmt(const Expr &loop_var, @@ -991,9 +990,9 @@ Expr ASTBuilder::expr_alloca() { return var; } -Expr make_local_matrix(const std::vector &shape, - const DataType &dt, - const std::vector &elements) { +Expr ASTBuilder::make_local_matrix(const std::vector &shape, + const DataType &dt, + const std::vector &elements) { return Expr(std::make_shared(elements, shape, dt)); } @@ -1001,15 +1000,6 @@ Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements, std::string tb) { - if (this->use_real_matrix_) { - auto matrix_expr = make_local_matrix(shape, element_type, elements.exprs); - auto v = this->expr_alloca(); - this->expr_assign(v, matrix_expr, tb); - // type check for variable `v` since - // expr_assign couldn't propagate the info - v->ret_type = matrix_expr.cast()->dt; - return v; - } auto var = Expr(std::make_shared(get_next_id())); this->insert(std::make_unique( std::static_pointer_cast(var.expr)->id, shape, diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 66b99e9772996..29d07e3e24cdd 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -867,11 +867,9 @@ class ASTBuilder { Arch arch_; ForLoopDecoratorRecorder for_loop_dec_; int id_counter_{0}; - bool use_real_matrix_{false}; public: - ASTBuilder(Block *initial, Arch arch, bool real_matrix) - : arch_(arch), use_real_matrix_(real_matrix) { + ASTBuilder(Block *initial, Arch arch) : arch_(arch) { stack_.push_back(initial); loop_state_stack_.push_back(None); } @@ -890,6 +888,9 @@ class ASTBuilder { const std::function &func); Expr make_id_expr(const std::string &name); + Expr make_local_matrix(const std::vector &shape, + const DataType &dt, + const std::vector &elements); Expr insert_thread_idx_expr(); Expr insert_patch_idx_expr(); void create_kernel_exprgroup_return(const ExprGroup &group); @@ -972,7 +973,7 @@ class FrontendContext { std::unique_ptr root_node_; public: - FrontendContext(Arch arch, bool real_matrix); + FrontendContext(Arch arch); ASTBuilder &builder() { return *current_builder_; diff --git a/taichi/program/function.cpp b/taichi/program/function.cpp index 64dd970cde4ea..e2655fed02432 100644 --- a/taichi/program/function.cpp +++ b/taichi/program/function.cpp @@ -12,8 +12,7 @@ Function::Function(Program *program, const FunctionKey &func_key) } void Function::set_function_body(const std::function &func) { - context = std::make_unique(program->config.arch, - program->config.real_matrix); + context = std::make_unique(program->config.arch); ir = context->get_root(); { // Note: this is not a mutex diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 0276a71269cf1..f6f0115955634 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -391,8 +391,7 @@ void Kernel::init(Program &program, is_accessor = false; is_evaluator = false; compiled_ = nullptr; - context = std::make_unique(program.config.arch, - program.config.real_matrix); + context = std::make_unique(program.config.arch); ir = context->get_root(); ir_is_ast_ = true; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 569e631611d85..29af61a2f9166 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -291,6 +291,7 @@ void export_lang(py::module &m) { .def("insert_deactivate", &ASTBuilder::insert_snode_deactivate) .def("insert_activate", &ASTBuilder::insert_snode_activate) .def("insert_external_func_call", &ASTBuilder::insert_external_func_call) + .def("make_local_matrix", &ASTBuilder::make_local_matrix) .def("expr_alloca", &ASTBuilder::expr_alloca) .def("expr_alloca_local_tensor", &ASTBuilder::expr_alloca_local_tensor) .def("expr_alloca_shared_array", &ASTBuilder::expr_alloca_shared_array) From f82dc259c6b7a34ff53c5544dc30a2593a478ec9 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 25 Aug 2022 00:24:27 -0400 Subject: [PATCH 36/92] get rid of duplicated check --- taichi/transforms/alg_simp.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index f7830fad9d850..5fc782c094842 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -100,11 +100,6 @@ class AlgSimp : public BasicStmtVisitor { bool optimize_multiplication(BinaryOpStmt *stmt) { // return true iff the IR is modified - if (stmt->lhs->ret_type->is() || - stmt->rhs->ret_type->is()) { - // TODO: support tensor type - return false; - } auto lhs = stmt->lhs->cast(); auto rhs = stmt->rhs->cast(); TI_ASSERT(stmt->op_type == BinaryOpType::mul); From bd68c2369f26b04edfe58aecf7604686b307bbcc Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 25 Aug 2022 01:42:48 -0400 Subject: [PATCH 37/92] save changes --- python/taichi/lang/impl.py | 2 +- python/taichi/lang/matrix.py | 8 ++++---- taichi/transforms/alg_simp.cpp | 5 ----- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 1727d4d5e44cf..632cbc0e95d42 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -38,7 +38,7 @@ def expr_init_local_tensor(shape, element_type, elements): @taichi_scope -def make_local_matrix(shape, element_type, elements): +def make_matrix_expr(shape, element_type, elements): return get_runtime().prog.current_ast_builder().make_local_matrix( shape, element_type, elements) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 026a966de024b..887637d5bc306 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -104,11 +104,11 @@ def make_matrix(arr, dt=None): if dt is None: dt = _make_entries_initializer(is_matrix).infer_dt(arr) if not is_matrix: - return impl.make_local_matrix([len(arr)], dt, - [expr.Expr(elt).ptr for elt in arr]) - return impl.make_local_matrix( + return impl.Expr(impl.make_matrix_expr([len(arr)], dt, + [expr.Expr(elt).ptr for elt in arr])) + return impl.Expr(impl.make_matrix_expr( [len(arr), len(arr[0])], dt, - [expr.Expr(elt).ptr for row in arr for elt in row]) + [expr.Expr(elt).ptr for row in arr for elt in row])) class _MatrixBaseImpl: diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 5fc782c094842..b8bdf2a32f698 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -151,11 +151,6 @@ class AlgSimp : public BasicStmtVisitor { bool optimize_division(BinaryOpStmt *stmt) { // return true iff the IR is modified - if (stmt->lhs->ret_type->is() || - stmt->rhs->ret_type->is()) { - // TODO: support tensor type - return false; - } auto rhs = stmt->rhs->cast(); TI_ASSERT(stmt->op_type == BinaryOpType::div || stmt->op_type == BinaryOpType::floordiv); From 5d00c98a6548815623150382f8ea4931d052c3da Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 25 Aug 2022 01:43:21 -0400 Subject: [PATCH 38/92] format --- python/taichi/lang/matrix.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 887637d5bc306..2bbb23c1acf93 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -104,11 +104,13 @@ def make_matrix(arr, dt=None): if dt is None: dt = _make_entries_initializer(is_matrix).infer_dt(arr) if not is_matrix: - return impl.Expr(impl.make_matrix_expr([len(arr)], dt, - [expr.Expr(elt).ptr for elt in arr])) - return impl.Expr(impl.make_matrix_expr( - [len(arr), len(arr[0])], dt, - [expr.Expr(elt).ptr for row in arr for elt in row])) + return impl.Expr( + impl.make_matrix_expr([len(arr)], dt, + [expr.Expr(elt).ptr for elt in arr])) + return impl.Expr( + impl.make_matrix_expr( + [len(arr), len(arr[0])], dt, + [expr.Expr(elt).ptr for row in arr for elt in row])) class _MatrixBaseImpl: From 3451397a0c94a69ddf981b0b2c5758236d1b7a36 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 25 Aug 2022 02:02:27 -0400 Subject: [PATCH 39/92] also rename cxx part --- taichi/ir/frontend_ir.cpp | 6 +++--- taichi/ir/frontend_ir.h | 6 +++--- taichi/python/export_lang.cpp | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 1cc8404c73c77..9ef38c9701896 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -990,9 +990,9 @@ Expr ASTBuilder::expr_alloca() { return var; } -Expr ASTBuilder::make_local_matrix(const std::vector &shape, - const DataType &dt, - const std::vector &elements) { +Expr ASTBuilder::make_matrix_expr(const std::vector &shape, + const DataType &dt, + const std::vector &elements) { return Expr(std::make_shared(elements, shape, dt)); } diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 29d07e3e24cdd..d4da9b635820f 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -888,9 +888,9 @@ class ASTBuilder { const std::function &func); Expr make_id_expr(const std::string &name); - Expr make_local_matrix(const std::vector &shape, - const DataType &dt, - const std::vector &elements); + Expr make_matrix_expr(const std::vector &shape, + const DataType &dt, + const std::vector &elements); Expr insert_thread_idx_expr(); Expr insert_patch_idx_expr(); void create_kernel_exprgroup_return(const ExprGroup &group); diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 29af61a2f9166..89ad1eaeb2a4f 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -291,7 +291,7 @@ void export_lang(py::module &m) { .def("insert_deactivate", &ASTBuilder::insert_snode_deactivate) .def("insert_activate", &ASTBuilder::insert_snode_activate) .def("insert_external_func_call", &ASTBuilder::insert_external_func_call) - .def("make_local_matrix", &ASTBuilder::make_local_matrix) + .def("make_local_matrix", &ASTBuilder::make_matrix_expr) .def("expr_alloca", &ASTBuilder::expr_alloca) .def("expr_alloca_local_tensor", &ASTBuilder::expr_alloca_local_tensor) .def("expr_alloca_shared_array", &ASTBuilder::expr_alloca_shared_array) From 5ed0e9300ef7e5f332d5cef60bafac199da34d23 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Thu, 25 Aug 2022 15:15:16 +0800 Subject: [PATCH 40/92] Apply suggestions from code review --- python/taichi/lang/impl.py | 2 +- taichi/python/export_lang.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 632cbc0e95d42..027696989dc9d 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -39,7 +39,7 @@ def expr_init_local_tensor(shape, element_type, elements): @taichi_scope def make_matrix_expr(shape, element_type, elements): - return get_runtime().prog.current_ast_builder().make_local_matrix( + return get_runtime().prog.current_ast_builder().make_matrix_expr( shape, element_type, elements) diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 89ad1eaeb2a4f..8b44098656b3d 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -291,7 +291,7 @@ void export_lang(py::module &m) { .def("insert_deactivate", &ASTBuilder::insert_snode_deactivate) .def("insert_activate", &ASTBuilder::insert_snode_activate) .def("insert_external_func_call", &ASTBuilder::insert_external_func_call) - .def("make_local_matrix", &ASTBuilder::make_matrix_expr) + .def("make_matrix_expr", &ASTBuilder::make_matrix_expr) .def("expr_alloca", &ASTBuilder::expr_alloca) .def("expr_alloca_local_tensor", &ASTBuilder::expr_alloca_local_tensor) .def("expr_alloca_shared_array", &ASTBuilder::expr_alloca_shared_array) From 0fee70767c68a36a0325820a52e9efd4993e4f82 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Mon, 15 Aug 2022 16:51:37 -0400 Subject: [PATCH 41/92] add indexing p1 --- python/taichi/lang/expr.py | 9 +++++++++ taichi/codegen/llvm/codegen_llvm.cpp | 26 ++++++++++++++++++++++++-- taichi/ir/frontend_ir.cpp | 25 +++++++++++++++++++++++++ taichi/ir/frontend_ir.h | 1 + taichi/ir/statements.h | 7 ++++++- taichi/python/export_lang.cpp | 1 + 6 files changed, 66 insertions(+), 3 deletions(-) diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 7bc2b5d8a66b2..af2df90fc1b79 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -38,6 +38,15 @@ def __init__(self, *args, tb=None, dtype=None): if self.tb: self.ptr.set_tb(self.tb) self.ptr.type_check(impl.get_runtime().prog.config) + + def __getitem__(self, *indices): + if not isinstance(indices, (list, tuple)): + indices = (indices, ) + + indices = make_expr_group(*indices) + return Expr( + impl.get_runtime().prog.current_ast_builder().expr_indexed_matrix( + self.ptr, indices)) def __hash__(self): return self.ptr.get_raw_address() diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index a6ddb17d88555..79e858967d280 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1761,8 +1761,30 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) { llvm_val[stmt] = builder->CreateGEP(ptr_ty, llvm_val[stmt->origin], llvm_val[stmt->offset]); #else - llvm_val[stmt] = - builder->CreateGEP(llvm_val[stmt->origin], llvm_val[stmt->offset]); + if (stmt->origin->ret_type->is() || + (stmt->origin->ret_type->is() && + stmt->origin->ret_type->cast() + ->get_pointee_type() + ->is())) { + TensorType *stmt_dtype; + if (stmt->origin->ret_type->is()) { + stmt_dtype = stmt->origin->ret_type->cast() + ->get_pointee_type() + ->cast(); + } else { + stmt_dtype = stmt->origin->ret_type->cast(); + } + auto element_dtype = stmt_dtype->get_element_type(); + auto llvm_type = tlctx->get_data_type(element_dtype); + auto casted_ptr = builder->CreateBitCast( + llvm_val[stmt->origin], llvm::PointerType::get(llvm_type, 0)); + llvm_val[stmt] = builder->CreateBitCast( + builder->CreateGEP(casted_ptr, llvm_val[stmt->offset]), + llvm::PointerType::get(llvm_type, 0)); + } else { + llvm_val[stmt] = + builder->CreateGEP(llvm_val[stmt->origin], llvm_val[stmt->offset]); + } #endif } else { auto origin_address = builder->CreatePtrToInt( diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 9ef38c9701896..fdf081b071c2a 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -415,6 +415,9 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx, std::vector shape, int stride) { flatten_lvalue(var, ctx); + if (var->stmt->ret_type->is_primitive(PrimitiveTypeID::unknown)) { + var->stmt->ret_type = var->ret_type; + } Stmt *offset_stmt = ctx->push_back(TypedConstant(0)); for (int i = 0; i < (int)indices.size(); ++i) { flatten_rvalue(indices[i], ctx); @@ -996,6 +999,28 @@ Expr ASTBuilder::make_matrix_expr(const std::vector &shape, return Expr(std::make_shared(elements, shape, dt)); } +Expr ASTBuilder::expr_indexed_matrix(const Expr &matrix, + const ExprGroup &indices) { + TI_ASSERT(matrix.get_ret_type()->is()); + auto shape = matrix.get_ret_type()->as()->get_shape(); + if (indices.size() != shape.size()) { + std::string shape_str = "["; + if (shape.size() > 0) { + shape_str += std::to_string(shape[0]); + for (int i = 1; i < shape.size(); i++) { + shape_str += ", " + std::to_string(shape[i]); + } + } + shape_str += "]"; + TI_ERROR( + "Indexed matrix of shape {} has wrong number of indices. Expected {} " + "but got " + "{}.", + shape_str, shape.size(), indices.size()); + } + return Expr(std::make_shared(matrix, indices)); +} + Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements, diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index d4da9b635820f..9a897ccc5f907 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -913,6 +913,7 @@ class ASTBuilder { std::string tb); Expr expr_alloca_shared_array(const std::vector &shape, const DataType &element_type); + Expr expr_indexed_matrix(const Expr &matrix, const ExprGroup &indices); void expr_assign(const Expr &lhs, const Expr &rhs, std::string tb); void create_assert_stmt(const Expr &cond, const std::string &msg, diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 11c87dd98b67c..61e8c0719e72c 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -367,7 +367,12 @@ class PtrOffsetStmt : public Stmt { bool is_local_ptr() const { if (origin->is() || origin->is()) { - TI_ASSERT_INFO(origin->ret_type->is(), + auto is_tensor_type = origin->ret_type->is() + ? origin->ret_type->cast() + ->get_pointee_type() + ->is() + : origin->ret_type->is(); + TI_ASSERT_INFO(is_tensor_type, "PtrOffsetStmt can only be used for Alloca (TensorType)."); } return origin->is() || origin->is(); diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 8b44098656b3d..ad1f9efea31d8 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -294,6 +294,7 @@ void export_lang(py::module &m) { .def("make_matrix_expr", &ASTBuilder::make_matrix_expr) .def("expr_alloca", &ASTBuilder::expr_alloca) .def("expr_alloca_local_tensor", &ASTBuilder::expr_alloca_local_tensor) + .def("expr_indexed_matrix", &ASTBuilder::expr_indexed_matrix) .def("expr_alloca_shared_array", &ASTBuilder::expr_alloca_shared_array) .def("create_assert_stmt", &ASTBuilder::create_assert_stmt) .def("expr_assign", &ASTBuilder::expr_assign) From 25e886c30ce30e91d2de141be5bc5b40412db85b Mon Sep 17 00:00:00 2001 From: AD1024 Date: Mon, 15 Aug 2022 17:18:18 -0400 Subject: [PATCH 42/92] pick fixes --- python/taichi/lang/expr.py | 2 +- taichi/codegen/llvm/codegen_llvm.cpp | 12 ++++++------ taichi/ir/statements.h | 8 ++++---- taichi/transforms/type_check.cpp | 3 +-- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index af2df90fc1b79..3aa0f11797b00 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -38,7 +38,7 @@ def __init__(self, *args, tb=None, dtype=None): if self.tb: self.ptr.set_tb(self.tb) self.ptr.type_check(impl.get_runtime().prog.config) - + def __getitem__(self, *indices): if not isinstance(indices, (list, tuple)): indices = (indices, ) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 79e858967d280..5daccc57cbf09 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1762,15 +1762,15 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) { llvm_val[stmt->offset]); #else if (stmt->origin->ret_type->is() || - (stmt->origin->ret_type->is() && - stmt->origin->ret_type->cast() - ->get_pointee_type() - ->is())) { + (stmt->origin->ret_type->is() && + stmt->origin->ret_type->cast() + ->get_pointee_type() + ->is())) { TensorType *stmt_dtype; if (stmt->origin->ret_type->is()) { stmt_dtype = stmt->origin->ret_type->cast() - ->get_pointee_type() - ->cast(); + ->get_pointee_type() + ->cast(); } else { stmt_dtype = stmt->origin->ret_type->cast(); } diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 61e8c0719e72c..bbfe9a6d42682 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -368,10 +368,10 @@ class PtrOffsetStmt : public Stmt { bool is_local_ptr() const { if (origin->is() || origin->is()) { auto is_tensor_type = origin->ret_type->is() - ? origin->ret_type->cast() - ->get_pointee_type() - ->is() - : origin->ret_type->is(); + ? origin->ret_type->cast() + ->get_pointee_type() + ->is() + : origin->ret_type->is(); TI_ASSERT_INFO(is_tensor_type, "PtrOffsetStmt can only be used for Alloca (TensorType)."); } diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index b6cb7112e1cf8..192835ad89bb5 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -537,8 +537,7 @@ class TypeCheck : public IRVisitor { } void visit(GlobalTemporaryStmt *stmt) override { - if (!stmt->ret_type->is()) - stmt->ret_type.set_is_pointer(true); + stmt->ret_type.set_is_pointer(true); } void visit(InternalFuncStmt *stmt) override { From e5c8c9a4912714a5eddb1a0ecdb9f36d54f1c26d Mon Sep 17 00:00:00 2001 From: AD1024 Date: Mon, 15 Aug 2022 17:44:50 -0400 Subject: [PATCH 43/92] fin indexing --- taichi/ir/statements.h | 10 +++++++++- taichi/ir/type.h | 6 ++++++ taichi/ir/type_utils.h | 4 ++++ taichi/transforms/alg_simp.cpp | 5 +++++ taichi/transforms/type_check.cpp | 14 +++++++++++++- 5 files changed, 37 insertions(+), 2 deletions(-) diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index bbfe9a6d42682..17f6756de1991 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -606,8 +606,16 @@ class GlobalStoreStmt : public Stmt { class LocalLoadStmt : public Stmt { public: LaneAttribute src; + std::vector shape; - explicit LocalLoadStmt(const LaneAttribute &src) : src(src) { + explicit LocalLoadStmt(const LaneAttribute &src) + : src(src), shape({static_cast(src.data.size())}) { + TI_STMT_REG_FIELDS; + } + + LocalLoadStmt(const LaneAttribute &src, + const std::vector &shape) + : src(src), shape(shape) { TI_STMT_REG_FIELDS; } diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 339e2553ffb32..a58701cbf2e1b 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -380,6 +380,12 @@ class TypedConstant { } TypedConstant(DataType dt) : dt(dt) { + if (!dt->is()) { + assert(false); + } + TI_ASSERT_INFO(dt->is(), + "TypedConstant can only be PrimitiveType, got {}", + dt->to_string()); value_bits = 0; } diff --git a/taichi/ir/type_utils.h b/taichi/ir/type_utils.h index 6d3b97154c94f..74423eaaebdfc 100644 --- a/taichi/ir/type_utils.h +++ b/taichi/ir/type_utils.h @@ -87,6 +87,8 @@ inline bool is_real(DataType dt) { } inline bool is_integral(DataType dt) { + if (dt->is()) + return is_integral(dt->as()->get_element_type()); return dt->is_primitive(PrimitiveTypeID::i8) || dt->is_primitive(PrimitiveTypeID::i16) || dt->is_primitive(PrimitiveTypeID::i32) || @@ -100,6 +102,8 @@ inline bool is_integral(DataType dt) { inline bool is_signed(DataType dt) { // Shall we return false if is_integral returns false? TI_ASSERT(is_integral(dt)); + if (auto t = dt->cast()) + return is_signed(t->get_element_type()); if (auto t = dt->cast()) return t->get_is_signed(); return dt->is_primitive(PrimitiveTypeID::i8) || diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index b8bdf2a32f698..4d56719294e85 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -112,6 +112,11 @@ class AlgSimp : public BasicStmtVisitor { if ((fast_math || is_integral(stmt->ret_type)) && (alg_is_zero(lhs) || alg_is_zero(rhs))) { // fast_math or integral operands: 0 * a -> 0, a * 0 -> 0 + if (stmt->ret_type->is() || + stmt->rhs->ret_type->is()) { + // TODO: handle 0-tensor + return false; + } replace_with_zero(stmt); return true; } diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 192835ad89bb5..28dc915b774b8 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -107,9 +107,21 @@ class TypeCheck : public IRVisitor { .ptr_removed(); stmt->ret_type = lookup; } - } else { + } else if (stmt->src.size() == 1) { auto lookup = stmt->src[0].var->ret_type; stmt->ret_type = lookup; + } else { + TI_ASSERT(stmt->src.size() > 1); + auto acc = stmt->src[0].var->ret_type; + for (int i = 1; i < stmt->src.size(); i++) { + acc = promoted_type(acc, stmt->src[i].var->ret_type); + } + if (stmt->ret_type != PrimitiveType::unknown) { + TI_ASSERT(stmt->ret_type->is()); + acc = promoted_type( + acc, stmt->ret_type->as()->get_element_type()); + } + stmt->ret_type = TypeFactory::create_tensor_type(stmt->shape, acc); } } From 6b139dc206af028e869f6bc25a2b4b64dcc7d593 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Aug 2022 21:49:17 +0000 Subject: [PATCH 44/92] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/transforms/alg_simp.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 4d56719294e85..39cb43a4a2609 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -113,7 +113,7 @@ class AlgSimp : public BasicStmtVisitor { (alg_is_zero(lhs) || alg_is_zero(rhs))) { // fast_math or integral operands: 0 * a -> 0, a * 0 -> 0 if (stmt->ret_type->is() || - stmt->rhs->ret_type->is()) { + stmt->rhs->ret_type->is()) { // TODO: handle 0-tensor return false; } From afa66c9ff8ea3b1e904069338df66d3b9fa51887 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 14:51:13 -0400 Subject: [PATCH 45/92] some fixes --- taichi/analysis/data_source_analysis.cpp | 2 ++ taichi/ir/frontend_ir.cpp | 33 ++++++++++++------------ 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index 4a018afa6bf47..93c740646e911 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -37,6 +37,8 @@ std::vector get_load_pointers(Stmt *load_stmt) { return external_func->arg_stmts; } else if (auto ref = load_stmt->cast()) { return {ref->var}; + } else if (auto ptr_offset = load_stmt->cast()) { + return {ptr_offset->origin}; } else { return std::vector(); } diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index fdf081b071c2a..b47f693baeae6 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -483,6 +483,22 @@ void IndexExpression::type_check(CompileConfig *) { } else if (is_ndarray()) { // ndarray ret_type = var.cast()->dt; } else if (is_tensor()) { // local tensor + auto shape = var.get_ret_type()->as()->get_shape(); + if (indices.size() != shape.size()) { + std::string shape_str = "["; + if (shape.size() > 0) { + shape_str += std::to_string(shape[0]); + for (int i = 1; i < shape.size(); i++) { + shape_str += ", " + std::to_string(shape[i]); + } + } + shape_str += "]"; + TI_ERROR( + "Indexed matrix of shape {} has wrong number of indices. Expected {} " + "but got " + "{}.", + shape_str, shape.size(), indices.size()); + } ret_type = var->ret_type->cast()->get_element_type(); } else { throw TaichiTypeError( @@ -1001,23 +1017,6 @@ Expr ASTBuilder::make_matrix_expr(const std::vector &shape, Expr ASTBuilder::expr_indexed_matrix(const Expr &matrix, const ExprGroup &indices) { - TI_ASSERT(matrix.get_ret_type()->is()); - auto shape = matrix.get_ret_type()->as()->get_shape(); - if (indices.size() != shape.size()) { - std::string shape_str = "["; - if (shape.size() > 0) { - shape_str += std::to_string(shape[0]); - for (int i = 1; i < shape.size(); i++) { - shape_str += ", " + std::to_string(shape[i]); - } - } - shape_str += "]"; - TI_ERROR( - "Indexed matrix of shape {} has wrong number of indices. Expected {} " - "but got " - "{}.", - shape_str, shape.size(), indices.size()); - } return Expr(std::make_shared(matrix, indices)); } From dc144cecb3f1f67454b592ef36521163fed5317a Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 25 Aug 2022 15:16:03 -0400 Subject: [PATCH 46/92] fix assignment --- taichi/analysis/data_source_analysis.cpp | 2 -- taichi/ir/control_flow_graph.cpp | 11 ++++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index 93c740646e911..4a018afa6bf47 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -37,8 +37,6 @@ std::vector get_load_pointers(Stmt *load_stmt) { return external_func->arg_stmts; } else if (auto ref = load_stmt->cast()) { return {ref->var}; - } else if (auto ptr_offset = load_stmt->cast()) { - return {ptr_offset->origin}; } else { return std::vector(); } diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 3c4ddfedf2cac..0181c51331f81 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -147,7 +147,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // The UD-chain is inside this node. Stmt *result = irpass::analysis::get_store_data( block->statements[last_def_position].get()); - if (!var->is()) { + if (!var->is() || var->ret_type->is()) { for (int i = last_def_position + 1; i < position; i++) { if (!irpass::analysis::same_value( result, @@ -241,6 +241,15 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access) { // loop in reversed order auto stmt = block->statements[i].get(); auto data_source_ptrs = irpass::analysis::get_store_destination(stmt); + if (auto local_store = stmt->cast()) { + if (auto dest = local_store->dest->cast()) { + if (auto data = get_store_forwarding_data(dest->origin, i)) { + data_source_ptrs = std::vector(1, data); + } else { + data_source_ptrs = std::vector(); + } + } + } for (auto data_source_ptr : data_source_ptrs) { // stmt provides a data source if (after_lower_access && !(data_source_ptr->is())) { From b1c3cc9d36a114426241104baf2dcaf44af6165c Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 25 Aug 2022 15:26:55 -0400 Subject: [PATCH 47/92] fix indexing --- taichi/analysis/data_source_analysis.cpp | 2 ++ taichi/ir/control_flow_graph.cpp | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index 4a018afa6bf47..93c740646e911 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -37,6 +37,8 @@ std::vector get_load_pointers(Stmt *load_stmt) { return external_func->arg_stmts; } else if (auto ref = load_stmt->cast()) { return {ref->var}; + } else if (auto ptr_offset = load_stmt->cast()) { + return {ptr_offset->origin}; } else { return std::vector(); } diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 0181c51331f81..f549b444c60de 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -510,7 +510,7 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { } } auto load_ptrs = irpass::analysis::get_load_pointers(stmt); - if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1) { + if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1 && !stmt->is()) { // Identical load elimination auto load_ptr = load_ptrs.front(); if (!after_lower_access || From ca5005046a4d8eb82e491e9d14e6bc53adf10ede Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Aug 2022 21:57:10 +0000 Subject: [PATCH 48/92] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/ir/control_flow_graph.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index f549b444c60de..0f580d878f752 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -510,7 +510,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { } } auto load_ptrs = irpass::analysis::get_load_pointers(stmt); - if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1 && !stmt->is()) { + if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1 && + !stmt->is()) { // Identical load elimination auto load_ptr = load_ptrs.front(); if (!after_lower_access || From 02327d7bfee93769ad6e839ac95cc244f4c8ba25 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Fri, 26 Aug 2022 00:53:39 -0400 Subject: [PATCH 49/92] reuse code --- python/taichi/lang/expr.py | 5 +---- taichi/ir/frontend_ir.cpp | 5 ----- taichi/ir/frontend_ir.h | 1 - taichi/python/export_lang.cpp | 1 - 4 files changed, 1 insertion(+), 11 deletions(-) diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 3aa0f11797b00..70787185c70e0 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -43,10 +43,7 @@ def __getitem__(self, *indices): if not isinstance(indices, (list, tuple)): indices = (indices, ) - indices = make_expr_group(*indices) - return Expr( - impl.get_runtime().prog.current_ast_builder().expr_indexed_matrix( - self.ptr, indices)) + return impl.make_index_expr(self.ptr, indices) def __hash__(self): return self.ptr.get_raw_address() diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index b47f693baeae6..77e985aa5e121 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -1015,11 +1015,6 @@ Expr ASTBuilder::make_matrix_expr(const std::vector &shape, return Expr(std::make_shared(elements, shape, dt)); } -Expr ASTBuilder::expr_indexed_matrix(const Expr &matrix, - const ExprGroup &indices) { - return Expr(std::make_shared(matrix, indices)); -} - Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements, diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 9a897ccc5f907..d4da9b635820f 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -913,7 +913,6 @@ class ASTBuilder { std::string tb); Expr expr_alloca_shared_array(const std::vector &shape, const DataType &element_type); - Expr expr_indexed_matrix(const Expr &matrix, const ExprGroup &indices); void expr_assign(const Expr &lhs, const Expr &rhs, std::string tb); void create_assert_stmt(const Expr &cond, const std::string &msg, diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index a871af19da3a1..80accd4b17f74 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -294,7 +294,6 @@ void export_lang(py::module &m) { .def("make_matrix_expr", &ASTBuilder::make_matrix_expr) .def("expr_alloca", &ASTBuilder::expr_alloca) .def("expr_alloca_local_tensor", &ASTBuilder::expr_alloca_local_tensor) - .def("expr_indexed_matrix", &ASTBuilder::expr_indexed_matrix) .def("expr_alloca_shared_array", &ASTBuilder::expr_alloca_shared_array) .def("create_assert_stmt", &ASTBuilder::create_assert_stmt) .def("expr_assign", &ASTBuilder::expr_assign) From 6e8960773805aa89722b405c85aec508872baa6f Mon Sep 17 00:00:00 2001 From: AD1024 Date: Fri, 26 Aug 2022 14:09:18 -0400 Subject: [PATCH 50/92] fix compilation --- taichi/ir/control_flow_graph.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 757fae8922075..d39f45a41e088 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -498,8 +498,7 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { } } auto load_ptrs = irpass::analysis::get_load_pointers(stmt); - if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1 && - !stmt->is()) { + if (load_ptrs.size() == 1 && store_ptrs.empty() && !stmt->is()) { // Identical load elimination auto load_ptr = load_ptrs.front(); if (!after_lower_access || From 991c64cde9724e5aeff61538c4598d576030503f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 26 Aug 2022 18:10:47 +0000 Subject: [PATCH 51/92] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/ir/control_flow_graph.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index d39f45a41e088..542aba6f2292c 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -498,7 +498,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { } } auto load_ptrs = irpass::analysis::get_load_pointers(stmt); - if (load_ptrs.size() == 1 && store_ptrs.empty() && !stmt->is()) { + if (load_ptrs.size() == 1 && store_ptrs.empty() && + !stmt->is()) { // Identical load elimination auto load_ptr = load_ptrs.front(); if (!after_lower_access || From 4f722dec38ebe01968abb34c05804a6b8fb61d0e Mon Sep 17 00:00:00 2001 From: AD1024 Date: Fri, 26 Aug 2022 14:20:22 -0400 Subject: [PATCH 52/92] remove unused code --- taichi/ir/type.h | 3 --- taichi/ir/type_utils.h | 4 ---- taichi/transforms/alg_simp.cpp | 5 ----- taichi/transforms/type_check.cpp | 12 ------------ 4 files changed, 24 deletions(-) diff --git a/taichi/ir/type.h b/taichi/ir/type.h index b9bc9e5273e41..064a7fb0e8462 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -384,9 +384,6 @@ class TypedConstant { } TypedConstant(DataType dt) : dt(dt) { - if (!dt->is()) { - assert(false); - } TI_ASSERT_INFO(dt->is(), "TypedConstant can only be PrimitiveType, got {}", dt->to_string()); diff --git a/taichi/ir/type_utils.h b/taichi/ir/type_utils.h index 30d76b4a4aa3f..4b0a280342919 100644 --- a/taichi/ir/type_utils.h +++ b/taichi/ir/type_utils.h @@ -89,8 +89,6 @@ inline bool is_real(DataType dt) { } inline bool is_integral(DataType dt) { - if (dt->is()) - return is_integral(dt->as()->get_element_type()); return dt->is_primitive(PrimitiveTypeID::i8) || dt->is_primitive(PrimitiveTypeID::i16) || dt->is_primitive(PrimitiveTypeID::i32) || @@ -104,8 +102,6 @@ inline bool is_integral(DataType dt) { inline bool is_signed(DataType dt) { // Shall we return false if is_integral returns false? TI_ASSERT(is_integral(dt)); - if (auto t = dt->cast()) - return is_signed(t->get_element_type()); if (auto t = dt->cast()) return t->get_is_signed(); return dt->is_primitive(PrimitiveTypeID::i8) || diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 62e90a8b73fa8..8ec875f25b1c3 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -112,11 +112,6 @@ class AlgSimp : public BasicStmtVisitor { if ((fast_math || is_integral(stmt->ret_type)) && (alg_is_zero(lhs) || alg_is_zero(rhs))) { // fast_math or integral operands: 0 * a -> 0, a * 0 -> 0 - if (stmt->ret_type->is() || - stmt->rhs->ret_type->is()) { - // TODO: handle 0-tensor - return false; - } replace_with_zero(stmt); return true; } diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index cd6c0d76ec654..f20cc477d6fc2 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -107,18 +107,6 @@ class TypeCheck : public IRVisitor { } else if (stmt->src.size() == 1) { auto lookup = stmt->src[0].var->ret_type; stmt->ret_type = lookup; - } else { - TI_ASSERT(stmt->src.size() > 1); - auto acc = stmt->src[0].var->ret_type; - for (int i = 1; i < stmt->src.size(); i++) { - acc = promoted_type(acc, stmt->src[i].var->ret_type); - } - if (stmt->ret_type != PrimitiveType::unknown) { - TI_ASSERT(stmt->ret_type->is()); - acc = promoted_type( - acc, stmt->ret_type->as()->get_element_type()); - } - stmt->ret_type = TypeFactory::create_tensor_type(stmt->shape, acc); } } From 8a320247d9b29153619ed3c46bb5b5069a5c4b19 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 30 Aug 2022 14:43:02 -0400 Subject: [PATCH 53/92] move to `impl.subscript` --- python/taichi/lang/expr.py | 6 ------ python/taichi/lang/impl.py | 3 +++ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 70787185c70e0..7bc2b5d8a66b2 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -39,12 +39,6 @@ def __init__(self, *args, tb=None, dtype=None): self.ptr.set_tb(self.tb) self.ptr.type_check(impl.get_runtime().prog.config) - def __getitem__(self, *indices): - if not isinstance(indices, (list, tuple)): - indices = (indices, ) - - return impl.make_index_expr(self.ptr, indices) - def __hash__(self): return self.ptr.get_raw_address() diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 4b7e98e623b1c..03540a367bfbc 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -130,6 +130,9 @@ def subscript(value, *_indices, skip_reordered=False, get_ref=False): if isinstance(value, np.ndarray): return value.__getitem__(_indices) + if isinstance(value, Expr): + return make_index_expr(value.ptr, _indices) + if isinstance(value, (tuple, list, dict)): assert len(_indices) == 1 return value[_indices[0]] From 39d1243c01ceb9d1f7da6693dfe12098bb7c6baf Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 30 Aug 2022 15:03:26 -0400 Subject: [PATCH 54/92] add some tests --- tests/python/test_matrix.py | 42 +++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index dda05d9c35855..6f6b0ec43356e 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -699,3 +699,45 @@ def bar(): with pytest.raises(TaichiCompilationError, match=r'Expected 2 indices, got 1'): bar() + + +@test_utils.test(arch=[ti.cuda, ti.cpu, ti.gpu], real_matrix=True) +def test_local_matrix_read(): + + s = ti.field(ti.i32, shape=()) + + @ti.kernel + def get_index(i: ti.i32, j: ti.i32): + mat = ti.Matrix([[x * 3 + y for y in range(3)] for x in range(3)]) + s[None] = mat[i, j] + + for i in range(3): + for j in range(3): + get_index(i, j) + assert s[None] == i * 3 + j + + +@test_utils.test(arch=[ti.cuda, ti.cpu, ti.gpu], real_matrix=True) +def test_local_matrix_indexing_in_loop(): + @ti.kernel + def test(): + mat = ti.Matrix([[x * 3 + y for y in range(3)] for x in range(3)]) + for i in range(3): + for j in range(3): + assert mat[i, j] == i * 3 + j + + test() + + +@test_utils.test(arch=[ti.cuda, ti.cpu, ti.gpu], real_matrix=True) +def test_local_matrix_indexing_ops(): + @ti.kernel + def basic_ops(): + mat = ti.Matrix([[x * 3 + y for y in range(3)] for x in range(3)]) + s = 0 + for i in range(3): + for j in range(3): + s += mat[i, j] + assert s == 72 + + basic_ops() From 9a6461dbdbd9f3776063de5282c1c3a909f03f43 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 30 Aug 2022 15:09:18 -0400 Subject: [PATCH 55/92] enable fetching store destination for tensor types --- taichi/analysis/data_source_analysis.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index a7e1ed5a47730..bf4100554f704 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -56,7 +56,7 @@ Stmt *get_store_data(Stmt *store_stmt) { std::vector get_store_destination(Stmt *store_stmt) { // If store_stmt provides some data sources, return the pointers of the data. - if (store_stmt->is() && !store_stmt->ret_type->is()) { + if (store_stmt->is()) { // The statement itself provides a data source (const [0]). return std::vector(1, store_stmt); } else if (auto local_store = store_stmt->cast()) { From 8bcc6dcd9e706f14297dc0e4b267e0164329f60d Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 31 Aug 2022 16:41:45 -0400 Subject: [PATCH 56/92] disable indexing check for now --- taichi/analysis/data_source_analysis.cpp | 2 +- taichi/ir/frontend_ir.cpp | 16 ---------------- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index bf4100554f704..18287868378b1 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -41,7 +41,7 @@ std::vector get_load_pointers(Stmt *load_stmt) { Stmt *get_store_data(Stmt *store_stmt) { // If store_stmt provides one data source, return the data. - if (store_stmt->is() && !store_stmt->ret_type->is()) { + if (store_stmt->is()) { // For convenience, return store_stmt instead of the const [0] it actually // stores. return store_stmt; diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 624db47e6c2df..17728db1204c9 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -482,22 +482,6 @@ void IndexExpression::type_check(CompileConfig *) { } else if (is_ndarray()) { // ndarray ret_type = var.cast()->dt; } else if (is_tensor()) { // local tensor - auto shape = var.get_ret_type()->as()->get_shape(); - if (indices.size() != shape.size()) { - std::string shape_str = "["; - if (shape.size() > 0) { - shape_str += std::to_string(shape[0]); - for (int i = 1; i < shape.size(); i++) { - shape_str += ", " + std::to_string(shape[i]); - } - } - shape_str += "]"; - TI_ERROR( - "Indexed matrix of shape {} has wrong number of indices. Expected {} " - "but got " - "{}.", - shape_str, shape.size(), indices.size()); - } ret_type = var->ret_type->cast()->get_element_type(); } else { throw TaichiTypeError( From 77ed7ad0a57694bca57520c60725c40195ada8b5 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 31 Aug 2022 17:59:09 -0400 Subject: [PATCH 57/92] fix dynamic index check --- taichi/codegen/llvm/codegen_llvm.cpp | 3 ++- taichi/runtime/llvm/llvm_context.cpp | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index dd1756fd83140..acff01ab16b4b 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -703,7 +703,8 @@ llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) { } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*llvm_context); } else if (dt->is()) { - TI_ASSERT_INFO(kernel->program->config.real_matrix, + TI_ASSERT_INFO(kernel->program->config.real_matrix || + kernel->program->config.dynamic_index, "Real matrix not enabled but got TensorType"); auto tensor_type = dt->cast(); auto element_type = llvm_type(tensor_type->get_element_type()); diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index 8e99f32503b58..e3b29401188b3 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -143,7 +143,7 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*ctx); } else if (dt->is()) { - TI_ASSERT_INFO(config_->real_matrix, + TI_ASSERT_INFO(config_->real_matrix || config_->dynamic_index, "Real matrix not enabled but got TensorType"); auto vectorty = dt->as(); auto dtype = this->get_data_type(vectorty->get_element_type()); From 10fec4f486071e09364cd77ab10a245233e1759d Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 1 Sep 2022 15:25:01 -0400 Subject: [PATCH 58/92] fix matrix solve test --- taichi/analysis/data_source_analysis.cpp | 2 -- taichi/ir/control_flow_graph.cpp | 3 +++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index 18287868378b1..9a38bd454031b 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -32,8 +32,6 @@ std::vector get_load_pointers(Stmt *load_stmt) { return external_func->arg_stmts; } else if (auto ref = load_stmt->cast()) { return {ref->var}; - } else if (auto ptr_offset = load_stmt->cast()) { - return {ptr_offset->origin}; } else { return std::vector(); } diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index d1b23f8b13c87..809a3a3731eaa 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -363,6 +363,9 @@ void CFGNode::live_variable_analysis(bool after_lower_access) { for (int i = begin_location; i < end_location; i++) { auto stmt = block->statements[i].get(); auto load_ptrs = irpass::analysis::get_load_pointers(stmt); + // if (auto ptr_offset = stmt->cast()) { + // load_ptrs = std::vector(1, ptr_offset->origin); + // } for (auto &load_ptr : load_ptrs) { if (!after_lower_access || (load_ptr->is() || load_ptr->is())) { From c7fa2edb38c563b5c4d678b03bd8bcf1a4012795 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 23 Aug 2022 14:13:57 -0400 Subject: [PATCH 59/92] fix print on cuda --- taichi/codegen/cuda/codegen_cuda.cpp | 53 +++++++++++++++++++--------- 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index eb98dfa502444..5cbe5c4d24d48 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -66,6 +66,24 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { llvm::Type::getInt8PtrTy(*llvm_context))); } + std::tuple create_value_and_type(llvm::Value *value, DataType dt) { + auto value_type = tlctx->get_data_type(dt); + if (dt->is_primitive(PrimitiveTypeID::f32) || + dt->is_primitive(PrimitiveTypeID::f16)) { + value_type = tlctx->get_data_type(PrimitiveType::f64); + value = builder->CreateFPExt(value, value_type); + } + if (dt->is_primitive(PrimitiveTypeID::i8)) { + value_type = tlctx->get_data_type(PrimitiveType::i16); + value = builder->CreateSExt(value, value_type); + } + if (dt->is_primitive(PrimitiveTypeID::u8)) { + value_type = tlctx->get_data_type(PrimitiveType::u16); + value = builder->CreateZExt(value, value_type); + } + return std::make_tuple(value, value_type); + } + void visit(PrintStmt *stmt) override { TI_ASSERT_INFO(stmt->contents.size() < 32, "CUDA `print()` doesn't support more than 32 entries"); @@ -74,31 +92,32 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { std::vector values; std::string formats; + size_t num_contents = 0; for (auto const &content : stmt->contents) { if (std::holds_alternative(content)) { auto arg_stmt = std::get(content); formats += data_type_format(arg_stmt->ret_type); - auto value_type = tlctx->get_data_type(arg_stmt->ret_type); auto value = llvm_val[arg_stmt]; - if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) || - arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) { - value_type = tlctx->get_data_type(PrimitiveType::f64); - value = builder->CreateFPExt(value, value_type); - } - if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::i8)) { - value_type = tlctx->get_data_type(PrimitiveType::i16); - value = builder->CreateSExt(value, value_type); - } - if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::u8)) { - value_type = tlctx->get_data_type(PrimitiveType::u16); - value = builder->CreateZExt(value, value_type); + if (arg_stmt->ret_type->is()) { + auto dtype = arg_stmt->ret_type->cast(); + num_contents += dtype->get_num_elements(); + auto elem_type = dtype->get_element_type(); + for (int i = 0; i < dtype->get_num_elements(); ++i) { + auto elem_value = builder->CreateExtractElement(value, i); + auto [casted_value, elem_value_type] = create_value_and_type(elem_value, elem_type); + types.push_back(elem_value_type); + values.push_back(casted_value); + } + } else { + num_contents++; + auto [val, dtype] = create_value_and_type(value, arg_stmt->ret_type); + types.push_back(dtype); + values.push_back(val); } - - types.push_back(value_type); - values.push_back(value); } else { + num_contents += 1; auto arg_str = std::get(content); auto value = builder->CreateGlobalStringPtr(arg_str, "content_string"); @@ -110,6 +129,8 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { values.push_back(value); formats += "%s"; } + TI_ASSERT_INFO(num_contents < 32, + "CUDA `print()` doesn't support more than 32 entries"); } llvm_val[stmt] = create_print(formats, types, values); From 4284ee58b327f34b4024f6958be8b3ef1ac8a688 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 1 Sep 2022 16:03:01 -0400 Subject: [PATCH 60/92] check real matrix --- taichi/ir/control_flow_graph.cpp | 29 +++++++++++++---------- taichi/ir/control_flow_graph.h | 6 ++--- taichi/ir/transforms.h | 1 + taichi/transforms/cfg_optimization.cpp | 3 ++- taichi/transforms/compile_to_offloads.cpp | 2 +- taichi/transforms/simplify.cpp | 2 +- 6 files changed, 25 insertions(+), 18 deletions(-) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 809a3a3731eaa..ff4934bb13e5d 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -232,7 +232,8 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { return result; } -void CFGNode::reaching_definition_analysis(bool after_lower_access) { +void CFGNode::reaching_definition_analysis(bool after_lower_access, + bool real_matrix_enabled) { // Calculate |reach_gen| and |reach_kill|. reach_gen.clear(); reach_kill.clear(); @@ -240,12 +241,14 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access) { // loop in reversed order auto stmt = block->statements[i].get(); auto data_source_ptrs = irpass::analysis::get_store_destination(stmt); - if (auto local_store = stmt->cast()) { - if (auto dest = local_store->dest->cast()) { - if (auto data = get_store_forwarding_data(dest->origin, i)) { - data_source_ptrs = std::vector(1, data); - } else { - data_source_ptrs = std::vector(); + if (real_matrix_enabled) { + if (auto local_store = stmt->cast()) { + if (auto dest = local_store->dest->cast()) { + if (auto data = get_store_forwarding_data(dest->origin, i)) { + data_source_ptrs = std::vector(1, data); + } else { + data_source_ptrs = std::vector(); + } } } } @@ -602,7 +605,8 @@ void ControlFlowGraph::print_graph_structure() const { } } -void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { +void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access, + bool real_matrix_enabled) { TI_AUTO_PROF; const int num_nodes = size(); std::queue to_visit; @@ -627,7 +631,7 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { } for (int i = 0; i < num_nodes; i++) { if (i != start_node) { - nodes[i]->reaching_definition_analysis(after_lower_access); + nodes[i]->reaching_definition_analysis(after_lower_access, real_matrix_enabled); } nodes[i]->reach_in.clear(); nodes[i]->reach_out = nodes[i]->reach_gen; @@ -824,9 +828,10 @@ bool ControlFlowGraph::unreachable_code_elimination() { } bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access, - bool autodiff_enabled) { + bool autodiff_enabled, + bool real_matrix_enabled) { TI_AUTO_PROF; - reaching_definition_analysis(after_lower_access); + reaching_definition_analysis(after_lower_access, real_matrix_enabled); const int num_nodes = size(); bool modified = false; for (int i = 0; i < num_nodes; i++) { @@ -853,7 +858,7 @@ bool ControlFlowGraph::dead_store_elimination( std::unordered_set ControlFlowGraph::gather_loaded_snodes() { TI_AUTO_PROF; - reaching_definition_analysis(/*after_lower_access=*/false); + reaching_definition_analysis(/*after_lower_access=*/false, /*real_matrix=*/false); const int num_nodes = size(); std::unordered_set snodes; diff --git a/taichi/ir/control_flow_graph.h b/taichi/ir/control_flow_graph.h index 1776a57ea4eea..ff6cacd616915 100644 --- a/taichi/ir/control_flow_graph.h +++ b/taichi/ir/control_flow_graph.h @@ -75,7 +75,7 @@ class CFGNode { Stmt *get_store_forwarding_data(Stmt *var, int position) const; // Analyses and optimizations inside a CFGNode. - void reaching_definition_analysis(bool after_lower_access); + void reaching_definition_analysis(bool after_lower_access, bool real_matrix_enabled); bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled); void gather_loaded_snodes(std::unordered_set &snodes) const; void live_variable_analysis(bool after_lower_access); @@ -117,7 +117,7 @@ class ControlFlowGraph { * @param after_lower_access * When after_lower_access is true, only consider local variables (allocas). */ - void reaching_definition_analysis(bool after_lower_access); + void reaching_definition_analysis(bool after_lower_access, bool real_matrix_enabled); /** * Perform live variable analysis using the worklist algorithm, @@ -145,7 +145,7 @@ class ControlFlowGraph { /** * Perform store-to-load forwarding and identical store elimination. */ - bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled); + bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled, bool store_to_load_forwarding); /** * Perform dead store elimination and identical load elimination. diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index ac6006646bc3a..b96d75834a66f 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -35,6 +35,7 @@ bool cfg_optimization( IRNode *root, bool after_lower_access, bool autodiff_enabled, + bool real_matrix_enabled, const std::optional &lva_config_opt = std::nullopt); bool alg_simp(IRNode *root, const CompileConfig &config); diff --git a/taichi/transforms/cfg_optimization.cpp b/taichi/transforms/cfg_optimization.cpp index 93efc651a411d..ea6f756ba805b 100644 --- a/taichi/transforms/cfg_optimization.cpp +++ b/taichi/transforms/cfg_optimization.cpp @@ -11,6 +11,7 @@ bool cfg_optimization( IRNode *root, bool after_lower_access, bool autodiff_enabled, + bool real_matrix_enabled, const std::optional &lva_config_opt) { TI_AUTO_PROF; @@ -19,7 +20,7 @@ bool cfg_optimization( while (true) { bool modified = false; cfg->simplify_graph(); - if (cfg->store_to_load_forwarding(after_lower_access, autodiff_enabled)) + if (cfg->store_to_load_forwarding(after_lower_access, autodiff_enabled, real_matrix_enabled)) modified = true; if (cfg->dead_store_elimination(after_lower_access, lva_config_opt)) modified = true; diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 505ea107ff1f5..d1e12f645df54 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -133,7 +133,7 @@ void compile_to_offloads(IRNode *ir, // TODO: This pass may be redundant as cfg_optimization() is already called // in full_simplify(). if (config.opt_level > 0 && config.cfg_optimization) { - irpass::cfg_optimization(ir, false, /*autodiff_enabled*/ false); + irpass::cfg_optimization(ir, false, /*autodiff_enabled*/ false, config.real_matrix); print("Optimized by CFG"); irpass::analysis::verify(ir); } diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 7ce1b04784c0a..eeb8b091dd4e3 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -637,7 +637,7 @@ void full_simplify(IRNode *root, if (config.opt_level > 0 && (first_iteration || modified) && config.cfg_optimization && cfg_optimization(root, args.after_lower_access, - args.autodiff_enabled)) + args.autodiff_enabled, config.real_matrix)) modified = true; first_iteration = false; if (!modified) From e7002f09636c7d673cf628bded7c704d33a52e80 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 1 Sep 2022 20:05:08 +0000 Subject: [PATCH 61/92] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/codegen/cuda/codegen_cuda.cpp | 9 ++++++--- taichi/ir/control_flow_graph.cpp | 6 ++++-- taichi/ir/control_flow_graph.h | 10 +++++++--- taichi/transforms/cfg_optimization.cpp | 3 ++- taichi/transforms/compile_to_offloads.cpp | 3 ++- taichi/transforms/simplify.cpp | 4 ++-- 6 files changed, 23 insertions(+), 12 deletions(-) diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index 5cbe5c4d24d48..bff2a25541560 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -66,7 +66,9 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { llvm::Type::getInt8PtrTy(*llvm_context))); } - std::tuple create_value_and_type(llvm::Value *value, DataType dt) { + std::tuple create_value_and_type( + llvm::Value *value, + DataType dt) { auto value_type = tlctx->get_data_type(dt); if (dt->is_primitive(PrimitiveTypeID::f32) || dt->is_primitive(PrimitiveTypeID::f16)) { @@ -106,7 +108,8 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { auto elem_type = dtype->get_element_type(); for (int i = 0; i < dtype->get_num_elements(); ++i) { auto elem_value = builder->CreateExtractElement(value, i); - auto [casted_value, elem_value_type] = create_value_and_type(elem_value, elem_type); + auto [casted_value, elem_value_type] = + create_value_and_type(elem_value, elem_type); types.push_back(elem_value_type); values.push_back(casted_value); } @@ -130,7 +133,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { formats += "%s"; } TI_ASSERT_INFO(num_contents < 32, - "CUDA `print()` doesn't support more than 32 entries"); + "CUDA `print()` doesn't support more than 32 entries"); } llvm_val[stmt] = create_print(formats, types, values); diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index ff4934bb13e5d..27af3b7ebdba3 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -631,7 +631,8 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access, } for (int i = 0; i < num_nodes; i++) { if (i != start_node) { - nodes[i]->reaching_definition_analysis(after_lower_access, real_matrix_enabled); + nodes[i]->reaching_definition_analysis(after_lower_access, + real_matrix_enabled); } nodes[i]->reach_in.clear(); nodes[i]->reach_out = nodes[i]->reach_gen; @@ -858,7 +859,8 @@ bool ControlFlowGraph::dead_store_elimination( std::unordered_set ControlFlowGraph::gather_loaded_snodes() { TI_AUTO_PROF; - reaching_definition_analysis(/*after_lower_access=*/false, /*real_matrix=*/false); + reaching_definition_analysis(/*after_lower_access=*/false, + /*real_matrix=*/false); const int num_nodes = size(); std::unordered_set snodes; diff --git a/taichi/ir/control_flow_graph.h b/taichi/ir/control_flow_graph.h index ff6cacd616915..9a8c4284afb5e 100644 --- a/taichi/ir/control_flow_graph.h +++ b/taichi/ir/control_flow_graph.h @@ -75,7 +75,8 @@ class CFGNode { Stmt *get_store_forwarding_data(Stmt *var, int position) const; // Analyses and optimizations inside a CFGNode. - void reaching_definition_analysis(bool after_lower_access, bool real_matrix_enabled); + void reaching_definition_analysis(bool after_lower_access, + bool real_matrix_enabled); bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled); void gather_loaded_snodes(std::unordered_set &snodes) const; void live_variable_analysis(bool after_lower_access); @@ -117,7 +118,8 @@ class ControlFlowGraph { * @param after_lower_access * When after_lower_access is true, only consider local variables (allocas). */ - void reaching_definition_analysis(bool after_lower_access, bool real_matrix_enabled); + void reaching_definition_analysis(bool after_lower_access, + bool real_matrix_enabled); /** * Perform live variable analysis using the worklist algorithm, @@ -145,7 +147,9 @@ class ControlFlowGraph { /** * Perform store-to-load forwarding and identical store elimination. */ - bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled, bool store_to_load_forwarding); + bool store_to_load_forwarding(bool after_lower_access, + bool autodiff_enabled, + bool store_to_load_forwarding); /** * Perform dead store elimination and identical load elimination. diff --git a/taichi/transforms/cfg_optimization.cpp b/taichi/transforms/cfg_optimization.cpp index ea6f756ba805b..441011a194997 100644 --- a/taichi/transforms/cfg_optimization.cpp +++ b/taichi/transforms/cfg_optimization.cpp @@ -20,7 +20,8 @@ bool cfg_optimization( while (true) { bool modified = false; cfg->simplify_graph(); - if (cfg->store_to_load_forwarding(after_lower_access, autodiff_enabled, real_matrix_enabled)) + if (cfg->store_to_load_forwarding(after_lower_access, autodiff_enabled, + real_matrix_enabled)) modified = true; if (cfg->dead_store_elimination(after_lower_access, lva_config_opt)) modified = true; diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index d1e12f645df54..8c8c3d3d6af82 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -133,7 +133,8 @@ void compile_to_offloads(IRNode *ir, // TODO: This pass may be redundant as cfg_optimization() is already called // in full_simplify(). if (config.opt_level > 0 && config.cfg_optimization) { - irpass::cfg_optimization(ir, false, /*autodiff_enabled*/ false, config.real_matrix); + irpass::cfg_optimization(ir, false, /*autodiff_enabled*/ false, + config.real_matrix); print("Optimized by CFG"); irpass::analysis::verify(ir); } diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index eeb8b091dd4e3..6c18c88a9194d 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -636,8 +636,8 @@ void full_simplify(IRNode *root, // not modified. if (config.opt_level > 0 && (first_iteration || modified) && config.cfg_optimization && - cfg_optimization(root, args.after_lower_access, - args.autodiff_enabled, config.real_matrix)) + cfg_optimization(root, args.after_lower_access, args.autodiff_enabled, + config.real_matrix)) modified = true; first_iteration = false; if (!modified) From 9375327cef86414b56a8c21497001c1fc1c27610 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 1 Sep 2022 16:05:50 -0400 Subject: [PATCH 62/92] uncomment --- taichi/ir/control_flow_graph.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index ff4934bb13e5d..7e8a582890c15 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -366,9 +366,9 @@ void CFGNode::live_variable_analysis(bool after_lower_access) { for (int i = begin_location; i < end_location; i++) { auto stmt = block->statements[i].get(); auto load_ptrs = irpass::analysis::get_load_pointers(stmt); - // if (auto ptr_offset = stmt->cast()) { - // load_ptrs = std::vector(1, ptr_offset->origin); - // } + if (auto ptr_offset = stmt->cast()) { + load_ptrs = std::vector(1, ptr_offset->origin); + } for (auto &load_ptr : load_ptrs) { if (!after_lower_access || (load_ptr->is() || load_ptr->is())) { From b468eb45724533bf9cc4f88534c4d1dbeec0c938 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 1 Sep 2022 20:34:33 -0400 Subject: [PATCH 63/92] fix index read --- taichi/ir/control_flow_graph.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index e62473b1ed073..47fe8aa687914 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -486,6 +486,9 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { } } auto load_ptrs = irpass::analysis::get_load_pointers(stmt); + if (auto ptr_offset = stmt->cast()) { + load_ptrs = std::vector(1, ptr_offset->origin); + } if (load_ptrs.size() == 1 && store_ptrs.empty() && !stmt->is()) { // Identical load elimination From ff295ae71335931242c9ecea65cb267c837373fb Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 1 Sep 2022 21:48:26 -0400 Subject: [PATCH 64/92] check flag for dead store elem --- taichi/ir/control_flow_graph.cpp | 12 ++++++++---- taichi/ir/control_flow_graph.h | 3 ++- taichi/transforms/cfg_optimization.cpp | 2 +- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 47fe8aa687914..08ce6ea6afff6 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -399,7 +399,8 @@ void CFGNode::live_variable_analysis(bool after_lower_access) { } } -bool CFGNode::dead_store_elimination(bool after_lower_access) { +bool CFGNode::dead_store_elimination(bool after_lower_access, + bool real_matrix_enabled) { bool modified = false; std::unordered_set live_in_this_node; std::unordered_set killed_in_this_node; @@ -486,8 +487,10 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { } } auto load_ptrs = irpass::analysis::get_load_pointers(stmt); - if (auto ptr_offset = stmt->cast()) { - load_ptrs = std::vector(1, ptr_offset->origin); + if (real_matrix_enabled) { + if (auto ptr_offset = stmt->cast()) { + load_ptrs = std::vector(1, ptr_offset->origin); + } } if (load_ptrs.size() == 1 && store_ptrs.empty() && !stmt->is()) { @@ -848,13 +851,14 @@ bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access, bool ControlFlowGraph::dead_store_elimination( bool after_lower_access, + bool real_matrix_enabled, const std::optional &lva_config_opt) { TI_AUTO_PROF; live_variable_analysis(after_lower_access, lva_config_opt); const int num_nodes = size(); bool modified = false; for (int i = 0; i < num_nodes; i++) { - if (nodes[i]->dead_store_elimination(after_lower_access)) + if (nodes[i]->dead_store_elimination(after_lower_access, real_matrix_enabled)) modified = true; } return modified; diff --git a/taichi/ir/control_flow_graph.h b/taichi/ir/control_flow_graph.h index 9a8c4284afb5e..e389084a0fc10 100644 --- a/taichi/ir/control_flow_graph.h +++ b/taichi/ir/control_flow_graph.h @@ -80,7 +80,7 @@ class CFGNode { bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled); void gather_loaded_snodes(std::unordered_set &snodes) const; void live_variable_analysis(bool after_lower_access); - bool dead_store_elimination(bool after_lower_access); + bool dead_store_elimination(bool after_lower_access, bool real_matrix_enabled); }; class ControlFlowGraph { @@ -156,6 +156,7 @@ class ControlFlowGraph { */ bool dead_store_elimination( bool after_lower_access, + bool real_matrix_enabled, const std::optional &lva_config_opt); /** diff --git a/taichi/transforms/cfg_optimization.cpp b/taichi/transforms/cfg_optimization.cpp index 441011a194997..e1c0a0f80230a 100644 --- a/taichi/transforms/cfg_optimization.cpp +++ b/taichi/transforms/cfg_optimization.cpp @@ -23,7 +23,7 @@ bool cfg_optimization( if (cfg->store_to_load_forwarding(after_lower_access, autodiff_enabled, real_matrix_enabled)) modified = true; - if (cfg->dead_store_elimination(after_lower_access, lva_config_opt)) + if (cfg->dead_store_elimination(after_lower_access, real_matrix_enabled, lva_config_opt)) modified = true; if (modified) result_modified = true; From 87029e5c093a9c4b3129ff603965bb21258ee174 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 2 Sep 2022 01:50:00 +0000 Subject: [PATCH 65/92] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/ir/control_flow_graph.cpp | 3 ++- taichi/ir/control_flow_graph.h | 3 ++- taichi/transforms/cfg_optimization.cpp | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 08ce6ea6afff6..018f6aad2e19f 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -858,7 +858,8 @@ bool ControlFlowGraph::dead_store_elimination( const int num_nodes = size(); bool modified = false; for (int i = 0; i < num_nodes; i++) { - if (nodes[i]->dead_store_elimination(after_lower_access, real_matrix_enabled)) + if (nodes[i]->dead_store_elimination(after_lower_access, + real_matrix_enabled)) modified = true; } return modified; diff --git a/taichi/ir/control_flow_graph.h b/taichi/ir/control_flow_graph.h index e389084a0fc10..5ee111b86a730 100644 --- a/taichi/ir/control_flow_graph.h +++ b/taichi/ir/control_flow_graph.h @@ -80,7 +80,8 @@ class CFGNode { bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled); void gather_loaded_snodes(std::unordered_set &snodes) const; void live_variable_analysis(bool after_lower_access); - bool dead_store_elimination(bool after_lower_access, bool real_matrix_enabled); + bool dead_store_elimination(bool after_lower_access, + bool real_matrix_enabled); }; class ControlFlowGraph { diff --git a/taichi/transforms/cfg_optimization.cpp b/taichi/transforms/cfg_optimization.cpp index e1c0a0f80230a..abc498dda1278 100644 --- a/taichi/transforms/cfg_optimization.cpp +++ b/taichi/transforms/cfg_optimization.cpp @@ -23,7 +23,8 @@ bool cfg_optimization( if (cfg->store_to_load_forwarding(after_lower_access, autodiff_enabled, real_matrix_enabled)) modified = true; - if (cfg->dead_store_elimination(after_lower_access, real_matrix_enabled, lva_config_opt)) + if (cfg->dead_store_elimination(after_lower_access, real_matrix_enabled, + lva_config_opt)) modified = true; if (modified) result_modified = true; From 2302a247ff7e338a459f3ac0f203d5fd856c06f6 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 1 Sep 2022 22:02:03 -0400 Subject: [PATCH 66/92] turn off cfg for new matrix impl --- taichi/ir/control_flow_graph.cpp | 13 +++++++++---- taichi/ir/control_flow_graph.h | 7 +++++-- taichi/transforms/cfg_optimization.cpp | 3 ++- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 08ce6ea6afff6..752408898dacd 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -267,7 +267,11 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access, } bool CFGNode::store_to_load_forwarding(bool after_lower_access, - bool autodiff_enabled) { + bool autodiff_enabled, + bool real_matrix_enabled) { + if (real_matrix_enabled) + // Disable this for new matrices for now + return false; bool modified = false; for (int i = begin_location; i < end_location; i++) { // Store-to-load forwarding @@ -842,8 +846,8 @@ bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access, const int num_nodes = size(); bool modified = false; for (int i = 0; i < num_nodes; i++) { - if (nodes[i]->store_to_load_forwarding(after_lower_access, - autodiff_enabled)) + if (nodes[i]->store_to_load_forwarding(after_lower_access, autodiff_enabled, + real_matrix_enabled)) modified = true; } return modified; @@ -858,7 +862,8 @@ bool ControlFlowGraph::dead_store_elimination( const int num_nodes = size(); bool modified = false; for (int i = 0; i < num_nodes; i++) { - if (nodes[i]->dead_store_elimination(after_lower_access, real_matrix_enabled)) + if (nodes[i]->dead_store_elimination(after_lower_access, + real_matrix_enabled)) modified = true; } return modified; diff --git a/taichi/ir/control_flow_graph.h b/taichi/ir/control_flow_graph.h index e389084a0fc10..011482a445c06 100644 --- a/taichi/ir/control_flow_graph.h +++ b/taichi/ir/control_flow_graph.h @@ -77,10 +77,13 @@ class CFGNode { // Analyses and optimizations inside a CFGNode. void reaching_definition_analysis(bool after_lower_access, bool real_matrix_enabled); - bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled); + bool store_to_load_forwarding(bool after_lower_access, + bool autodiff_enabled, + bool real_matrix_enabled); void gather_loaded_snodes(std::unordered_set &snodes) const; void live_variable_analysis(bool after_lower_access); - bool dead_store_elimination(bool after_lower_access, bool real_matrix_enabled); + bool dead_store_elimination(bool after_lower_access, + bool real_matrix_enabled); }; class ControlFlowGraph { diff --git a/taichi/transforms/cfg_optimization.cpp b/taichi/transforms/cfg_optimization.cpp index e1c0a0f80230a..abc498dda1278 100644 --- a/taichi/transforms/cfg_optimization.cpp +++ b/taichi/transforms/cfg_optimization.cpp @@ -23,7 +23,8 @@ bool cfg_optimization( if (cfg->store_to_load_forwarding(after_lower_access, autodiff_enabled, real_matrix_enabled)) modified = true; - if (cfg->dead_store_elimination(after_lower_access, real_matrix_enabled, lva_config_opt)) + if (cfg->dead_store_elimination(after_lower_access, real_matrix_enabled, + lva_config_opt)) modified = true; if (modified) result_modified = true; From c92b578f122455bdd1ebc63473ee361ed773c876 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 1 Sep 2022 22:25:18 -0400 Subject: [PATCH 67/92] disable cfg pass for new impl --- taichi/analysis/data_source_analysis.cpp | 4 +- taichi/ir/control_flow_graph.cpp | 58 +++++++----------------- taichi/ir/control_flow_graph.h | 13 ++---- taichi/ir/type.h | 1 + 4 files changed, 23 insertions(+), 53 deletions(-) diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index 9a38bd454031b..7c6721800f695 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -39,7 +39,7 @@ std::vector get_load_pointers(Stmt *load_stmt) { Stmt *get_store_data(Stmt *store_stmt) { // If store_stmt provides one data source, return the data. - if (store_stmt->is()) { + if (store_stmt->is() && !store_stmt->ret_type->is()) { // For convenience, return store_stmt instead of the const [0] it actually // stores. return store_stmt; @@ -54,7 +54,7 @@ Stmt *get_store_data(Stmt *store_stmt) { std::vector get_store_destination(Stmt *store_stmt) { // If store_stmt provides some data sources, return the pointers of the data. - if (store_stmt->is()) { + if (store_stmt->is() && !store_stmt->ret_type->is()) { // The statement itself provides a data source (const [0]). return std::vector(1, store_stmt); } else if (auto local_store = store_stmt->cast()) { diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 752408898dacd..66f039722b371 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -147,7 +147,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // The UD-chain is inside this node. Stmt *result = irpass::analysis::get_store_data( block->statements[last_def_position].get()); - if (!var->is() || var->ret_type->is()) { + if (!var->is()) { for (int i = last_def_position + 1; i < position; i++) { if (!irpass::analysis::same_value( result, @@ -232,8 +232,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { return result; } -void CFGNode::reaching_definition_analysis(bool after_lower_access, - bool real_matrix_enabled) { +void CFGNode::reaching_definition_analysis(bool after_lower_access) { // Calculate |reach_gen| and |reach_kill|. reach_gen.clear(); reach_kill.clear(); @@ -241,17 +240,6 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access, // loop in reversed order auto stmt = block->statements[i].get(); auto data_source_ptrs = irpass::analysis::get_store_destination(stmt); - if (real_matrix_enabled) { - if (auto local_store = stmt->cast()) { - if (auto dest = local_store->dest->cast()) { - if (auto data = get_store_forwarding_data(dest->origin, i)) { - data_source_ptrs = std::vector(1, data); - } else { - data_source_ptrs = std::vector(); - } - } - } - } for (auto data_source_ptr : data_source_ptrs) { // stmt provides a data source if (after_lower_access && !(data_source_ptr->is())) { @@ -267,11 +255,7 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access, } bool CFGNode::store_to_load_forwarding(bool after_lower_access, - bool autodiff_enabled, - bool real_matrix_enabled) { - if (real_matrix_enabled) - // Disable this for new matrices for now - return false; + bool autodiff_enabled) { bool modified = false; for (int i = begin_location; i < end_location; i++) { // Store-to-load forwarding @@ -370,9 +354,6 @@ void CFGNode::live_variable_analysis(bool after_lower_access) { for (int i = begin_location; i < end_location; i++) { auto stmt = block->statements[i].get(); auto load_ptrs = irpass::analysis::get_load_pointers(stmt); - if (auto ptr_offset = stmt->cast()) { - load_ptrs = std::vector(1, ptr_offset->origin); - } for (auto &load_ptr : load_ptrs) { if (!after_lower_access || (load_ptr->is() || load_ptr->is())) { @@ -403,8 +384,7 @@ void CFGNode::live_variable_analysis(bool after_lower_access) { } } -bool CFGNode::dead_store_elimination(bool after_lower_access, - bool real_matrix_enabled) { +bool CFGNode::dead_store_elimination(bool after_lower_access) { bool modified = false; std::unordered_set live_in_this_node; std::unordered_set killed_in_this_node; @@ -491,13 +471,7 @@ bool CFGNode::dead_store_elimination(bool after_lower_access, } } auto load_ptrs = irpass::analysis::get_load_pointers(stmt); - if (real_matrix_enabled) { - if (auto ptr_offset = stmt->cast()) { - load_ptrs = std::vector(1, ptr_offset->origin); - } - } - if (load_ptrs.size() == 1 && store_ptrs.empty() && - !stmt->is()) { + if (load_ptrs.size() == 1 && store_ptrs.empty()) { // Identical load elimination auto load_ptr = load_ptrs.front(); if (!after_lower_access || @@ -615,8 +589,7 @@ void ControlFlowGraph::print_graph_structure() const { } } -void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access, - bool real_matrix_enabled) { +void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { TI_AUTO_PROF; const int num_nodes = size(); std::queue to_visit; @@ -641,8 +614,7 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access, } for (int i = 0; i < num_nodes; i++) { if (i != start_node) { - nodes[i]->reaching_definition_analysis(after_lower_access, - real_matrix_enabled); + nodes[i]->reaching_definition_analysis(after_lower_access); } nodes[i]->reach_in.clear(); nodes[i]->reach_out = nodes[i]->reach_gen; @@ -842,12 +814,14 @@ bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled, bool real_matrix_enabled) { TI_AUTO_PROF; - reaching_definition_analysis(after_lower_access, real_matrix_enabled); + if (real_matrix_enabled) + return false; + reaching_definition_analysis(after_lower_access); const int num_nodes = size(); bool modified = false; for (int i = 0; i < num_nodes; i++) { - if (nodes[i]->store_to_load_forwarding(after_lower_access, autodiff_enabled, - real_matrix_enabled)) + if (nodes[i]->store_to_load_forwarding(after_lower_access, + autodiff_enabled)) modified = true; } return modified; @@ -858,12 +832,13 @@ bool ControlFlowGraph::dead_store_elimination( bool real_matrix_enabled, const std::optional &lva_config_opt) { TI_AUTO_PROF; + if (real_matrix_enabled) + return false; live_variable_analysis(after_lower_access, lva_config_opt); const int num_nodes = size(); bool modified = false; for (int i = 0; i < num_nodes; i++) { - if (nodes[i]->dead_store_elimination(after_lower_access, - real_matrix_enabled)) + if (nodes[i]->dead_store_elimination(after_lower_access)) modified = true; } return modified; @@ -871,8 +846,7 @@ bool ControlFlowGraph::dead_store_elimination( std::unordered_set ControlFlowGraph::gather_loaded_snodes() { TI_AUTO_PROF; - reaching_definition_analysis(/*after_lower_access=*/false, - /*real_matrix=*/false); + reaching_definition_analysis(/*after_lower_access=*/false); const int num_nodes = size(); std::unordered_set snodes; diff --git a/taichi/ir/control_flow_graph.h b/taichi/ir/control_flow_graph.h index 011482a445c06..196591c3bcc97 100644 --- a/taichi/ir/control_flow_graph.h +++ b/taichi/ir/control_flow_graph.h @@ -75,15 +75,11 @@ class CFGNode { Stmt *get_store_forwarding_data(Stmt *var, int position) const; // Analyses and optimizations inside a CFGNode. - void reaching_definition_analysis(bool after_lower_access, - bool real_matrix_enabled); - bool store_to_load_forwarding(bool after_lower_access, - bool autodiff_enabled, - bool real_matrix_enabled); + void reaching_definition_analysis(bool after_lower_access); + bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled); void gather_loaded_snodes(std::unordered_set &snodes) const; void live_variable_analysis(bool after_lower_access); - bool dead_store_elimination(bool after_lower_access, - bool real_matrix_enabled); + bool dead_store_elimination(bool after_lower_access); }; class ControlFlowGraph { @@ -121,8 +117,7 @@ class ControlFlowGraph { * @param after_lower_access * When after_lower_access is true, only consider local variables (allocas). */ - void reaching_definition_analysis(bool after_lower_access, - bool real_matrix_enabled); + void reaching_definition_analysis(bool after_lower_access); /** * Perform live variable analysis using the worklist algorithm, diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 064a7fb0e8462..e3c23a1f814fb 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -445,6 +445,7 @@ class TypedConstant { } else if (dt->is_primitive(PrimitiveTypeID::u64)) { val_u64 = value; } else { + assert(false); TI_NOT_IMPLEMENTED } } From 768269926fb3f34f4fa3ef64c8a7b7b35cef58cc Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 1 Sep 2022 22:25:44 -0400 Subject: [PATCH 68/92] rm debug code --- taichi/ir/type.h | 1 - 1 file changed, 1 deletion(-) diff --git a/taichi/ir/type.h b/taichi/ir/type.h index e3c23a1f814fb..064a7fb0e8462 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -445,7 +445,6 @@ class TypedConstant { } else if (dt->is_primitive(PrimitiveTypeID::u64)) { val_u64 = value; } else { - assert(false); TI_NOT_IMPLEMENTED } } From bc04628ea45bd5aac32330a4f4e059489a3851eb Mon Sep 17 00:00:00 2001 From: AD1024 Date: Fri, 2 Sep 2022 11:29:28 -0400 Subject: [PATCH 69/92] disable in transformation pass --- taichi/ir/control_flow_graph.cpp | 8 +------- taichi/ir/control_flow_graph.h | 5 +---- taichi/transforms/cfg_optimization.cpp | 8 +++----- 3 files changed, 5 insertions(+), 16 deletions(-) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 66f039722b371..74f6e4311641f 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -811,11 +811,8 @@ bool ControlFlowGraph::unreachable_code_elimination() { } bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access, - bool autodiff_enabled, - bool real_matrix_enabled) { + bool autodiff_enabled) { TI_AUTO_PROF; - if (real_matrix_enabled) - return false; reaching_definition_analysis(after_lower_access); const int num_nodes = size(); bool modified = false; @@ -829,11 +826,8 @@ bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access, bool ControlFlowGraph::dead_store_elimination( bool after_lower_access, - bool real_matrix_enabled, const std::optional &lva_config_opt) { TI_AUTO_PROF; - if (real_matrix_enabled) - return false; live_variable_analysis(after_lower_access, lva_config_opt); const int num_nodes = size(); bool modified = false; diff --git a/taichi/ir/control_flow_graph.h b/taichi/ir/control_flow_graph.h index 196591c3bcc97..1776a57ea4eea 100644 --- a/taichi/ir/control_flow_graph.h +++ b/taichi/ir/control_flow_graph.h @@ -145,16 +145,13 @@ class ControlFlowGraph { /** * Perform store-to-load forwarding and identical store elimination. */ - bool store_to_load_forwarding(bool after_lower_access, - bool autodiff_enabled, - bool store_to_load_forwarding); + bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled); /** * Perform dead store elimination and identical load elimination. */ bool dead_store_elimination( bool after_lower_access, - bool real_matrix_enabled, const std::optional &lva_config_opt); /** diff --git a/taichi/transforms/cfg_optimization.cpp b/taichi/transforms/cfg_optimization.cpp index abc498dda1278..9366300f06ec1 100644 --- a/taichi/transforms/cfg_optimization.cpp +++ b/taichi/transforms/cfg_optimization.cpp @@ -17,14 +17,12 @@ bool cfg_optimization( TI_AUTO_PROF; auto cfg = analysis::build_cfg(root); bool result_modified = false; - while (true) { + while (true && !real_matrix_enabled) { bool modified = false; cfg->simplify_graph(); - if (cfg->store_to_load_forwarding(after_lower_access, autodiff_enabled, - real_matrix_enabled)) + if (cfg->store_to_load_forwarding(after_lower_access, autodiff_enabled)) modified = true; - if (cfg->dead_store_elimination(after_lower_access, real_matrix_enabled, - lva_config_opt)) + if (cfg->dead_store_elimination(after_lower_access, lva_config_opt)) modified = true; if (modified) result_modified = true; From e13dcd59510341c5aee104106c97fa3cac0dd6d4 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Sun, 4 Sep 2022 15:05:11 -0400 Subject: [PATCH 70/92] turn on indexing check --- python/taichi/lang/_ndrange.py | 2 +- python/taichi/lang/impl.py | 6 ++++-- python/taichi/lang/matrix.py | 24 +++++++++++++++++++----- taichi/ir/frontend_ir.cpp | 16 ++++++++++++++++ 4 files changed, 40 insertions(+), 8 deletions(-) diff --git a/python/taichi/lang/_ndrange.py b/python/taichi/lang/_ndrange.py index 51ce24288def1..f414bf5e75cab 100644 --- a/python/taichi/lang/_ndrange.py +++ b/python/taichi/lang/_ndrange.py @@ -144,7 +144,7 @@ def __init__(self, r): def __iter__(self): for ind in self.r: - yield _IntermediateMatrix(len(ind), 1, list(ind)) + yield _IntermediateMatrix(len(ind), 1, list(ind), ndim=1) __all__ = ['ndrange'] diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 9f96dccddf2bc..8eb9002e75b8e 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -15,8 +15,8 @@ from taichi.lang.field import Field, ScalarField from taichi.lang.kernel_arguments import SparseMatrixProxy from taichi.lang.matrix import (Matrix, MatrixField, MatrixNdarray, MatrixType, - _IntermediateMatrix, _MatrixFieldElement, - make_matrix) + Vector, _IntermediateMatrix, + _MatrixFieldElement, make_matrix) from taichi.lang.mesh import (ConvType, MeshElementFieldProxy, MeshInstance, MeshRelationAccessProxy, MeshReorderedMatrixFieldProxy, @@ -63,6 +63,8 @@ def expr_init(rhs): entries = [[rhs(i, j) for j in range(rhs.m)] for i in range(rhs.n)] return make_matrix(entries) + if isinstance(rhs, Vector): + return Vector(rhs.to_list(), ndim=rhs.ndim) return Matrix(rhs.to_list(), ndim=rhs.ndim) if isinstance(rhs, SharedArray): return rhs diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 2d47d504448a6..0b2ef9cb054ba 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -113,6 +113,10 @@ def make_matrix(arr, dt=None): [expr.Expr(elt).ptr for row in arr for elt in row])) +def is_vector(x): + return isinstance(x, Vector) or getattr(x, "ndim", None) == 1 + + class _MatrixBaseImpl: def __init__(self, m, n, entries): self.m = m @@ -441,7 +445,7 @@ def __init__(self, elif isinstance(arr[0], Matrix): raise Exception('cols/rows required when using list of vectors') else: - is_matrix = isinstance(arr[0], Iterable) + is_matrix = isinstance(arr[0], Iterable) and not is_vector(self) initializer = _make_entries_initializer(is_matrix) self.ndim = 2 if is_matrix else 1 @@ -490,17 +494,20 @@ def __init__(self, def _element_wise_binary(self, foo, other): other = self._broadcast_copy(other) + if is_vector(self): + return Vector([foo(self(i), other(i)) for i in range(self.n)], + ndim=self.ndim) return Matrix([[foo(self(i, j), other(i, j)) for j in range(self.m)] for i in range(self.n)], ndim=self.ndim) def _broadcast_copy(self, other): if isinstance(other, (list, tuple)): - other = Matrix(other) + other = type(self)(other) if not isinstance(other, Matrix): - other = Matrix([[other for _ in range(self.m)] - for _ in range(self.n)], - ndim=self.ndim) + other = type(self)([[other for _ in range(self.m)] + for _ in range(self.n)], + ndim=self.ndim) assert self.m == other.m and self.n == other.n, f"Dimension mismatch between shapes ({self.n}, {self.m}), ({other.n}, {other.m})" return other @@ -665,6 +672,10 @@ def cast(self, dtype): >>> B [0.0, 1.0, 2.0] """ + if is_vector(self): + # when using _IntermediateMatrix, we can only check `self.ndim` + return Vector( + [ops_mod.cast(self(i), dtype) for i in range(self.n)]) return Matrix( [[ops_mod.cast(self(i, j), dtype) for j in range(self.m)] for i in range(self.n)], @@ -1476,6 +1487,9 @@ def ndarray(cls, n, dtype, shape, layout=Layout.AOS): shape = (shape, ) return VectorNdarray(n, dtype, shape, layout) + def to_list(self): + return [self(i) for i in range(self.n)] + class _IntermediateMatrix(Matrix): """Intermediate matrix class for compiler internal use only. diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index c013ae163fe55..7dc9d057240bf 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -503,6 +503,22 @@ void IndexExpression::type_check(CompileConfig *) { ret_type = var.cast()->dt; } } else if (is_tensor()) { // local tensor + auto shape = var->ret_type->as()->get_shape(); + if (indices.size() != shape.size()) { + std::string shape_str = "["; + if (shape.size() > 0) { + shape_str += std::to_string(shape[0]); + for (int i = 1; i < shape.size(); i++) { + shape_str += ", " + std::to_string(shape[i]); + } + } + shape_str += "]"; + TI_ERROR( + "Indexed matrix of shape {} has wrong number of indices. Expected {} " + "but got " + "{}.", + shape_str, shape.size(), indices.size()); + } ret_type = var->ret_type->cast()->get_element_type(); } else { throw TaichiTypeError( From a12de9e55f085e7c40ab845f83f5780855af91be Mon Sep 17 00:00:00 2001 From: AD1024 Date: Sun, 4 Sep 2022 15:43:48 -0400 Subject: [PATCH 71/92] fix element wise copy --- python/taichi/lang/matrix.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 0b2ef9cb054ba..658b20ab9da2a 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -505,7 +505,10 @@ def _broadcast_copy(self, other): if isinstance(other, (list, tuple)): other = type(self)(other) if not isinstance(other, Matrix): - other = type(self)([[other for _ in range(self.m)] + if isinstance(self, Vector): + other = Vector([other for _ in range(self.n)]) + else: + other = Matrix([[other for _ in range(self.m)] for _ in range(self.n)], ndim=self.ndim) assert self.m == other.m and self.n == other.n, f"Dimension mismatch between shapes ({self.n}, {self.m}), ({other.n}, {other.m})" From 9721042e5bc392f6029e09c782f7a4fabeef2779 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Sun, 4 Sep 2022 15:56:10 -0400 Subject: [PATCH 72/92] check type --- python/taichi/lang/matrix.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 658b20ab9da2a..d831d22f63507 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -503,7 +503,10 @@ def _element_wise_binary(self, foo, other): def _broadcast_copy(self, other): if isinstance(other, (list, tuple)): - other = type(self)(other) + if is_vector(self): + other = Vector(other, ndim=self.ndim) + else: + other = Matrix(other, ndim=self.ndim) if not isinstance(other, Matrix): if isinstance(self, Vector): other = Vector([other for _ in range(self.n)]) From fece2e7159831f9e0aff577901ae172ec01379de Mon Sep 17 00:00:00 2001 From: AD1024 Date: Sun, 4 Sep 2022 22:39:19 -0400 Subject: [PATCH 73/92] fix intermediate matrix --- python/taichi/_funcs.py | 4 ++-- python/taichi/lang/ast/ast_transformer.py | 7 +++++-- python/taichi/lang/impl.py | 3 ++- python/taichi/lang/matrix.py | 7 ++++--- tests/python/test_scalar_op.py | 4 ++-- 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/python/taichi/_funcs.py b/python/taichi/_funcs.py index 50eeac4098334..1fa9fc3d3ddc9 100644 --- a/python/taichi/_funcs.py +++ b/python/taichi/_funcs.py @@ -90,10 +90,10 @@ def _matrix_outer_product(self, other): """ impl.static( - impl.static_assert(self.m == 1, + impl.static_assert(self.m == 1 and isinstance(self, Vector), "lhs for outer_product is not a vector")) impl.static( - impl.static_assert(other.m == 1, + impl.static_assert(other.m == 1 and isinstance(other, Vector), "rhs for outer_product is not a vector")) return matrix.Matrix([[self[i] * other[j] for j in range(other.n)] for i in range(self.n)]) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 8e6d43905b270..add42c73cc35a 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -662,10 +662,13 @@ def build_Return(ctx, node): ti_ops.cast(expr.Expr(node.value.ptr), ctx.func.return_type).ptr)) elif isinstance(ctx.func.return_type, MatrixType): + item_iter = iter(node.value.ptr.to_list())\ + if isinstance(node.value.ptr, Vector) or node.value.ptr.ndim == 1\ + else itertools.chain.from_iterable(node.value.ptr.to_list()) ctx.ast_builder.create_kernel_exprgroup_return( expr.make_expr_group([ - ti_ops.cast(exp, ctx.func.return_type.dtype) for exp in - itertools.chain.from_iterable(node.value.ptr.to_list()) + ti_ops.cast(exp, ctx.func.return_type.dtype) + for exp in item_iter ])) else: raise TaichiSyntaxError( diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 8eb9002e75b8e..c53f0a57cd1e2 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -63,7 +63,8 @@ def expr_init(rhs): entries = [[rhs(i, j) for j in range(rhs.m)] for i in range(rhs.n)] return make_matrix(entries) - if isinstance(rhs, Vector): + if isinstance(rhs, Vector) or getattr(rhs, "ndim", None) == 1: + # _IntermediateMatrix may reach here return Vector(rhs.to_list(), ndim=rhs.ndim) return Matrix(rhs.to_list(), ndim=rhs.ndim) if isinstance(rhs, SharedArray): diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index d831d22f63507..e65948b0e01a5 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -658,6 +658,8 @@ def to_list(self): This is similar to `numpy.ndarray`'s `flatten` and `ravel` methods, the difference is that this function always returns a new list. """ + if is_vector(self): + return [self(i) for i in range(self.n)] return [[self(i, j) for j in range(self.m)] for i in range(self.n)] @taichi_scope @@ -1493,9 +1495,6 @@ def ndarray(cls, n, dtype, shape, layout=Layout.AOS): shape = (shape, ) return VectorNdarray(n, dtype, shape, layout) - def to_list(self): - return [self(i) for i in range(self.n)] - class _IntermediateMatrix(Matrix): """Intermediate matrix class for compiler internal use only. @@ -1745,6 +1744,8 @@ def __getitem__(self, key): self._initialize_host_accessors() key = self._pad_key(key) _host_access = self._host_access(key) + if self.ndim == 1: + return Vector([_host_access[i] for i in range(self.n)]) return Matrix([[_host_access[i * self.m + j] for j in range(self.m)] for i in range(self.n)], ndim=self.ndim) diff --git a/tests/python/test_scalar_op.py b/tests/python/test_scalar_op.py index 1be7606302f8d..49fd7b16f8f54 100644 --- a/tests/python/test_scalar_op.py +++ b/tests/python/test_scalar_op.py @@ -77,8 +77,8 @@ def test_python_scope_matmul(): ti.init() a = np.array([[1, 2], [3, 4]]) b = np.array([[5, 6], [7, 8]]) - x = ti.Vector(a) - y = ti.Vector(b) + x = ti.Matrix(a) + y = ti.Matrix(b) result = (x @ y).to_numpy() expected = a @ b From b2262891aa1653cdc44ea27aa4557bbd16b260e7 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Sun, 4 Sep 2022 22:58:26 -0400 Subject: [PATCH 74/92] fix field fill --- python/taichi/_kernels.py | 12 ++++++++++-- python/taichi/lang/matrix.py | 7 +++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/python/taichi/_kernels.py b/python/taichi/_kernels.py index 16ce89f0e75d6..f1d81f3856de9 100644 --- a/python/taichi/_kernels.py +++ b/python/taichi/_kernels.py @@ -1,3 +1,5 @@ +from typing import Iterable + from taichi._lib.utils import get_os_name from taichi.lang import ops from taichi.lang._ndrange import ndrange @@ -241,9 +243,15 @@ def fill_matrix(mat: template(), vals: template()): for p in static(range(mat.n)): for q in static(range(mat.m)): if static(mat[I].ndim == 2): - mat[I][p, q] = vals[p][q] + if static(isinstance(vals[p], Iterable)): + mat[I][p, q] = vals[p][q] + else: + mat[I][p, q] = vals[p] else: - mat[I][p] = vals[p][q] + if static(isinstance(vals[p], Iterable)): + mat[I][p] = vals[p][q] + else: + mat[I][p] = vals[p] @kernel diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index e65948b0e01a5..6f348bffbd240 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1620,7 +1620,9 @@ def fill(self, val): elif isinstance(val, (list, tuple)) and isinstance(val[0], numbers.Number): assert self.m == 1 - val = tuple([(v, ) for v in val]) + val = tuple(val) + elif is_vector(val) or self.ndim == 1: + val = tuple([(val(i), ) for i in range(self.n)]) elif isinstance(val, Matrix): val_tuple = [] for i in range(val.n): @@ -1631,7 +1633,8 @@ def fill(self, val): val_tuple.append(row) val = tuple(val_tuple) assert len(val) == self.n - assert len(val[0]) == self.m + if self.ndim != 1: + assert len(val[0]) == self.m if in_python_scope(): from taichi._kernels import fill_matrix # pylint: disable=C0415 From 05e7497188aee1e609236bfe2d39346c2b8eafe6 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Sun, 4 Sep 2022 23:00:51 -0400 Subject: [PATCH 75/92] fix vector_arg --- python/taichi/lang/kernel_arguments.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/taichi/lang/kernel_arguments.py b/python/taichi/lang/kernel_arguments.py index 7d788fdac823e..a286eaabf0178 100644 --- a/python/taichi/lang/kernel_arguments.py +++ b/python/taichi/lang/kernel_arguments.py @@ -7,7 +7,7 @@ from taichi.lang.any_array import AnyArray from taichi.lang.enums import Layout from taichi.lang.expr import Expr -from taichi.lang.matrix import Matrix, MatrixType +from taichi.lang.matrix import Matrix, MatrixType, Vector, VectorType from taichi.lang.util import cook_dtype from taichi.types.primitive_types import RefType, f32, u64 @@ -58,6 +58,9 @@ def decl_scalar_arg(dtype): def decl_matrix_arg(matrixtype): + if isinstance(matrixtype, VectorType): + return Vector( + [decl_scalar_arg(matrixtype.dtype) for _ in range(matrixtype.n)]) return Matrix( [[decl_scalar_arg(matrixtype.dtype) for _ in range(matrixtype.m)] for _ in range(matrixtype.n)], From 2d7471f3ab680efab693b9b07401a4a1920134c8 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Sun, 4 Sep 2022 23:12:14 -0400 Subject: [PATCH 76/92] fix slice --- python/taichi/lang/matrix.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 6f348bffbd240..fc12c441f704a 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -261,8 +261,7 @@ def _subscript(self, is_global_mat, *indices, get_ref=False): is_ref=get_ref) return Matrix([[self._subscript(is_global_mat, a, b) for b in j] for a in i], - is_ref=get_ref, - ndim=1) + is_ref=get_ref) if self.any_array_access: return self.any_array_access.subscript(i, j) From 553f98fa31f0c49bdbc5c29a689ec70971a5bf10 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Mon, 5 Sep 2022 13:31:14 -0400 Subject: [PATCH 77/92] fix transpose for vector --- python/taichi/_funcs.py | 5 ++++- tests/python/test_matrix.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/taichi/_funcs.py b/python/taichi/_funcs.py index 1fa9fc3d3ddc9..5ba3613447fa7 100644 --- a/python/taichi/_funcs.py +++ b/python/taichi/_funcs.py @@ -3,7 +3,7 @@ from taichi.lang import impl, matrix, ops from taichi.lang.impl import expr_init, get_runtime, grouped, static from taichi.lang.kernel_impl import func, pyfunc -from taichi.lang.matrix import Matrix, Vector +from taichi.lang.matrix import Matrix, Vector, is_vector from taichi.types import f32, f64 from taichi.types.annotations import template @@ -59,6 +59,9 @@ def _matrix_transpose(mat): Returns: Transpose of the input matrix. """ + if static(is_vector(mat)): + # Convert to row vector + return matrix.Matrix([[mat(i) for i in range(mat.n)]]) return matrix.Matrix([[mat(i, j) for i in range(mat.n)] for j in range(mat.m)], ndim=mat.ndim) diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 6f6b0ec43356e..88b853aa80e7d 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -544,10 +544,10 @@ def foo(): def test_matrix_dtype(): @ti.kernel def foo(): - a = ti.Vector([[1, 2], [3, 4]], ti.f32) + a = ti.Matrix([[1, 2], [3, 4]], ti.f32) a /= 2 assert all(abs(a - ((0.5, 1.), (1.5, 2.))) < 1e-6) - b = ti.Vector([[1.5, 2.5], [3.5, 4.5]], ti.i32) + b = ti.Matrix([[1.5, 2.5], [3.5, 4.5]], ti.i32) assert all(b == ((1, 2), (3, 4))) foo() From 1a0ab9c824ba1d8b8486f23a8512a196d6ad7ec7 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 6 Sep 2022 00:13:27 -0400 Subject: [PATCH 78/92] save fixes --- cmake/TaichiCore.cmake | 2 +- taichi/codegen/llvm/codegen_llvm.cpp | 71 +++++++++++++++------------- 2 files changed, 39 insertions(+), 34 deletions(-) diff --git a/cmake/TaichiCore.cmake b/cmake/TaichiCore.cmake index 4c57b56796c92..7f20de188c813 100644 --- a/cmake/TaichiCore.cmake +++ b/cmake/TaichiCore.cmake @@ -1,6 +1,6 @@ option(USE_STDCPP "Use -stdlib=libc++" OFF) option(TI_WITH_LLVM "Build with LLVM backends" ON) -option(TI_LLVM_15 "Switch to LLVM 15" OFF) +option(TI_LLVM_15 "Switch to LLVM 15" ON) option(TI_WITH_METAL "Build with the Metal backend" ON) option(TI_WITH_CUDA "Build with the CUDA backend" ON) option(TI_WITH_CUDA_TOOLKIT "Build with the CUDA toolkit" OFF) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index b6f83d22538c2..6871326bc7c7b 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1719,46 +1719,51 @@ void TaskCodeGenLLVM::visit(GetChStmt *stmt) { void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) { if (stmt->is_local_ptr()) { -#ifdef TI_LLVM_15 - // FIXME: get ptr_ty from taichi instead of llvm. - llvm::Type *ptr_ty = nullptr; - auto *val = llvm_val[stmt->origin]; - // For SharedArray which is in address space 3. - if (auto *addr_cast = llvm::dyn_cast(val)) - val = addr_cast->getOperand(0); - if (auto *alloc = llvm::dyn_cast(val)) - ptr_ty = alloc->getAllocatedType(); - else if (auto *gv = llvm::dyn_cast(val)) - ptr_ty = gv->getValueType(); - else if (auto *gep = llvm::dyn_cast(val)) - ptr_ty = gep->getResultElementType(); - else if (stmt->origin->is()) { - auto *tmpo_stmt = stmt->origin->cast(); - if (tmpo_stmt->ret_type->is()) { - ptr_ty = tlctx->get_data_type( - tmpo_stmt->ret_type->cast()->get_element_type()); - } else { - ptr_ty = tlctx->get_data_type(tmpo_stmt->ret_type.ptr_removed()); - } - } - TI_ASSERT(ptr_ty); - - llvm_val[stmt] = builder->CreateGEP(ptr_ty, llvm_val[stmt->origin], - llvm_val[stmt->offset]); -#else - if (stmt->origin->ret_type->is() || + TensorType *stmt_dtype = nullptr; + if (stmt->origin->ret_type->is() || (stmt->origin->ret_type->is() && stmt->origin->ret_type->cast() ->get_pointee_type() ->is())) { - TensorType *stmt_dtype; - if (stmt->origin->ret_type->is()) { - stmt_dtype = stmt->origin->ret_type->cast() + if (stmt->origin->ret_type->is()) { + stmt_dtype = stmt->origin->ret_type->cast() ->get_pointee_type() ->cast(); - } else { - stmt_dtype = stmt->origin->ret_type->cast(); + } else { + stmt_dtype = stmt->origin->ret_type->cast(); + } + } +#ifdef TI_LLVM_15 + // FIXME: get ptr_ty from taichi instead of llvm. + llvm::Type *ptr_ty = nullptr; + auto *val = llvm_val[stmt->origin]; + // For SharedArray which is in address space 3. + if (auto *addr_cast = llvm::dyn_cast(val)) + val = addr_cast->getOperand(0); + if (auto *alloc = llvm::dyn_cast(val)) + ptr_ty = alloc->getAllocatedType(); + else if (auto *gv = llvm::dyn_cast(val)) + ptr_ty = gv->getValueType(); + else if (auto *gep = llvm::dyn_cast(val)) + ptr_ty = gep->getResultElementType(); + else if (stmt->origin->is()) { + if (stmt->origin->ret_type->is()) { + ptr_ty = tlctx->get_data_type( + stmt->origin->ret_type->cast()->get_element_type()); + } else { + ptr_ty = tlctx->get_data_type(stmt->origin->ret_type.ptr_removed()); + } } + // else if (stmt->origin->is() && + // stmt->origin->ret_type->is()) { + // ptr_ty = tlctx->get_data_type(stmt_dtype); + // } + TI_ASSERT(ptr_ty); + + llvm_val[stmt] = builder->CreateGEP(ptr_ty, llvm_val[stmt->origin], + llvm_val[stmt->offset]); +#else + if (stmt_dtype) { auto element_dtype = stmt_dtype->get_element_type(); auto llvm_type = tlctx->get_data_type(element_dtype); auto casted_ptr = builder->CreateBitCast( From 9e2116360aa700c7232268b175fd5009b53bbd51 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 6 Sep 2022 02:30:23 -0400 Subject: [PATCH 79/92] fix for global tmp var --- taichi/codegen/llvm/codegen_llvm.cpp | 85 +++++++++++++++------------- taichi/transforms/type_check.cpp | 3 +- 2 files changed, 49 insertions(+), 39 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 6871326bc7c7b..c02dd8cbcf769 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1718,52 +1718,53 @@ void TaskCodeGenLLVM::visit(GetChStmt *stmt) { } void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) { + auto is_tensor_or_ptr = [](DataType dt) { + return dt->is() || + (dt->is() && + dt->cast()->get_pointee_type()->is()); + }; + auto get_tensor_type = [](DataType dt) -> TensorType * { + return dt.ptr_removed()->cast(); + }; if (stmt->is_local_ptr()) { - TensorType *stmt_dtype = nullptr; - if (stmt->origin->ret_type->is() || - (stmt->origin->ret_type->is() && - stmt->origin->ret_type->cast() - ->get_pointee_type() - ->is())) { - if (stmt->origin->ret_type->is()) { - stmt_dtype = stmt->origin->ret_type->cast() - ->get_pointee_type() - ->cast(); - } else { - stmt_dtype = stmt->origin->ret_type->cast(); - } - } #ifdef TI_LLVM_15 - // FIXME: get ptr_ty from taichi instead of llvm. - llvm::Type *ptr_ty = nullptr; - auto *val = llvm_val[stmt->origin]; - // For SharedArray which is in address space 3. - if (auto *addr_cast = llvm::dyn_cast(val)) - val = addr_cast->getOperand(0); - if (auto *alloc = llvm::dyn_cast(val)) + // FIXME: get ptr_ty from taichi instead of llvm. + llvm::Type *ptr_ty = nullptr; + auto *val = llvm_val[stmt->origin]; + // For SharedArray which is in address space 3. + if (auto *addr_cast = llvm::dyn_cast(val)) + val = addr_cast->getOperand(0); + if (auto *alloc = llvm::dyn_cast(val)) + if (!stmt->origin->ret_type->is()) ptr_ty = alloc->getAllocatedType(); - else if (auto *gv = llvm::dyn_cast(val)) - ptr_ty = gv->getValueType(); - else if (auto *gep = llvm::dyn_cast(val)) - ptr_ty = gep->getResultElementType(); - else if (stmt->origin->is()) { - if (stmt->origin->ret_type->is()) { + if (!ptr_ty) { + if (auto *gv = llvm::dyn_cast(val)) + if (!stmt->origin->ret_type->is()) + ptr_ty = gv->getValueType(); + if (auto *gep = llvm::dyn_cast(val)) + if (!stmt->origin->ret_type->is()) + ptr_ty = gep->getResultElementType(); + if (stmt->origin->is()) { + if (is_tensor_or_ptr(stmt->origin->ret_type)) { ptr_ty = tlctx->get_data_type( - stmt->origin->ret_type->cast()->get_element_type()); + get_tensor_type(stmt->origin->ret_type)->get_element_type()); + val = builder->CreateBitCast(val, llvm::PointerType::get(ptr_ty, 0)); } else { ptr_ty = tlctx->get_data_type(stmt->origin->ret_type.ptr_removed()); } + } else if (is_tensor_or_ptr(stmt->origin->ret_type)) { + auto dtype = + get_tensor_type(stmt->origin->ret_type)->get_element_type(); + ptr_ty = tlctx->get_data_type(dtype); + val = builder->CreateBitCast(val, llvm::PointerType::get(ptr_ty, 0)); } - // else if (stmt->origin->is() && - // stmt->origin->ret_type->is()) { - // ptr_ty = tlctx->get_data_type(stmt_dtype); - // } - TI_ASSERT(ptr_ty); - - llvm_val[stmt] = builder->CreateGEP(ptr_ty, llvm_val[stmt->origin], - llvm_val[stmt->offset]); + } + TI_ASSERT(ptr_ty); + + llvm_val[stmt] = builder->CreateGEP(ptr_ty, val, llvm_val[stmt->offset]); #else - if (stmt_dtype) { + if (is_tensor_or_ptr(stmt->origin->ret_type)) { + auto stmt_dtype = get_tensor_type(stmt->origin->ret_type); auto element_dtype = stmt_dtype->get_element_type(); auto llvm_type = tlctx->get_data_type(element_dtype); auto casted_ptr = builder->CreateBitCast( @@ -2291,16 +2292,24 @@ void TaskCodeGenLLVM::visit(GlobalTemporaryStmt *stmt) { auto buffer = call("get_temporary_pointer", runtime, tlctx->get_constant((int64)stmt->offset)); - if (stmt->ret_type->is()) { + if (stmt->ret_type->is() && !prog->config.real_matrix) { auto ptr_type = llvm::PointerType::get( tlctx->get_data_type( stmt->ret_type->cast()->get_element_type()), 0); llvm_val[stmt] = builder->CreatePointerCast(buffer, ptr_type); } else { + // if (prog->config.real_matrix && + // stmt->ret_type.ptr_removed()->is()) { + // auto tensor_type = stmt->ret_type.ptr_removed()->cast(); + // auto ptr_type = llvm::PointerType::get( + // tlctx->get_data_type(tensor_type->get_element_type()), 0); + // llvm_val[stmt] = builder->CreatePointerCast(buffer, ptr_type); + // } else { auto ptr_type = llvm::PointerType::get( tlctx->get_data_type(stmt->ret_type.ptr_removed()), 0); llvm_val[stmt] = builder->CreatePointerCast(buffer, ptr_type); + // } } } diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 7d43e47cfb7bb..477f8d803f14f 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -506,7 +506,8 @@ class TypeCheck : public IRVisitor { } void visit(GlobalTemporaryStmt *stmt) override { - stmt->ret_type.set_is_pointer(true); + if (!stmt->ret_type->is() || config_.real_matrix) + stmt->ret_type.set_is_pointer(true); } void visit(InternalFuncStmt *stmt) override { From 2bb10b75183319f13a4a38400f2389d70131596d Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 6 Sep 2022 02:48:48 -0400 Subject: [PATCH 80/92] revert local change --- cmake/TaichiCore.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/TaichiCore.cmake b/cmake/TaichiCore.cmake index 7f20de188c813..4c57b56796c92 100644 --- a/cmake/TaichiCore.cmake +++ b/cmake/TaichiCore.cmake @@ -1,6 +1,6 @@ option(USE_STDCPP "Use -stdlib=libc++" OFF) option(TI_WITH_LLVM "Build with LLVM backends" ON) -option(TI_LLVM_15 "Switch to LLVM 15" ON) +option(TI_LLVM_15 "Switch to LLVM 15" OFF) option(TI_WITH_METAL "Build with the Metal backend" ON) option(TI_WITH_CUDA "Build with the CUDA backend" ON) option(TI_WITH_CUDA_TOOLKIT "Build with the CUDA toolkit" OFF) From 504808a03b99c93322d0ce6f78a399a34a9f0411 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 6 Sep 2022 16:28:29 -0400 Subject: [PATCH 81/92] fix test_scan --- taichi/codegen/llvm/codegen_llvm.cpp | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index c02dd8cbcf769..82ac0a412afe2 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1731,24 +1731,27 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) { // FIXME: get ptr_ty from taichi instead of llvm. llvm::Type *ptr_ty = nullptr; auto *val = llvm_val[stmt->origin]; + auto *lhs = val; + auto fall_through = + stmt->origin->ret_type->is() && prog->config.real_matrix; // For SharedArray which is in address space 3. if (auto *addr_cast = llvm::dyn_cast(val)) val = addr_cast->getOperand(0); if (auto *alloc = llvm::dyn_cast(val)) - if (!stmt->origin->ret_type->is()) + if (!fall_through) ptr_ty = alloc->getAllocatedType(); if (!ptr_ty) { if (auto *gv = llvm::dyn_cast(val)) - if (!stmt->origin->ret_type->is()) + if (!fall_through) ptr_ty = gv->getValueType(); if (auto *gep = llvm::dyn_cast(val)) - if (!stmt->origin->ret_type->is()) + if (!fall_through) ptr_ty = gep->getResultElementType(); if (stmt->origin->is()) { if (is_tensor_or_ptr(stmt->origin->ret_type)) { ptr_ty = tlctx->get_data_type( get_tensor_type(stmt->origin->ret_type)->get_element_type()); - val = builder->CreateBitCast(val, llvm::PointerType::get(ptr_ty, 0)); + lhs = builder->CreateBitCast(lhs, llvm::PointerType::get(ptr_ty, 0)); } else { ptr_ty = tlctx->get_data_type(stmt->origin->ret_type.ptr_removed()); } @@ -1756,12 +1759,12 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) { auto dtype = get_tensor_type(stmt->origin->ret_type)->get_element_type(); ptr_ty = tlctx->get_data_type(dtype); - val = builder->CreateBitCast(val, llvm::PointerType::get(ptr_ty, 0)); + lhs = builder->CreateBitCast(lhs, llvm::PointerType::get(ptr_ty, 0)); } } TI_ASSERT(ptr_ty); - llvm_val[stmt] = builder->CreateGEP(ptr_ty, val, llvm_val[stmt->offset]); + llvm_val[stmt] = builder->CreateGEP(ptr_ty, lhs, llvm_val[stmt->offset]); #else if (is_tensor_or_ptr(stmt->origin->ret_type)) { auto stmt_dtype = get_tensor_type(stmt->origin->ret_type); @@ -2299,17 +2302,9 @@ void TaskCodeGenLLVM::visit(GlobalTemporaryStmt *stmt) { 0); llvm_val[stmt] = builder->CreatePointerCast(buffer, ptr_type); } else { - // if (prog->config.real_matrix && - // stmt->ret_type.ptr_removed()->is()) { - // auto tensor_type = stmt->ret_type.ptr_removed()->cast(); - // auto ptr_type = llvm::PointerType::get( - // tlctx->get_data_type(tensor_type->get_element_type()), 0); - // llvm_val[stmt] = builder->CreatePointerCast(buffer, ptr_type); - // } else { auto ptr_type = llvm::PointerType::get( tlctx->get_data_type(stmt->ret_type.ptr_removed()), 0); llvm_val[stmt] = builder->CreatePointerCast(buffer, ptr_type); - // } } } From 70c49c36cb3ccbf25cda22a3ed19324c17e2bf85 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 6 Sep 2022 16:29:47 -0400 Subject: [PATCH 82/92] remove dup arch --- tests/python/test_matrix.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 88b853aa80e7d..95f99b237ac21 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -701,7 +701,7 @@ def bar(): bar() -@test_utils.test(arch=[ti.cuda, ti.cpu, ti.gpu], real_matrix=True) +@test_utils.test(arch=[ti.cuda, ti.cpu], real_matrix=True) def test_local_matrix_read(): s = ti.field(ti.i32, shape=()) @@ -717,7 +717,7 @@ def get_index(i: ti.i32, j: ti.i32): assert s[None] == i * 3 + j -@test_utils.test(arch=[ti.cuda, ti.cpu, ti.gpu], real_matrix=True) +@test_utils.test(arch=[ti.cuda, ti.cpu], real_matrix=True) def test_local_matrix_indexing_in_loop(): @ti.kernel def test(): @@ -729,7 +729,7 @@ def test(): test() -@test_utils.test(arch=[ti.cuda, ti.cpu, ti.gpu], real_matrix=True) +@test_utils.test(arch=[ti.cuda, ti.cpu], real_matrix=True) def test_local_matrix_indexing_ops(): @ti.kernel def basic_ops(): From 22bd0933a3a5480e53784f465c60b3eb6ba7d7ac Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 6 Sep 2022 16:50:17 -0400 Subject: [PATCH 83/92] refine test case --- tests/python/test_matrix.py | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 95f99b237ac21..2167f532a0c0c 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -719,25 +719,51 @@ def get_index(i: ti.i32, j: ti.i32): @test_utils.test(arch=[ti.cuda, ti.cpu], real_matrix=True) def test_local_matrix_indexing_in_loop(): + s = ti.field(ti.i32, shape=(3, 3)) + @ti.kernel def test(): mat = ti.Matrix([[x * 3 + y for y in range(3)] for x in range(3)]) for i in range(3): for j in range(3): - assert mat[i, j] == i * 3 + j + s[i, j] = mat[i, j] + 1 test() + for i in range(3): + for j in range(3): + assert s[i, j] == i * 3 + j + 1 @test_utils.test(arch=[ti.cuda, ti.cpu], real_matrix=True) def test_local_matrix_indexing_ops(): @ti.kernel - def basic_ops(): + def element_write() -> ti.i32: mat = ti.Matrix([[x * 3 + y for y in range(3)] for x in range(3)]) s = 0 for i in range(3): for j in range(3): + mat[i, j] = 10 s += mat[i, j] - assert s == 72 + return s + + f = ti.field(ti.i32, shape=(3, 3)) - basic_ops() + @ti.kernel + def assign_from_index(): + mat = ti.Matrix([[x * 3 + y for y in range(3)] for x in range(3)]) + result = ti.Matrix([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) + # TODO: fix parallelization + ti.loop_config(serialize=True) + for i in range(3): + for j in range(3): + result[i, j] = mat[j, i] + for i in range(3): + for j in range(3): + f[i, j] = result[i, j] + + assert element_write() == 90 + assign_from_index() + xs = [[x * 3 + y for y in range(3)] for x in range(3)] + for i in range(3): + for j in range(3): + assert f[i, j] == xs[j][i] From f623d5791d48bdeaf91876225fd1410ffda8b04f Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 6 Sep 2022 17:23:08 -0400 Subject: [PATCH 84/92] fill in type if idexpr is already type-checked --- taichi/ir/frontend_ir.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 7dc9d057240bf..9963f5e5d16c5 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -424,9 +424,6 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx, std::vector shape, int stride) { flatten_lvalue(var, ctx); - if (var->stmt->ret_type->is_primitive(PrimitiveTypeID::unknown)) { - var->stmt->ret_type = var->ret_type; - } Stmt *offset_stmt = ctx->push_back(TypedConstant(0)); for (int i = 0; i < (int)indices.size(); ++i) { flatten_rvalue(indices[i], ctx); @@ -453,7 +450,6 @@ void MatrixExpression::type_check(CompileConfig *config) { } void MatrixExpression::flatten(FlattenContext *ctx) { - // TODO: implement flatten TI_ASSERT(this->dt->is()); std::vector values; for (auto &elt : elements) { @@ -605,6 +601,9 @@ void LoopUniqueExpression::flatten(FlattenContext *ctx) { void IdExpression::flatten(FlattenContext *ctx) { stmt = ctx->current_block->lookup_var(id); + if (!ret_type->is_primitive(PrimitiveTypeID::unknown)) { + stmt->ret_type = ret_type; + } } void AtomicOpExpression::type_check(CompileConfig *) { @@ -1037,7 +1036,9 @@ Expr ASTBuilder::expr_alloca() { Expr ASTBuilder::make_matrix_expr(const std::vector &shape, const DataType &dt, const std::vector &elements) { - return Expr(std::make_shared(elements, shape, dt)); + auto mat = Expr(std::make_shared(elements, shape, dt)); + mat->ret_type = dt; + return mat; } Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, From 18f121a518a9eab811f6f17b985f3bcb12d57764 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 6 Sep 2022 20:16:21 -0400 Subject: [PATCH 85/92] rename func && modify docstring --- python/taichi/_funcs.py | 6 +++--- python/taichi/lang/matrix.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/taichi/_funcs.py b/python/taichi/_funcs.py index 5ba3613447fa7..40a9f3b097cc6 100644 --- a/python/taichi/_funcs.py +++ b/python/taichi/_funcs.py @@ -82,11 +82,11 @@ def _matrix_cross2d(self, other): @pyfunc -def _matrix_outer_product(self, other): - """Perform the outer product with the input Vector (1-D Matrix). +def _vector_outer_product(self, other): + """Perform the outer product with the input Vector. Args: - other (:class:`~taichi.lang.matrix.Matrix`): The input Vector (1-D Matrix) to perform the outer product. + other (:class:`~taichi.lang.matrix.Vector`): The input Vector to perform the outer product. Returns: :class:`~taichi.lang.matrix.Matrix`: The outer product result (Matrix) of the two Vectors. diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index fc12c441f704a..b9b6a7b4f1c1a 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1440,8 +1440,8 @@ def outer_product(self, other): :class:`~taichi.Matrix`: The outer product of the two Vectors. """ from taichi._funcs import \ - _matrix_outer_product # pylint: disable=C0415 - return _matrix_outer_product(self, other) + _vector_outer_product # pylint: disable=C0415 + return _vector_outer_product(self, other) class Vector(Matrix): From 3ace1a17d9460d5aae80d5af26ea1e8311899bbf Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 6 Sep 2022 21:06:21 -0400 Subject: [PATCH 86/92] merge changes (except llvm 15 part) --- taichi/codegen/llvm/codegen_llvm.cpp | 22 +++++++++++++++++----- taichi/transforms/type_check.cpp | 5 +++++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 66fb1299d72d9..582e3af1fcfd9 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1729,19 +1729,25 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) { auto get_tensor_type = [](DataType dt) -> TensorType * { return dt.ptr_removed()->cast(); }; - if (stmt->offset_used_as_index()) {} + if (stmt->offset_used_as_index()) { #ifdef TI_LLVM_15 // FIXME: get ptr_ty from taichi instead of llvm. llvm::Type *ptr_ty = nullptr; auto *val = llvm_val[stmt->origin]; auto *lhs = val; - auto fall_through = - stmt->origin->ret_type->is() && prog->config.real_matrix; // For SharedArray which is in address space 3. if (auto *addr_cast = llvm::dyn_cast(val)) val = addr_cast->getOperand(0); if (auto *alloc = llvm::dyn_cast(val)) + if (stmt->origin->ret_type.ptr_removed()->is()) { + ptr_ty = stmt->origin->ret_type.ptr_removed() + ->cast() + ->get_element_type(); + lhs = builder->CreatePointerCast( + lhs, llvm::PointerType::get(tlctx->get_data_type(ptr_ty), 0)); + } else { ptr_ty = alloc->getAllocatedType(); + } else if (auto *gv = llvm::dyn_cast(val)) ptr_ty = gv->getValueType(); else if (stmt->origin->is()) { @@ -1779,8 +1785,14 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) { https://llvm.org/doxygen/classllvm_1_1AllocaInst.html#ac68a7586b8be7de3c39531d9eca902e6 */ if (stmt->tensor_type_represented_as_primitive_type_ptr()) { - llvm_val[stmt] = - builder->CreateGEP(llvm_val[stmt->origin], llvm_val[stmt->offset]); + auto element_type = stmt->origin->ret_type.ptr_removed() + ->as() + ->get_element_type(); + auto element_ptr = + llvm::PointerType::get(tlctx->get_data_type(element_type), 0); + auto val = + builder->CreatePointerCast(llvm_val[stmt->origin], element_ptr); + llvm_val[stmt] = builder->CreateGEP(val, llvm_val[stmt->offset]); } else { llvm_val[stmt] = builder->CreateGEP(llvm_val[stmt->origin], diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 477f8d803f14f..9cd42bba1b38e 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -506,6 +506,11 @@ class TypeCheck : public IRVisitor { } void visit(GlobalTemporaryStmt *stmt) override { + /** + * We need to convert TensorType to pointer when + * real_matrix is enabled because one can store value + * in a loop to a tensor defined outside the loop + */ if (!stmt->ret_type->is() || config_.real_matrix) stmt->ret_type.set_is_pointer(true); } From 1cd3d5eca36e293922d49f33d57aa55f6a1cd7d0 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 6 Sep 2022 21:18:20 -0400 Subject: [PATCH 87/92] fix compilation --- taichi/codegen/llvm/codegen_llvm.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 582e3af1fcfd9..a9f7ab7bff1e7 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1740,11 +1740,11 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) { val = addr_cast->getOperand(0); if (auto *alloc = llvm::dyn_cast(val)) if (stmt->origin->ret_type.ptr_removed()->is()) { - ptr_ty = stmt->origin->ret_type.ptr_removed() + ptr_ty = tlctx->get_data_type(stmt->origin->ret_type.ptr_removed() ->cast() - ->get_element_type(); + ->get_element_type()); lhs = builder->CreatePointerCast( - lhs, llvm::PointerType::get(tlctx->get_data_type(ptr_ty), 0)); + lhs, llvm::PointerType::get(ptr_ty, 0)); } else { ptr_ty = alloc->getAllocatedType(); } From ad27eb0e61ab56f041d3d20aaa59fb645e7b2cc4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Sep 2022 01:19:55 +0000 Subject: [PATCH 88/92] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/codegen/llvm/codegen_llvm.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index a9f7ab7bff1e7..32528f9c758b6 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1741,10 +1741,10 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) { if (auto *alloc = llvm::dyn_cast(val)) if (stmt->origin->ret_type.ptr_removed()->is()) { ptr_ty = tlctx->get_data_type(stmt->origin->ret_type.ptr_removed() - ->cast() - ->get_element_type()); - lhs = builder->CreatePointerCast( - lhs, llvm::PointerType::get(ptr_ty, 0)); + ->cast() + ->get_element_type()); + lhs = + builder->CreatePointerCast(lhs, llvm::PointerType::get(ptr_ty, 0)); } else { ptr_ty = alloc->getAllocatedType(); } From 9ba4a03787866a7b48b9500838a3474b63aff890 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 6 Sep 2022 21:36:20 -0400 Subject: [PATCH 89/92] unused --- taichi/codegen/llvm/codegen_llvm.cpp | 8 -------- 1 file changed, 8 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index a9f7ab7bff1e7..3fa21a3998781 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1721,14 +1721,6 @@ void TaskCodeGenLLVM::visit(GetChStmt *stmt) { } void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) { - auto is_tensor_or_ptr = [](DataType dt) { - return dt->is() || - (dt->is() && - dt->cast()->get_pointee_type()->is()); - }; - auto get_tensor_type = [](DataType dt) -> TensorType * { - return dt.ptr_removed()->cast(); - }; if (stmt->offset_used_as_index()) { #ifdef TI_LLVM_15 // FIXME: get ptr_ty from taichi instead of llvm. From 595d5af068b03690a893c623e23e5baea5d4d43a Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 6 Sep 2022 22:26:24 -0400 Subject: [PATCH 90/92] fix GlobalTmpVar in loop --- taichi/codegen/llvm/codegen_llvm.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index ad5ff6d76ea57..53f859cdab6dd 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1747,9 +1747,12 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) { } else if (auto *gep = llvm::dyn_cast(val)) ptr_ty = gep->getResultElementType(); else if (stmt->origin->is()) { - if (stmt->origin->ret_type->is()) { - ptr_ty = tlctx->get_data_type( - stmt->origin->ret_type->cast()->get_element_type()); + if (stmt->origin->ret_type.ptr_removed()->is()) { + ptr_ty = tlctx->get_data_type(stmt->origin->ret_type.ptr_removed() + ->cast() + ->get_element_type()); + lhs = + builder->CreatePointerCast(lhs, llvm::PointerType::get(ptr_ty, 0)); } else { ptr_ty = tlctx->get_data_type(stmt->origin->ret_type.ptr_removed()); } @@ -1757,8 +1760,7 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) { TI_ASSERT(ptr_ty); if (stmt->tensor_type_represented_as_primitive_type_ptr()) { - llvm_val[stmt] = builder->CreateGEP(ptr_ty, llvm_val[stmt->origin], - llvm_val[stmt->offset]); + llvm_val[stmt] = builder->CreateGEP(ptr_ty, lhs, llvm_val[stmt->offset]); } else { llvm_val[stmt] = builder->CreateGEP(ptr_ty, llvm_val[stmt->origin], From 4b1a619e53ba03cec7b1d2ca4a67139d524cdf30 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 6 Sep 2022 23:09:46 -0400 Subject: [PATCH 91/92] add index mismatch tests --- taichi/ir/frontend_ir.cpp | 15 ++------------- tests/python/test_matrix.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 41f004c08fb9b..bbd290f8c6749 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -501,19 +501,8 @@ void IndexExpression::type_check(CompileConfig *) { } else if (is_tensor()) { // local tensor auto shape = var->ret_type->as()->get_shape(); if (indices.size() != shape.size()) { - std::string shape_str = "["; - if (shape.size() > 0) { - shape_str += std::to_string(shape[0]); - for (int i = 1; i < shape.size(); i++) { - shape_str += ", " + std::to_string(shape[i]); - } - } - shape_str += "]"; - TI_ERROR( - "Indexed matrix of shape {} has wrong number of indices. Expected {} " - "but got " - "{}.", - shape_str, shape.size(), indices.size()); + TI_ERROR("Expected {} indices, but got {}.", shape.size(), + indices.size()); } ret_type = var->ret_type->cast()->get_element_type(); } else { diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 2167f532a0c0c..2bc18f564eb86 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -767,3 +767,24 @@ def assign_from_index(): for i in range(3): for j in range(3): assert f[i, j] == xs[j][i] + + +@test_utils.test(arch=[ti.cuda, ti.cpu], real_matrix=True) +def test_local_matrix_index_check(): + @ti.kernel + def foo(): + mat = ti.Matrix([[1, 2, 3], [4, 5, 6]]) + print(mat[0]) + + with pytest.raises(TaichiCompilationError, + match=r'Expected 2 indices, but got 1'): + foo() + + @ti.kernel + def bar(): + vec = ti.Vector([1, 2, 3, 4]) + print(vec[0, 0]) + + with pytest.raises(TaichiCompilationError, + match=r'Expected 1 indices, but got 2'): + bar() From 19ec61f17464ab2104c31da67d1ceb1af4f00866 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 13 Sep 2022 16:16:22 -0400 Subject: [PATCH 92/92] remove unused code --- python/taichi/lang/impl.py | 3 --- taichi/ir/frontend_ir.cpp | 1 - 2 files changed, 4 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 1c66db07280da..ad76fc019c5d3 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -139,9 +139,6 @@ def subscript(value, *_indices, skip_reordered=False, get_ref=False): if isinstance(value, np.ndarray): return value.__getitem__(_indices) - if isinstance(value, Expr): - return make_index_expr(value.ptr, _indices) - if isinstance(value, (tuple, list, dict)): assert len(_indices) == 1 return value[_indices[0]] diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index bbd290f8c6749..bfe32d5d4cd18 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -1026,7 +1026,6 @@ Expr ASTBuilder::make_matrix_expr(const std::vector &shape, const DataType &dt, const std::vector &elements) { auto mat = Expr(std::make_shared(elements, shape, dt)); - mat->ret_type = dt; return mat; }