From 4ce08c87305de093850ba22518fe14dd1d6159ca Mon Sep 17 00:00:00 2001 From: AD1024 Date: Mon, 15 Aug 2022 13:39:32 -0400 Subject: [PATCH 01/40] cherrypick Matrix repr support --- python/taichi/lang/ast/ast_transformer.py | 7 ++++- python/taichi/lang/impl.py | 6 ++++ python/taichi/lang/matrix.py | 16 ++++++++++ taichi/analysis/data_source_analysis.cpp | 6 +++- taichi/analysis/gen_offline_cache_key.cpp | 9 ++++++ taichi/analysis/offline_cache_util.cpp | 1 + taichi/analysis/same_statements.cpp | 18 +++++++++++ taichi/codegen/llvm/codegen_llvm.cpp | 37 ++++++++++++++++++----- taichi/codegen/llvm/codegen_llvm.h | 2 ++ taichi/inc/expressions.inc.h | 1 + taichi/inc/statements.inc.h | 1 + taichi/ir/control_flow_graph.cpp | 3 +- taichi/ir/expression_printer.h | 8 +++++ taichi/ir/frontend_ir.cpp | 26 ++++++++++++++++ taichi/ir/frontend_ir.h | 23 ++++++++++++++ taichi/ir/statements.h | 12 ++++++++ taichi/ir/type.cpp | 8 +++++ taichi/ir/type.h | 2 ++ taichi/ir/type_utils.cpp | 32 ++++++++++++++++++++ taichi/program/compile_config.cpp | 1 + taichi/program/compile_config.h | 1 + taichi/python/export_lang.cpp | 2 ++ taichi/runtime/llvm/llvm_context.cpp | 4 +++ taichi/transforms/alg_simp.cpp | 28 +++++++++++++++-- taichi/transforms/die.cpp | 7 +++++ taichi/transforms/ir_printer.cpp | 13 ++++++++ taichi/transforms/type_check.cpp | 20 ++++++++++++ 27 files changed, 281 insertions(+), 13 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 1d30ffceae77d..aeca62a0c7421 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -14,7 +14,7 @@ from taichi.lang.ast.symbol_resolver import ASTResolver from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError from taichi.lang.field import Field -from taichi.lang.matrix import (Matrix, MatrixType, _PyScopeMatrixImpl, +from taichi.lang.matrix import (Matrix, Vector, MatrixType, _PyScopeMatrixImpl, _TiScopeMatrixImpl) from taichi.lang.snode import append from taichi.lang.util import in_taichi_scope, is_taichi_class, to_taichi_type @@ -487,6 +487,11 @@ def build_Call(ctx, node): args.insert(0, node.func.value.ptr) node.ptr = impl.ti_format(*args, **keywords) return node.ptr + + if isinstance(node.func, + ast.Attribute) and func == Matrix or func == Vector: + node.ptr = matrix.make_matrix(*args, **keywords) + return node.ptr if ASTTransformer.build_call_if_is_builtin(ctx, node, args, keywords): return node.ptr diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 1cf13596ac70f..d9ab715b28dad 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -35,6 +35,12 @@ def expr_init_local_tensor(shape, element_type, elements): shape, element_type, elements) +@taichi_scope +def expr_init_matrix(shape, element_type, elements): + return get_runtime().prog.current_ast_builder().expr_alloca_matrix( + shape, element_type, elements) + + @taichi_scope def expr_init_shared_array(shape, element_type): return get_runtime().prog.current_ast_builder().expr_alloca_shared_array( diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index e72831224738b..d694659de0975 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -97,6 +97,22 @@ def prop_setter(instance, value): return cls +def make_matrix(arr, dt=None, suppress_warning=False, is_ref=False, **kwargs): + if not impl.current_cfg().real_matrix or in_python_scope(): + return Matrix(arr, dt, suppress_warning, is_ref, **kwargs) + cast = (lambda x: ops_mod.cast(x, dt)) if dt else ( + lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x)) + if len(arr) == 0: + return impl.expr_init(impl.expr_init_matrix([0], dt, [])) + if not isinstance(arr[0], Iterable): + return impl.expr_init( + impl.expr_init_matrix([len(arr)], dt, + [cast(elt).ptr for elt in arr])) + return impl.expr_init( + impl.expr_init_matrix([len(arr), len(arr[0])], dt, + [cast(elt).ptr for row in arr for elt in row])) + + class _MatrixBaseImpl: def __init__(self, m, n, entries): self.m = m diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index 4a018afa6bf47..9baad69fb80fb 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -37,6 +37,10 @@ std::vector get_load_pointers(Stmt *load_stmt) { return external_func->arg_stmts; } else if (auto ref = load_stmt->cast()) { return {ref->var}; + } else if (auto matrix_init = load_stmt->cast()) { + return matrix_init->values; + } else if (auto ptr_offset = load_stmt->cast()) { + return {ptr_offset->origin}; } else { return std::vector(); } @@ -59,7 +63,7 @@ Stmt *get_store_data(Stmt *store_stmt) { std::vector get_store_destination(Stmt *store_stmt) { // If store_stmt provides some data sources, return the pointers of the data. - if (store_stmt->is() && !store_stmt->ret_type->is()) { + if (store_stmt->is()) { // The statement itself provides a data source (const [0]). return std::vector(1, store_stmt); } else if (auto local_store = store_stmt->cast()) { diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp index d1dfd5166c8ba..06c4a012e02b3 100644 --- a/taichi/analysis/gen_offline_cache_key.cpp +++ b/taichi/analysis/gen_offline_cache_key.cpp @@ -165,6 +165,15 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { emit(expr->indices.exprs); } + void visit(MatrixExpression *expr) override { + emit(ExprOpCode::MatrixExpression); + emit(expr->dt); + for (auto elt : expr->elements) { + emit(elt); + } + } + + void visit(StrideExpression *expr) override { emit(ExprOpCode::StrideExpression); emit(expr->var); diff --git a/taichi/analysis/offline_cache_util.cpp b/taichi/analysis/offline_cache_util.cpp index 19f8a32c01dbf..28263b0f0c3d0 100644 --- a/taichi/analysis/offline_cache_util.cpp +++ b/taichi/analysis/offline_cache_util.cpp @@ -65,6 +65,7 @@ static std::vector get_offline_cache_key_of_compile_config( serializer(config->demote_no_access_mesh_fors); serializer(config->experimental_auto_mesh_local); serializer(config->auto_mesh_local_default_occupacy); + serializer(config->real_matrix); serializer.finalize(); return serializer.data; diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index c26957e81f4a8..5f908603c9497 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -196,6 +196,24 @@ class IRNodeComparator : public IRVisitor { basic_check(stmt); } + void visit(MatrixInitStmt *stmt) override { + basic_check(stmt); + if (!same) + return; + auto o = other_node_->as(); + if (stmt->values.size() != o->values.size()) { + same = false; + return; + } + for (int i = 0; i < stmt->values.size(); ++i) { + other_node_ = o->values[i]; + stmt->values[i]->accept(this); + other_node_ = o; + if (!same) + return; + } + } + void visit(IfStmt *stmt) override { basic_check(stmt); if (!same) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 771f9f055dbf0..5b6d31a345f70 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -124,7 +124,7 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) { void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { if (stmt->ret_type->is()) { auto tensor_type = stmt->ret_type->cast(); - auto type = tlctx->get_data_type(tensor_type->get_element_type()); + auto type = tlctx->get_data_type(tensor_type); auto array_size = tlctx->get_constant(tensor_type->get_num_elements()); // Return type is [array_size x type]*. if (stmt->is_shared) { @@ -688,6 +688,11 @@ llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) { return llvm::Type::getDoubleTy(*llvm_context); } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*llvm_context); + } else if (dt->is()) { + auto tensor_type = dt->cast(); + auto element_type = llvm_type(tensor_type->get_element_type()); + return llvm::VectorType::get(element_type, tensor_type->get_num_elements(), + false); } else { TI_NOT_IMPLEMENTED; } @@ -800,12 +805,20 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) { if (std::holds_alternative(content)) { auto arg_stmt = std::get(content); auto value = llvm_val[arg_stmt]; - if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) || - arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) - value = builder->CreateFPExt(value, - tlctx->get_data_type(PrimitiveType::f64)); - args.push_back(value); - formats += data_type_format(arg_stmt->ret_type); + if (arg_stmt->ret_type->is()) { + auto dtype = arg_stmt->ret_type->cast(); + for (int i = 0; i < dtype->get_num_elements(); ++i) { + args.push_back(builder->CreateExtractElement(value, i)); + } + formats += data_type_format(arg_stmt->ret_type); + } else { + if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) || + arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) + value = builder->CreateFPExt( + value, tlctx->get_data_type(PrimitiveType::f64)); + args.push_back(value); + formats += data_type_format(arg_stmt->ret_type); + } } else { auto arg_str = std::get(content); auto value = builder->CreateGlobalStringPtr(arg_str, "content_string"); @@ -2515,6 +2528,16 @@ void TaskCodeGenLLVM::visit(MeshPatchIndexStmt *stmt) { llvm_val[stmt] = get_arg(2); } +void TaskCodeGenLLVM::visit(MatrixInitStmt *stmt) { + auto type = tlctx->get_data_type(stmt->ret_type->as()); + llvm::Value *vec = llvm::UndefValue::get(type); + for (int i = 0; i < stmt->values.size(); ++i) { + auto *elem = llvm_val[stmt->values[i]]; + vec = builder->CreateInsertElement(vec, elem, i); + } + llvm_val[stmt] = vec; +} + void TaskCodeGenLLVM::eliminate_unused_functions() { TaichiLLVMContext::eliminate_unused_functions( module.get(), [&](std::string func_name) { diff --git a/taichi/codegen/llvm/codegen_llvm.h b/taichi/codegen/llvm/codegen_llvm.h index 6f97ed7dff0f4..356866b12b8c4 100644 --- a/taichi/codegen/llvm/codegen_llvm.h +++ b/taichi/codegen/llvm/codegen_llvm.h @@ -369,6 +369,8 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(ReferenceStmt *stmt) override; + void visit(MatrixInitStmt *stmt) override; + llvm::Value *create_xlogue(std::unique_ptr &block); llvm::Value *create_mesh_xlogue(std::unique_ptr &block); diff --git a/taichi/inc/expressions.inc.h b/taichi/inc/expressions.inc.h index 9b20ba86bd80a..ac3d3b7bc9b1b 100644 --- a/taichi/inc/expressions.inc.h +++ b/taichi/inc/expressions.inc.h @@ -7,6 +7,7 @@ PER_EXPRESSION(InternalFuncCallExpression) PER_EXPRESSION(ExternalTensorExpression) PER_EXPRESSION(GlobalVariableExpression) PER_EXPRESSION(IndexExpression) +PER_EXPRESSION(MatrixExpression) PER_EXPRESSION(StrideExpression) PER_EXPRESSION(RangeAssumptionExpression) PER_EXPRESSION(LoopUniqueExpression) diff --git a/taichi/inc/statements.inc.h b/taichi/inc/statements.inc.h index fe12a8941f7f5..05056ce46b9ae 100644 --- a/taichi/inc/statements.inc.h +++ b/taichi/inc/statements.inc.h @@ -38,6 +38,7 @@ PER_STATEMENT(LoopUniqueStmt) PER_STATEMENT(AssertStmt) PER_STATEMENT(ExternalFuncCallStmt) PER_STATEMENT(ExternalTensorShapeAlongAxisStmt) +PER_STATEMENT(MatrixInitStmt) // Locals with reverse-mode autodiff PER_STATEMENT(AdStackAllocaStmt) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 3c4ddfedf2cac..0460894fb65f6 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -501,7 +501,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { } } auto load_ptrs = irpass::analysis::get_load_pointers(stmt); - if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1) { + if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1 && + !stmt->is()) { // Identical load elimination auto load_ptr = load_ptrs.front(); if (!after_lower_access || diff --git a/taichi/ir/expression_printer.h b/taichi/ir/expression_printer.h index 74887c890e099..5d7bb1c012292 100644 --- a/taichi/ir/expression_printer.h +++ b/taichi/ir/expression_printer.h @@ -110,6 +110,14 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter { } } + void visit(MatrixExpression *expr) override { + emit('['); + emit_vector(expr->elements); + emit(']'); + emit(fmt::format(" (dt={})", expr->dt->to_string())); + } + + void visit(IndexExpression *expr) override { expr->var->accept(this); emit('['); diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 56d5de025a9e8..6a00ed9e62121 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -429,6 +429,25 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx, return ctx->push_back(var->stmt, offset_stmt); } +void MatrixExpression::type_check(CompileConfig *config) { + // TODO: typecheck matrix + for (auto &arg : elements) { + TI_ASSERT_TYPE_CHECKED(arg); + } +} + +void MatrixExpression::flatten(FlattenContext *ctx) { + // TODO: implement flatten + TI_ASSERT(this->dt->is()); + std::vector values; + for (auto &elt : elements) { + flatten_rvalue(elt, ctx); + values.push_back(elt->stmt); + } + stmt = ctx->push_back(values); + stmt->ret_type = this->dt; +} + bool IndexExpression::is_field() const { return var.is(); } @@ -960,6 +979,13 @@ Expr ASTBuilder::expr_alloca() { return var; } +Expr ASTBuilder::expr_alloca_local_matrix(const std::vector &shape, + const std::optional &dt, + const std::vector &elements) { + auto dtype = dt.value_or(PrimitiveType::unknown); + return Expr(std::make_shared(elements, shape, dtype)); +} + Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements) { diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index b56c3ec48f416..d1dfcbe56d2d5 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -504,6 +504,26 @@ class GlobalVariableExpression : public Expression { TI_DEFINE_ACCEPT_FOR_EXPRESSION }; +class MatrixExpression : public Expression { + public: + std::vector elements; + DataType dt; + + MatrixExpression(const std::vector &elements, + std::vector shape, + DataType element_type) + : elements(elements) { + this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type)); + this->ret_type = this->dt; + } + + void type_check(CompileConfig *config) override; + + void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION +}; + class IndexExpression : public Expression { public: // `var` is one of GlobalVariableExpression, ExternalTensorExpression, @@ -876,6 +896,9 @@ class ASTBuilder { const ExprGroup &args, const ExprGroup &outputs); Expr expr_alloca(); + Expr expr_alloca_local_matrix(const std::vector &shape, + const std::optional &dt, + const std::vector &elements); Expr expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements); diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 9a2ea841e6a66..784c0a499c202 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -1807,5 +1807,17 @@ class MeshPatchIndexStmt : public Stmt { TI_DEFINE_ACCEPT_AND_CLONE }; +class MatrixInitStmt : public Stmt { + public: + std::vector values; + + MatrixInitStmt(const std::vector &values) : values(values) { + TI_STMT_REG_FIELDS; + } + + TI_STMT_DEF_FIELDS(ret_type, values); + TI_DEFINE_ACCEPT_AND_CLONE +}; + } // namespace lang } // namespace taichi diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 6b9d1a51e7990..0188722f15e6f 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -87,6 +87,14 @@ std::string TensorType::to_string() const { return s; } +int TensorType::vector_width() const { + int vw = 1; + for (auto dim : shape_) { + vw *= dim; + } + return vw; +} + int Type::vector_width() const { return 1; // TODO: CPU vectorization } diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 339e2553ffb32..7588093588bcc 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -175,6 +175,8 @@ class TensorType : public Type { return shape_; } + int vector_width() const; + Type *get_compute_type() override { return this; } diff --git a/taichi/ir/type_utils.cpp b/taichi/ir/type_utils.cpp index e49428b022445..dc5f816ecfa19 100644 --- a/taichi/ir/type_utils.cpp +++ b/taichi/ir/type_utils.cpp @@ -53,6 +53,36 @@ int data_type_size(DataType t) { } } +std::string tensor_type_format_helper(const std::vector &shape, + std::string format_str, + int dim) { + std::string fmt = "["; + for (int i = 0; i < shape[dim]; ++i) { + if (dim != shape.size() - 1) { + fmt += tensor_type_format_helper(shape, format_str, dim + 1); + } else { + fmt += format_str; + } + if (i != shape[dim] - 1) { + fmt += ", "; + if (dim == 0) { + fmt += "\n"; + } + } + } + fmt += "]"; + return fmt; +} + +std::string tensor_type_format(DataType t) { + TI_ASSERT(t->is()); + auto tensor_type = t->as(); + auto shape = tensor_type->get_shape(); + auto element_type = tensor_type->get_element_type(); + auto element_type_format = data_type_format(element_type); + return tensor_type_format_helper(shape, element_type_format, 0); +} + std::string data_type_format(DataType dt) { if (dt->is_primitive(PrimitiveTypeID::i16)) { return "%hd"; @@ -79,6 +109,8 @@ std::string data_type_format(DataType dt) { // TaskCodeGenLLVM::visit(PrintStmt *stmt) and // TaskCodeGenCUDA::visit(PrintStmt *stmt) for more details. return "%f"; + } else if (dt->is()) { + return tensor_type_format(dt); } else { TI_NOT_IMPLEMENTED } diff --git a/taichi/program/compile_config.cpp b/taichi/program/compile_config.cpp index a15ebcf7f9c5a..6f3d7a16deb2f 100644 --- a/taichi/program/compile_config.cpp +++ b/taichi/program/compile_config.cpp @@ -48,6 +48,7 @@ CompileConfig::CompileConfig() { detect_read_only = true; ndarray_use_cached_allocator = true; use_mesh = false; + real_matrix = false; saturating_grid_dim = 0; max_block_dim = 0; diff --git a/taichi/program/compile_config.h b/taichi/program/compile_config.h index 7820834066041..14f17e419fd6d 100644 --- a/taichi/program/compile_config.h +++ b/taichi/program/compile_config.h @@ -44,6 +44,7 @@ struct CompileConfig { bool detect_read_only; bool ndarray_use_cached_allocator; bool use_mesh; + bool real_matrix; DataType default_fp; DataType default_ip; DataType default_up; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 39a2149db656f..f210522a92996 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -193,6 +193,7 @@ void export_lang(py::module &m) { .def_readwrite("ndarray_use_cached_allocator", &CompileConfig::ndarray_use_cached_allocator) .def_readwrite("use_mesh", &CompileConfig::use_mesh) + .def_readwrite("real_matrix", &CompileConfig::real_matrix) .def_readwrite("cc_compile_cmd", &CompileConfig::cc_compile_cmd) .def_readwrite("cc_link_cmd", &CompileConfig::cc_link_cmd) .def_readwrite("quant_opt_store_fusion", @@ -285,6 +286,7 @@ void export_lang(py::module &m) { .def("insert_external_func_call", &ASTBuilder::insert_external_func_call) .def("expr_alloca", &ASTBuilder::expr_alloca) .def("expr_alloca_local_tensor", &ASTBuilder::expr_alloca_local_tensor) + .def("expr_alloca_matrix", &ASTBuilder::expr_alloca_local_matrix) .def("expr_alloca_shared_array", &ASTBuilder::expr_alloca_shared_array) .def("create_assert_stmt", &ASTBuilder::create_assert_stmt) .def("expr_assign", &ASTBuilder::expr_assign) diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index 379acb1efdae0..edde242f48f96 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -135,6 +135,10 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { return llvm::Type::getInt64Ty(*ctx); } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*ctx); + } else if (dt->is()) { + auto vectorty = dt->as(); + auto dtype = this->get_data_type(vectorty->get_element_type()); + return llvm::VectorType::get(dtype, vectorty->get_num_elements(), false); } else { TI_INFO(data_type_name(dt)); TI_NOT_IMPLEMENTED diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 4ce787066f230..75ed12641e31b 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -164,8 +164,14 @@ class AlgSimp : public BasicStmtVisitor { if ((fast_math || is_integral(stmt->ret_type)) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a / a -> 1 - replace_with_one(stmt); - return true; + if (stmt->lhs->ret_type->is() && + stmt->rhs->ret_type->is()) { + replace_with_one(stmt); + return true; + } else { + // TODO: handle tensor division + return false; + } } if (fast_math && rhs && is_real(rhs->ret_type) && stmt->op_type != BinaryOpType::floordiv) { @@ -245,7 +251,13 @@ class AlgSimp : public BasicStmtVisitor { (fast_math || is_integral(stmt->ret_type)) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a -^ a -> 0 - replace_with_zero(stmt); + if (stmt->lhs->ret_type->is() && + stmt->rhs->ret_type->is()) { + replace_with_zero(stmt); + } else { + // TODO: handle tensor operations + return; + } } } else if (rhs && stmt->op_type == BinaryOpType::pow) { float64 exponent = rhs->val[0].val_cast_to_float64(); @@ -330,6 +342,11 @@ class AlgSimp : public BasicStmtVisitor { modifier.erase(stmt); } else if (alg_is_zero(lhs) || alg_is_zero(rhs)) { // 0 & a -> 0, a & 0 -> 0 + if (stmt->ret_type->is() || + stmt->rhs->ret_type->is()) { + // TODO: support tensor type + return; + } replace_with_zero(stmt); } else if (irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // a & a -> a @@ -344,6 +361,11 @@ class AlgSimp : public BasicStmtVisitor { // a << 0 -> a // 0 << a -> 0 // 0 >> a -> 0 + if (stmt->ret_type->is() || + stmt->rhs->ret_type->is()) { + // TODO: support tensor type + return; + } TI_ASSERT(stmt->lhs->ret_type == stmt->ret_type); stmt->replace_usages_with(stmt->lhs); modifier.erase(stmt); diff --git a/taichi/transforms/die.cpp b/taichi/transforms/die.cpp index 3176d8f576949..f6d696c4da498 100644 --- a/taichi/transforms/die.cpp +++ b/taichi/transforms/die.cpp @@ -108,6 +108,13 @@ class DIE : public IRVisitor { } stmt->all_blocks_accept(this, true); } + + void visit(MatrixInitStmt *stmt) override { + register_usage(stmt); + for (auto &elts : stmt->values) { + elts->accept(this); + } + } }; namespace irpass { diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index ca462e42773e8..97538ee93acce 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -794,6 +794,19 @@ class IRPrinter : public IRVisitor { print("{}{} = ref({})", stmt->type_hint(), stmt->name(), stmt->var->name()); } + void visit(MatrixInitStmt *stmt) override { + std::string result = ""; + result += fmt::format("{}{} = [", stmt->type_hint(), stmt->name()); + for (int i = 0; i < stmt->values.size(); ++i) { + result += stmt->values[i]->name(); + if (i != stmt->values.size() - 1) { + result += ", "; + } + } + result += "]"; + print(result); + } + private: std::string expr_to_string(Expr &expr) { return expr_to_string(expr.expr.get()); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index e16c62501ba3c..cd4d4f7e6cf59 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -553,6 +553,26 @@ class TypeCheck : public IRVisitor { stmt->ret_type = stmt->var->ret_type; stmt->ret_type.set_is_pointer(true); } + + void visit(MatrixInitStmt *stmt) override { + TI_ASSERT_INFO(stmt->ret_type->is(), + "Matrix should have tensor type, got {}", + stmt->ret_type->to_string()); + auto tensor_type = stmt->ret_type->as(); + auto element_dtype = tensor_type->get_element_type(); + for (auto elt : stmt->values) { + element_dtype = promoted_type(element_dtype, elt->ret_type); + } + for (int i = 0; i < stmt->values.size(); ++i) { + if (element_dtype != stmt->values[i]->ret_type) { + cast(stmt->values[i], element_dtype); + } + } + if (element_dtype != tensor_type->get_element_type()) { + stmt->ret_type = TypeFactory::create_tensor_type(tensor_type->get_shape(), + element_dtype); + } + } }; namespace irpass { From 549a359e5e69d648a047d8a6698db8decfcf2852 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Aug 2022 17:51:12 +0000 Subject: [PATCH 02/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/taichi/lang/ast/ast_transformer.py | 6 +++--- taichi/analysis/gen_offline_cache_key.cpp | 3 +-- taichi/ir/expression_printer.h | 1 - taichi/transforms/alg_simp.cpp | 6 +++--- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index aeca62a0c7421..f70b9e20c494e 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -14,7 +14,7 @@ from taichi.lang.ast.symbol_resolver import ASTResolver from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError from taichi.lang.field import Field -from taichi.lang.matrix import (Matrix, Vector, MatrixType, _PyScopeMatrixImpl, +from taichi.lang.matrix import (Matrix, MatrixType, Vector, _PyScopeMatrixImpl, _TiScopeMatrixImpl) from taichi.lang.snode import append from taichi.lang.util import in_taichi_scope, is_taichi_class, to_taichi_type @@ -487,9 +487,9 @@ def build_Call(ctx, node): args.insert(0, node.func.value.ptr) node.ptr = impl.ti_format(*args, **keywords) return node.ptr - + if isinstance(node.func, - ast.Attribute) and func == Matrix or func == Vector: + ast.Attribute) and func == Matrix or func == Vector: node.ptr = matrix.make_matrix(*args, **keywords) return node.ptr diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp index 06c4a012e02b3..e0c825761c973 100644 --- a/taichi/analysis/gen_offline_cache_key.cpp +++ b/taichi/analysis/gen_offline_cache_key.cpp @@ -170,10 +170,9 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { emit(expr->dt); for (auto elt : expr->elements) { emit(elt); - } + } } - void visit(StrideExpression *expr) override { emit(ExprOpCode::StrideExpression); emit(expr->var); diff --git a/taichi/ir/expression_printer.h b/taichi/ir/expression_printer.h index 5d7bb1c012292..a32a0468d6907 100644 --- a/taichi/ir/expression_printer.h +++ b/taichi/ir/expression_printer.h @@ -117,7 +117,6 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter { emit(fmt::format(" (dt={})", expr->dt->to_string())); } - void visit(IndexExpression *expr) override { expr->var->accept(this); emit('['); diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 75ed12641e31b..f4892fdc439a8 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -165,7 +165,7 @@ class AlgSimp : public BasicStmtVisitor { irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a / a -> 1 if (stmt->lhs->ret_type->is() && - stmt->rhs->ret_type->is()) { + stmt->rhs->ret_type->is()) { replace_with_one(stmt); return true; } else { @@ -252,7 +252,7 @@ class AlgSimp : public BasicStmtVisitor { irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a -^ a -> 0 if (stmt->lhs->ret_type->is() && - stmt->rhs->ret_type->is()) { + stmt->rhs->ret_type->is()) { replace_with_zero(stmt); } else { // TODO: handle tensor operations @@ -362,7 +362,7 @@ class AlgSimp : public BasicStmtVisitor { // 0 << a -> 0 // 0 >> a -> 0 if (stmt->ret_type->is() || - stmt->rhs->ret_type->is()) { + stmt->rhs->ret_type->is()) { // TODO: support tensor type return; } From 07a4dc12b34d98da7267cee950cc83c5b0a44381 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Mon, 15 Aug 2022 13:57:46 -0400 Subject: [PATCH 03/40] matrix assign --- python/taichi/lang/impl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index d9ab715b28dad..60389ce57227d 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -54,6 +54,8 @@ def expr_init(rhs): if isinstance(rhs, Matrix) and (hasattr(rhs, "_DIM")): return Matrix(*rhs.to_list(), ndim=rhs.ndim) if isinstance(rhs, Matrix): + if current_cfg().real_matrix: + return rhs return Matrix(rhs.to_list(), ndim=rhs.ndim) if isinstance(rhs, SharedArray): return rhs From efca3f04a08fe1683d3801c8b83dd7575d573f37 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Mon, 15 Aug 2022 14:41:39 -0400 Subject: [PATCH 04/40] move checks to caller side --- python/taichi/lang/ast/ast_transformer.py | 4 ++-- python/taichi/lang/matrix.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index f70b9e20c494e..45e01b18a8ab2 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -488,8 +488,8 @@ def build_Call(ctx, node): node.ptr = impl.ti_format(*args, **keywords) return node.ptr - if isinstance(node.func, - ast.Attribute) and func == Matrix or func == Vector: + if (isinstance(node.func, + ast.Attribute) and (func in {Matrix, Vector})) and impl.current_cfg().real_matrix and in_taichi_scope(): node.ptr = matrix.make_matrix(*args, **keywords) return node.ptr diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index d694659de0975..8d8c59d9f2922 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -97,9 +97,7 @@ def prop_setter(instance, value): return cls -def make_matrix(arr, dt=None, suppress_warning=False, is_ref=False, **kwargs): - if not impl.current_cfg().real_matrix or in_python_scope(): - return Matrix(arr, dt, suppress_warning, is_ref, **kwargs) +def make_matrix(arr, dt=None): cast = (lambda x: ops_mod.cast(x, dt)) if dt else ( lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x)) if len(arr) == 0: From 82c941383f29067ea973516bf1ee2a332038c703 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Aug 2022 18:43:03 +0000 Subject: [PATCH 05/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/taichi/lang/ast/ast_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 45e01b18a8ab2..4230da35f78b9 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -488,8 +488,8 @@ def build_Call(ctx, node): node.ptr = impl.ti_format(*args, **keywords) return node.ptr - if (isinstance(node.func, - ast.Attribute) and (func in {Matrix, Vector})) and impl.current_cfg().real_matrix and in_taichi_scope(): + if (isinstance(node.func, ast.Attribute) and (func in {Matrix, Vector}) + ) and impl.current_cfg().real_matrix and in_taichi_scope(): node.ptr = matrix.make_matrix(*args, **keywords) return node.ptr From 38ec7501c5b24c1e1e14f2ec480d4a556f4d84a6 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Mon, 15 Aug 2022 15:34:07 -0400 Subject: [PATCH 06/40] use == --- python/taichi/lang/ast/ast_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 45e01b18a8ab2..85cf388c8aa46 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -489,7 +489,7 @@ def build_Call(ctx, node): return node.ptr if (isinstance(node.func, - ast.Attribute) and (func in {Matrix, Vector})) and impl.current_cfg().real_matrix and in_taichi_scope(): + ast.Attribute) and (func == Matrix or func == Vector)) and impl.current_cfg().real_matrix and in_taichi_scope(): node.ptr = matrix.make_matrix(*args, **keywords) return node.ptr From 9c91103358a48b489cfee64c1a502c10552cca52 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 17 Aug 2022 13:37:56 -0400 Subject: [PATCH 07/40] refine impl --- python/taichi/lang/matrix.py | 17 +++++++++-------- taichi/ir/frontend_ir.cpp | 16 +++++++++------- taichi/ir/frontend_ir.h | 5 +++-- taichi/program/function.cpp | 2 +- taichi/program/kernel.cpp | 2 +- taichi/python/export_lang.cpp | 1 - taichi/runtime/llvm/llvm_context.cpp | 2 +- taichi/transforms/alg_simp.cpp | 24 +++++++++++++----------- 8 files changed, 37 insertions(+), 32 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 8d8c59d9f2922..06dea39d7adb3 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -98,17 +98,18 @@ def prop_setter(instance, value): def make_matrix(arr, dt=None): - cast = (lambda x: ops_mod.cast(x, dt)) if dt else ( - lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x)) if len(arr) == 0: - return impl.expr_init(impl.expr_init_matrix([0], dt, [])) - if not isinstance(arr[0], Iterable): + return impl.expr_init(impl.expr_init_local_tensor([0], ti_python_core.DataType_unknown, impl.make_expr_group([]))) + is_matrix = isinstance(arr[0], Iterable) + if dt is None: + dt = _make_entries_initializer(is_matrix).infer_dt(arr) + if not is_matrix: return impl.expr_init( - impl.expr_init_matrix([len(arr)], dt, - [cast(elt).ptr for elt in arr])) + impl.expr_init_local_tensor([len(arr)], dt, + impl.make_expr_group([expr.Expr(elt) for elt in arr]))) return impl.expr_init( - impl.expr_init_matrix([len(arr), len(arr[0])], dt, - [cast(elt).ptr for row in arr for elt in row])) + impl.expr_init_local_tensor([len(arr), len(arr[0])], dt, + impl.make_expr_group([expr.Expr(elt) for row in arr for elt in row]))) class _MatrixBaseImpl: diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 6a00ed9e62121..3984a4dfa474e 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -87,9 +87,9 @@ FrontendForStmt::FrontendForStmt(const ExprGroup &loop_var, } } -FrontendContext::FrontendContext(Arch arch) { +FrontendContext::FrontendContext(Arch arch, bool real_matrix) { root_node_ = std::make_unique(); - current_builder_ = std::make_unique(root_node_.get(), arch); + current_builder_ = std::make_unique(root_node_.get(), arch, real_matrix); } FrontendForStmt::FrontendForStmt(const Expr &loop_var, @@ -979,16 +979,18 @@ Expr ASTBuilder::expr_alloca() { return var; } -Expr ASTBuilder::expr_alloca_local_matrix(const std::vector &shape, - const std::optional &dt, - const std::vector &elements) { - auto dtype = dt.value_or(PrimitiveType::unknown); - return Expr(std::make_shared(elements, shape, dtype)); +Expr make_local_matrix(const std::vector &shape, + const DataType &dt, + const std::vector &elements) { + return Expr(std::make_shared(elements, shape, dt)); } Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements) { + if (this->use_real_matrix_) { + return make_local_matrix(shape,element_type, elements.exprs); + } auto var = Expr(std::make_shared(get_next_id())); this->insert(std::make_unique( std::static_pointer_cast(var.expr)->id, shape, diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index d1dfcbe56d2d5..8e774ebfcccfd 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -861,9 +861,10 @@ class ASTBuilder { Arch arch_; ForLoopDecoratorRecorder for_loop_dec_; int id_counter_{0}; + bool use_real_matrix_{false}; public: - ASTBuilder(Block *initial, Arch arch) : arch_(arch) { + ASTBuilder(Block *initial, Arch arch, bool real_matrix) : arch_(arch), use_real_matrix_(real_matrix) { stack_.push_back(initial); loop_state_stack_.push_back(None); } @@ -964,7 +965,7 @@ class FrontendContext { std::unique_ptr root_node_; public: - FrontendContext(Arch arch); + FrontendContext(Arch arch, bool real_matrix); ASTBuilder &builder() { return *current_builder_; diff --git a/taichi/program/function.cpp b/taichi/program/function.cpp index e2655fed02432..2ff368711fddd 100644 --- a/taichi/program/function.cpp +++ b/taichi/program/function.cpp @@ -12,7 +12,7 @@ Function::Function(Program *program, const FunctionKey &func_key) } void Function::set_function_body(const std::function &func) { - context = std::make_unique(program->config.arch); + context = std::make_unique(program->config.arch, program->config.real_matrix); ir = context->get_root(); { // Note: this is not a mutex diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 9f417fb408672..89cf804a2961d 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -391,7 +391,7 @@ void Kernel::init(Program &program, is_accessor = false; is_evaluator = false; compiled_ = nullptr; - context = std::make_unique(program.config.arch); + context = std::make_unique(program.config.arch, program.config.real_matrix); ir = context->get_root(); ir_is_ast_ = true; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index f210522a92996..b9ceda1b6cbfe 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -286,7 +286,6 @@ void export_lang(py::module &m) { .def("insert_external_func_call", &ASTBuilder::insert_external_func_call) .def("expr_alloca", &ASTBuilder::expr_alloca) .def("expr_alloca_local_tensor", &ASTBuilder::expr_alloca_local_tensor) - .def("expr_alloca_matrix", &ASTBuilder::expr_alloca_local_matrix) .def("expr_alloca_shared_array", &ASTBuilder::expr_alloca_shared_array) .def("create_assert_stmt", &ASTBuilder::create_assert_stmt) .def("expr_assign", &ASTBuilder::expr_assign) diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index edde242f48f96..75cd098bf7b3e 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -138,7 +138,7 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { } else if (dt->is()) { auto vectorty = dt->as(); auto dtype = this->get_data_type(vectorty->get_element_type()); - return llvm::VectorType::get(dtype, vectorty->get_num_elements(), false); + return llvm::VectorType::get(dtype, vectorty->get_num_elements(), /*scalable=*/false); } else { TI_INFO(data_type_name(dt)); TI_NOT_IMPLEMENTED diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index f4892fdc439a8..cb8b4c1741e13 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -164,14 +164,13 @@ class AlgSimp : public BasicStmtVisitor { if ((fast_math || is_integral(stmt->ret_type)) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a / a -> 1 - if (stmt->lhs->ret_type->is() && - stmt->rhs->ret_type->is()) { - replace_with_one(stmt); - return true; - } else { - // TODO: handle tensor division + if (stmt->lhs->ret_type->is() || + stmt->rhs->ret_type->is()) { + // TODO: support tensor type return false; } + replace_with_one(stmt); + return true; } if (fast_math && rhs && is_real(rhs->ret_type) && stmt->op_type != BinaryOpType::floordiv) { @@ -251,12 +250,11 @@ class AlgSimp : public BasicStmtVisitor { (fast_math || is_integral(stmt->ret_type)) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a -^ a -> 0 - if (stmt->lhs->ret_type->is() && - stmt->rhs->ret_type->is()) { - replace_with_zero(stmt); + if (stmt->lhs->ret_type->is() && + stmt->rhs->ret_type->is()) { + // TODO: support tensor type } else { - // TODO: handle tensor operations - return; + replace_with_zero(stmt); } } } else if (rhs && stmt->op_type == BinaryOpType::pow) { @@ -267,6 +265,10 @@ class AlgSimp : public BasicStmtVisitor { modifier.erase(stmt); } else if (exponent == 0) { // a ** 0 -> 1 + if (stmt->ret_type->is()) { + // TODO: support tensor type + return; + } replace_with_one(stmt); } else if (exponent == 0.5) { // a ** 0.5 -> sqrt(a) From 28f3e0ac57c20f1d3db07b9d5000d85845af0471 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Aug 2022 17:39:30 +0000 Subject: [PATCH 08/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/taichi/lang/matrix.py | 12 +++++++++--- taichi/ir/frontend_ir.cpp | 5 +++-- taichi/ir/frontend_ir.h | 3 ++- taichi/program/function.cpp | 3 ++- taichi/program/kernel.cpp | 3 ++- taichi/runtime/llvm/llvm_context.cpp | 3 ++- 6 files changed, 20 insertions(+), 9 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 06dea39d7adb3..0ab648ddb5155 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -99,17 +99,23 @@ def prop_setter(instance, value): def make_matrix(arr, dt=None): if len(arr) == 0: - return impl.expr_init(impl.expr_init_local_tensor([0], ti_python_core.DataType_unknown, impl.make_expr_group([]))) + return impl.expr_init( + impl.expr_init_local_tensor([0], ti_python_core.DataType_unknown, + impl.make_expr_group([]))) is_matrix = isinstance(arr[0], Iterable) if dt is None: dt = _make_entries_initializer(is_matrix).infer_dt(arr) if not is_matrix: return impl.expr_init( impl.expr_init_local_tensor([len(arr)], dt, - impl.make_expr_group([expr.Expr(elt) for elt in arr]))) + impl.make_expr_group( + [expr.Expr(elt) for elt in arr]))) return impl.expr_init( impl.expr_init_local_tensor([len(arr), len(arr[0])], dt, - impl.make_expr_group([expr.Expr(elt) for row in arr for elt in row]))) + impl.make_expr_group([ + expr.Expr(elt) for row in arr + for elt in row + ]))) class _MatrixBaseImpl: diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 3984a4dfa474e..edb1aca141b47 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -89,7 +89,8 @@ FrontendForStmt::FrontendForStmt(const ExprGroup &loop_var, FrontendContext::FrontendContext(Arch arch, bool real_matrix) { root_node_ = std::make_unique(); - current_builder_ = std::make_unique(root_node_.get(), arch, real_matrix); + current_builder_ = + std::make_unique(root_node_.get(), arch, real_matrix); } FrontendForStmt::FrontendForStmt(const Expr &loop_var, @@ -989,7 +990,7 @@ Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements) { if (this->use_real_matrix_) { - return make_local_matrix(shape,element_type, elements.exprs); + return make_local_matrix(shape, element_type, elements.exprs); } auto var = Expr(std::make_shared(get_next_id())); this->insert(std::make_unique( diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 8e774ebfcccfd..effd8abd566e9 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -864,7 +864,8 @@ class ASTBuilder { bool use_real_matrix_{false}; public: - ASTBuilder(Block *initial, Arch arch, bool real_matrix) : arch_(arch), use_real_matrix_(real_matrix) { + ASTBuilder(Block *initial, Arch arch, bool real_matrix) + : arch_(arch), use_real_matrix_(real_matrix) { stack_.push_back(initial); loop_state_stack_.push_back(None); } diff --git a/taichi/program/function.cpp b/taichi/program/function.cpp index 2ff368711fddd..64dd970cde4ea 100644 --- a/taichi/program/function.cpp +++ b/taichi/program/function.cpp @@ -12,7 +12,8 @@ Function::Function(Program *program, const FunctionKey &func_key) } void Function::set_function_body(const std::function &func) { - context = std::make_unique(program->config.arch, program->config.real_matrix); + context = std::make_unique(program->config.arch, + program->config.real_matrix); ir = context->get_root(); { // Note: this is not a mutex diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 89cf804a2961d..5f7d5ea31db31 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -391,7 +391,8 @@ void Kernel::init(Program &program, is_accessor = false; is_evaluator = false; compiled_ = nullptr; - context = std::make_unique(program.config.arch, program.config.real_matrix); + context = std::make_unique(program.config.arch, + program.config.real_matrix); ir = context->get_root(); ir_is_ast_ = true; diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index 75cd098bf7b3e..076e7ad8788e0 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -138,7 +138,8 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { } else if (dt->is()) { auto vectorty = dt->as(); auto dtype = this->get_data_type(vectorty->get_element_type()); - return llvm::VectorType::get(dtype, vectorty->get_num_elements(), /*scalable=*/false); + return llvm::VectorType::get(dtype, vectorty->get_num_elements(), + /*scalable=*/false); } else { TI_INFO(data_type_name(dt)); TI_NOT_IMPLEMENTED From bf719a3ddd93c79839cbb6c722f0bc280f3fd4cb Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 17 Aug 2022 14:06:07 -0400 Subject: [PATCH 09/40] no long in use --- taichi/ir/type.cpp | 8 -------- taichi/ir/type.h | 2 -- 2 files changed, 10 deletions(-) diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 0188722f15e6f..6b9d1a51e7990 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -87,14 +87,6 @@ std::string TensorType::to_string() const { return s; } -int TensorType::vector_width() const { - int vw = 1; - for (auto dim : shape_) { - vw *= dim; - } - return vw; -} - int Type::vector_width() const { return 1; // TODO: CPU vectorization } diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 7588093588bcc..339e2553ffb32 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -175,8 +175,6 @@ class TensorType : public Type { return shape_; } - int vector_width() const; - Type *get_compute_type() override { return this; } From 65199eabc6e0f28315688dc666e4b6e543a8d10b Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 17 Aug 2022 14:14:29 -0400 Subject: [PATCH 10/40] add some comments --- taichi/codegen/llvm/codegen_llvm.cpp | 2 +- taichi/ir/frontend_ir.h | 4 ++++ taichi/ir/statements.h | 3 +++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 5b6d31a345f70..bc54a5a2e8f51 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -692,7 +692,7 @@ llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) { auto tensor_type = dt->cast(); auto element_type = llvm_type(tensor_type->get_element_type()); return llvm::VectorType::get(element_type, tensor_type->get_num_elements(), - false); + /*scalable=*/false); } else { TI_NOT_IMPLEMENTED; } diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index effd8abd566e9..a544d15793dc9 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -504,6 +504,10 @@ class GlobalVariableExpression : public Expression { TI_DEFINE_ACCEPT_FOR_EXPRESSION }; +/** + * Creating a local matrix; + * lowered from ti.Matrix with real_matrix=True + */ class MatrixExpression : public Expression { public: std::vector elements; diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 784c0a499c202..11c87dd98b67c 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -1807,6 +1807,9 @@ class MeshPatchIndexStmt : public Stmt { TI_DEFINE_ACCEPT_AND_CLONE }; +/** + * Initialization of a local matrix + */ class MatrixInitStmt : public Stmt { public: std::vector values; From cbf1ea84ec52f9c721e1b687e06068de1fdb2386 Mon Sep 17 00:00:00 2001 From: Mike He Date: Tue, 23 Aug 2022 19:17:51 -0400 Subject: [PATCH 11/40] get rid of always-true condition Co-authored-by: Yi Xu --- python/taichi/lang/ast/ast_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 2376f77c79092..fcc0b46cfebb8 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -490,7 +490,7 @@ def build_Call(ctx, node): if (isinstance(node.func, ast.Attribute) and (func == Matrix or func == Vector) - ) and impl.current_cfg().real_matrix and in_taichi_scope(): + ) and impl.current_cfg().real_matrix: node.ptr = matrix.make_matrix(*args, **keywords) return node.ptr From 4c8d6b7d7dc9e957d180baca78ad531c366b1e53 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Aug 2022 23:19:08 +0000 Subject: [PATCH 12/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/taichi/lang/ast/ast_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index fcc0b46cfebb8..135a8f14d175c 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -489,8 +489,8 @@ def build_Call(ctx, node): return node.ptr if (isinstance(node.func, ast.Attribute) and - (func == Matrix or func == Vector) - ) and impl.current_cfg().real_matrix: + (func == Matrix + or func == Vector)) and impl.current_cfg().real_matrix: node.ptr = matrix.make_matrix(*args, **keywords) return node.ptr From 1a9df8c7107c0b48dac477253ba61821b80b3edf Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 23 Aug 2022 19:40:49 -0400 Subject: [PATCH 13/40] save --- python/taichi/lang/impl.py | 16 ++++++++-------- python/taichi/lang/matrix.py | 6 +++--- taichi/ir/frontend_ir.h | 3 --- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 60389ce57227d..4f45a8fb4c8d2 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -15,7 +15,8 @@ from taichi.lang.field import Field, ScalarField from taichi.lang.kernel_arguments import SparseMatrixProxy from taichi.lang.matrix import (Matrix, MatrixField, MatrixNdarray, MatrixType, - _IntermediateMatrix, _MatrixFieldElement) + _IntermediateMatrix, _MatrixFieldElement, + make_matrix) from taichi.lang.mesh import (ConvType, MeshElementFieldProxy, MeshInstance, MeshRelationAccessProxy, MeshReorderedMatrixFieldProxy, @@ -35,12 +36,6 @@ def expr_init_local_tensor(shape, element_type, elements): shape, element_type, elements) -@taichi_scope -def expr_init_matrix(shape, element_type, elements): - return get_runtime().prog.current_ast_builder().expr_alloca_matrix( - shape, element_type, elements) - - @taichi_scope def expr_init_shared_array(shape, element_type): return get_runtime().prog.current_ast_builder().expr_alloca_shared_array( @@ -55,7 +50,12 @@ def expr_init(rhs): return Matrix(*rhs.to_list(), ndim=rhs.ndim) if isinstance(rhs, Matrix): if current_cfg().real_matrix: - return rhs + if rhs.ndim == 1: + entries = [rhs(i) for i in range(rhs.n)] + else: + entries = [[rhs(i, j) for j in range(rhs.m)] + for i in range(rhs.n)] + return make_matrix(entries) return Matrix(rhs.to_list(), ndim=rhs.ndim) if isinstance(rhs, SharedArray): return rhs diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 0ab648ddb5155..1a755b6492004 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -99,18 +99,18 @@ def prop_setter(instance, value): def make_matrix(arr, dt=None): if len(arr) == 0: - return impl.expr_init( + return impl.Expr( impl.expr_init_local_tensor([0], ti_python_core.DataType_unknown, impl.make_expr_group([]))) is_matrix = isinstance(arr[0], Iterable) if dt is None: dt = _make_entries_initializer(is_matrix).infer_dt(arr) if not is_matrix: - return impl.expr_init( + return impl.Expr( impl.expr_init_local_tensor([len(arr)], dt, impl.make_expr_group( [expr.Expr(elt) for elt in arr]))) - return impl.expr_init( + return impl.Expr( impl.expr_init_local_tensor([len(arr), len(arr[0])], dt, impl.make_expr_group([ expr.Expr(elt) for row in arr diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index a544d15793dc9..afaf6db776bc3 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -902,9 +902,6 @@ class ASTBuilder { const ExprGroup &args, const ExprGroup &outputs); Expr expr_alloca(); - Expr expr_alloca_local_matrix(const std::vector &shape, - const std::optional &dt, - const std::vector &elements); Expr expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements); From 834699e1bc47d1883f605489957cbc1e1192438f Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 23 Aug 2022 19:50:06 -0400 Subject: [PATCH 14/40] some fixes for print and matrix expr --- taichi/codegen/llvm/codegen_llvm.cpp | 8 +++++++- taichi/ir/frontend_ir.cpp | 1 + taichi/ir/frontend_ir.h | 1 - 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index bc54a5a2e8f51..4609efa1d43f8 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -807,8 +807,14 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) { auto value = llvm_val[arg_stmt]; if (arg_stmt->ret_type->is()) { auto dtype = arg_stmt->ret_type->cast(); + auto elem_type = dtype->get_element_type(); for (int i = 0; i < dtype->get_num_elements(); ++i) { - args.push_back(builder->CreateExtractElement(value, i)); + auto elem_value = builder->CreateExtractElement(value, i); + if (elem_type->is_primitive(PrimitiveTypeID::f32) || + elem_type->is_primitive(PrimitiveTypeID::f16)) + elem_value = builder->CreateFPExt( + elem_value, tlctx->get_data_type(PrimitiveType::f64)); + args.push_back(elem_value); } formats += data_type_format(arg_stmt->ret_type); } else { diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index edb1aca141b47..d272d295c98df 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -435,6 +435,7 @@ void MatrixExpression::type_check(CompileConfig *config) { for (auto &arg : elements) { TI_ASSERT_TYPE_CHECKED(arg); } + ret_type = dt; } void MatrixExpression::flatten(FlattenContext *ctx) { diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index afaf6db776bc3..df472b954377e 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -518,7 +518,6 @@ class MatrixExpression : public Expression { DataType element_type) : elements(elements) { this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type)); - this->ret_type = this->dt; } void type_check(CompileConfig *config) override; From 23d7bf7bca927d138384cb5666613ac8c8e7b291 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 23 Aug 2022 19:53:20 -0400 Subject: [PATCH 15/40] fix codegen alloca size --- taichi/codegen/llvm/codegen_llvm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 4609efa1d43f8..397da5cf0b42b 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -124,7 +124,7 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) { void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { if (stmt->ret_type->is()) { auto tensor_type = stmt->ret_type->cast(); - auto type = tlctx->get_data_type(tensor_type); + auto type = tlctx->get_data_type(tensor_type->get_element_type()); auto array_size = tlctx->get_constant(tensor_type->get_num_elements()); // Return type is [array_size x type]*. if (stmt->is_shared) { From 6a8a8cbf4eb1b38a0e159a89fc14879f2ab03629 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 02:34:04 -0400 Subject: [PATCH 16/40] unsupport empty matrix --- python/taichi/lang/matrix.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 1a755b6492004..2d68b83d2b248 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -98,10 +98,7 @@ def prop_setter(instance, value): def make_matrix(arr, dt=None): - if len(arr) == 0: - return impl.Expr( - impl.expr_init_local_tensor([0], ti_python_core.DataType_unknown, - impl.make_expr_group([]))) + assert len(arr) > 0, "Cannot create empty matrix" is_matrix = isinstance(arr[0], Iterable) if dt is None: dt = _make_entries_initializer(is_matrix).infer_dt(arr) @@ -440,10 +437,8 @@ def __init__(self, raise TaichiTypeError( "An Matrix/Vector can only be initialized with an array-like object" ) - if len(arr) == 0: - mat = [] - self.ndim = 0 - elif isinstance(arr[0], Matrix): + assert len(arr) > 0, "Cannot create empty matrix" + if isinstance(arr[0], Matrix): raise Exception('cols/rows required when using list of vectors') else: is_matrix = isinstance(arr[0], Iterable) From 08926eff85bc8465e945ed350fd5c5f6e797d738 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 02:41:31 -0400 Subject: [PATCH 17/40] only check and cast elements --- taichi/transforms/type_check.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index cd4d4f7e6cf59..c0a24954b2bad 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -560,18 +560,11 @@ class TypeCheck : public IRVisitor { stmt->ret_type->to_string()); auto tensor_type = stmt->ret_type->as(); auto element_dtype = tensor_type->get_element_type(); - for (auto elt : stmt->values) { - element_dtype = promoted_type(element_dtype, elt->ret_type); - } for (int i = 0; i < stmt->values.size(); ++i) { if (element_dtype != stmt->values[i]->ret_type) { cast(stmt->values[i], element_dtype); } } - if (element_dtype != tensor_type->get_element_type()) { - stmt->ret_type = TypeFactory::create_tensor_type(tensor_type->get_shape(), - element_dtype); - } } }; From 1ae02aae51b5b1fa86d7656b1b007da86b28907a Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 02:44:02 -0400 Subject: [PATCH 18/40] fmt Vectors to one line --- taichi/ir/type_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/ir/type_utils.cpp b/taichi/ir/type_utils.cpp index dc5f816ecfa19..76ee3aa1f7b30 100644 --- a/taichi/ir/type_utils.cpp +++ b/taichi/ir/type_utils.cpp @@ -65,7 +65,7 @@ std::string tensor_type_format_helper(const std::vector &shape, } if (i != shape[dim] - 1) { fmt += ", "; - if (dim == 0) { + if (dim == 0 && dim != shape.size() - 1) { fmt += "\n"; } } From 17412e81060c2212a09d129eea6c3c86d5f3258b Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 02:52:50 -0400 Subject: [PATCH 19/40] lift duplicate part --- taichi/codegen/llvm/codegen_llvm.cpp | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 397da5cf0b42b..695e28e96a310 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -801,6 +801,13 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) { TI_ASSERT(stmt->width() == 1); std::vector args; std::string formats; + auto value_for_printf = [this](llvm::Value *to_print, DataType dtype) { + if (dtype->is_primitive(PrimitiveTypeID::f32) || + dtype->is_primitive(PrimitiveTypeID::f16)) + return this->builder->CreateFPExt( + to_print, this->tlctx->get_data_type(PrimitiveType::f64)); + return to_print; + }; for (auto const &content : stmt->contents) { if (std::holds_alternative(content)) { auto arg_stmt = std::get(content); @@ -810,19 +817,11 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) { auto elem_type = dtype->get_element_type(); for (int i = 0; i < dtype->get_num_elements(); ++i) { auto elem_value = builder->CreateExtractElement(value, i); - if (elem_type->is_primitive(PrimitiveTypeID::f32) || - elem_type->is_primitive(PrimitiveTypeID::f16)) - elem_value = builder->CreateFPExt( - elem_value, tlctx->get_data_type(PrimitiveType::f64)); - args.push_back(elem_value); + args.push_back(value_for_printf(elem_value, elem_type)); } formats += data_type_format(arg_stmt->ret_type); } else { - if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) || - arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) - value = builder->CreateFPExt( - value, tlctx->get_data_type(PrimitiveType::f64)); - args.push_back(value); + args.push_back(value_for_printf(value, arg_stmt->ret_type)); formats += data_type_format(arg_stmt->ret_type); } } else { From 5421fe1a8eae6c462f141567d603ac50bca4216c Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 03:02:49 -0400 Subject: [PATCH 20/40] clean-up --- taichi/analysis/data_source_analysis.cpp | 2 -- taichi/ir/control_flow_graph.cpp | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index 9baad69fb80fb..39bfa6750a0b9 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -37,8 +37,6 @@ std::vector get_load_pointers(Stmt *load_stmt) { return external_func->arg_stmts; } else if (auto ref = load_stmt->cast()) { return {ref->var}; - } else if (auto matrix_init = load_stmt->cast()) { - return matrix_init->values; } else if (auto ptr_offset = load_stmt->cast()) { return {ptr_offset->origin}; } else { diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 0460894fb65f6..3c4ddfedf2cac 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -501,8 +501,7 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { } } auto load_ptrs = irpass::analysis::get_load_pointers(stmt); - if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1 && - !stmt->is()) { + if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1) { // Identical load elimination auto load_ptr = load_ptrs.front(); if (!after_lower_access || From b9fd3a941629d9ded1fa53bc446f86084307478c Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 12:53:29 -0400 Subject: [PATCH 21/40] clean-up cse code --- taichi/analysis/same_statements.cpp | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index 5f908603c9497..c26957e81f4a8 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -196,24 +196,6 @@ class IRNodeComparator : public IRVisitor { basic_check(stmt); } - void visit(MatrixInitStmt *stmt) override { - basic_check(stmt); - if (!same) - return; - auto o = other_node_->as(); - if (stmt->values.size() != o->values.size()) { - same = false; - return; - } - for (int i = 0; i < stmt->values.size(); ++i) { - other_node_ = o->values[i]; - stmt->values[i]->accept(this); - other_node_ = o; - if (!same) - return; - } - } - void visit(IfStmt *stmt) override { basic_check(stmt); if (!same) From 43a456a1c97db6fa4a78d77a8ea90570afad2b78 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 13:03:56 -0400 Subject: [PATCH 22/40] breaks ci; keep as original impl --- python/taichi/lang/matrix.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 2d68b83d2b248..fe429e7238e59 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -437,8 +437,10 @@ def __init__(self, raise TaichiTypeError( "An Matrix/Vector can only be initialized with an array-like object" ) - assert len(arr) > 0, "Cannot create empty matrix" - if isinstance(arr[0], Matrix): + if len(arr) == 0: + mat = [] + self.ndim = 0 + elif isinstance(arr[0], Matrix): raise Exception('cols/rows required when using list of vectors') else: is_matrix = isinstance(arr[0], Iterable) From 78ad14ac822ba54621dc5d7f86ebb0f85eaeb7d6 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 13:10:13 -0400 Subject: [PATCH 23/40] handle alloca --- taichi/ir/frontend_ir.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index d272d295c98df..361d77945507c 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -991,7 +991,10 @@ Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements) { if (this->use_real_matrix_) { - return make_local_matrix(shape, element_type, elements.exprs); + auto matrix_expr = make_local_matrix(shape, element_type, elements.exprs); + auto v = this->expr_alloca(); + this->insert(std::make_unique(v, matrix_expr)); + return v; } auto var = Expr(std::make_shared(get_next_id())); this->insert(std::make_unique( From 40825f466a9b1a55218fcac2163b77b03e555852 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 13:18:40 -0400 Subject: [PATCH 24/40] move checks to front --- taichi/transforms/alg_simp.cpp | 41 +++++++++++++--------------------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index cb8b4c1741e13..f7830fad9d850 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -100,6 +100,11 @@ class AlgSimp : public BasicStmtVisitor { bool optimize_multiplication(BinaryOpStmt *stmt) { // return true iff the IR is modified + if (stmt->lhs->ret_type->is() || + stmt->rhs->ret_type->is()) { + // TODO: support tensor type + return false; + } auto lhs = stmt->lhs->cast(); auto rhs = stmt->rhs->cast(); TI_ASSERT(stmt->op_type == BinaryOpType::mul); @@ -151,6 +156,11 @@ class AlgSimp : public BasicStmtVisitor { bool optimize_division(BinaryOpStmt *stmt) { // return true iff the IR is modified + if (stmt->lhs->ret_type->is() || + stmt->rhs->ret_type->is()) { + // TODO: support tensor type + return false; + } auto rhs = stmt->rhs->cast(); TI_ASSERT(stmt->op_type == BinaryOpType::div || stmt->op_type == BinaryOpType::floordiv); @@ -164,11 +174,6 @@ class AlgSimp : public BasicStmtVisitor { if ((fast_math || is_integral(stmt->ret_type)) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a / a -> 1 - if (stmt->lhs->ret_type->is() || - stmt->rhs->ret_type->is()) { - // TODO: support tensor type - return false; - } replace_with_one(stmt); return true; } @@ -218,6 +223,11 @@ class AlgSimp : public BasicStmtVisitor { } void visit(BinaryOpStmt *stmt) override { + if (stmt->lhs->ret_type->is() || + stmt->rhs->ret_type->is()) { + // TODO: support tensor type + return; + } auto lhs = stmt->lhs->cast(); auto rhs = stmt->rhs->cast(); if (stmt->width() != 1) { @@ -250,12 +260,7 @@ class AlgSimp : public BasicStmtVisitor { (fast_math || is_integral(stmt->ret_type)) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a -^ a -> 0 - if (stmt->lhs->ret_type->is() && - stmt->rhs->ret_type->is()) { - // TODO: support tensor type - } else { - replace_with_zero(stmt); - } + replace_with_zero(stmt); } } else if (rhs && stmt->op_type == BinaryOpType::pow) { float64 exponent = rhs->val[0].val_cast_to_float64(); @@ -265,10 +270,6 @@ class AlgSimp : public BasicStmtVisitor { modifier.erase(stmt); } else if (exponent == 0) { // a ** 0 -> 1 - if (stmt->ret_type->is()) { - // TODO: support tensor type - return; - } replace_with_one(stmt); } else if (exponent == 0.5) { // a ** 0.5 -> sqrt(a) @@ -344,11 +345,6 @@ class AlgSimp : public BasicStmtVisitor { modifier.erase(stmt); } else if (alg_is_zero(lhs) || alg_is_zero(rhs)) { // 0 & a -> 0, a & 0 -> 0 - if (stmt->ret_type->is() || - stmt->rhs->ret_type->is()) { - // TODO: support tensor type - return; - } replace_with_zero(stmt); } else if (irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // a & a -> a @@ -363,11 +359,6 @@ class AlgSimp : public BasicStmtVisitor { // a << 0 -> a // 0 << a -> 0 // 0 >> a -> 0 - if (stmt->ret_type->is() || - stmt->rhs->ret_type->is()) { - // TODO: support tensor type - return; - } TI_ASSERT(stmt->lhs->ret_type == stmt->ret_type); stmt->replace_usages_with(stmt->lhs); modifier.erase(stmt); From f395d2a3bda9fee931412384427b0339fa90aed1 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 13:27:43 -0400 Subject: [PATCH 25/40] reuse code --- taichi/ir/frontend_ir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index e65e5009ceade..7627fceb79cd9 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -1004,7 +1004,7 @@ Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, if (this->use_real_matrix_) { auto matrix_expr = make_local_matrix(shape, element_type, elements.exprs); auto v = this->expr_alloca(); - this->insert(std::make_unique(v, matrix_expr)); + this->expr_assign(v, matrix_expr, tb); return v; } auto var = Expr(std::make_shared(get_next_id())); From 2a6a8e64828fad73ba21259cc7f1c052fe3d36b5 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 13:45:33 -0400 Subject: [PATCH 26/40] Revert "clean-up cse code" This reverts commit b9fd3a941629d9ded1fa53bc446f86084307478c. --- taichi/analysis/same_statements.cpp | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index c26957e81f4a8..5f908603c9497 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -196,6 +196,24 @@ class IRNodeComparator : public IRVisitor { basic_check(stmt); } + void visit(MatrixInitStmt *stmt) override { + basic_check(stmt); + if (!same) + return; + auto o = other_node_->as(); + if (stmt->values.size() != o->values.size()) { + same = false; + return; + } + for (int i = 0; i < stmt->values.size(); ++i) { + other_node_ = o->values[i]; + stmt->values[i]->accept(this); + other_node_ = o; + if (!same) + return; + } + } + void visit(IfStmt *stmt) override { basic_check(stmt); if (!same) From 7f8ca37da666d389ac6bf07ca6e098f3bf40c418 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 13:54:46 -0400 Subject: [PATCH 27/40] clean up together --- taichi/analysis/data_source_analysis.cpp | 2 -- taichi/analysis/same_statements.cpp | 18 ------------------ taichi/codegen/llvm/codegen_llvm.cpp | 1 + 3 files changed, 1 insertion(+), 20 deletions(-) diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index 39bfa6750a0b9..8bdeb72e9a70b 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -37,8 +37,6 @@ std::vector get_load_pointers(Stmt *load_stmt) { return external_func->arg_stmts; } else if (auto ref = load_stmt->cast()) { return {ref->var}; - } else if (auto ptr_offset = load_stmt->cast()) { - return {ptr_offset->origin}; } else { return std::vector(); } diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index 5f908603c9497..c26957e81f4a8 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -196,24 +196,6 @@ class IRNodeComparator : public IRVisitor { basic_check(stmt); } - void visit(MatrixInitStmt *stmt) override { - basic_check(stmt); - if (!same) - return; - auto o = other_node_->as(); - if (stmt->values.size() != o->values.size()) { - same = false; - return; - } - for (int i = 0; i < stmt->values.size(); ++i) { - other_node_ = o->values[i]; - stmt->values[i]->accept(this); - other_node_ = o; - if (!same) - return; - } - } - void visit(IfStmt *stmt) override { basic_check(stmt); if (!same) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 695e28e96a310..c436484d32351 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -689,6 +689,7 @@ llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) { } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*llvm_context); } else if (dt->is()) { + TI_ASSERT_INFO(kernel->program->config.real_matrix, "Real matrix not enabled but got TensorType"); auto tensor_type = dt->cast(); auto element_type = llvm_type(tensor_type->get_element_type()); return llvm::VectorType::get(element_type, tensor_type->get_num_elements(), From 988abb3c36b2854e5c0b8c5509a91f1058c0946a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Aug 2022 17:56:37 +0000 Subject: [PATCH 28/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/codegen/llvm/codegen_llvm.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index c436484d32351..ccaef8001cecb 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -689,7 +689,8 @@ llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) { } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*llvm_context); } else if (dt->is()) { - TI_ASSERT_INFO(kernel->program->config.real_matrix, "Real matrix not enabled but got TensorType"); + TI_ASSERT_INFO(kernel->program->config.real_matrix, + "Real matrix not enabled but got TensorType"); auto tensor_type = dt->cast(); auto element_type = llvm_type(tensor_type->get_element_type()); return llvm::VectorType::get(element_type, tensor_type->get_num_elements(), From 93b3a036c824b9906673c1558ea57502a6ab74a0 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 13:58:07 -0400 Subject: [PATCH 29/40] also checks for tlctx --- taichi/runtime/llvm/llvm_context.cpp | 3 ++- taichi/runtime/llvm/llvm_context.h | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index f0d2acbd450f7..a8d857d2c2ead 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -64,7 +64,7 @@ namespace lang { using namespace llvm; TaichiLLVMContext::TaichiLLVMContext(CompileConfig *config, Arch arch) - : arch_(arch) { + : config_(config), arch_(arch) { TI_TRACE("Creating Taichi llvm context for arch: {}", arch_name(arch)); main_thread_id_ = std::this_thread::get_id(); main_thread_data_ = get_this_thread_data(); @@ -143,6 +143,7 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*ctx); } else if (dt->is()) { + TI_ASSERT_INFO(config_->real_matrix, "Real matrix not enabled but got TensorType"); auto vectorty = dt->as(); auto dtype = this->get_data_type(vectorty->get_element_type()); return llvm::VectorType::get(dtype, vectorty->get_num_elements(), diff --git a/taichi/runtime/llvm/llvm_context.h b/taichi/runtime/llvm/llvm_context.h index afcccebbbbcfe..ae87699f48484 100644 --- a/taichi/runtime/llvm/llvm_context.h +++ b/taichi/runtime/llvm/llvm_context.h @@ -33,6 +33,7 @@ class TaichiLLVMContext { std::unique_ptr struct_module{nullptr}; ~ThreadLocalData(); }; + CompileConfig *config_; public: std::unique_ptr jit{nullptr}; From b2e101af0e0cca1e52400241caed7ca13cb48002 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 13:58:42 -0400 Subject: [PATCH 30/40] format --- taichi/runtime/llvm/llvm_context.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index a8d857d2c2ead..8e99f32503b58 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -143,7 +143,8 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*ctx); } else if (dt->is()) { - TI_ASSERT_INFO(config_->real_matrix, "Real matrix not enabled but got TensorType"); + TI_ASSERT_INFO(config_->real_matrix, + "Real matrix not enabled but got TensorType"); auto vectorty = dt->as(); auto dtype = this->get_data_type(vectorty->get_element_type()); return llvm::VectorType::get(dtype, vectorty->get_num_elements(), From 6fcf07026ebcf2535342035651bedb72581dd4b0 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 14:57:54 -0400 Subject: [PATCH 31/40] fix codegen: allocate pointer to vector --- taichi/codegen/llvm/codegen_llvm.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index ccaef8001cecb..4bb9c9f12c2fc 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -124,9 +124,8 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) { void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { if (stmt->ret_type->is()) { auto tensor_type = stmt->ret_type->cast(); - auto type = tlctx->get_data_type(tensor_type->get_element_type()); - auto array_size = tlctx->get_constant(tensor_type->get_num_elements()); - // Return type is [array_size x type]*. + auto type = tlctx->get_data_type(tensor_type); + // Return type is vector*. if (stmt->is_shared) { size_t data_element_size = tlctx->get_type_size( tlctx->get_data_type(tensor_type->get_element_type())); @@ -148,7 +147,7 @@ void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { tlctx->get_data_type(tensor_type->get_element_type()), 0); llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type); } else { - llvm_val[stmt] = create_entry_block_alloca(type, 0, array_size); + llvm_val[stmt] = create_entry_block_alloca(type, stmt->ret_type.is_pointer()); } } else { TI_ASSERT(stmt->width() == 1); From f43d2a80928529820b08058de7ab8e5529f8e574 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Aug 2022 18:59:26 +0000 Subject: [PATCH 32/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/codegen/llvm/codegen_llvm.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 4bb9c9f12c2fc..c91a072898230 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -147,7 +147,8 @@ void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { tlctx->get_data_type(tensor_type->get_element_type()), 0); llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type); } else { - llvm_val[stmt] = create_entry_block_alloca(type, stmt->ret_type.is_pointer()); + llvm_val[stmt] = + create_entry_block_alloca(type, stmt->ret_type.is_pointer()); } } else { TI_ASSERT(stmt->width() == 1); From c61331885b604a676ff3e35e852da30e93d70a36 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 17:19:09 -0400 Subject: [PATCH 33/40] check real matrix when allocating memory --- taichi/codegen/llvm/codegen_llvm.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 4bb9c9f12c2fc..9ac9b286ffdaa 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -124,8 +124,9 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) { void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { if (stmt->ret_type->is()) { auto tensor_type = stmt->ret_type->cast(); - auto type = tlctx->get_data_type(tensor_type); - // Return type is vector*. + auto type = kernel->program->config.real_matrix ? tlctx->get_data_type(tensor_type) : tlctx->get_data_type(tensor_type->get_element_type()); + // Return type is vector* if use real matrix. + // otherwise the return type is [type * array_size]* if (stmt->is_shared) { size_t data_element_size = tlctx->get_type_size( tlctx->get_data_type(tensor_type->get_element_type())); @@ -147,7 +148,12 @@ void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { tlctx->get_data_type(tensor_type->get_element_type()), 0); llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type); } else { - llvm_val[stmt] = create_entry_block_alloca(type, stmt->ret_type.is_pointer()); + if (kernel->program->config.real_matrix) + llvm_val[stmt] = + create_entry_block_alloca(type, stmt->ret_type.is_pointer()); + else + llvm_val[stmt] = + create_entry_block_alloca(type, 0, tlctx->get_constant(tensor_type->get_num_elements())); } } else { TI_ASSERT(stmt->width() == 1); From 9fc758d06ffb0d6f45186225162088b41ce74d0f Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 24 Aug 2022 17:23:05 -0400 Subject: [PATCH 34/40] format and fix tc for variable holding matrix expression --- taichi/codegen/llvm/codegen_llvm.cpp | 8 +++++--- taichi/ir/frontend_ir.cpp | 3 +++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 9ac9b286ffdaa..a6ddb17d88555 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -124,7 +124,9 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) { void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { if (stmt->ret_type->is()) { auto tensor_type = stmt->ret_type->cast(); - auto type = kernel->program->config.real_matrix ? tlctx->get_data_type(tensor_type) : tlctx->get_data_type(tensor_type->get_element_type()); + auto type = kernel->program->config.real_matrix + ? tlctx->get_data_type(tensor_type) + : tlctx->get_data_type(tensor_type->get_element_type()); // Return type is vector* if use real matrix. // otherwise the return type is [type * array_size]* if (stmt->is_shared) { @@ -152,8 +154,8 @@ void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { llvm_val[stmt] = create_entry_block_alloca(type, stmt->ret_type.is_pointer()); else - llvm_val[stmt] = - create_entry_block_alloca(type, 0, tlctx->get_constant(tensor_type->get_num_elements())); + llvm_val[stmt] = create_entry_block_alloca( + type, 0, tlctx->get_constant(tensor_type->get_num_elements())); } } else { TI_ASSERT(stmt->width() == 1); diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 7627fceb79cd9..09f8c33e3b569 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -1005,6 +1005,9 @@ Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, auto matrix_expr = make_local_matrix(shape, element_type, elements.exprs); auto v = this->expr_alloca(); this->expr_assign(v, matrix_expr, tb); + // type check for variable `v` since + // expr_assign couldn't propagate the info + v->ret_type = matrix_expr.cast()->dt; return v; } auto var = Expr(std::make_shared(get_next_id())); From f635b7c20d46f2f48bcc21f22c567255dcd1aa24 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 25 Aug 2022 00:23:20 -0400 Subject: [PATCH 35/40] refactor: change to `make_local_matrix` which returns only an Expr; postpone var alloca to ast transformer --- python/taichi/lang/impl.py | 6 ++++++ python/taichi/lang/matrix.py | 15 +++++---------- taichi/analysis/data_source_analysis.cpp | 2 +- taichi/ir/frontend_ir.cpp | 20 +++++--------------- taichi/ir/frontend_ir.h | 9 +++++---- taichi/program/function.cpp | 3 +-- taichi/program/kernel.cpp | 3 +-- taichi/python/export_lang.cpp | 1 + 8 files changed, 25 insertions(+), 34 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 3fd16ec2ade16..1727d4d5e44cf 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -37,6 +37,12 @@ def expr_init_local_tensor(shape, element_type, elements): get_runtime().get_current_src_info()) +@taichi_scope +def make_local_matrix(shape, element_type, elements): + return get_runtime().prog.current_ast_builder().make_local_matrix( + shape, element_type, elements) + + @taichi_scope def expr_init_shared_array(shape, element_type): return get_runtime().prog.current_ast_builder().expr_alloca_shared_array( diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 336fb534978d9..026a966de024b 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -104,16 +104,11 @@ def make_matrix(arr, dt=None): if dt is None: dt = _make_entries_initializer(is_matrix).infer_dt(arr) if not is_matrix: - return impl.Expr( - impl.expr_init_local_tensor([len(arr)], dt, - impl.make_expr_group( - [expr.Expr(elt) for elt in arr]))) - return impl.Expr( - impl.expr_init_local_tensor([len(arr), len(arr[0])], dt, - impl.make_expr_group([ - expr.Expr(elt) for row in arr - for elt in row - ]))) + return impl.make_local_matrix([len(arr)], dt, + [expr.Expr(elt).ptr for elt in arr]) + return impl.make_local_matrix( + [len(arr), len(arr[0])], dt, + [expr.Expr(elt).ptr for row in arr for elt in row]) class _MatrixBaseImpl: diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index 8bdeb72e9a70b..4a018afa6bf47 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -59,7 +59,7 @@ Stmt *get_store_data(Stmt *store_stmt) { std::vector get_store_destination(Stmt *store_stmt) { // If store_stmt provides some data sources, return the pointers of the data. - if (store_stmt->is()) { + if (store_stmt->is() && !store_stmt->ret_type->is()) { // The statement itself provides a data source (const [0]). return std::vector(1, store_stmt); } else if (auto local_store = store_stmt->cast()) { diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 09f8c33e3b569..1cc8404c73c77 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -87,10 +87,9 @@ FrontendForStmt::FrontendForStmt(const ExprGroup &loop_var, } } -FrontendContext::FrontendContext(Arch arch, bool real_matrix) { +FrontendContext::FrontendContext(Arch arch) { root_node_ = std::make_unique(); - current_builder_ = - std::make_unique(root_node_.get(), arch, real_matrix); + current_builder_ = std::make_unique(root_node_.get(), arch); } FrontendForStmt::FrontendForStmt(const Expr &loop_var, @@ -991,9 +990,9 @@ Expr ASTBuilder::expr_alloca() { return var; } -Expr make_local_matrix(const std::vector &shape, - const DataType &dt, - const std::vector &elements) { +Expr ASTBuilder::make_local_matrix(const std::vector &shape, + const DataType &dt, + const std::vector &elements) { return Expr(std::make_shared(elements, shape, dt)); } @@ -1001,15 +1000,6 @@ Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements, std::string tb) { - if (this->use_real_matrix_) { - auto matrix_expr = make_local_matrix(shape, element_type, elements.exprs); - auto v = this->expr_alloca(); - this->expr_assign(v, matrix_expr, tb); - // type check for variable `v` since - // expr_assign couldn't propagate the info - v->ret_type = matrix_expr.cast()->dt; - return v; - } auto var = Expr(std::make_shared(get_next_id())); this->insert(std::make_unique( std::static_pointer_cast(var.expr)->id, shape, diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 66b99e9772996..29d07e3e24cdd 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -867,11 +867,9 @@ class ASTBuilder { Arch arch_; ForLoopDecoratorRecorder for_loop_dec_; int id_counter_{0}; - bool use_real_matrix_{false}; public: - ASTBuilder(Block *initial, Arch arch, bool real_matrix) - : arch_(arch), use_real_matrix_(real_matrix) { + ASTBuilder(Block *initial, Arch arch) : arch_(arch) { stack_.push_back(initial); loop_state_stack_.push_back(None); } @@ -890,6 +888,9 @@ class ASTBuilder { const std::function &func); Expr make_id_expr(const std::string &name); + Expr make_local_matrix(const std::vector &shape, + const DataType &dt, + const std::vector &elements); Expr insert_thread_idx_expr(); Expr insert_patch_idx_expr(); void create_kernel_exprgroup_return(const ExprGroup &group); @@ -972,7 +973,7 @@ class FrontendContext { std::unique_ptr root_node_; public: - FrontendContext(Arch arch, bool real_matrix); + FrontendContext(Arch arch); ASTBuilder &builder() { return *current_builder_; diff --git a/taichi/program/function.cpp b/taichi/program/function.cpp index 64dd970cde4ea..e2655fed02432 100644 --- a/taichi/program/function.cpp +++ b/taichi/program/function.cpp @@ -12,8 +12,7 @@ Function::Function(Program *program, const FunctionKey &func_key) } void Function::set_function_body(const std::function &func) { - context = std::make_unique(program->config.arch, - program->config.real_matrix); + context = std::make_unique(program->config.arch); ir = context->get_root(); { // Note: this is not a mutex diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 0276a71269cf1..f6f0115955634 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -391,8 +391,7 @@ void Kernel::init(Program &program, is_accessor = false; is_evaluator = false; compiled_ = nullptr; - context = std::make_unique(program.config.arch, - program.config.real_matrix); + context = std::make_unique(program.config.arch); ir = context->get_root(); ir_is_ast_ = true; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 569e631611d85..29af61a2f9166 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -291,6 +291,7 @@ void export_lang(py::module &m) { .def("insert_deactivate", &ASTBuilder::insert_snode_deactivate) .def("insert_activate", &ASTBuilder::insert_snode_activate) .def("insert_external_func_call", &ASTBuilder::insert_external_func_call) + .def("make_local_matrix", &ASTBuilder::make_local_matrix) .def("expr_alloca", &ASTBuilder::expr_alloca) .def("expr_alloca_local_tensor", &ASTBuilder::expr_alloca_local_tensor) .def("expr_alloca_shared_array", &ASTBuilder::expr_alloca_shared_array) From f82dc259c6b7a34ff53c5544dc30a2593a478ec9 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 25 Aug 2022 00:24:27 -0400 Subject: [PATCH 36/40] get rid of duplicated check --- taichi/transforms/alg_simp.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index f7830fad9d850..5fc782c094842 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -100,11 +100,6 @@ class AlgSimp : public BasicStmtVisitor { bool optimize_multiplication(BinaryOpStmt *stmt) { // return true iff the IR is modified - if (stmt->lhs->ret_type->is() || - stmt->rhs->ret_type->is()) { - // TODO: support tensor type - return false; - } auto lhs = stmt->lhs->cast(); auto rhs = stmt->rhs->cast(); TI_ASSERT(stmt->op_type == BinaryOpType::mul); From bd68c2369f26b04edfe58aecf7604686b307bbcc Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 25 Aug 2022 01:42:48 -0400 Subject: [PATCH 37/40] save changes --- python/taichi/lang/impl.py | 2 +- python/taichi/lang/matrix.py | 8 ++++---- taichi/transforms/alg_simp.cpp | 5 ----- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 1727d4d5e44cf..632cbc0e95d42 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -38,7 +38,7 @@ def expr_init_local_tensor(shape, element_type, elements): @taichi_scope -def make_local_matrix(shape, element_type, elements): +def make_matrix_expr(shape, element_type, elements): return get_runtime().prog.current_ast_builder().make_local_matrix( shape, element_type, elements) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 026a966de024b..887637d5bc306 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -104,11 +104,11 @@ def make_matrix(arr, dt=None): if dt is None: dt = _make_entries_initializer(is_matrix).infer_dt(arr) if not is_matrix: - return impl.make_local_matrix([len(arr)], dt, - [expr.Expr(elt).ptr for elt in arr]) - return impl.make_local_matrix( + return impl.Expr(impl.make_matrix_expr([len(arr)], dt, + [expr.Expr(elt).ptr for elt in arr])) + return impl.Expr(impl.make_matrix_expr( [len(arr), len(arr[0])], dt, - [expr.Expr(elt).ptr for row in arr for elt in row]) + [expr.Expr(elt).ptr for row in arr for elt in row])) class _MatrixBaseImpl: diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 5fc782c094842..b8bdf2a32f698 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -151,11 +151,6 @@ class AlgSimp : public BasicStmtVisitor { bool optimize_division(BinaryOpStmt *stmt) { // return true iff the IR is modified - if (stmt->lhs->ret_type->is() || - stmt->rhs->ret_type->is()) { - // TODO: support tensor type - return false; - } auto rhs = stmt->rhs->cast(); TI_ASSERT(stmt->op_type == BinaryOpType::div || stmt->op_type == BinaryOpType::floordiv); From 5d00c98a6548815623150382f8ea4931d052c3da Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 25 Aug 2022 01:43:21 -0400 Subject: [PATCH 38/40] format --- python/taichi/lang/matrix.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 887637d5bc306..2bbb23c1acf93 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -104,11 +104,13 @@ def make_matrix(arr, dt=None): if dt is None: dt = _make_entries_initializer(is_matrix).infer_dt(arr) if not is_matrix: - return impl.Expr(impl.make_matrix_expr([len(arr)], dt, - [expr.Expr(elt).ptr for elt in arr])) - return impl.Expr(impl.make_matrix_expr( - [len(arr), len(arr[0])], dt, - [expr.Expr(elt).ptr for row in arr for elt in row])) + return impl.Expr( + impl.make_matrix_expr([len(arr)], dt, + [expr.Expr(elt).ptr for elt in arr])) + return impl.Expr( + impl.make_matrix_expr( + [len(arr), len(arr[0])], dt, + [expr.Expr(elt).ptr for row in arr for elt in row])) class _MatrixBaseImpl: From 3451397a0c94a69ddf981b0b2c5758236d1b7a36 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 25 Aug 2022 02:02:27 -0400 Subject: [PATCH 39/40] also rename cxx part --- taichi/ir/frontend_ir.cpp | 6 +++--- taichi/ir/frontend_ir.h | 6 +++--- taichi/python/export_lang.cpp | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 1cc8404c73c77..9ef38c9701896 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -990,9 +990,9 @@ Expr ASTBuilder::expr_alloca() { return var; } -Expr ASTBuilder::make_local_matrix(const std::vector &shape, - const DataType &dt, - const std::vector &elements) { +Expr ASTBuilder::make_matrix_expr(const std::vector &shape, + const DataType &dt, + const std::vector &elements) { return Expr(std::make_shared(elements, shape, dt)); } diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 29d07e3e24cdd..d4da9b635820f 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -888,9 +888,9 @@ class ASTBuilder { const std::function &func); Expr make_id_expr(const std::string &name); - Expr make_local_matrix(const std::vector &shape, - const DataType &dt, - const std::vector &elements); + Expr make_matrix_expr(const std::vector &shape, + const DataType &dt, + const std::vector &elements); Expr insert_thread_idx_expr(); Expr insert_patch_idx_expr(); void create_kernel_exprgroup_return(const ExprGroup &group); diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 29af61a2f9166..89ad1eaeb2a4f 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -291,7 +291,7 @@ void export_lang(py::module &m) { .def("insert_deactivate", &ASTBuilder::insert_snode_deactivate) .def("insert_activate", &ASTBuilder::insert_snode_activate) .def("insert_external_func_call", &ASTBuilder::insert_external_func_call) - .def("make_local_matrix", &ASTBuilder::make_local_matrix) + .def("make_local_matrix", &ASTBuilder::make_matrix_expr) .def("expr_alloca", &ASTBuilder::expr_alloca) .def("expr_alloca_local_tensor", &ASTBuilder::expr_alloca_local_tensor) .def("expr_alloca_shared_array", &ASTBuilder::expr_alloca_shared_array) From 5ed0e9300ef7e5f332d5cef60bafac199da34d23 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Thu, 25 Aug 2022 15:15:16 +0800 Subject: [PATCH 40/40] Apply suggestions from code review --- python/taichi/lang/impl.py | 2 +- taichi/python/export_lang.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 632cbc0e95d42..027696989dc9d 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -39,7 +39,7 @@ def expr_init_local_tensor(shape, element_type, elements): @taichi_scope def make_matrix_expr(shape, element_type, elements): - return get_runtime().prog.current_ast_builder().make_local_matrix( + return get_runtime().prog.current_ast_builder().make_matrix_expr( shape, element_type, elements) diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 89ad1eaeb2a4f..8b44098656b3d 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -291,7 +291,7 @@ void export_lang(py::module &m) { .def("insert_deactivate", &ASTBuilder::insert_snode_deactivate) .def("insert_activate", &ASTBuilder::insert_snode_activate) .def("insert_external_func_call", &ASTBuilder::insert_external_func_call) - .def("make_local_matrix", &ASTBuilder::make_matrix_expr) + .def("make_matrix_expr", &ASTBuilder::make_matrix_expr) .def("expr_alloca", &ASTBuilder::expr_alloca) .def("expr_alloca_local_tensor", &ASTBuilder::expr_alloca_local_tensor) .def("expr_alloca_shared_array", &ASTBuilder::expr_alloca_shared_array)