From ab37bfe5a7636c336b4f54fa8f3793f6a9211d2f Mon Sep 17 00:00:00 2001 From: AD1024 Date: Mon, 18 Jul 2022 18:36:52 -0700 Subject: [PATCH 01/35] init: the real matrix --- python/taichi/lang/impl.py | 13 +++++++ python/taichi/lang/matrix.py | 45 +++++++++++++++++++++-- taichi/analysis/gen_offline_cache_key.cpp | 8 ++++ taichi/analysis/offline_cache_util.cpp | 1 + taichi/codegen/cpu/codegen_cpu.cpp | 3 ++ taichi/inc/expressions.inc.h | 1 + taichi/ir/expression_printer.h | 7 ++++ taichi/ir/frontend_ir.cpp | 25 +++++++++++++ taichi/ir/frontend_ir.h | 23 ++++++++++++ taichi/program/compile_config.cpp | 1 + taichi/program/compile_config.h | 1 + taichi/program/kernel.cpp | 7 ++++ taichi/python/export_lang.cpp | 2 + 13 files changed, 133 insertions(+), 4 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 2df39b5692ac8..dc6f5a7088836 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -32,6 +32,17 @@ def expr_init_local_tensor(shape, element_type, elements): shape, element_type, elements) +@taichi_scope +def expr_init_matrix(shape, element_type, elements): + return get_runtime().prog.current_ast_builder().expr_alloca_matrix( + shape, element_type, elements) + + +@taichi_scope +def expr_init_indexed_matrix(mat, indices): + pass + + @taichi_scope def expr_init(rhs): if rhs is None: @@ -39,6 +50,8 @@ def expr_init(rhs): if isinstance(rhs, Matrix) and (hasattr(rhs, "_DIM")): return type(rhs)(*rhs.to_list()) if isinstance(rhs, Matrix): + if current_cfg().real_matrix: + return rhs return Matrix(rhs.to_list()) if isinstance(rhs, Struct): return Struct(rhs.to_dict(include_methods=True)) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 84df8b4cf755d..520e5ccef586c 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -97,6 +97,29 @@ def prop_setter(instance, value): return cls +class _Matrix: + def __init__(self, arr, dt, is_vec=False): + cast = ( + lambda x: ops_mod.cast(x, dt) + ) if dt else lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x) + if is_vec: + self._impl = impl.expr_init_matrix([len(arr)], dt, + [cast(elt).ptr for elt in arr]) + else: + self._impl = impl.expr_init_matrix( + [len(arr), len(arr[0])], dt, + [cast(elt).ptr for row in arr for elt in row]) + self.dt = dt + + def __getitem__(self, key): + # TODO(@AD1024): create IndexMatrixStmt here + raise NotImplementedError() + + def __setitem__(self, key, value): + # TODO(@AD1024): create IndexMatrixStmt here + raise NotImplementedError() + + class _MatrixBaseImpl: def __init__(self, m, n, entries): self.m = m @@ -308,6 +331,10 @@ def pyscope_or_ref(self, arr): return [[x] for x in arr] def no_dynamic_index(self, arr, dt): + if impl.current_cfg().real_matrix: + if dt is None: + dt = self.infer_dt(arr) + return _Matrix(arr, dt, is_vec=True) return [[impl.expr_init(ops_mod.cast(x, dt) if dt else x)] for x in arr] @@ -333,6 +360,10 @@ def pyscope_or_ref(self, arr): return [list(row) for row in arr] def no_dynamic_index(self, arr, dt): + if impl.current_cfg().real_matrix: + if dt is None: + dt = self.infer_dt(arr) + return _Matrix(arr, dt) return [[ impl.expr_init(ops_mod.cast(x, dt) if dt else x) for x in row ] for row in arr] @@ -435,10 +466,16 @@ def __init__(self, arr, dt=None, suppress_warning=False, is_ref=False): local_tensor_proxy, mat = initializer.with_dynamic_index( arr, dt) - self.n, self.m = len(mat), 1 - if len(mat) > 0: - self.m = len(mat[0]) - entries = [x for row in mat for x in row] + if 
impl.current_cfg().real_matrix: + self.n, self.m = len(arr), 1 + if len(arr) > 0: + self.m = len(arr[0]) + entries = mat + else: + self.n, self.m = len(mat), 1 + if len(mat) > 0: + self.m = len(mat[0]) + entries = [x for row in mat for x in row] if self.n * self.m > 32 and not suppress_warning: warning( diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp index 935f08fc783ff..fada6f915598b 100644 --- a/taichi/analysis/gen_offline_cache_key.cpp +++ b/taichi/analysis/gen_offline_cache_key.cpp @@ -152,6 +152,14 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { emit(expr->dual); } + void visit(MatrixExpression *expr) override { + emit(ExprOpCode::MatrixExpression); + emit(expr->dt); + for (auto elt : expr->elements) { + emit(elt); + } + } + void visit(IndexExpression *expr) override { emit(ExprOpCode::IndexExpression); emit(expr->var); diff --git a/taichi/analysis/offline_cache_util.cpp b/taichi/analysis/offline_cache_util.cpp index 66b344b184c30..f8c6e652efbe2 100644 --- a/taichi/analysis/offline_cache_util.cpp +++ b/taichi/analysis/offline_cache_util.cpp @@ -64,6 +64,7 @@ static std::vector get_offline_cache_key_of_compile_config( serializer(config->demote_no_access_mesh_fors); serializer(config->experimental_auto_mesh_local); serializer(config->auto_mesh_local_default_occupacy); + serializer(config->real_matrix); serializer.finalize(); return serializer.data; diff --git a/taichi/codegen/cpu/codegen_cpu.cpp b/taichi/codegen/cpu/codegen_cpu.cpp index 901723416f602..67140adce29a6 100644 --- a/taichi/codegen/cpu/codegen_cpu.cpp +++ b/taichi/codegen/cpu/codegen_cpu.cpp @@ -286,7 +286,9 @@ FunctionType CodeGenCPU::codegen() { auto *llvm_prog = get_llvm_program(prog); auto *tlctx = llvm_prog->get_llvm_context(kernel->arch); auto &config = prog->config; + TI_TRACE("in codegen()"); std::string kernel_key = get_hashed_offline_cache_key(&config, kernel); + TI_TRACE("in codegen() 1"); kernel->set_kernel_key_for_cache(kernel_key); if (config.offline_cache && !config.async_mode && this->supports_offline_cache() && !kernel->is_evaluator) { @@ -299,6 +301,7 @@ FunctionType CodeGenCPU::codegen() { } } if (!kernel->lowered()) { + TI_TRACE("calling lowering"); kernel->lower(/*to_executable=*/false); } diff --git a/taichi/inc/expressions.inc.h b/taichi/inc/expressions.inc.h index 9b20ba86bd80a..7fb3ef3975947 100644 --- a/taichi/inc/expressions.inc.h +++ b/taichi/inc/expressions.inc.h @@ -6,6 +6,7 @@ PER_EXPRESSION(TernaryOpExpression) PER_EXPRESSION(InternalFuncCallExpression) PER_EXPRESSION(ExternalTensorExpression) PER_EXPRESSION(GlobalVariableExpression) +PER_EXPRESSION(MatrixExpression) PER_EXPRESSION(IndexExpression) PER_EXPRESSION(StrideExpression) PER_EXPRESSION(RangeAssumptionExpression) diff --git a/taichi/ir/expression_printer.h b/taichi/ir/expression_printer.h index 74887c890e099..a32a0468d6907 100644 --- a/taichi/ir/expression_printer.h +++ b/taichi/ir/expression_printer.h @@ -110,6 +110,13 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter { } } + void visit(MatrixExpression *expr) override { + emit('['); + emit_vector(expr->elements); + emit(']'); + emit(fmt::format(" (dt={})", expr->dt->to_string())); + } + void visit(IndexExpression *expr) override { expr->var->accept(this); emit('['); diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index f2c03ea18c99d..da6b0742b1666 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -429,6 +429,15 @@ Stmt 
*make_tensor_access(Expression::FlattenContext *ctx, return ctx->push_back(var->stmt, offset_stmt); } +void MatrixExpression::type_check(CompileConfig *config) { + // TODO: typecheck matrix +} + +void MatrixExpression::flatten(FlattenContext *ctx) { + // TODO: implement flatten + TI_NOT_IMPLEMENTED +} + bool IndexExpression::is_field() const { return var.is(); } @@ -960,6 +969,22 @@ Expr ASTBuilder::expr_alloca() { return var; } +Expr ASTBuilder::expr_alloca_local_matrix(const std::vector &shape, + const DataType &dt, + const std::vector &elements) { + auto var = Expr(std::make_shared(get_next_id())); + this->insert(std::make_unique( + std::static_pointer_cast(var.expr)->id, shape, dt)); + auto rhs = Expr(std::make_shared(elements, shape, dt)); + this->insert(std::make_unique(var, rhs)); + return var; +} + +Expr ASTBuilder::expr_indexed_matrix(const Expr &matrix, + const ExprGroup &indices) { + return Expr(std::make_shared(matrix, indices)); +} + Expr ASTBuilder::expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements) { diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index fdbe86c95c620..f1cbfd0cd8aee 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -499,6 +499,25 @@ class GlobalVariableExpression : public Expression { TI_DEFINE_ACCEPT_FOR_EXPRESSION }; +class MatrixExpression : public Expression { + public: + std::vector elements; + DataType dt; + + MatrixExpression(const std::vector &elements, + std::vector shape, + DataType element_type) + : elements(elements) { + this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type)); + } + + void type_check(CompileConfig *config) override; + + void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION +}; + class IndexExpression : public Expression { public: // `var` is one of GlobalVariableExpression, ExternalTensorExpression, @@ -871,6 +890,10 @@ class ASTBuilder { const ExprGroup &args, const ExprGroup &outputs); Expr expr_alloca(); + Expr expr_alloca_local_matrix(const std::vector &shape, + const DataType &dt, + const std::vector &elements); + Expr expr_indexed_matrix(const Expr &matrix, const ExprGroup &indices); Expr expr_alloca_local_tensor(const std::vector &shape, const DataType &element_type, const ExprGroup &elements); diff --git a/taichi/program/compile_config.cpp b/taichi/program/compile_config.cpp index 08508a8e7de34..f6b9f8ba2033e 100644 --- a/taichi/program/compile_config.cpp +++ b/taichi/program/compile_config.cpp @@ -47,6 +47,7 @@ CompileConfig::CompileConfig() { detect_read_only = true; ndarray_use_cached_allocator = true; use_mesh = false; + real_matrix = false; saturating_grid_dim = 0; max_block_dim = 0; diff --git a/taichi/program/compile_config.h b/taichi/program/compile_config.h index eff885c347166..61fd094c3cbbe 100644 --- a/taichi/program/compile_config.h +++ b/taichi/program/compile_config.h @@ -44,6 +44,7 @@ struct CompileConfig { bool detect_read_only; bool ndarray_use_cached_allocator; bool use_mesh; + bool real_matrix; DataType default_fp; DataType default_ip; std::string extra_flags; diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 92e050af1a57f..2302339fea93e 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -63,6 +63,7 @@ Kernel::Kernel(Program &program, void Kernel::compile() { CurrentCallableGuard _(program, this); + TI_TRACE("compiling kernel {}", name); compiled_ = program->compile(*this); } @@ -111,9 +112,11 @@ void Kernel::lower(bool 
to_executable) { void Kernel::operator()(LaunchContextBuilder &ctx_builder) { if (!program->config.async_mode || this->is_evaluator) { + TI_TRACE("called kernel::operator()"); if (!compiled_) { compile(); } + TI_TRACE("compiled"); if (!this->from_offline_cache_) { for (auto &offloaded : ir->as()->statements) { @@ -121,7 +124,9 @@ void Kernel::operator()(LaunchContextBuilder &ctx_builder) { } } + TI_TRACE("calling compiled_"); compiled_(ctx_builder.get_context()); + TI_TRACE("called compiled_"); program->sync = (program->sync && arch_is_cpu(arch)); // Note that Kernel::arch may be different from program.config.arch @@ -441,12 +446,14 @@ void Kernel::init(Program &program, // concurrently, we need to lock this block of code together with // taichi::lang::context with a mutex. CurrentCallableGuard _(this->program, this); + TI_TRACE("Calling func in {}", name); func(); ir->as()->kernel = this; } if (!program.config.lazy_compilation) compile(); + TI_TRACE("Finish compiling {}", name); } // static diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index ebfa3a7616156..a4de588128236 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -196,6 +196,7 @@ void export_lang(py::module &m) { .def_readwrite("ndarray_use_cached_allocator", &CompileConfig::ndarray_use_cached_allocator) .def_readwrite("use_mesh", &CompileConfig::use_mesh) + .def_readwrite("real_matrix", &CompileConfig::real_matrix) .def_readwrite("cc_compile_cmd", &CompileConfig::cc_compile_cmd) .def_readwrite("cc_link_cmd", &CompileConfig::cc_link_cmd) .def_readwrite("async_opt_passes", &CompileConfig::async_opt_passes) @@ -302,6 +303,7 @@ void export_lang(py::module &m) { .def("insert_activate", &ASTBuilder::insert_snode_activate) .def("insert_external_func_call", &ASTBuilder::insert_external_func_call) .def("expr_alloca", &ASTBuilder::expr_alloca) + .def("expr_alloca_matrix", &ASTBuilder::expr_alloca_local_matrix) .def("expr_alloca_local_tensor", &ASTBuilder::expr_alloca_local_tensor) .def("create_assert_stmt", &ASTBuilder::create_assert_stmt) .def("expr_assign", &ASTBuilder::expr_assign) From 841b4050d4ef0ddedfa4083fcb5996232279d612 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 20 Jul 2022 17:27:54 -0700 Subject: [PATCH 02/35] save: refactor matrix init --- python/taichi/lang/ast/ast_transformer.py | 6 ++- python/taichi/lang/matrix.py | 51 +++++++++-------------- taichi/ir/frontend_ir.cpp | 27 ++++++++---- taichi/ir/frontend_ir.h | 2 +- taichi/ir/statements.h | 30 +++++++++++++ 5 files changed, 74 insertions(+), 42 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index e565045cb109c..03e377f769be4 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -6,7 +6,7 @@ from sys import version_info from taichi._lib import core as _ti_core -from taichi.lang import expr, impl, kernel_arguments, matrix, mesh +from taichi.lang import expr, impl, kernel_arguments, matrix, mesh, ops from taichi.lang import ops as ti_ops from taichi.lang._ndrange import _Ndrange, ndrange from taichi.lang.ast.ast_transformer_utils import (Builder, LoopStatus, @@ -479,6 +479,10 @@ def build_Call(ctx, node): node.ptr = impl.ti_format(*args, **keywords) return node.ptr + if isinstance(node.func, ast.Attribute) and node.func.ptr == Matrix: + node.ptr = matrix.make_matrix(*args, **keywords) + return node.ptr + if ASTTransformer.build_call_if_is_builtin(ctx, node, args, keywords): return node.ptr diff 
--git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 520e5ccef586c..316c2ee4e1076 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -97,28 +97,21 @@ def prop_setter(instance, value): return cls -class _Matrix: - def __init__(self, arr, dt, is_vec=False): - cast = ( - lambda x: ops_mod.cast(x, dt) - ) if dt else lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x) - if is_vec: - self._impl = impl.expr_init_matrix([len(arr)], dt, - [cast(elt).ptr for elt in arr]) - else: - self._impl = impl.expr_init_matrix( - [len(arr), len(arr[0])], dt, - [cast(elt).ptr for row in arr for elt in row]) - self.dt = dt - - def __getitem__(self, key): - # TODO(@AD1024): create IndexMatrixStmt here - raise NotImplementedError() - - def __setitem__(self, key, value): - # TODO(@AD1024): create IndexMatrixStmt here - raise NotImplementedError() - +def make_matrix(arr, dt=None, suppress_warning=False, is_ref=False): + if not impl.current_cfg().real_matrix or in_python_scope(): + return Matrix(arr, dt, suppress_warning, is_ref) + cast = ( + lambda x: ops_mod.cast(x, dt) + ) if dt else lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x) + if len(arr) == 0: + return impl.expr_init_matrix([0], dt, []) + if not isinstance(arr[0], Iterable): + return impl.expr_init_matrix([len(arr)], dt, + [cast(elt).ptr for elt in arr]) + else: + return impl.expr_init_matrix( + [len(arr), len(arr[0])], dt, + [cast(elt).ptr for row in arr for elt in row]) class _MatrixBaseImpl: def __init__(self, m, n, entries): @@ -466,16 +459,10 @@ def __init__(self, arr, dt=None, suppress_warning=False, is_ref=False): local_tensor_proxy, mat = initializer.with_dynamic_index( arr, dt) - if impl.current_cfg().real_matrix: - self.n, self.m = len(arr), 1 - if len(arr) > 0: - self.m = len(arr[0]) - entries = mat - else: - self.n, self.m = len(mat), 1 - if len(mat) > 0: - self.m = len(mat[0]) - entries = [x for row in mat for x in row] + self.n, self.m = len(mat), 1 + if len(mat) > 0: + self.m = len(mat[0]) + entries = [x for row in mat for x in row] if self.n * self.m > 32 and not suppress_warning: warning( diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index da6b0742b1666..79cc3e010a0d6 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -431,11 +431,21 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx, void MatrixExpression::type_check(CompileConfig *config) { // TODO: typecheck matrix + for (auto &arg : elements) { + TI_ASSERT_TYPE_CHECKED(arg); + } } void MatrixExpression::flatten(FlattenContext *ctx) { // TODO: implement flatten - TI_NOT_IMPLEMENTED + std::vector values; + for (auto &elt : elements) { + flatten_rvalue(elt, ctx); + values.push_back(elt->stmt); + } + ctx->push_back(std::make_unique(values)); + stmt = ctx->back_stmt(); + stmt->ret_type = this->ret_type; } bool IndexExpression::is_field() const { @@ -970,14 +980,15 @@ Expr ASTBuilder::expr_alloca() { } Expr ASTBuilder::expr_alloca_local_matrix(const std::vector &shape, - const DataType &dt, + const std::optional &dt, const std::vector &elements) { - auto var = Expr(std::make_shared(get_next_id())); - this->insert(std::make_unique( - std::static_pointer_cast(var.expr)->id, shape, dt)); - auto rhs = Expr(std::make_shared(elements, shape, dt)); - this->insert(std::make_unique(var, rhs)); - return var; + // auto var = Expr(std::make_shared(get_next_id())); + auto dtype = dt.value_or(PrimitiveType::unknown); + // this->insert(std::make_unique( + // 
std::static_pointer_cast(var.expr)->id, shape, dtype)); + auto rhs = Expr(std::make_shared(elements, shape, dtype)); + // this->insert(std::make_unique(var, rhs)); + return rhs; } Expr ASTBuilder::expr_indexed_matrix(const Expr &matrix, diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index f1cbfd0cd8aee..5efa927ed1920 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -891,7 +891,7 @@ class ASTBuilder { const ExprGroup &outputs); Expr expr_alloca(); Expr expr_alloca_local_matrix(const std::vector &shape, - const DataType &dt, + const std::optional &dt, const std::vector &elements); Expr expr_indexed_matrix(const Expr &matrix, const ExprGroup &indices); Expr expr_alloca_local_tensor(const std::vector &shape, diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 0ffa885f1fb27..63b2f15e41415 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -1803,5 +1803,35 @@ class MeshPatchIndexStmt : public Stmt { TI_DEFINE_ACCEPT_AND_CLONE }; +class MatrixInitStmt : public Stmt { + public: + std::vector values; + + MatrixInitStmt(const std::vector &values) : values(values) { + TI_STMT_REG_FIELDS; + } + + TI_STMT_DEF_FIELDS(ret_type); + TI_DEFINE_ACCEPT_AND_CLONE +}; + +class IndexStmt : public Stmt { + public: + Stmt *target; + std::vector indices; + + IndexStmt(Stmt *target, const std::vector &indices) + : target(target), indices(indices) { + TI_STMT_REG_FIELDS; + } + + bool has_global_side_effect() const override { + return false; + } + + TI_STMT_DEF_FIELDS(ret_type); + TI_DEFINE_ACCEPT_AND_CLONE +}; + } // namespace lang } // namespace taichi From 241223872f7733ab0674e059b1fd46423a554a9f Mon Sep 17 00:00:00 2001 From: AD1024 Date: Fri, 22 Jul 2022 15:35:34 -0700 Subject: [PATCH 03/35] fix some passes --- python/taichi/lang/ast/ast_transformer.py | 2 +- python/taichi/lang/matrix.py | 9 ++++----- taichi/inc/statements.inc.h | 1 + taichi/ir/control_flow_graph.cpp | 20 +++++++++++--------- taichi/ir/frontend_ir.cpp | 20 +++++++++----------- taichi/ir/statements.h | 12 ++++++++++-- taichi/ir/type.cpp | 5 +++++ taichi/ir/type.h | 2 ++ taichi/transforms/die.cpp | 7 +++++++ taichi/transforms/ir_printer.cpp | 15 ++++++++++++++- taichi/transforms/type_check.cpp | 17 +++++++++++++++-- 11 files changed, 79 insertions(+), 31 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 03e377f769be4..8f913b1d1be8c 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -6,7 +6,7 @@ from sys import version_info from taichi._lib import core as _ti_core -from taichi.lang import expr, impl, kernel_arguments, matrix, mesh, ops +from taichi.lang import expr, impl, kernel_arguments, matrix, mesh from taichi.lang import ops as ti_ops from taichi.lang._ndrange import _Ndrange, ndrange from taichi.lang.ast.ast_transformer_utils import (Builder, LoopStatus, diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 316c2ee4e1076..ec118f57c4906 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -107,11 +107,10 @@ def make_matrix(arr, dt=None, suppress_warning=False, is_ref=False): return impl.expr_init_matrix([0], dt, []) if not isinstance(arr[0], Iterable): return impl.expr_init_matrix([len(arr)], dt, - [cast(elt).ptr for elt in arr]) - else: - return impl.expr_init_matrix( - [len(arr), len(arr[0])], dt, - [cast(elt).ptr for row in arr for elt in row]) + [cast(elt).ptr for elt in arr]) + return 
impl.expr_init_matrix([len(arr), len(arr[0])], dt, + [cast(elt).ptr for row in arr for elt in row]) + class _MatrixBaseImpl: def __init__(self, m, n, entries): diff --git a/taichi/inc/statements.inc.h b/taichi/inc/statements.inc.h index fe12a8941f7f5..05056ce46b9ae 100644 --- a/taichi/inc/statements.inc.h +++ b/taichi/inc/statements.inc.h @@ -38,6 +38,7 @@ PER_STATEMENT(LoopUniqueStmt) PER_STATEMENT(AssertStmt) PER_STATEMENT(ExternalFuncCallStmt) PER_STATEMENT(ExternalTensorShapeAlongAxisStmt) +PER_STATEMENT(MatrixInitStmt) // Locals with reverse-mode autodiff PER_STATEMENT(AdStackAllocaStmt) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index b00eb9dc8366b..969b874b6562a 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -262,16 +262,18 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) { auto stmt = block->statements[i].get(); Stmt *result = nullptr; if (auto local_load = stmt->cast()) { - bool regular = true; - auto alloca = local_load->src[0].var; - for (int l = 0; l < stmt->width(); l++) { - if (local_load->src[l].offset != l || - local_load->src[l].var != alloca) { - regular = false; + for (int i = 0; i < local_load->src.size(); ++i) { + bool regular = true; + auto alloca = local_load->src[i].var; + for (int l = 0; l < stmt->width(); l++) { + if (local_load->src[l].offset != l || + local_load->src[l].var != alloca) { + regular = false; + } + } + if (regular) { + result = get_store_forwarding_data(alloca, i); } - } - if (regular) { - result = get_store_forwarding_data(alloca, i); } } else if (auto global_load = stmt->cast()) { if (!after_lower_access) { diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 79cc3e010a0d6..4c88853c18933 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -438,14 +438,17 @@ void MatrixExpression::type_check(CompileConfig *config) { void MatrixExpression::flatten(FlattenContext *ctx) { // TODO: implement flatten - std::vector values; + TI_ASSERT(this->dt->is()); + std::vector values; for (auto &elt : elements) { flatten_rvalue(elt, ctx); - values.push_back(elt->stmt); + auto elt_alloca = ctx->push_back(elt->stmt->ret_type); + ctx->push_back(elt_alloca, elt->stmt); + values.push_back(LocalAddress(elt_alloca, 0)); } - ctx->push_back(std::make_unique(values)); - stmt = ctx->back_stmt(); - stmt->ret_type = this->ret_type; + stmt = ctx->push_back(values, + this->dt->as()->get_shape()); + stmt->ret_type = this->dt; } bool IndexExpression::is_field() const { @@ -982,13 +985,8 @@ Expr ASTBuilder::expr_alloca() { Expr ASTBuilder::expr_alloca_local_matrix(const std::vector &shape, const std::optional &dt, const std::vector &elements) { - // auto var = Expr(std::make_shared(get_next_id())); auto dtype = dt.value_or(PrimitiveType::unknown); - // this->insert(std::make_unique( - // std::static_pointer_cast(var.expr)->id, shape, dtype)); - auto rhs = Expr(std::make_shared(elements, shape, dtype)); - // this->insert(std::make_unique(var, rhs)); - return rhs; + return Expr(std::make_shared(elements, shape, dtype)); } Expr ASTBuilder::expr_indexed_matrix(const Expr &matrix, diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 63b2f15e41415..4e4d872d2d40b 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -597,8 +597,16 @@ class GlobalStoreStmt : public Stmt { class LocalLoadStmt : public Stmt { public: LaneAttribute src; + std::vector shape; - explicit LocalLoadStmt(const LaneAttribute &src) : src(src) { + explicit 
LocalLoadStmt(const LaneAttribute &src) + : src(src), shape({static_cast(src.data.size())}) { + TI_STMT_REG_FIELDS; + } + + LocalLoadStmt(const LaneAttribute &src, + const std::vector &shape) + : src(src), shape(shape) { TI_STMT_REG_FIELDS; } @@ -1811,7 +1819,7 @@ class MatrixInitStmt : public Stmt { TI_STMT_REG_FIELDS; } - TI_STMT_DEF_FIELDS(ret_type); + TI_STMT_DEF_FIELDS(ret_type, values); TI_DEFINE_ACCEPT_AND_CLONE }; diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 194b1211d1547..e09d9fa75b860 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -1,3 +1,4 @@ +#include #include "taichi/ir/type.h" #include "taichi/ir/type_factory.h" @@ -87,6 +88,10 @@ std::string TensorType::to_string() const { return s; } +int TensorType::vector_width() const { + return std::reduce(shape_.begin(), shape_.end(), 1, std::multiplies()); +} + int Type::vector_width() const { return 1; // TODO: CPU vectorization } diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 9294f484103eb..15cae2c220fed 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -175,6 +175,8 @@ class TensorType : public Type { return shape_; } + int vector_width() const; + Type *get_compute_type() override { return this; } diff --git a/taichi/transforms/die.cpp b/taichi/transforms/die.cpp index 3176d8f576949..f6d696c4da498 100644 --- a/taichi/transforms/die.cpp +++ b/taichi/transforms/die.cpp @@ -108,6 +108,13 @@ class DIE : public IRVisitor { } stmt->all_blocks_accept(this, true); } + + void visit(MatrixInitStmt *stmt) override { + register_usage(stmt); + for (auto &elts : stmt->values) { + elts->accept(this); + } + } }; namespace irpass { diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index eb94695e83125..9d41cb76169a1 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -37,7 +37,7 @@ std::string block_dim_info(int block_dim) { std::string to_string(const LaneAttribute &ptr) { std::string ret = " ["; for (int i = 0; i < (int)ptr.size(); i++) { - ret += fmt::format("{}[{}]", ptr[i].var->name(), ptr[i].offset); + ret += fmt::format("{}", ptr[i].var->name()); if (i + 1 < (int)ptr.size()) ret += ", "; } @@ -792,6 +792,19 @@ class IRPrinter : public IRVisitor { print("{}{} = ref({})", stmt->type_hint(), stmt->name(), stmt->var->name()); } + void visit(MatrixInitStmt *stmt) override { + std::string result = ""; + result += fmt::format("{}{} = [", stmt->type_hint(), stmt->name()); + for (int i = 0; i < stmt->values.size(); ++i) { + result += stmt->values[i]->name(); + if (i != stmt->values.size() - 1) { + result += ", "; + } + } + result += "]"; + print(result); + } + private: std::string expr_to_string(Expr &expr) { return expr_to_string(expr.expr.get()); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 59ad94bee8ea5..f58f8f61c9f94 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -86,7 +86,8 @@ class TypeCheck : public IRVisitor { void visit(LocalLoadStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - TI_ASSERT_INFO(stmt->src.size() == 1, "Vectorization has been disabled."); + // TI_ASSERT_INFO(stmt->src.size() == 1, "Vectorization has been + // disabled."); TI_ASSERT(stmt->src[0].var->is() || stmt->src[0].var->is()); if (auto ptr_offset_stmt = stmt->src[0].var->cast()) { @@ -106,9 +107,21 @@ class TypeCheck : public IRVisitor { .ptr_removed(); stmt->ret_type = lookup; } - } else { + } else if (stmt->src.size() == 1) { auto lookup = stmt->src[0].var->ret_type; 
stmt->ret_type = lookup; + } else { + TI_ASSERT(stmt->src.size() > 1); + auto acc = stmt->src[0].var->ret_type; + for (int i = 1; i < stmt->src.size(); i++) { + acc = promoted_type(acc, stmt->src[i].var->ret_type); + } + if (stmt->ret_type != PrimitiveType::unknown) { + TI_ASSERT(stmt->ret_type->is()); + acc = promoted_type( + acc, stmt->ret_type->as()->get_element_type()); + } + stmt->ret_type = TypeFactory::create_tensor_type(stmt->shape, acc); } } From 49d3d2bdeb719bf77ac1d2b18daf5d5160d05903 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Fri, 22 Jul 2022 16:18:03 -0700 Subject: [PATCH 04/35] PrintStmt for TensorType --- taichi/codegen/llvm/codegen_llvm.cpp | 16 +++++++++------ taichi/ir/type_utils.cpp | 30 ++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 6966c81f44d84..d27d9dfb50d40 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -776,12 +776,16 @@ void CodeGenLLVM::visit(PrintStmt *stmt) { if (std::holds_alternative(content)) { auto arg_stmt = std::get(content); auto value = llvm_val[arg_stmt]; - if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) || - arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) - value = builder->CreateFPExt(value, - tlctx->get_data_type(PrimitiveType::f64)); - args.push_back(value); - formats += data_type_format(arg_stmt->ret_type); + if (arg_stmt->ret_type->is()) { + formats += data_type_format(arg_stmt->ret_type); + } else { + if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) || + arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) + value = builder->CreateFPExt(value, + tlctx->get_data_type(PrimitiveType::f64)); + args.push_back(value); + formats += data_type_format(arg_stmt->ret_type); + } } else { auto arg_str = std::get(content); auto value = builder->CreateGlobalStringPtr(arg_str, "content_string"); diff --git a/taichi/ir/type_utils.cpp b/taichi/ir/type_utils.cpp index 879ba26ff599a..8fa7d37a94abc 100644 --- a/taichi/ir/type_utils.cpp +++ b/taichi/ir/type_utils.cpp @@ -53,6 +53,34 @@ int data_type_size(DataType t) { } } +std::string tensor_type_format_helper(const std::vector &shape, std::string format_str, int dim) { + std::string fmt = "["; + for (int i = 0; i < shape[dim]; ++i) { + if (dim != shape.size() - 1) { + fmt += tensor_type_format_helper(shape, format_str, dim + 1); + } else { + fmt += format_str; + } + if (i != shape[dim] - 1) { + fmt += ", "; + if (dim == 0) { + fmt += "\n"; + } + } + } + fmt += "]"; + return fmt; +} + +std::string tensor_type_format(DataType t) { + TI_ASSERT(t->is()); + auto tensor_type = t->as(); + auto shape = tensor_type->get_shape(); + auto element_type = tensor_type->get_element_type(); + auto element_type_format = data_type_format(element_type); + return tensor_type_format_helper(shape, element_type_format, 0); +} + std::string data_type_format(DataType dt) { if (dt->is_primitive(PrimitiveTypeID::i16)) { return "%hd"; @@ -79,6 +107,8 @@ std::string data_type_format(DataType dt) { // CodeGenLLVM::visit(PrintStmt *stmt) and // CodeGenLLVMCUDA::visit(PrintStmt *stmt) for more details. 
return "%f"; + } else if (dt->is()) { + return tensor_type_format(dt); } else { TI_NOT_IMPLEMENTED } From 82df5966fb7611313e97a6a31c8dc75747484fc1 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Fri, 22 Jul 2022 17:07:45 -0700 Subject: [PATCH 05/35] use MatrixInitStmt --- python/taichi/lang/matrix.py | 2 +- taichi/ir/frontend_ir.cpp | 14 ++++++++------ taichi/transforms/type_check.cpp | 3 +-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index ec118f57c4906..b1c914aa76751 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -102,7 +102,7 @@ def make_matrix(arr, dt=None, suppress_warning=False, is_ref=False): return Matrix(arr, dt, suppress_warning, is_ref) cast = ( lambda x: ops_mod.cast(x, dt) - ) if dt else lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x) + ) if dt else (lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x)) if len(arr) == 0: return impl.expr_init_matrix([0], dt, []) if not isinstance(arr[0], Iterable): diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 4c88853c18933..5707cc2a1159b 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -439,15 +439,17 @@ void MatrixExpression::type_check(CompileConfig *config) { void MatrixExpression::flatten(FlattenContext *ctx) { // TODO: implement flatten TI_ASSERT(this->dt->is()); - std::vector values; + // std::vector values; + std::vector values; for (auto &elt : elements) { flatten_rvalue(elt, ctx); - auto elt_alloca = ctx->push_back(elt->stmt->ret_type); - ctx->push_back(elt_alloca, elt->stmt); - values.push_back(LocalAddress(elt_alloca, 0)); + // auto elt_alloca = ctx->push_back(elt->stmt->ret_type); + // ctx->push_back(elt_alloca, elt->stmt); + values.push_back(elt->stmt); } - stmt = ctx->push_back(values, - this->dt->as()->get_shape()); + // stmt = ctx->push_back(values, + // this->dt->as()->get_shape()); + stmt = ctx->push_back(values); stmt->ret_type = this->dt; } diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index f58f8f61c9f94..826f3035fa57e 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -86,8 +86,7 @@ class TypeCheck : public IRVisitor { void visit(LocalLoadStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - // TI_ASSERT_INFO(stmt->src.size() == 1, "Vectorization has been - // disabled."); + TI_ASSERT_INFO(stmt->src.size() == 1, "Vectorization has been disabled."); TI_ASSERT(stmt->src[0].var->is() || stmt->src[0].var->is()); if (auto ptr_offset_stmt = stmt->src[0].var->cast()) { From 85b868b9ffbab29748b1a63760f0b5cc1de4321c Mon Sep 17 00:00:00 2001 From: AD1024 Date: Fri, 22 Jul 2022 18:10:37 -0700 Subject: [PATCH 06/35] try codegen --- taichi/codegen/llvm/codegen_llvm.cpp | 13 +++++++++++++ taichi/codegen/llvm/codegen_llvm.h | 2 ++ 2 files changed, 15 insertions(+) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index d27d9dfb50d40..59c6f63593522 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -2468,6 +2468,19 @@ void CodeGenLLVM::visit(MeshPatchIndexStmt *stmt) { llvm_val[stmt] = get_arg(2); } +void CodeGenLLVM::visit(MatrixInitStmt *stmt) { + TI_TRACE("build matrix init"); + auto type = tlctx->get_data_type(stmt->ret_type->as()->get_element_type()); + auto *vectorty = llvm::VectorType::get(type, stmt->width()); + llvm::Value *vec = llvm::UndefValue::get(vectorty); + for (int i = 0; i < 
stmt->values.size(); ++i) { + auto *elem = llvm_val[stmt->values[i]]; + vec = builder->CreateInsertElement(vec, elem, i); + } + llvm_val[stmt] = vec; + // llvm_val[stmt] = tlctx->get_constant(0); +} + void CodeGenLLVM::eliminate_unused_functions() { TaichiLLVMContext::eliminate_unused_functions( module.get(), [&](std::string func_name) { diff --git a/taichi/codegen/llvm/codegen_llvm.h b/taichi/codegen/llvm/codegen_llvm.h index 9cb97fd056f49..478598b4cff82 100644 --- a/taichi/codegen/llvm/codegen_llvm.h +++ b/taichi/codegen/llvm/codegen_llvm.h @@ -379,6 +379,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { void visit(ReferenceStmt *stmt) override; + void visit(MatrixInitStmt *stmt) override; + llvm::Value *create_xlogue(std::unique_ptr &block); llvm::Value *create_mesh_xlogue(std::unique_ptr &block); From 2c746c6e72252adee6e541ae4c95d43a0f48886a Mon Sep 17 00:00:00 2001 From: AD1024 Date: Mon, 25 Jul 2022 16:46:59 -0700 Subject: [PATCH 07/35] finish codegen --- taichi/codegen/llvm/codegen_llvm.cpp | 15 ++++++++------- taichi/runtime/llvm/llvm_context.cpp | 4 ++++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 59c6f63593522..83f8d8bd693e3 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -125,10 +125,9 @@ void CodeGenLLVM::visit(Block *stmt_list) { void CodeGenLLVM::visit(AllocaStmt *stmt) { if (stmt->ret_type->is()) { auto tensor_type = stmt->ret_type->cast(); - auto type = tlctx->get_data_type(tensor_type->get_element_type()); - auto array_size = tlctx->get_constant(tensor_type->get_num_elements()); + auto type = tlctx->get_data_type(tensor_type); // Return type is [array_size x type]*. 
- llvm_val[stmt] = create_entry_block_alloca(type, 0, array_size); + llvm_val[stmt] = create_entry_block_alloca(type, false); } else { TI_ASSERT(stmt->width() == 1); llvm_val[stmt] = @@ -777,6 +776,10 @@ void CodeGenLLVM::visit(PrintStmt *stmt) { auto arg_stmt = std::get(content); auto value = llvm_val[arg_stmt]; if (arg_stmt->ret_type->is()) { + auto dtype = arg_stmt->ret_type->cast(); + for (int i = 0; i < dtype->get_num_elements(); ++i) { + args.push_back(builder->CreateExtractElement(value, i)); + } formats += data_type_format(arg_stmt->ret_type); } else { if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) || @@ -2469,10 +2472,8 @@ void CodeGenLLVM::visit(MeshPatchIndexStmt *stmt) { } void CodeGenLLVM::visit(MatrixInitStmt *stmt) { - TI_TRACE("build matrix init"); - auto type = tlctx->get_data_type(stmt->ret_type->as()->get_element_type()); - auto *vectorty = llvm::VectorType::get(type, stmt->width()); - llvm::Value *vec = llvm::UndefValue::get(vectorty); + auto type = tlctx->get_data_type(stmt->ret_type->as()); + llvm::Value *vec = llvm::UndefValue::get(type); for (int i = 0; i < stmt->values.size(); ++i) { auto *elem = llvm_val[stmt->values[i]]; vec = builder->CreateInsertElement(vec, elem, i); diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index 379acb1efdae0..a9cd16ec78967 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -135,6 +135,10 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { return llvm::Type::getInt64Ty(*ctx); } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*ctx); + } else if (dt->is()){ + auto vectorty = dt->as(); + auto dtype = this->get_data_type(vectorty->get_element_type()); + return llvm::VectorType::get(dtype, vectorty->get_num_elements()); } else { TI_INFO(data_type_name(dt)); TI_NOT_IMPLEMENTED From f21373cde8e3839060bcd88d0a255e6daa25756c Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 27 Jul 2022 08:22:04 -0700 Subject: [PATCH 08/35] [impl] basic indexing --- python/taichi/lang/expr.py | 7 +++++++ taichi/codegen/llvm/codegen_llvm.cpp | 9 ++++++--- taichi/ir/frontend_ir.cpp | 4 ++++ taichi/ir/frontend_ir.h | 1 + taichi/ir/statements.h | 8 ++++---- taichi/python/export_lang.cpp | 1 + taichi/transforms/cfg_optimization.cpp | 4 ++-- 7 files changed, 25 insertions(+), 9 deletions(-) diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 2f60e86aeb284..a8853469d0414 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -37,6 +37,13 @@ def __init__(self, *args, tb=None, dtype=None): if self.tb: self.ptr.set_tb(self.tb) self.ptr.type_check(impl.get_runtime().prog.config) + + def __getitem__(self, *indices): + if not isinstance(indices, (list, tuple)): + indices = (indices,) + + indices = make_expr_group(*indices) + return impl.get_runtime().prog.current_ast_builder().expr_indexed_matrix(self.ptr, indices) def __hash__(self): return self.ptr.get_raw_address() diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 83f8d8bd693e3..526eded11aaa8 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1710,8 +1710,12 @@ void CodeGenLLVM::visit(PtrOffsetStmt *stmt) { llvm_val[stmt] = builder->CreateGEP(ptr_ty, llvm_val[stmt->origin], llvm_val[stmt->offset]); #else - llvm_val[stmt] = - builder->CreateGEP(llvm_val[stmt->origin], llvm_val[stmt->offset]); + auto stmt_dtype = stmt->origin->ret_type->as(); + auto 
element_dtype = stmt_dtype->get_element_type(); + auto llvm_type = tlctx->get_data_type(element_dtype); + auto casted_ptr = builder->CreateBitCast(llvm_val[stmt->origin], + llvm::PointerType::get(llvm_type, 0)); + llvm_val[stmt] = builder->CreateGEP(casted_ptr, llvm_val[stmt->offset]); #endif } else { auto origin_address = builder->CreatePtrToInt( @@ -2479,7 +2483,6 @@ void CodeGenLLVM::visit(MatrixInitStmt *stmt) { vec = builder->CreateInsertElement(vec, elem, i); } llvm_val[stmt] = vec; - // llvm_val[stmt] = tlctx->get_constant(0); } void CodeGenLLVM::eliminate_unused_functions() { diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 5707cc2a1159b..6750e8adb966d 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -412,6 +412,9 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx, std::vector shape, int stride) { flatten_lvalue(var, ctx); + if (var->stmt->ret_type->is_primitive(PrimitiveTypeID::unknown)) { + var->stmt->ret_type = var->ret_type; + } Stmt *offset_stmt = ctx->push_back(TypedConstant(0)); for (int i = 0; i < (int)indices.size(); ++i) { flatten_rvalue(indices[i], ctx); @@ -993,6 +996,7 @@ Expr ASTBuilder::expr_alloca_local_matrix(const std::vector &shape, Expr ASTBuilder::expr_indexed_matrix(const Expr &matrix, const ExprGroup &indices) { + TI_ASSERT(matrix.get_ret_type()->is()); return Expr(std::make_shared(matrix, indices)); } diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 5efa927ed1920..53599e5f6fda0 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -509,6 +509,7 @@ class MatrixExpression : public Expression { DataType element_type) : elements(elements) { this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type)); + this->ret_type = this->dt; } void type_check(CompileConfig *config) override; diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 4e4d872d2d40b..11f301e636c43 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -1826,10 +1826,10 @@ class MatrixInitStmt : public Stmt { class IndexStmt : public Stmt { public: Stmt *target; - std::vector indices; + Stmt *index; - IndexStmt(Stmt *target, const std::vector &indices) - : target(target), indices(indices) { + IndexStmt(Stmt *target, Stmt *index) + : target(target), index(index) { TI_STMT_REG_FIELDS; } @@ -1837,7 +1837,7 @@ class IndexStmt : public Stmt { return false; } - TI_STMT_DEF_FIELDS(ret_type); + TI_STMT_DEF_FIELDS(ret_type, target, index); TI_DEFINE_ACCEPT_AND_CLONE }; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index a4de588128236..80b6884d6532b 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -304,6 +304,7 @@ void export_lang(py::module &m) { .def("insert_external_func_call", &ASTBuilder::insert_external_func_call) .def("expr_alloca", &ASTBuilder::expr_alloca) .def("expr_alloca_matrix", &ASTBuilder::expr_alloca_local_matrix) + .def("expr_indexed_matrix", &ASTBuilder::expr_indexed_matrix) .def("expr_alloca_local_tensor", &ASTBuilder::expr_alloca_local_tensor) .def("create_assert_stmt", &ASTBuilder::create_assert_stmt) .def("expr_assign", &ASTBuilder::expr_assign) diff --git a/taichi/transforms/cfg_optimization.cpp b/taichi/transforms/cfg_optimization.cpp index 48a49a6942eae..22050c59a8379 100644 --- a/taichi/transforms/cfg_optimization.cpp +++ b/taichi/transforms/cfg_optimization.cpp @@ -20,8 +20,8 @@ bool cfg_optimization( cfg->simplify_graph(); if (cfg->store_to_load_forwarding(after_lower_access)) modified = 
true; - if (cfg->dead_store_elimination(after_lower_access, lva_config_opt)) - modified = true; + // if (cfg->dead_store_elimination(after_lower_access, lva_config_opt)) + // modified = true; if (modified) result_modified = true; else From 23656ab0d5c56f7b0c392f26e7447b1005f85531 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 27 Jul 2022 08:29:03 -0700 Subject: [PATCH 09/35] format code --- python/taichi/lang/expr.py | 9 +++++---- python/taichi/lang/matrix.py | 5 ++--- taichi/codegen/llvm/codegen_llvm.cpp | 8 ++++---- taichi/ir/statements.h | 3 +-- taichi/ir/type_utils.cpp | 4 +++- taichi/program/kernel.cpp | 4 ---- taichi/runtime/llvm/llvm_context.cpp | 2 +- 7 files changed, 16 insertions(+), 19 deletions(-) diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index a8853469d0414..45aadc3d5aca2 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -37,13 +37,14 @@ def __init__(self, *args, tb=None, dtype=None): if self.tb: self.ptr.set_tb(self.tb) self.ptr.type_check(impl.get_runtime().prog.config) - + def __getitem__(self, *indices): if not isinstance(indices, (list, tuple)): - indices = (indices,) - + indices = (indices, ) + indices = make_expr_group(*indices) - return impl.get_runtime().prog.current_ast_builder().expr_indexed_matrix(self.ptr, indices) + return impl.get_runtime().prog.current_ast_builder( + ).expr_indexed_matrix(self.ptr, indices) def __hash__(self): return self.ptr.get_raw_address() diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index b1c914aa76751..0b17c3295e038 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -100,9 +100,8 @@ def prop_setter(instance, value): def make_matrix(arr, dt=None, suppress_warning=False, is_ref=False): if not impl.current_cfg().real_matrix or in_python_scope(): return Matrix(arr, dt, suppress_warning, is_ref) - cast = ( - lambda x: ops_mod.cast(x, dt) - ) if dt else (lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x)) + cast = (lambda x: ops_mod.cast(x, dt)) if dt else ( + lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x)) if len(arr) == 0: return impl.expr_init_matrix([0], dt, []) if not isinstance(arr[0], Iterable): diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 526eded11aaa8..86652c6683a6c 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -784,8 +784,8 @@ void CodeGenLLVM::visit(PrintStmt *stmt) { } else { if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) || arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16)) - value = builder->CreateFPExt(value, - tlctx->get_data_type(PrimitiveType::f64)); + value = builder->CreateFPExt( + value, tlctx->get_data_type(PrimitiveType::f64)); args.push_back(value); formats += data_type_format(arg_stmt->ret_type); } @@ -1713,8 +1713,8 @@ void CodeGenLLVM::visit(PtrOffsetStmt *stmt) { auto stmt_dtype = stmt->origin->ret_type->as(); auto element_dtype = stmt_dtype->get_element_type(); auto llvm_type = tlctx->get_data_type(element_dtype); - auto casted_ptr = builder->CreateBitCast(llvm_val[stmt->origin], - llvm::PointerType::get(llvm_type, 0)); + auto casted_ptr = builder->CreateBitCast( + llvm_val[stmt->origin], llvm::PointerType::get(llvm_type, 0)); llvm_val[stmt] = builder->CreateGEP(casted_ptr, llvm_val[stmt->offset]); #endif } else { diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 11f301e636c43..88f5f5c8c05fd 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h 
@@ -1828,8 +1828,7 @@ class IndexStmt : public Stmt { Stmt *target; Stmt *index; - IndexStmt(Stmt *target, Stmt *index) - : target(target), index(index) { + IndexStmt(Stmt *target, Stmt *index) : target(target), index(index) { TI_STMT_REG_FIELDS; } diff --git a/taichi/ir/type_utils.cpp b/taichi/ir/type_utils.cpp index 8fa7d37a94abc..8fd6229496794 100644 --- a/taichi/ir/type_utils.cpp +++ b/taichi/ir/type_utils.cpp @@ -53,7 +53,9 @@ int data_type_size(DataType t) { } } -std::string tensor_type_format_helper(const std::vector &shape, std::string format_str, int dim) { +std::string tensor_type_format_helper(const std::vector &shape, + std::string format_str, + int dim) { std::string fmt = "["; for (int i = 0; i < shape[dim]; ++i) { if (dim != shape.size() - 1) { diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 2302339fea93e..485db45998c87 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -112,11 +112,9 @@ void Kernel::lower(bool to_executable) { void Kernel::operator()(LaunchContextBuilder &ctx_builder) { if (!program->config.async_mode || this->is_evaluator) { - TI_TRACE("called kernel::operator()"); if (!compiled_) { compile(); } - TI_TRACE("compiled"); if (!this->from_offline_cache_) { for (auto &offloaded : ir->as()->statements) { @@ -124,9 +122,7 @@ void Kernel::operator()(LaunchContextBuilder &ctx_builder) { } } - TI_TRACE("calling compiled_"); compiled_(ctx_builder.get_context()); - TI_TRACE("called compiled_"); program->sync = (program->sync && arch_is_cpu(arch)); // Note that Kernel::arch may be different from program.config.arch diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index a9cd16ec78967..bdc4f2f82938a 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -135,7 +135,7 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { return llvm::Type::getInt64Ty(*ctx); } else if (dt->is_primitive(PrimitiveTypeID::f16)) { return llvm::Type::getHalfTy(*ctx); - } else if (dt->is()){ + } else if (dt->is()) { auto vectorty = dt->as(); auto dtype = this->get_data_type(vectorty->get_element_type()); return llvm::VectorType::get(dtype, vectorty->get_num_elements()); From 1342f7c4a1ccae439518a4d71ee627da3d3bc6de Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 27 Jul 2022 12:54:38 -0700 Subject: [PATCH 10/35] [impl] basic operators step 1 --- taichi/analysis/same_statements.cpp | 18 ++++++++++++++++++ taichi/ir/frontend_ir.cpp | 15 +++++++++++++-- taichi/ir/frontend_ir.h | 26 +++++++++++++++++++++++++- taichi/ir/type_utils.h | 6 ++++++ taichi/transforms/cfg_optimization.cpp | 4 ++-- taichi/transforms/type_check.cpp | 18 ++++++++++++++++++ 6 files changed, 82 insertions(+), 5 deletions(-) diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index 4a8cd890ff1b8..515928a33d59c 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -236,6 +236,24 @@ class IRNodeComparator : public IRVisitor { basic_check(stmt); } + void visit(MatrixInitStmt *stmt) override { + basic_check(stmt); + if (!same) + return; + auto o = other_node_->as(); + if (stmt->values.size() != o->values.size()) { + same = false; + return; + } + for (int i = 0; i < stmt->values.size(); ++i) { + other_node_ = o->values[i]; + stmt->values[i]->accept(this); + other_node_ = o; + if (!same) + return; + } + } + void visit(IfStmt *stmt) override { basic_check(stmt); if (!same) diff --git a/taichi/ir/frontend_ir.cpp 
b/taichi/ir/frontend_ir.cpp
index 6750e8adb966d..171836aed86cc 100644
--- a/taichi/ir/frontend_ir.cpp
+++ b/taichi/ir/frontend_ir.cpp
@@ -195,8 +195,8 @@ void BinaryOpExpression::type_check(CompileConfig *config) {
         binary_op_type_symbol(type), lhs->ret_type->to_string(),
         rhs->ret_type->to_string()));
   };
-  if (!lhs_type->is<PrimitiveType>() || !rhs_type->is<PrimitiveType>())
-    error();
+  // if (!lhs_type->is<PrimitiveType>() || !rhs_type->is<PrimitiveType>())
+  //   error();
   if (binary_is_bitwise(type) &&
       (!is_integral(lhs_type) || !is_integral(rhs_type)))
     error();
@@ -212,6 +212,17 @@ void BinaryOpExpression::type_check(CompileConfig *config) {
     return;
   }
 
+  if (lhs_type->is<TensorType>()) {
+    auto dtype = lhs_type->as<TensorType>()->get_element_type();
+    if (rhs_type->is<PrimitiveType>()) {
+      ret_type = promoted_type(dtype, rhs_type);
+    } else {
+      TI_ASSERT(rhs_type->is<TensorType>());
+      ret_type = promoted_type(dtype,
+                               rhs_type->as<TensorType>()->get_element_type());
+    }
+    return;
+  }
+
   // Some backends such as vulkan doesn't support fp64
   // Try not promoting to fp64 unless necessary
   if (type == BinaryOpType::atan2) {
diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h
index 53599e5f6fda0..56f1c831bdd7f 100644
--- a/taichi/ir/frontend_ir.h
+++ b/taichi/ir/frontend_ir.h
@@ -374,7 +374,31 @@ class BinaryOpExpression : public Expression {
   Expr lhs, rhs;
 
   BinaryOpExpression(const BinaryOpType &type, const Expr &lhs, const Expr &rhs)
-      : type(type), lhs(lhs), rhs(rhs) {
+      : type(type) {
+    auto to_broadcast_tensor = [](const Expr &elt, const DataType &dt) -> Expr {
+      TI_ASSERT(dt->is<TensorType>());
+      auto tensor_type = dt->as<TensorType>();
+      auto elt_type = tensor_type->get_element_type();
+      TI_ASSERT_INFO(elt_type->is<PrimitiveType>(),
+                     "Only primitive types are supported in Tensors, got {}",
+                     elt_type->to_string());
+      std::vector<Expr> broadcast_values(tensor_type->get_num_elements(), elt);
+      return Expr::make<MatrixExpression>(broadcast_values,
+                                          tensor_type->get_shape(), elt_type);
+    };
+
+    auto unify_expr = [&](const Expr &e1, const Expr &e2) {
+      if ((!e1->ret_type->is<TensorType>() && !e2->ret_type->is<TensorType>()) ||
+          (e1->ret_type->is<TensorType>() && e2->ret_type->is<TensorType>())) {
+        return std::tuple(e1, e2);
+      }
+      if (!e1->ret_type->is<TensorType>()) {
+        return std::tuple(to_broadcast_tensor(e1, e2->ret_type), e2);
+      }
+      return std::tuple(e1, to_broadcast_tensor(e2, e1->ret_type));
+    };
+
+    auto [unified_l, unified_r] = unify_expr(lhs, rhs);
+    this->lhs = unified_l;
+    this->rhs = unified_r;
   }
 
   void type_check(CompileConfig *config) override;
diff --git a/taichi/ir/type_utils.h b/taichi/ir/type_utils.h
index 32898f16c9382..5fdeef6050b19 100644
--- a/taichi/ir/type_utils.h
+++ b/taichi/ir/type_utils.h
@@ -80,6 +80,8 @@ inline bool is_quant(DataType dt) {
 }
 
 inline bool is_real(DataType dt) {
+  if (dt->is<TensorType>())
+    return is_real(dt->as<TensorType>()->get_element_type());
   return dt->is_primitive(PrimitiveTypeID::f16) ||
          dt->is_primitive(PrimitiveTypeID::f32) ||
          dt->is_primitive(PrimitiveTypeID::f64) || dt->is<QuantFloatType>() ||
          dt->is<QuantFixedType>();
 }
 
 inline bool is_integral(DataType dt) {
+  if (dt->is<TensorType>())
+    return is_integral(dt->as<TensorType>()->get_element_type());
   return dt->is_primitive(PrimitiveTypeID::i8) ||
          dt->is_primitive(PrimitiveTypeID::i16) ||
          dt->is_primitive(PrimitiveTypeID::i32) ||
@@ -100,6 +104,8 @@
 inline bool is_signed(DataType dt) {
   // Shall we return false if is_integral returns false?
   TI_ASSERT(is_integral(dt));
+  if (auto t = dt->cast<TensorType>())
+    return is_signed(t->get_element_type());
   if (auto t = dt->cast<QuantIntType>())
     return t->get_is_signed();
   return dt->is_primitive(PrimitiveTypeID::i8) ||
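Taken together, the BinaryOpExpression constructor and these type-utility changes give `matrix op scalar` expressions broadcast semantics: the scalar operand is replicated into a MatrixExpression of the tensor operand's shape before the binary op is type-checked, and is_real/is_integral/is_signed recurse into the element type. A minimal standalone C++ sketch of that behavior follows; broadcast_scalar and elementwise_add are invented stand-ins for to_broadcast_tensor and the vector lowering, not Taichi APIs.

#include <cstddef>
#include <cstdio>
#include <vector>

// Replicate a scalar to the tensor operand's element count,
// like to_broadcast_tensor above.
static std::vector<float> broadcast_scalar(float scalar, std::size_t n) {
  return std::vector<float>(n, scalar);
}

// Elementwise add, mirroring what a single CreateFAdd on an LLVM
// vector value computes per lane.
static std::vector<float> elementwise_add(const std::vector<float> &a,
                                          const std::vector<float> &b) {
  std::vector<float> out(a.size());
  for (std::size_t i = 0; i < a.size(); i++)
    out[i] = a[i] + b[i];
  return out;
}

int main() {
  std::vector<float> m = {1, 2, 3, 4};  // a 2x2 matrix, flattened
  auto r = elementwise_add(m, broadcast_scalar(0.5f, m.size()));  // m + 0.5
  for (float v : r)
    std::printf("%g ", v);  // prints: 1.5 2.5 3.5 4.5
  std::printf("\n");
  return 0;
}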
diff --git a/taichi/transforms/cfg_optimization.cpp b/taichi/transforms/cfg_optimization.cpp
index 22050c59a8379..48a49a6942eae 100644
--- a/taichi/transforms/cfg_optimization.cpp
+++ b/taichi/transforms/cfg_optimization.cpp
@@ -20,8 +20,8 @@ bool cfg_optimization(
       cfg->simplify_graph();
       if (cfg->store_to_load_forwarding(after_lower_access))
         modified = true;
-      // if (cfg->dead_store_elimination(after_lower_access, lva_config_opt))
-      //   modified = true;
+      if (cfg->dead_store_elimination(after_lower_access, lva_config_opt))
+        modified = true;
       if (modified)
         result_modified = true;
       else
diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp
index 826f3035fa57e..eb4de86a23912 100644
--- a/taichi/transforms/type_check.cpp
+++ b/taichi/transforms/type_check.cpp
@@ -327,6 +327,10 @@ class TypeCheck : public IRVisitor {
       }
     }
 
+    if (stmt->lhs->ret_type->is<TensorType>()) {
+
+    }
+
     if (stmt->lhs->ret_type != stmt->rhs->ret_type) {
       DataType ret_type;
       if (is_shift_op(stmt->op_type)) {
@@ -556,6 +560,20 @@ class TypeCheck : public IRVisitor {
     stmt->ret_type = stmt->var->ret_type;
     stmt->ret_type.set_is_pointer(true);
   }
+
+  void visit(MatrixInitStmt *stmt) override {
+    TI_ASSERT_INFO(stmt->ret_type->is<TensorType>(),
+                   "Matrix should have tensor type, got {}",
+                   stmt->ret_type->to_string());
+    auto tensor_type = stmt->ret_type->as<TensorType>();
+    auto element_dtype = tensor_type->get_element_type();
+    for (auto elt : stmt->values) {
+      element_dtype = promoted_type(element_dtype, elt->ret_type);
+    }
+    for (int i = 0; i < stmt->values.size(); ++i) {
+      if (element_dtype != stmt->values[i]->ret_type) {
+        cast(stmt->values[i], element_dtype);
+      }
+    }
+  }
 };
 
 namespace irpass {
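The new TypeCheck::visit(MatrixInitStmt *) folds promoted_type over every element and then casts the stragglers, so a literal like [1, 2.0, 3] ends up with one common element type. A minimal standalone sketch of that fold, where PrimKind and promote() are invented stand-ins for Taichi's DataType and promoted_type() (the real promotion lattice is richer than a max):

#include <algorithm>
#include <cstdio>
#include <vector>

enum class PrimKind { i32 = 0, f32 = 1, f64 = 2 };  // ordered by promotion rank

static PrimKind promote(PrimKind a, PrimKind b) {
  return std::max(a, b);  // i32 + f32 -> f32, f32 + f64 -> f64
}

static PrimKind matrix_element_type(const std::vector<PrimKind> &values) {
  PrimKind acc = values.front();
  for (PrimKind v : values)
    acc = promote(acc, v);  // the fold in visit(MatrixInitStmt)
  return acc;  // every element is then cast to this type
}

int main() {
  // [i32, f32, i32] -> common element type f32 (rank 1)
  std::printf("%d\n", static_cast<int>(matrix_element_type(
                          {PrimKind::i32, PrimKind::f32, PrimKind::i32})));
  return 0;
}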
llvm::Instruction::CastOps::SIToFP + : llvm::Instruction::CastOps::UIToFP; + } else { + TI_P(data_type_name(from)); + TI_P(data_type_name(to)); + TI_NOT_IMPLEMENTED; + } + return cast_op; + }; if (stmt->op_type == UnaryOpType::cast_value) { llvm::CastInst::CastOps cast_op; auto from = stmt->operand->ret_type; auto to = stmt->cast_type; if (from == to) { llvm_val[stmt] = llvm_val[stmt->operand]; + } else if (from->is()) { + TI_ASSERT_INFO(to->is(), + "Only tensor to tensor cast is supported, {} provided", to->to_string()); + auto from_ty = from->cast()->get_element_type(); + auto to_ty = to->cast()->get_element_type(); + cast_op = get_cast_op(from_ty, to_ty); + auto type = tlctx->get_data_type(to->cast()); + llvm::Value *vec = llvm::UndefValue::get(type); + for (int i = 0; i < from->cast()->get_num_elements(); ++i) { + auto elem = builder->CreateExtractElement(llvm_val[stmt->operand], i); + auto cast_input = builder->CreateCast(cast_op, elem, tlctx->get_data_type(to_ty)); + vec = builder->CreateInsertElement(vec, cast_input, i); + } + llvm_val[stmt] = vec; } else if (is_real(from) != is_real(to)) { - if (is_real(from) && is_integral(to)) { - cast_op = is_signed(to) ? llvm::Instruction::CastOps::FPToSI - : llvm::Instruction::CastOps::FPToUI; - } else if (is_integral(from) && is_real(to)) { - cast_op = is_signed(from) ? llvm::Instruction::CastOps::SIToFP - : llvm::Instruction::CastOps::UIToFP; - } else { - TI_P(data_type_name(from)); - TI_P(data_type_name(to)); - TI_NOT_IMPLEMENTED; - } + cast_op = get_cast_op(from, to); auto cast_type = to->is_primitive(PrimitiveTypeID::f16) ? PrimitiveType::f32 : stmt->cast_type; @@ -412,8 +431,16 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { auto op = stmt->op_type; auto ret_type = stmt->ret_type; + auto is_real_tensor = [](const DataType &dt) { + return dt->is() && is_real(dt->cast()->get_element_type()); + }; + + auto is_integral_tensor = [](const DataType &dt) { + return dt->is() && is_integral(dt->cast()->get_element_type()); + }; + if (op == BinaryOpType::add) { - if (is_real(stmt->ret_type)) { + if (is_real(stmt->ret_type) || is_real_tensor(stmt->ret_type)) { llvm_val[stmt] = builder->CreateFAdd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { @@ -421,7 +448,7 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { builder->CreateAdd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } } else if (op == BinaryOpType::sub) { - if (is_real(stmt->ret_type)) { + if (is_real(stmt->ret_type) || is_real_tensor(stmt->ret_type)) { llvm_val[stmt] = builder->CreateFSub(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { @@ -429,7 +456,7 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { builder->CreateSub(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } } else if (op == BinaryOpType::mul) { - if (is_real(stmt->ret_type)) { + if (is_real(stmt->ret_type) || is_real_tensor(stmt->ret_type)) { llvm_val[stmt] = builder->CreateFMul(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { @@ -437,7 +464,7 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { builder->CreateMul(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } } else if (op == BinaryOpType::floordiv) { - if (is_integral(ret_type)) + if (is_integral(ret_type) || is_integral_tensor(ret_type)) llvm_val[stmt] = create_call(fmt::format("floordiv_{}", data_type_name(ret_type)), {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); @@ -447,7 +474,7 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) { llvm::Intrinsic::floor, {tlctx->get_data_type(ret_type)}, {div}); } } else if (op == BinaryOpType::div) { - if (is_real(stmt->ret_type)) { + if 
(is_real(stmt->ret_type) || is_real_tensor(stmt->ret_type)) {
       llvm_val[stmt] =
           builder->CreateFDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
     } else {
@@ -484,7 +511,7 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) {
         create_call("max_" #x, {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); \
   }

-    if (is_real(ret_type)) {
+    if (is_real(ret_type) || is_real_tensor(ret_type)) {
       llvm_val[stmt] =
           builder->CreateMaxNum(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
     }
@@ -505,7 +532,7 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) {
         create_call("min_" #x, {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); \
   }

-    if (is_real(ret_type)) {
+    if (is_real(ret_type) || is_real_tensor) {
       llvm_val[stmt] =
           builder->CreateMinNum(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
     }
@@ -523,13 +550,13 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) {
     llvm::Value *cmp = nullptr;
     auto input_type = stmt->lhs->ret_type;
     if (op == BinaryOpType::cmp_eq) {
-      if (is_real(input_type)) {
+      if (is_real(input_type) || is_real_tensor(input_type)) {
         cmp = builder->CreateFCmpOEQ(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
       } else {
         cmp = builder->CreateICmpEQ(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
       }
     } else if (op == BinaryOpType::cmp_le) {
-      if (is_real(input_type)) {
+      if (is_real(input_type) || is_real_tensor(input_type)) {
         cmp = builder->CreateFCmpOLE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
       } else {
         if (is_signed(input_type)) {
@@ -541,7 +568,7 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) {
         }
       }
     } else if (op == BinaryOpType::cmp_ge) {
-      if (is_real(input_type)) {
+      if (is_real(input_type) || is_real_tensor(input_type)) {
         cmp = builder->CreateFCmpOGE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
       } else {
         if (is_signed(input_type)) {
@@ -553,7 +580,7 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) {
         }
       }
     } else if (op == BinaryOpType::cmp_lt) {
-      if (is_real(input_type)) {
+      if (is_real(input_type) || is_real_tensor(input_type)) {
        cmp = builder->CreateFCmpOLT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
       } else {
         if (is_signed(input_type)) {
@@ -565,7 +592,7 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) {
         }
       }
     } else if (op == BinaryOpType::cmp_gt) {
-      if (is_real(input_type)) {
+      if (is_real(input_type) || is_real_tensor(input_type)) {
         cmp = builder->CreateFCmpOGT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
       } else {
         if (is_signed(input_type)) {
@@ -577,7 +604,7 @@ void CodeGenLLVM::visit(BinaryOpStmt *stmt) {
         }
       }
     } else if (op == BinaryOpType::cmp_ne) {
-      if (is_real(input_type)) {
+      if (is_real(input_type) || is_real_tensor(input_type)) {
         cmp = builder->CreateFCmpONE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
       } else {
         cmp = builder->CreateICmpNE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
       }
@@ -666,6 +693,10 @@ llvm::Type *CodeGenLLVM::llvm_type(DataType dt) {
     return llvm::Type::getDoubleTy(*llvm_context);
   } else if (dt->is_primitive(PrimitiveTypeID::f16)) {
     return llvm::Type::getHalfTy(*llvm_context);
+  } else if (dt->is<TensorType>()) {
+    auto tensor_type = dt->cast<TensorType>();
+    auto element_type = llvm_type(tensor_type->get_element_type());
+    return llvm::VectorType::get(element_type, tensor_type->get_num_elements());
   } else {
     TI_NOT_IMPLEMENTED;
   }
diff --git a/taichi/ir/type.h
index 15cae2c220fed..f3d1c0fdaa1d1 100644
--- a/taichi/ir/type.h
+++ b/taichi/ir/type.h
@@ -383,6 +383,10 @@ class TypedConstant {
   }

   TypedConstant(DataType dt) : dt(dt) {
+    if (!dt->is<PrimitiveType>()) {
+      assert(false);
+    }
+    TI_ASSERT_INFO(dt->is<PrimitiveType>(), "TypedConstant can only be PrimitiveType, got {}", dt->to_string());
     value_bits = 0;
   }
diff --git a/taichi/ir/type_utils.h
b/taichi/ir/type_utils.h index 5fdeef6050b19..2d7d6da6699fd 100644 --- a/taichi/ir/type_utils.h +++ b/taichi/ir/type_utils.h @@ -80,8 +80,6 @@ inline bool is_quant(DataType dt) { } inline bool is_real(DataType dt) { - if (dt->is()) - return is_real(dt->as()->get_element_type()); return dt->is_primitive(PrimitiveTypeID::f16) || dt->is_primitive(PrimitiveTypeID::f32) || dt->is_primitive(PrimitiveTypeID::f64) || dt->is() || diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index caaedbd215b65..2ee1a4a4d3524 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -112,6 +112,10 @@ class AlgSimp : public BasicStmtVisitor { if ((fast_math || is_integral(stmt->ret_type)) && (alg_is_zero(lhs) || alg_is_zero(rhs))) { // fast_math or integral operands: 0 * a -> 0, a * 0 -> 0 + if (stmt->ret_type->is() || stmt->rhs->ret_type->is()) { + // TODO: handle 0-tensor + return false; + } replace_with_zero(stmt); return true; } @@ -163,8 +167,13 @@ class AlgSimp : public BasicStmtVisitor { if ((fast_math || is_integral(stmt->ret_type)) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a / a -> 1 - replace_with_one(stmt); - return true; + if (stmt->lhs->ret_type->is() && stmt->rhs->ret_type->is()) { + replace_with_one(stmt); + return true; + } else { + // TODO: handle tensor division + return false; + } } if (fast_math && rhs && is_real(rhs->ret_type) && stmt->op_type != BinaryOpType::floordiv) { @@ -244,7 +253,13 @@ class AlgSimp : public BasicStmtVisitor { (fast_math || is_integral(stmt->ret_type)) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a -^ a -> 0 - replace_with_zero(stmt); + if (stmt->lhs->ret_type->is() && + stmt->rhs->ret_type->is()) { + replace_with_zero(stmt); + } else { + // TODO: handle tensor operations + return; + } } } else if (rhs && stmt->op_type == BinaryOpType::pow) { float64 exponent = rhs->val[0].val_cast_to_float64(); @@ -329,6 +344,10 @@ class AlgSimp : public BasicStmtVisitor { modifier.erase(stmt); } else if (alg_is_zero(lhs) || alg_is_zero(rhs)) { // 0 & a -> 0, a & 0 -> 0 + if (stmt->ret_type->is() || stmt->rhs->ret_type->is()) { + // TODO: support tensor type + return; + } replace_with_zero(stmt); } else if (irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // a & a -> a @@ -343,6 +362,10 @@ class AlgSimp : public BasicStmtVisitor { // a << 0 -> a // 0 << a -> 0 // 0 >> a -> 0 + if (stmt->ret_type->is() || stmt->rhs->ret_type->is()) { + // TODO: support tensor type + return; + } TI_ASSERT(stmt->lhs->ret_type == stmt->ret_type); stmt->replace_usages_with(stmt->lhs); modifier.erase(stmt); diff --git a/taichi/transforms/constant_fold.cpp b/taichi/transforms/constant_fold.cpp index b8bf03d990ae6..bd83bcebd8990 100644 --- a/taichi/transforms/constant_fold.cpp +++ b/taichi/transforms/constant_fold.cpp @@ -133,6 +133,9 @@ class ConstantFold : public BasicStmtVisitor { } void visit(BinaryOpStmt *stmt) override { + if (stmt->lhs->ret_type->is() || stmt->rhs->ret_type->is()) + return; + TI_TRACE("{} {}", stmt->lhs->ret_type->to_string(), stmt->rhs->ret_type->to_string()); auto lhs = stmt->lhs->cast(); auto rhs = stmt->rhs->cast(); if (!lhs || !rhs) diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index eb4de86a23912..edae88bc68004 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -276,6 +276,7 @@ class TypeCheck : public IRVisitor { } void cast(Stmt *&val, DataType dt) { + 
TI_TRACE("Cast {} to {}", val->name(), dt->to_string());
     if (val->ret_type == dt)
       return;
@@ -304,10 +305,20 @@ class TypeCheck : public IRVisitor {
     if (stmt->op_type == BinaryOpType::truediv) {
       auto default_fp = config_.default_fp;
       if (!is_real(stmt->lhs->ret_type)) {
-        cast(stmt->lhs, default_fp);
+        if (stmt->lhs->ret_type->is<PrimitiveType>()) {
+          cast(stmt->lhs, default_fp);
+        } else {
+          TI_ASSERT(stmt->lhs->ret_type->is<TensorType>());
+          cast(stmt->lhs, TypeFactory::create_tensor_type(stmt->lhs->ret_type->as<TensorType>()->get_shape(), default_fp));
+        }
       }
       if (!is_real(stmt->rhs->ret_type)) {
-        cast(stmt->rhs, default_fp);
+        if (stmt->rhs->ret_type->is<PrimitiveType>()) {
+          cast(stmt->rhs, default_fp);
+        } else {
+          TI_ASSERT(stmt->rhs->ret_type->is<TensorType>());
+          cast(stmt->rhs, TypeFactory::create_tensor_type(stmt->rhs->ret_type->as<TensorType>()->get_shape(), default_fp));
+        }
       }
       stmt->op_type = BinaryOpType::div;
     }
@@ -327,8 +338,22 @@ class TypeCheck : public IRVisitor {
       }
     }

-    if (stmt->lhs->ret_type->is<TensorType>()) {
-
+    auto lhs_is_tensor = stmt->lhs->ret_type->is<TensorType>();
+    auto rhs_is_tensor = stmt->rhs->ret_type->is<TensorType>();
+
+    if (lhs_is_tensor || rhs_is_tensor) {
+      auto lhs_dtype = lhs_is_tensor ? DataType(stmt->lhs->ret_type->as<TensorType>()->get_element_type()) : stmt->lhs->ret_type;
+      auto rhs_dtype = rhs_is_tensor ? DataType(stmt->rhs->ret_type->as<TensorType>()->get_element_type()) : stmt->rhs->ret_type;
+      auto dtype = promoted_type(lhs_dtype, rhs_dtype);
+      if (dtype != lhs_dtype)
+        cast(stmt->lhs, lhs_is_tensor ? TypeFactory::create_tensor_type(stmt->lhs->ret_type->as<TensorType>()->get_shape(), dtype) : dtype);
+      if (dtype != rhs_dtype)
+        cast(stmt->rhs, rhs_is_tensor ? TypeFactory::create_tensor_type(stmt->rhs->ret_type->as<TensorType>()->get_shape(), dtype) : dtype);
+      // TODO: add shape inference for matrix ops below
+      stmt->ret_type = stmt->lhs->ret_type;
+      return;
     }

     if (stmt->lhs->ret_type != stmt->rhs->ret_type) {
       DataType ret_type;
       if (is_shift_op(stmt->op_type)) {
@@ -573,6 +598,9 @@ class TypeCheck : public IRVisitor {
         cast(stmt->values[i], element_dtype);
       }
     }
+    if (element_dtype != tensor_type->get_element_type()) {
+      stmt->ret_type = TypeFactory::create_tensor_type(tensor_type->get_shape(), element_dtype);
+    }
   }
 };

From ef33ff19e602d9ccb56d3eb9170020f26e5d99d9 Mon Sep 17 00:00:00 2001
From: AD1024
Date: Wed, 27 Jul 2022 15:34:22 -0700
Subject: [PATCH 12/35] add simple ad hoc shape check placeholder

---
 taichi/ir/frontend_ir.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/taichi/ir/frontend_ir.cpp
index 171836aed86cc..03eb156df32c4 100644
--- a/taichi/ir/frontend_ir.cpp
+++ b/taichi/ir/frontend_ir.cpp
@@ -218,8 +218,13 @@ void BinaryOpExpression::type_check(CompileConfig *config) {
       ret_type = promoted_type(dtype, rhs_type);
     } else {
       TI_ASSERT(rhs_type->is<TensorType>());
+      auto rhs_tensor_type = rhs_type->cast<TensorType>();
+      if (rhs_tensor_type->get_shape() != lhs_type->cast<TensorType>()->get_shape())
+        error();
       ret_type = promoted_type(dtype, rhs_type->as<TensorType>()->get_element_type());
     }
+    // TODO: shape check!
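    // The check above only rejects tensor-tensor operands whose shapes differ;
    // a scalar operand never reaches this branch, because the
    // BinaryOpExpression constructor introduced earlier in this series has
    // already broadcast it into a tensor of matching shape. Expected typing,
    // as a sketch (shapes assumed): vec3<f32> + vec3<i32> -> vec3<f32>
    // (element types promoted), while vec3<f32> + vec4<f32> hits error().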
+ ret_type = TypeFactory::create_tensor_type(lhs_type->cast()->get_shape(), ret_type); return; } From a90a901b5abeac31e412d2fd0f6e8e53c51d4bd6 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 2 Aug 2022 13:37:29 -0400 Subject: [PATCH 13/35] save --- taichi/transforms/constant_fold.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/taichi/transforms/constant_fold.cpp b/taichi/transforms/constant_fold.cpp index bd83bcebd8990..03cf9af668b9f 100644 --- a/taichi/transforms/constant_fold.cpp +++ b/taichi/transforms/constant_fold.cpp @@ -135,7 +135,6 @@ class ConstantFold : public BasicStmtVisitor { void visit(BinaryOpStmt *stmt) override { if (stmt->lhs->ret_type->is() || stmt->rhs->ret_type->is()) return; - TI_TRACE("{} {}", stmt->lhs->ret_type->to_string(), stmt->rhs->ret_type->to_string()); auto lhs = stmt->lhs->cast(); auto rhs = stmt->rhs->cast(); if (!lhs || !rhs) From 743393a0287b6ce43e7ef4264e7e73e3acf3d723 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 3 Aug 2022 16:17:01 -0400 Subject: [PATCH 14/35] fix cfg pass --- python/taichi/lang/ast/ast_transformer.py | 2 +- python/taichi/lang/matrix.py | 12 ++------- taichi/codegen/cpu/codegen_cpu.cpp | 3 --- taichi/ir/control_flow_graph.cpp | 33 +++++++++++++++-------- 4 files changed, 25 insertions(+), 25 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 8f913b1d1be8c..7d9cefb9202a2 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -479,7 +479,7 @@ def build_Call(ctx, node): node.ptr = impl.ti_format(*args, **keywords) return node.ptr - if isinstance(node.func, ast.Attribute) and node.func.ptr == Matrix: + if isinstance(node.func, ast.Attribute) and func == Matrix: node.ptr = matrix.make_matrix(*args, **keywords) return node.ptr diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 0b17c3295e038..988a276909eea 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -97,9 +97,9 @@ def prop_setter(instance, value): return cls -def make_matrix(arr, dt=None, suppress_warning=False, is_ref=False): +def make_matrix(arr, dt=None, suppress_warning=False, is_ref=False, **kwargs): if not impl.current_cfg().real_matrix or in_python_scope(): - return Matrix(arr, dt, suppress_warning, is_ref) + return Matrix(arr, dt, suppress_warning, is_ref, **kwargs) cast = (lambda x: ops_mod.cast(x, dt)) if dt else ( lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x)) if len(arr) == 0: @@ -322,10 +322,6 @@ def pyscope_or_ref(self, arr): return [[x] for x in arr] def no_dynamic_index(self, arr, dt): - if impl.current_cfg().real_matrix: - if dt is None: - dt = self.infer_dt(arr) - return _Matrix(arr, dt, is_vec=True) return [[impl.expr_init(ops_mod.cast(x, dt) if dt else x)] for x in arr] @@ -351,10 +347,6 @@ def pyscope_or_ref(self, arr): return [list(row) for row in arr] def no_dynamic_index(self, arr, dt): - if impl.current_cfg().real_matrix: - if dt is None: - dt = self.infer_dt(arr) - return _Matrix(arr, dt) return [[ impl.expr_init(ops_mod.cast(x, dt) if dt else x) for x in row ] for row in arr] diff --git a/taichi/codegen/cpu/codegen_cpu.cpp b/taichi/codegen/cpu/codegen_cpu.cpp index 67140adce29a6..901723416f602 100644 --- a/taichi/codegen/cpu/codegen_cpu.cpp +++ b/taichi/codegen/cpu/codegen_cpu.cpp @@ -286,9 +286,7 @@ FunctionType CodeGenCPU::codegen() { auto *llvm_prog = get_llvm_program(prog); auto *tlctx = llvm_prog->get_llvm_context(kernel->arch); auto &config = prog->config; - 
TI_TRACE("in codegen()"); std::string kernel_key = get_hashed_offline_cache_key(&config, kernel); - TI_TRACE("in codegen() 1"); kernel->set_kernel_key_for_cache(kernel_key); if (config.offline_cache && !config.async_mode && this->supports_offline_cache() && !kernel->is_evaluator) { @@ -301,7 +299,6 @@ FunctionType CodeGenCPU::codegen() { } } if (!kernel->lowered()) { - TI_TRACE("calling lowering"); kernel->lower(/*to_executable=*/false); } diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 969b874b6562a..d8d540c9af135 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -262,19 +262,30 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) { auto stmt = block->statements[i].get(); Stmt *result = nullptr; if (auto local_load = stmt->cast()) { - for (int i = 0; i < local_load->src.size(); ++i) { - bool regular = true; - auto alloca = local_load->src[i].var; - for (int l = 0; l < stmt->width(); l++) { - if (local_load->src[l].offset != l || - local_load->src[l].var != alloca) { - regular = false; - } - } - if (regular) { - result = get_store_forwarding_data(alloca, i); + bool regular = true; + auto alloca = local_load->src[0].var; + for (int l = 0; l < stmt->width(); l++) { + if (local_load->src[l].offset != l || + local_load->src[l].var != alloca) { + regular = false; } } + if (regular) { + result = get_store_forwarding_data(alloca, i); + } + // for (int i = 0; i < local_load->src.size(); ++i) { + // bool regular = true; + // auto alloca = local_load->src[i].var; + // for (int l = 0; l < stmt->width(); l++) { + // if (local_load->src[l].offset != l || + // local_load->src[l].var != alloca) { + // regular = false; + // } + // } + // if (regular) { + // result = get_store_forwarding_data(alloca, i); + // } + // } } else if (auto global_load = stmt->cast()) { if (!after_lower_access) { bool store_forwarding = true; From b7fc15d9e40d786571024dfead376a7b7c30e348 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 4 Aug 2022 21:04:37 +0000 Subject: [PATCH 15/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/taichi/lang/impl.py | 1 + taichi/codegen/llvm/codegen_llvm.cpp | 34 +++++++++++---------- taichi/ir/frontend_ir.cpp | 9 ++++-- taichi/ir/frontend_ir.h | 11 ++++--- taichi/ir/type.h | 4 ++- taichi/transforms/alg_simp.cpp | 12 +++++--- taichi/transforms/constant_fold.cpp | 3 +- taichi/transforms/type_check.cpp | 45 +++++++++++++++++++++------- 8 files changed, 81 insertions(+), 38 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 6a1a7499a05a1..cfa94b701a95e 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -40,6 +40,7 @@ def expr_init_matrix(shape, element_type, elements): return get_runtime().prog.current_ast_builder().expr_alloca_matrix( shape, element_type, elements) + @taichi_scope def expr_init_shared_array(shape, element_type): return get_runtime().prog.current_ast_builder().expr_alloca_shared_array( diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 22e42555bce41..83de0da9f0147 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -368,19 +368,21 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) { if (from == to) { llvm_val[stmt] = llvm_val[stmt->operand]; } else if (from->is()) { - TI_ASSERT_INFO(to->is(), - "Only tensor to 
tensor cast is supported, {} provided", to->to_string()); - auto from_ty = from->cast()->get_element_type(); - auto to_ty = to->cast()->get_element_type(); - cast_op = get_cast_op(from_ty, to_ty); - auto type = tlctx->get_data_type(to->cast()); - llvm::Value *vec = llvm::UndefValue::get(type); - for (int i = 0; i < from->cast()->get_num_elements(); ++i) { - auto elem = builder->CreateExtractElement(llvm_val[stmt->operand], i); - auto cast_input = builder->CreateCast(cast_op, elem, tlctx->get_data_type(to_ty)); - vec = builder->CreateInsertElement(vec, cast_input, i); - } - llvm_val[stmt] = vec; + TI_ASSERT_INFO(to->is(), + "Only tensor to tensor cast is supported, {} provided", + to->to_string()); + auto from_ty = from->cast()->get_element_type(); + auto to_ty = to->cast()->get_element_type(); + cast_op = get_cast_op(from_ty, to_ty); + auto type = tlctx->get_data_type(to->cast()); + llvm::Value *vec = llvm::UndefValue::get(type); + for (int i = 0; i < from->cast()->get_num_elements(); ++i) { + auto elem = builder->CreateExtractElement(llvm_val[stmt->operand], i); + auto cast_input = + builder->CreateCast(cast_op, elem, tlctx->get_data_type(to_ty)); + vec = builder->CreateInsertElement(vec, cast_input, i); + } + llvm_val[stmt] = vec; } else if (is_real(from) != is_real(to)) { cast_op = get_cast_op(from, to); auto cast_type = to->is_primitive(PrimitiveTypeID::f16) @@ -453,11 +455,13 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) { auto ret_type = stmt->ret_type; auto is_real_tensor = [](const DataType &dt) { - return dt->is() && is_real(dt->cast()->get_element_type()); + return dt->is() && + is_real(dt->cast()->get_element_type()); }; auto is_integral_tensor = [](const DataType &dt) { - return dt->is() && is_integral(dt->cast()->get_element_type()); + return dt->is() && + is_integral(dt->cast()->get_element_type()); }; if (op == BinaryOpType::add) { diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index fa75b161d6019..cd9bc60ddc64f 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -219,12 +219,15 @@ void BinaryOpExpression::type_check(CompileConfig *config) { } else { TI_ASSERT(rhs_type->is()); auto rhs_tensor_type = rhs_type->cast(); - if (rhs_tensor_type->get_shape() != lhs_type->cast()->get_shape()) + if (rhs_tensor_type->get_shape() != + lhs_type->cast()->get_shape()) error(); - ret_type = promoted_type(dtype, rhs_type->as()->get_element_type()); + ret_type = + promoted_type(dtype, rhs_type->as()->get_element_type()); } // TODO: shape check! 
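      // After element-type promotion, the scalar result is wrapped back into
      // a TensorType of the lhs shape (restored on the next line). Sketch of
      // the rule applied here, assuming promoted_type follows the usual
      // C-style arithmetic promotion:
      //   promoted_type(i32, f32) == f32
      //   => vec3<i32> * vec3<f32> type-checks as vec3<f32>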
- ret_type = TypeFactory::create_tensor_type(lhs_type->cast()->get_shape(), ret_type); + ret_type = TypeFactory::create_tensor_type( + lhs_type->cast()->get_shape(), ret_type); return; } diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 7c1f6f46bc0c8..ccab1cc5bf7d1 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -379,18 +379,21 @@ class BinaryOpExpression : public Expression { BinaryOpExpression(const BinaryOpType &type, const Expr &lhs, const Expr &rhs) : type(type) { - auto to_broadcast_tensor = [](const Expr &elt, const DataType &dt) -> Expr { TI_ASSERT(dt->is()); auto tensor_type = dt->as(); auto elt_type = tensor_type->get_element_type(); - TI_ASSERT_INFO(elt_type->is(), "Only primitive types are supported in Tensors, got {}", elt_type->to_string()); + TI_ASSERT_INFO(elt_type->is(), + "Only primitive types are supported in Tensors, got {}", + elt_type->to_string()); std::vector broadcast_values(tensor_type->get_num_elements(), elt); - return Expr::make(broadcast_values, tensor_type->get_shape(), elt_type); + return Expr::make(broadcast_values, + tensor_type->get_shape(), elt_type); }; auto unify_expr = [&](const Expr &e1, const Expr &e2) { - if ((!e1->ret_type->is() && !e2->ret_type->is()) || + if ((!e1->ret_type->is() && + !e2->ret_type->is()) || (e1->ret_type->is() && e2->ret_type->is())) { return std::tuple(e1, e2); } diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 521c40f8f179b..fa5fe7a6b8b6b 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -385,7 +385,9 @@ class TypedConstant { if (!dt->is()) { assert(false); } - TI_ASSERT_INFO(dt->is(), "TypedConstant can only be PrimitiveType, got {}", dt->to_string()); + TI_ASSERT_INFO(dt->is(), + "TypedConstant can only be PrimitiveType, got {}", + dt->to_string()); value_bits = 0; } diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 1d8060121b5af..c32dc2b2f4a8c 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -112,7 +112,8 @@ class AlgSimp : public BasicStmtVisitor { if ((fast_math || is_integral(stmt->ret_type)) && (alg_is_zero(lhs) || alg_is_zero(rhs))) { // fast_math or integral operands: 0 * a -> 0, a * 0 -> 0 - if (stmt->ret_type->is() || stmt->rhs->ret_type->is()) { + if (stmt->ret_type->is() || + stmt->rhs->ret_type->is()) { // TODO: handle 0-tensor return false; } @@ -167,7 +168,8 @@ class AlgSimp : public BasicStmtVisitor { if ((fast_math || is_integral(stmt->ret_type)) && irpass::analysis::same_value(stmt->lhs, stmt->rhs)) { // fast_math or integral operands: a / a -> 1 - if (stmt->lhs->ret_type->is() && stmt->rhs->ret_type->is()) { + if (stmt->lhs->ret_type->is() && + stmt->rhs->ret_type->is()) { replace_with_one(stmt); return true; } else { @@ -344,7 +346,8 @@ class AlgSimp : public BasicStmtVisitor { modifier.erase(stmt); } else if (alg_is_zero(lhs) || alg_is_zero(rhs)) { // 0 & a -> 0, a & 0 -> 0 - if (stmt->ret_type->is() || stmt->rhs->ret_type->is()) { + if (stmt->ret_type->is() || + stmt->rhs->ret_type->is()) { // TODO: support tensor type return; } @@ -362,7 +365,8 @@ class AlgSimp : public BasicStmtVisitor { // a << 0 -> a // 0 << a -> 0 // 0 >> a -> 0 - if (stmt->ret_type->is() || stmt->rhs->ret_type->is()) { + if (stmt->ret_type->is() || + stmt->rhs->ret_type->is()) { // TODO: support tensor type return; } diff --git a/taichi/transforms/constant_fold.cpp b/taichi/transforms/constant_fold.cpp index 03cf9af668b9f..6ea23b3c3c621 100644 --- a/taichi/transforms/constant_fold.cpp +++ 
b/taichi/transforms/constant_fold.cpp @@ -133,7 +133,8 @@ class ConstantFold : public BasicStmtVisitor { } void visit(BinaryOpStmt *stmt) override { - if (stmt->lhs->ret_type->is() || stmt->rhs->ret_type->is()) + if (stmt->lhs->ret_type->is() || + stmt->rhs->ret_type->is()) return; auto lhs = stmt->lhs->cast(); auto rhs = stmt->rhs->cast(); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index edce5ecd7e22c..d2f3e91e7ef19 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -311,7 +311,10 @@ class TypeCheck : public IRVisitor { cast(stmt->lhs, default_fp); } else { TI_ASSERT(stmt->lhs->ret_type->is()); - cast(stmt->lhs, TypeFactory::create_tensor_type(stmt->lhs->ret_type->as()->get_shape(), default_fp)); + cast(stmt->lhs, + TypeFactory::create_tensor_type( + stmt->lhs->ret_type->as()->get_shape(), + default_fp)); } } if (!is_real(stmt->rhs->ret_type)) { @@ -319,7 +322,10 @@ class TypeCheck : public IRVisitor { cast(stmt->rhs, default_fp); } else { TI_ASSERT(stmt->rhs->ret_type->is()); - cast(stmt->rhs, TypeFactory::create_tensor_type(stmt->rhs->ret_type->as()->get_shape(), default_fp)); + cast(stmt->rhs, + TypeFactory::create_tensor_type( + stmt->rhs->ret_type->as()->get_shape(), + default_fp)); } } stmt->op_type = BinaryOpType::div; @@ -344,15 +350,31 @@ class TypeCheck : public IRVisitor { auto rhs_is_tensor = stmt->rhs->ret_type->is(); if (lhs_is_tensor || rhs_is_tensor) { - auto lhs_dtype = lhs_is_tensor ? DataType(stmt->lhs->ret_type->as()->get_element_type()) - : stmt->lhs->ret_type; - auto rhs_dtype = rhs_is_tensor ? DataType(stmt->rhs->ret_type->as()->get_element_type()) - : stmt->rhs->ret_type; + auto lhs_dtype = + lhs_is_tensor + ? DataType( + stmt->lhs->ret_type->as()->get_element_type()) + : stmt->lhs->ret_type; + auto rhs_dtype = + rhs_is_tensor + ? DataType( + stmt->rhs->ret_type->as()->get_element_type()) + : stmt->rhs->ret_type; auto dtype = promoted_type(lhs_dtype, rhs_dtype); if (dtype != lhs_dtype) - cast(stmt->lhs, lhs_is_tensor ? TypeFactory::create_tensor_type(stmt->lhs->ret_type->as()->get_shape(), dtype) : dtype); + cast( + stmt->lhs, + lhs_is_tensor + ? TypeFactory::create_tensor_type( + stmt->lhs->ret_type->as()->get_shape(), dtype) + : dtype); if (dtype != rhs_dtype) - cast(stmt->rhs, rhs_is_tensor ? TypeFactory::create_tensor_type(stmt->rhs->ret_type->as()->get_shape(), dtype) : dtype); + cast( + stmt->rhs, + rhs_is_tensor + ? 
TypeFactory::create_tensor_type( + stmt->rhs->ret_type->as()->get_shape(), dtype) + : dtype); // TODO: add shape inference for matrix ops below stmt->ret_type = stmt->lhs->ret_type; return; @@ -596,7 +618,9 @@ class TypeCheck : public IRVisitor { } void visit(MatrixInitStmt *stmt) override { - TI_ASSERT_INFO(stmt->ret_type->is(), "Matrix should have tensor type, got {}", stmt->ret_type->to_string()); + TI_ASSERT_INFO(stmt->ret_type->is(), + "Matrix should have tensor type, got {}", + stmt->ret_type->to_string()); auto tensor_type = stmt->ret_type->as(); auto element_dtype = tensor_type->get_element_type(); for (auto elt : stmt->values) { @@ -608,7 +632,8 @@ class TypeCheck : public IRVisitor { } } if (element_dtype != tensor_type->get_element_type()) { - stmt->ret_type = TypeFactory::create_tensor_type(tensor_type->get_shape(), element_dtype); + stmt->ret_type = TypeFactory::create_tensor_type(tensor_type->get_shape(), + element_dtype); } } }; From ff8b73c161a20a3b940b64b906480806a693f6df Mon Sep 17 00:00:00 2001 From: AD1024 Date: Fri, 5 Aug 2022 15:30:05 -0400 Subject: [PATCH 16/35] get rid of reduce --- taichi/ir/type.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index daca642ce9e6b..0188722f15e6f 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -1,4 +1,3 @@ -#include #include "taichi/ir/type.h" #include "taichi/ir/type_factory.h" @@ -89,7 +88,11 @@ std::string TensorType::to_string() const { } int TensorType::vector_width() const { - return std::reduce(shape_.begin(), shape_.end(), 1, std::multiplies()); + int vw = 1; + for (auto dim : shape_) { + vw *= dim; + } + return vw; } int Type::vector_width() const { From 378e61a98a3c99ba949f032006c402a1b7000b99 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Fri, 5 Aug 2022 16:59:41 -0400 Subject: [PATCH 17/35] fix typo bug --- taichi/codegen/llvm/codegen_llvm.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 83de0da9f0147..cb8a0eb93d0c0 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -557,7 +557,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) { create_call("min_" #x, {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); \ } - if (is_real(ret_type) || is_real_tensor) { + if (is_real(ret_type) || is_real_tensor(ret_type)) { llvm_val[stmt] = builder->CreateMinNum(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } @@ -1436,6 +1436,7 @@ void TaskCodeGenLLVM::visit(AtomicOpStmt *stmt) { if (is_local) { TI_ERROR("Local atomics should have been demoted."); } + TI_TRACE("Atomic: {} ({}, {})", stmt->ret_type->to_string(), stmt->dest->ret_type->to_string(), stmt->val->ret_type->to_string()); TI_ASSERT(stmt->width() == 1); for (int l = 0; l < stmt->width(); l++) { llvm::Value *old_value; From bcca7f5ae525e286e1e39b8c937e265c460933cf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 Aug 2022 21:01:05 +0000 Subject: [PATCH 18/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/codegen/llvm/codegen_llvm.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index cb8a0eb93d0c0..eb7838459f40a 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1436,7 +1436,8 @@ void 
TaskCodeGenLLVM::visit(AtomicOpStmt *stmt) { if (is_local) { TI_ERROR("Local atomics should have been demoted."); } - TI_TRACE("Atomic: {} ({}, {})", stmt->ret_type->to_string(), stmt->dest->ret_type->to_string(), stmt->val->ret_type->to_string()); + TI_TRACE("Atomic: {} ({}, {})", stmt->ret_type->to_string(), + stmt->dest->ret_type->to_string(), stmt->val->ret_type->to_string()); TI_ASSERT(stmt->width() == 1); for (int l = 0; l < stmt->width(); l++) { llvm::Value *old_value; From 3132514f7df6c81fa2caa0a3f1cff6115c65df4a Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 9 Aug 2022 16:10:36 -0400 Subject: [PATCH 19/35] save --- python/taichi/lang/expr.py | 4 ++-- taichi/codegen/llvm/codegen_llvm.cpp | 19 ++++++++++++------- taichi/ir/frontend_ir.cpp | 2 ++ taichi/ir/statements.h | 17 ----------------- taichi/transforms/cfg_optimization.cpp | 4 ++-- 5 files changed, 18 insertions(+), 28 deletions(-) diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 45aadc3d5aca2..cb74c487f9c06 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -43,8 +43,8 @@ def __getitem__(self, *indices): indices = (indices, ) indices = make_expr_group(*indices) - return impl.get_runtime().prog.current_ast_builder( - ).expr_indexed_matrix(self.ptr, indices) + return Expr(impl.get_runtime().prog.current_ast_builder( + ).expr_indexed_matrix(self.ptr, indices)) def __hash__(self): return self.ptr.get_raw_address() diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index cb8a0eb93d0c0..62fb9a7afaf94 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -125,6 +125,7 @@ void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { if (stmt->ret_type->is()) { auto tensor_type = stmt->ret_type->cast(); auto type = tlctx->get_data_type(tensor_type); + auto array_size = tlctx->get_constant(tensor_type->get_num_elements()); // Return type is [array_size x type]*. 
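    // With real_matrix enabled, a matrix local is now one alloca covering the
    // whole tensor instead of one alloca per element; array_size above carries
    // get_num_elements(). For a 3x3 f32 matrix the result is roughly
    // (illustrative IR only):
    //   %mat = alloca [9 x float]   ; flattened storage, per the comment above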
if (stmt->is_shared) { size_t data_element_size = tlctx->get_type_size( @@ -147,7 +148,7 @@ void TaskCodeGenLLVM::visit(AllocaStmt *stmt) { tlctx->get_data_type(tensor_type->get_element_type()), 0); llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type); } else { - llvm_val[stmt] = create_entry_block_alloca(type, false); + llvm_val[stmt] = create_entry_block_alloca(type, 0, array_size); } } else { TI_ASSERT(stmt->width() == 1); @@ -1774,12 +1775,16 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) { llvm_val[stmt] = builder->CreateGEP(ptr_ty, llvm_val[stmt->origin], llvm_val[stmt->offset]); #else - auto stmt_dtype = stmt->origin->ret_type->as(); - auto element_dtype = stmt_dtype->get_element_type(); - auto llvm_type = tlctx->get_data_type(element_dtype); - auto casted_ptr = builder->CreateBitCast( - llvm_val[stmt->origin], llvm::PointerType::get(llvm_type, 0)); - llvm_val[stmt] = builder->CreateGEP(casted_ptr, llvm_val[stmt->offset]); + if (stmt->origin->ret_type->is()) { + auto stmt_dtype = stmt->origin->ret_type->cast(); + auto element_dtype = stmt_dtype->get_element_type(); + auto llvm_type = tlctx->get_data_type(element_dtype); + auto casted_ptr = builder->CreateBitCast( + llvm_val[stmt->origin], llvm::PointerType::get(llvm_type, 0)); + llvm_val[stmt] = builder->CreateGEP(casted_ptr, llvm_val[stmt->offset]); + } else { + llvm_val[stmt] = builder->CreateGEP(llvm_val[stmt->origin], llvm_val[stmt->offset]); + } #endif } else { auto origin_address = builder->CreatePtrToInt( diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index cd9bc60ddc64f..4cb9e462c4bcd 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -1188,6 +1188,8 @@ void flatten_rvalue(Expr ptr, Expression::FlattenContext *ctx) { } } else if (ptr.is()) { auto ix = ptr.cast(); + // if (ix->var->ret_type->is()) + // return; if (ix->is_local()) { flatten_local_load(ptr, ctx); } else { diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 82e04530a186d..1a201acbcfaf5 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -1827,22 +1827,5 @@ class MatrixInitStmt : public Stmt { TI_DEFINE_ACCEPT_AND_CLONE }; -class IndexStmt : public Stmt { - public: - Stmt *target; - Stmt *index; - - IndexStmt(Stmt *target, Stmt *index) : target(target), index(index) { - TI_STMT_REG_FIELDS; - } - - bool has_global_side_effect() const override { - return false; - } - - TI_STMT_DEF_FIELDS(ret_type, target, index); - TI_DEFINE_ACCEPT_AND_CLONE -}; - } // namespace lang } // namespace taichi diff --git a/taichi/transforms/cfg_optimization.cpp b/taichi/transforms/cfg_optimization.cpp index 93efc651a411d..520357c07198c 100644 --- a/taichi/transforms/cfg_optimization.cpp +++ b/taichi/transforms/cfg_optimization.cpp @@ -21,8 +21,8 @@ bool cfg_optimization( cfg->simplify_graph(); if (cfg->store_to_load_forwarding(after_lower_access, autodiff_enabled)) modified = true; - if (cfg->dead_store_elimination(after_lower_access, lva_config_opt)) - modified = true; + // if (cfg->dead_store_elimination(after_lower_access, lva_config_opt)) + // modified = true; if (modified) result_modified = true; else From 25f2ed864dd72914e3b4dbb611762a693d752267 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 9 Aug 2022 16:27:39 -0400 Subject: [PATCH 20/35] shape check for indexing --- taichi/ir/frontend_ir.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 4cb9e462c4bcd..5246f0c600d42 100644 --- a/taichi/ir/frontend_ir.cpp +++ 
b/taichi/ir/frontend_ir.cpp
@@ -1016,6 +1016,20 @@ Expr ASTBuilder::expr_alloca_local_matrix(const std::vector<int> &shape,
 Expr ASTBuilder::expr_indexed_matrix(const Expr &matrix,
                                      const ExprGroup &indices) {
   TI_ASSERT(matrix.get_ret_type()->is<TensorType>());
+  auto shape = matrix.get_ret_type()->as<TensorType>()->get_shape();
+  if (indices.size() != shape.size()) {
+    std::string shape_str = "[";
+    if (shape.size() > 0) {
+      shape_str += std::to_string(shape[0]);
+      for (int i = 1; i < shape.size(); i++) {
+        shape_str += ", " + std::to_string(shape[i]);
+      }
+    }
+    shape_str += "]";
+    TI_ERROR("Indexed matrix of shape {} has wrong number of indices. Expected {} but got "
+             "{}.",
+             shape_str, shape.size(), indices.size());
+  }
   return Expr(std::make_shared<IndexExpression>(matrix, indices));
 }

From 08a337c130067c57a5a9f61b8cad0fb81f84adba Mon Sep 17 00:00:00 2001
From: AD1024
Date: Tue, 9 Aug 2022 16:28:37 -0400
Subject: [PATCH 21/35] format

---
 python/taichi/lang/expr.py | 5 +++--
 taichi/codegen/llvm/codegen_llvm.cpp | 3 ++-
 taichi/ir/frontend_ir.cpp | 8 +++++---
 3 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/python/taichi/lang/expr.py
index cb74c487f9c06..01fc8dff46627 100644
--- a/python/taichi/lang/expr.py
+++ b/python/taichi/lang/expr.py
@@ -43,8 +43,9 @@ def __getitem__(self, *indices):
             indices = (indices, )
         indices = make_expr_group(*indices)
-        return Expr(impl.get_runtime().prog.current_ast_builder(
-        ).expr_indexed_matrix(self.ptr, indices))
+        return Expr(
+            impl.get_runtime().prog.current_ast_builder().expr_indexed_matrix(
+                self.ptr, indices))

     def __hash__(self):
         return self.ptr.get_raw_address()
diff --git a/taichi/codegen/llvm/codegen_llvm.cpp
index df5832215d1d0..2a96f4d2d75ca 100644
--- a/taichi/codegen/llvm/codegen_llvm.cpp
+++ b/taichi/codegen/llvm/codegen_llvm.cpp
@@ -1784,7 +1784,8 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) {
           llvm_val[stmt->origin], llvm::PointerType::get(llvm_type, 0));
       llvm_val[stmt] = builder->CreateGEP(casted_ptr, llvm_val[stmt->offset]);
     } else {
-      llvm_val[stmt] = builder->CreateGEP(llvm_val[stmt->origin], llvm_val[stmt->offset]);
+      llvm_val[stmt] =
+          builder->CreateGEP(llvm_val[stmt->origin], llvm_val[stmt->offset]);
     }
 #endif
   } else {
diff --git a/taichi/ir/frontend_ir.cpp
index 5246f0c600d42..e33b989f19fa5 100644
--- a/taichi/ir/frontend_ir.cpp
+++ b/taichi/ir/frontend_ir.cpp
@@ -1026,9 +1026,11 @@ Expr ASTBuilder::expr_indexed_matrix(const Expr &matrix,
       }
     }
     shape_str += "]";
-    TI_ERROR("Indexed matrix of shape {} has wrong number of indices. Expected {} but got "
-             "{}.",
-             shape_str, shape.size(), indices.size());
+    TI_ERROR(
+        "Indexed matrix of shape {} has wrong number of indices.
Expected {} " + "but got " + "{}.", + shape_str, shape.size(), indices.size()); } return Expr(std::make_shared(matrix, indices)); } From b493e12905dd1b1e3741ac35cb95c588b6064630 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 9 Aug 2022 17:30:33 -0400 Subject: [PATCH 22/35] remove log --- taichi/codegen/llvm/codegen_llvm.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 2a96f4d2d75ca..1771d4af91085 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1437,8 +1437,6 @@ void TaskCodeGenLLVM::visit(AtomicOpStmt *stmt) { if (is_local) { TI_ERROR("Local atomics should have been demoted."); } - TI_TRACE("Atomic: {} ({}, {})", stmt->ret_type->to_string(), - stmt->dest->ret_type->to_string(), stmt->val->ret_type->to_string()); TI_ASSERT(stmt->width() == 1); for (int l = 0; l < stmt->width(); l++) { llvm::Value *old_value; From d3daeef2728dfab2d4e1c0cf35c78a911f1d0709 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 9 Aug 2022 17:36:55 -0400 Subject: [PATCH 23/35] oopsss --- taichi/transforms/cfg_optimization.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/taichi/transforms/cfg_optimization.cpp b/taichi/transforms/cfg_optimization.cpp index 520357c07198c..93efc651a411d 100644 --- a/taichi/transforms/cfg_optimization.cpp +++ b/taichi/transforms/cfg_optimization.cpp @@ -21,8 +21,8 @@ bool cfg_optimization( cfg->simplify_graph(); if (cfg->store_to_load_forwarding(after_lower_access, autodiff_enabled)) modified = true; - // if (cfg->dead_store_elimination(after_lower_access, lva_config_opt)) - // modified = true; + if (cfg->dead_store_elimination(after_lower_access, lva_config_opt)) + modified = true; if (modified) result_modified = true; else From a868388f1e2fdb7a7dcff9c4a9662b13fe9974dd Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 9 Aug 2022 18:22:35 -0400 Subject: [PATCH 24/35] fix cfg for matrix --- taichi/analysis/data_source_analysis.cpp | 6 +++++- taichi/ir/control_flow_graph.cpp | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index 4a018afa6bf47..9baad69fb80fb 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -37,6 +37,10 @@ std::vector get_load_pointers(Stmt *load_stmt) { return external_func->arg_stmts; } else if (auto ref = load_stmt->cast()) { return {ref->var}; + } else if (auto matrix_init = load_stmt->cast()) { + return matrix_init->values; + } else if (auto ptr_offset = load_stmt->cast()) { + return {ptr_offset->origin}; } else { return std::vector(); } @@ -59,7 +63,7 @@ Stmt *get_store_data(Stmt *store_stmt) { std::vector get_store_destination(Stmt *store_stmt) { // If store_stmt provides some data sources, return the pointers of the data. - if (store_stmt->is() && !store_stmt->ret_type->is()) { + if (store_stmt->is()) { // The statement itself provides a data source (const [0]). return std::vector(1, store_stmt); } else if (auto local_store = store_stmt->cast()) { diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index ad0e3788e3457..9e63d3cbc8bbf 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -463,6 +463,7 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { // Neither used in other nodes nor used in this node. if (!stmt->is()) { // Eliminate the dead store. 
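        // MatrixInitStmt values and a PtrOffsetStmt's origin are now reported
        // by get_load_pointers() (previous hunk), so matrix element accesses
        // count as uses in this pass. Without that, a sequence like this
        // pseudo-IR sketch would lose its store to dead-store elimination:
        //   %a = alloca ...                ; matrix storage
        //   store %x, %a                   ; element write
        //   %m = MatrixInitStmt(load %a)   ; consumes the written value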
+ TI_TRACE("Elminate dead store: {}", block->operator[](i)->name()); erase(i); modified = true; continue; @@ -514,7 +515,7 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { } } auto load_ptrs = irpass::analysis::get_load_pointers(stmt); - if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1) { + if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1 && !stmt->is()) { // Identical load elimination auto load_ptr = load_ptrs.front(); if (!after_lower_access || From feef9e88a418254ff5a44df16dd1bcac8094e0a0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Aug 2022 22:23:59 +0000 Subject: [PATCH 25/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/ir/control_flow_graph.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 9e63d3cbc8bbf..baa64c21d8846 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -515,7 +515,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { } } auto load_ptrs = irpass::analysis::get_load_pointers(stmt); - if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1 && !stmt->is()) { + if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1 && + !stmt->is()) { // Identical load elimination auto load_ptr = load_ptrs.front(); if (!after_lower_access || From 1e6511d335650fbc60e04a3dba1b6c49efb722eb Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 10 Aug 2022 14:43:20 -0400 Subject: [PATCH 26/35] fix error for windows --- taichi/codegen/llvm/codegen_llvm.cpp | 2 +- taichi/runtime/llvm/llvm_context.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 1771d4af91085..d90b3773780c9 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -722,7 +722,7 @@ llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) { } else if (dt->is()) { auto tensor_type = dt->cast(); auto element_type = llvm_type(tensor_type->get_element_type()); - return llvm::VectorType::get(element_type, tensor_type->get_num_elements()); + return llvm::VectorType::get(element_type, llvm::ElementCount(tensor_type->get_num_elements(), false)); } else { TI_NOT_IMPLEMENTED; } diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index bdc4f2f82938a..307cbbcb8d3e9 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -138,7 +138,7 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { } else if (dt->is()) { auto vectorty = dt->as(); auto dtype = this->get_data_type(vectorty->get_element_type()); - return llvm::VectorType::get(dtype, vectorty->get_num_elements()); + return llvm::VectorType::get(dtype, llvm::ElementCount(vectorty->get_num_elements(), false)); } else { TI_INFO(data_type_name(dt)); TI_NOT_IMPLEMENTED From b94e0a50e2624b06162a610582ec6f72f841a4b6 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 10 Aug 2022 14:44:25 -0400 Subject: [PATCH 27/35] format --- taichi/codegen/llvm/codegen_llvm.cpp | 4 +++- taichi/runtime/llvm/llvm_context.cpp | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index d90b3773780c9..4c6b1f271a2e8 100644 --- 
a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -722,7 +722,9 @@ llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) { } else if (dt->is()) { auto tensor_type = dt->cast(); auto element_type = llvm_type(tensor_type->get_element_type()); - return llvm::VectorType::get(element_type, llvm::ElementCount(tensor_type->get_num_elements(), false)); + return llvm::VectorType::get( + element_type, + llvm::ElementCount(tensor_type->get_num_elements(), false)); } else { TI_NOT_IMPLEMENTED; } diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index 307cbbcb8d3e9..adf54d533512c 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -138,7 +138,8 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { } else if (dt->is()) { auto vectorty = dt->as(); auto dtype = this->get_data_type(vectorty->get_element_type()); - return llvm::VectorType::get(dtype, llvm::ElementCount(vectorty->get_num_elements(), false)); + return llvm::VectorType::get( + dtype, llvm::ElementCount(vectorty->get_num_elements(), false)); } else { TI_INFO(data_type_name(dt)); TI_NOT_IMPLEMENTED From 8490a6257e3e492a18f73c04eb7453db81ae91ab Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 10 Aug 2022 15:40:39 -0400 Subject: [PATCH 28/35] wat --- taichi/codegen/llvm/codegen_llvm.cpp | 3 ++- taichi/runtime/llvm/llvm_context.cpp | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 4c6b1f271a2e8..ed39024d8bf9d 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -724,7 +724,8 @@ llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) { auto element_type = llvm_type(tensor_type->get_element_type()); return llvm::VectorType::get( element_type, - llvm::ElementCount(tensor_type->get_num_elements(), false)); + tensor_type->get_num_elements(), + false); } else { TI_NOT_IMPLEMENTED; } diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index adf54d533512c..ce9b1ae48bcc4 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -139,7 +139,7 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { auto vectorty = dt->as(); auto dtype = this->get_data_type(vectorty->get_element_type()); return llvm::VectorType::get( - dtype, llvm::ElementCount(vectorty->get_num_elements(), false)); + dtype, vectorty->get_num_elements(), false); } else { TI_INFO(data_type_name(dt)); TI_NOT_IMPLEMENTED From 4f1325d7605fae0058c1405ef108c8f924888d11 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 10 Aug 2022 15:41:23 -0400 Subject: [PATCH 29/35] format yet again --- taichi/codegen/llvm/codegen_llvm.cpp | 6 ++---- taichi/runtime/llvm/llvm_context.cpp | 3 +-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index ed39024d8bf9d..6cbdc4b15382b 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -722,10 +722,8 @@ llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) { } else if (dt->is()) { auto tensor_type = dt->cast(); auto element_type = llvm_type(tensor_type->get_element_type()); - return llvm::VectorType::get( - element_type, - tensor_type->get_num_elements(), - false); + return llvm::VectorType::get(element_type, tensor_type->get_num_elements(), + false); } else { TI_NOT_IMPLEMENTED; } diff --git 
a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index ce9b1ae48bcc4..edde242f48f96 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -138,8 +138,7 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { } else if (dt->is()) { auto vectorty = dt->as(); auto dtype = this->get_data_type(vectorty->get_element_type()); - return llvm::VectorType::get( - dtype, vectorty->get_num_elements(), false); + return llvm::VectorType::get(dtype, vectorty->get_num_elements(), false); } else { TI_INFO(data_type_name(dt)); TI_NOT_IMPLEMENTED From acb9fd297fa6425400074d5c78027351575c983b Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 10 Aug 2022 16:14:32 -0400 Subject: [PATCH 30/35] fix struct_for loop --- taichi/codegen/llvm/codegen_llvm.cpp | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 6cbdc4b15382b..abfb71e0a7e57 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1775,13 +1775,20 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) { llvm_val[stmt] = builder->CreateGEP(ptr_ty, llvm_val[stmt->origin], llvm_val[stmt->offset]); #else - if (stmt->origin->ret_type->is()) { - auto stmt_dtype = stmt->origin->ret_type->cast(); + if (stmt->origin->ret_type->is() + || (stmt->origin->ret_type->is() + && stmt->origin->ret_type->cast()->get_pointee_type()->is())) { + TensorType* stmt_dtype; + if (stmt->origin->ret_type->is()) { + stmt_dtype = stmt->origin->ret_type->cast()->get_pointee_type()->cast(); + } else { + stmt_dtype = stmt->origin->ret_type->cast(); + } auto element_dtype = stmt_dtype->get_element_type(); auto llvm_type = tlctx->get_data_type(element_dtype); auto casted_ptr = builder->CreateBitCast( llvm_val[stmt->origin], llvm::PointerType::get(llvm_type, 0)); - llvm_val[stmt] = builder->CreateGEP(casted_ptr, llvm_val[stmt->offset]); + llvm_val[stmt] = builder->CreateBitCast(builder->CreateGEP(casted_ptr, llvm_val[stmt->offset]), llvm::PointerType::get(llvm_type, 0)); } else { llvm_val[stmt] = builder->CreateGEP(llvm_val[stmt->origin], llvm_val[stmt->offset]); From bdfb179fd13705b22ff95626bbabb112b0bf34c7 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 10 Aug 2022 16:16:17 -0400 Subject: [PATCH 31/35] format yet yet again --- taichi/codegen/llvm/codegen_llvm.cpp | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index abfb71e0a7e57..1207e16d17a62 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1775,12 +1775,16 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) { llvm_val[stmt] = builder->CreateGEP(ptr_ty, llvm_val[stmt->origin], llvm_val[stmt->offset]); #else - if (stmt->origin->ret_type->is() - || (stmt->origin->ret_type->is() - && stmt->origin->ret_type->cast()->get_pointee_type()->is())) { - TensorType* stmt_dtype; + if (stmt->origin->ret_type->is() || + (stmt->origin->ret_type->is() && + stmt->origin->ret_type->cast() + ->get_pointee_type() + ->is())) { + TensorType *stmt_dtype; if (stmt->origin->ret_type->is()) { - stmt_dtype = stmt->origin->ret_type->cast()->get_pointee_type()->cast(); + stmt_dtype = stmt->origin->ret_type->cast() + ->get_pointee_type() + ->cast(); } else { stmt_dtype = stmt->origin->ret_type->cast(); } @@ -1788,7 +1792,9 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) { auto 
llvm_type = tlctx->get_data_type(element_dtype); auto casted_ptr = builder->CreateBitCast( llvm_val[stmt->origin], llvm::PointerType::get(llvm_type, 0)); - llvm_val[stmt] = builder->CreateBitCast(builder->CreateGEP(casted_ptr, llvm_val[stmt->offset]), llvm::PointerType::get(llvm_type, 0)); + llvm_val[stmt] = builder->CreateBitCast( + builder->CreateGEP(casted_ptr, llvm_val[stmt->offset]), + llvm::PointerType::get(llvm_type, 0)); } else { llvm_val[stmt] = builder->CreateGEP(llvm_val[stmt->origin], llvm_val[stmt->offset]); From 1b7315da117b69c28550f68124e240a77e096814 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 10 Aug 2022 15:32:25 -0400 Subject: [PATCH 32/35] try fix loop --- taichi/codegen/llvm/codegen_llvm.cpp | 1 + taichi/ir/statements.h | 5 ++++- taichi/transforms/type_check.cpp | 4 ++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 1207e16d17a62..f4e3355bbd03c 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1523,6 +1523,7 @@ void TaskCodeGenLLVM::create_global_load(GlobalLoadStmt *stmt, if (should_cache_as_read_only) { llvm_val[stmt] = create_intrinsic_load(ptr, llvm_type(stmt->ret_type)); } else { + ptr->dump(); llvm_val[stmt] = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), ptr); } diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 1a201acbcfaf5..a006e99b12ca0 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -367,7 +367,10 @@ class PtrOffsetStmt : public Stmt { bool is_local_ptr() const { if (origin->is() || origin->is()) { - TI_ASSERT_INFO(origin->ret_type->is(), + auto is_tensor_type = origin->ret_type->is() ? + origin->ret_type->cast()->get_pointee_type()->is() : + origin->ret_type->is(); + TI_ASSERT_INFO(is_tensor_type, "PtrOffsetStmt can only be used for Alloca (TensorType)."); } return origin->is() || origin->is(); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index d2f3e91e7ef19..1a69590287319 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -598,8 +598,8 @@ class TypeCheck : public IRVisitor { } void visit(GlobalTemporaryStmt *stmt) override { - if (!stmt->ret_type->is()) - stmt->ret_type.set_is_pointer(true); + // if (!stmt->ret_type->is()) + stmt->ret_type.set_is_pointer(true); } void visit(InternalFuncStmt *stmt) override { From 3dddae8c5c126e2af48813ef85925a48112d6570 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 10 Aug 2022 17:27:06 -0400 Subject: [PATCH 33/35] add one more pick --- taichi/codegen/llvm/codegen_llvm.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index f4e3355bbd03c..1207e16d17a62 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -1523,7 +1523,6 @@ void TaskCodeGenLLVM::create_global_load(GlobalLoadStmt *stmt, if (should_cache_as_read_only) { llvm_val[stmt] = create_intrinsic_load(ptr, llvm_type(stmt->ret_type)); } else { - ptr->dump(); llvm_val[stmt] = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), ptr); } From 8699512718437b2fc1684b74632f75332bf4630c Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 11 Aug 2022 16:22:28 -0400 Subject: [PATCH 34/35] fix some cases --- python/taichi/lang/ast/ast_transformer.py | 4 ++-- python/taichi/lang/matrix.py | 10 +++++----- taichi/codegen/llvm/codegen_llvm.cpp | 4 ++++ taichi/ir/control_flow_graph.cpp | 1 - 
From 8699512718437b2fc1684b74632f75332bf4630c Mon Sep 17 00:00:00 2001
From: AD1024
Date: Thu, 11 Aug 2022 16:22:28 -0400
Subject: [PATCH 34/35] fix some cases

---
 python/taichi/lang/ast/ast_transformer.py |  4 ++--
 python/taichi/lang/matrix.py              | 10 +++++-----
 taichi/codegen/llvm/codegen_llvm.cpp      |  4 ++++
 taichi/ir/control_flow_graph.cpp          |  1 -
 taichi/ir/frontend_ir.cpp                 |  5 +++--
 5 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py
index 34d8df0db9e58..82d8508cd47b4 100644
--- a/python/taichi/lang/ast/ast_transformer.py
+++ b/python/taichi/lang/ast/ast_transformer.py
@@ -15,7 +15,7 @@
 from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError
 from taichi.lang.field import Field
 from taichi.lang.matrix import (Matrix, MatrixType, _PyScopeMatrixImpl,
-                                _TiScopeMatrixImpl)
+                                _TiScopeMatrixImpl, Vector)
 from taichi.lang.snode import append
 from taichi.lang.util import in_taichi_scope, is_taichi_class, to_taichi_type
 from taichi.types import (annotations, ndarray_type, primitive_types,
@@ -488,7 +488,7 @@ def build_Call(ctx, node):
             node.ptr = impl.ti_format(*args, **keywords)
             return node.ptr

-        if isinstance(node.func, ast.Attribute) and func == Matrix:
+        if isinstance(node.func, ast.Attribute) and func == Matrix or func == Vector:
             node.ptr = matrix.make_matrix(*args, **keywords)
             return node.ptr

diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py
index 1c8112f4db733..fa8f6de2e5a14 100644
--- a/python/taichi/lang/matrix.py
+++ b/python/taichi/lang/matrix.py
@@ -103,12 +103,12 @@ def make_matrix(arr, dt=None, suppress_warning=False, is_ref=False, **kwargs):
     cast = (lambda x: ops_mod.cast(x, dt)) if dt else (
         lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x))
     if len(arr) == 0:
-        return impl.expr_init_matrix([0], dt, [])
+        return impl.expr_init(impl.expr_init_matrix([0], dt, []))
     if not isinstance(arr[0], Iterable):
-        return impl.expr_init_matrix([len(arr)], dt,
-                                     [cast(elt).ptr for elt in arr])
-    return impl.expr_init_matrix([len(arr), len(arr[0])], dt,
-                                 [cast(elt).ptr for row in arr for elt in row])
+        return impl.expr_init(impl.expr_init_matrix([len(arr)], dt,
+                              [cast(elt).ptr for elt in arr]))
+    return impl.expr_init(impl.expr_init_matrix([len(arr), len(arr[0])], dt,
+                          [cast(elt).ptr for row in arr for elt in row]))

diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp
index 1207e16d17a62..9ceef1c4335e9 100644
--- a/taichi/codegen/llvm/codegen_llvm.cpp
+++ b/taichi/codegen/llvm/codegen_llvm.cpp
@@ -1484,6 +1484,10 @@ void TaskCodeGenLLVM::visit(GlobalStoreStmt *stmt) {
       TI_NOT_IMPLEMENTED;
     }
   } else {
+    TI_TRACE("Store {} to {}", stmt->val->name(),
+             stmt->dest->name());
+    llvm_val[stmt->val]->dump();
+    llvm_val[stmt->dest]->dump();
     builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]);
   }
 }
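One subtlety in the `build_Call` condition above: Python precedence groups it as `(isinstance(node.func, ast.Attribute) and func == Matrix) or func == Vector`, so a `Vector` call is matched even when `node.func` is not an attribute access. A self-contained sketch of that grouping (toy stand-ins, not Taichi code):

```python
class Matrix:
    pass

class Vector:
    pass

def matches(func, is_attr):
    # Same condition and precedence as the patch.
    return is_attr and func == Matrix or func == Vector

assert matches(Vector, is_attr=False)      # Vector matches with no attribute
assert not matches(Matrix, is_attr=False)  # Matrix still requires one
```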
diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp
index baa64c21d8846..e49e234ab0635 100644
--- a/taichi/ir/control_flow_graph.cpp
+++ b/taichi/ir/control_flow_graph.cpp
@@ -463,7 +463,6 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
         // Neither used in other nodes nor used in this node.
         if (!stmt->is<AtomicOpStmt>()) {
           // Eliminate the dead store.
-          TI_TRACE("Elminate dead store: {}", block->operator[](i)->name());
           erase(i);
           modified = true;
           continue;
diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp
index e33b989f19fa5..fb856df25b553 100644
--- a/taichi/ir/frontend_ir.cpp
+++ b/taichi/ir/frontend_ir.cpp
@@ -222,8 +222,9 @@ void BinaryOpExpression::type_check(CompileConfig *config) {
     if (rhs_tensor_type->get_shape() !=
         lhs_type->cast<TensorType>()->get_shape())
       error();
-    ret_type =
-        promoted_type(dtype, rhs_type->as<TensorType>()->get_element_type());
+    auto rhs_elem_type = rhs_type->as<TensorType>()->get_element_type();
+    if (rhs_elem_type != PrimitiveType::unknown)
+      ret_type = promoted_type(dtype, rhs_elem_type);
   }
   // TODO: shape check!
   ret_type = TypeFactory::create_tensor_type(

From c513682ad55807001074219094d1a79ad9f0fa88 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 12 Aug 2022 01:00:17 +0000
Subject: [PATCH 35/35] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 python/taichi/lang/ast/ast_transformer.py |  7 ++++---
 python/taichi/lang/matrix.py              | 12 +++++++-----
 taichi/codegen/llvm/codegen_llvm.cpp      |  3 +--
 taichi/ir/statements.h                    |  8 +++++---
 4 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py
index 82d8508cd47b4..f70b9e20c494e 100644
--- a/python/taichi/lang/ast/ast_transformer.py
+++ b/python/taichi/lang/ast/ast_transformer.py
@@ -14,8 +14,8 @@
 from taichi.lang.ast.symbol_resolver import ASTResolver
 from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError
 from taichi.lang.field import Field
-from taichi.lang.matrix import (Matrix, MatrixType, _PyScopeMatrixImpl,
-                                _TiScopeMatrixImpl, Vector)
+from taichi.lang.matrix import (Matrix, MatrixType, Vector, _PyScopeMatrixImpl,
+                                _TiScopeMatrixImpl)
 from taichi.lang.snode import append
 from taichi.lang.util import in_taichi_scope, is_taichi_class, to_taichi_type
 from taichi.types import (annotations, ndarray_type, primitive_types,
@@ -488,7 +488,8 @@ def build_Call(ctx, node):
             node.ptr = impl.ti_format(*args, **keywords)
             return node.ptr

-        if isinstance(node.func, ast.Attribute) and func == Matrix or func == Vector:
+        if isinstance(node.func,
+                      ast.Attribute) and func == Matrix or func == Vector:
             node.ptr = matrix.make_matrix(*args, **keywords)
             return node.ptr

diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py
index fa8f6de2e5a14..31ff0d47d8699 100644
--- a/python/taichi/lang/matrix.py
+++ b/python/taichi/lang/matrix.py
@@ -103,12 +103,14 @@ def make_matrix(arr, dt=None, suppress_warning=False, is_ref=False, **kwargs):
     cast = (lambda x: ops_mod.cast(x, dt)) if dt else (
         lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x))
     if len(arr) == 0:
-        return impl.expr_init(impl.expr_init_matrix([0], dt, []))
+        return impl.expr_init(impl.expr_init_matrix([0], dt, []))
     if not isinstance(arr[0], Iterable):
-        return impl.expr_init(impl.expr_init_matrix([len(arr)], dt,
-                              [cast(elt).ptr for elt in arr]))
-    return impl.expr_init(impl.expr_init_matrix([len(arr), len(arr[0])], dt,
-                          [cast(elt).ptr for row in arr for elt in row]))
+        return impl.expr_init(
+            impl.expr_init_matrix([len(arr)], dt,
+                                  [cast(elt).ptr for elt in arr]))
+    return impl.expr_init(
+        impl.expr_init_matrix([len(arr), len(arr[0])], dt,
+                              [cast(elt).ptr for row in arr for elt in row]))
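The `BinaryOpExpression::type_check` change in patch 34 above only promotes when the right-hand element type is known, so an operand whose element type is still `PrimitiveType::unknown` no longer feeds `promoted_type`. A toy Python model of that guard (stand-in names; the real promotion table is the C++ `promoted_type`):

```python
UNKNOWN = None  # stands in for PrimitiveType::unknown

def promoted_type(a, b):
    # Toy promotion table: the widest type wins.
    order = ["i32", "f32", "f64"]
    return max(a, b, key=order.index)

def binary_elem_type(lhs_elem, rhs_elem):
    # The added guard: keep the lhs dtype while rhs is unknown.
    if rhs_elem is UNKNOWN:
        return lhs_elem
    return promoted_type(lhs_elem, rhs_elem)

assert binary_elem_type("f32", UNKNOWN) == "f32"
assert binary_elem_type("i32", "f32") == "f32"
```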
diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp
index 9ceef1c4335e9..cf3a2f0f9c3d0 100644
--- a/taichi/codegen/llvm/codegen_llvm.cpp
+++ b/taichi/codegen/llvm/codegen_llvm.cpp
@@ -1484,8 +1484,7 @@ void TaskCodeGenLLVM::visit(GlobalStoreStmt *stmt) {
      TI_NOT_IMPLEMENTED;
    }
  } else {
-    TI_TRACE("Store {} to {}", stmt->val->name(),
-             stmt->dest->name());
+    TI_TRACE("Store {} to {}", stmt->val->name(), stmt->dest->name());
     llvm_val[stmt->val]->dump();
     llvm_val[stmt->dest]->dump();
     builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]);
diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h
index a006e99b12ca0..2a9c0a7aa59fa 100644
--- a/taichi/ir/statements.h
+++ b/taichi/ir/statements.h
@@ -367,9 +367,11 @@ class PtrOffsetStmt : public Stmt {
   bool is_local_ptr() const {
     if (origin->is<AllocaStmt>() || origin->is<PtrOffsetStmt>()) {
-      auto is_tensor_type = origin->ret_type->is<PointerType>() ?
-          origin->ret_type->cast<PointerType>()->get_pointee_type()->is<TensorType>() :
-          origin->ret_type->is<TensorType>();
+      auto is_tensor_type = origin->ret_type->is<PointerType>()
+                                ? origin->ret_type->cast<PointerType>()
+                                      ->get_pointee_type()
+                                      ->is<TensorType>()
+                                : origin->ret_type->is<TensorType>();
       TI_ASSERT_INFO(is_tensor_type,
                      "PtrOffsetStmt can only be used for Alloca (TensorType).");
     }
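After this series, `PtrOffsetStmt::is_local_ptr` treats an origin as tensor-backed whether its return type is a `TensorType` or a pointer to one. A toy model of that predicate (illustrative classes only, not the C++ type system):

```python
class TensorType:
    pass

class PointerType:
    def __init__(self, pointee):
        self.pointee = pointee

def is_tensor_origin(ret_type):
    # Mirrors the ternary in is_local_ptr: unwrap a pointer before checking.
    if isinstance(ret_type, PointerType):
        return isinstance(ret_type.pointee, TensorType)
    return isinstance(ret_type, TensorType)

assert is_tensor_origin(TensorType())
assert is_tensor_origin(PointerType(TensorType()))
assert not is_tensor_origin(PointerType(object()))
```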