diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py
index 1d30ffceae77d..f70b9e20c494e 100644
--- a/python/taichi/lang/ast/ast_transformer.py
+++ b/python/taichi/lang/ast/ast_transformer.py
@@ -14,7 +14,7 @@
 from taichi.lang.ast.symbol_resolver import ASTResolver
 from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError
 from taichi.lang.field import Field
-from taichi.lang.matrix import (Matrix, MatrixType, _PyScopeMatrixImpl,
+from taichi.lang.matrix import (Matrix, MatrixType, Vector, _PyScopeMatrixImpl,
                                 _TiScopeMatrixImpl)
 from taichi.lang.snode import append
 from taichi.lang.util import in_taichi_scope, is_taichi_class, to_taichi_type
@@ -488,6 +488,11 @@ def build_Call(ctx, node):
             node.ptr = impl.ti_format(*args, **keywords)
             return node.ptr

+        if isinstance(node.func,
+                      ast.Attribute) and (func == Matrix or func == Vector):
+            node.ptr = matrix.make_matrix(*args, **keywords)
+            return node.ptr
+
         if ASTTransformer.build_call_if_is_builtin(ctx, node, args, keywords):
             return node.ptr

diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py
index 2f60e86aeb284..01fc8dff46627 100644
--- a/python/taichi/lang/expr.py
+++ b/python/taichi/lang/expr.py
@@ -38,6 +38,15 @@ def __init__(self, *args, tb=None, dtype=None):
         self.ptr.set_tb(self.tb)
         self.ptr.type_check(impl.get_runtime().prog.config)

+    def __getitem__(self, *indices):
+        if not isinstance(indices, (list, tuple)):
+            indices = (indices, )
+
+        indices = make_expr_group(*indices)
+        return Expr(
+            impl.get_runtime().prog.current_ast_builder().expr_indexed_matrix(
+                self.ptr, indices))
+
     def __hash__(self):
         return self.ptr.get_raw_address()

diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py
index 3cd2f20738184..cfa94b701a95e 100644
--- a/python/taichi/lang/impl.py
+++ b/python/taichi/lang/impl.py
@@ -35,6 +35,12 @@ def expr_init_local_tensor(shape, element_type, elements):
         shape, element_type, elements)


+@taichi_scope
+def expr_init_matrix(shape, element_type, elements):
+    return get_runtime().prog.current_ast_builder().expr_alloca_matrix(
+        shape, element_type, elements)
+
+
 @taichi_scope
 def expr_init_shared_array(shape, element_type):
     return get_runtime().prog.current_ast_builder().expr_alloca_shared_array(
@@ -48,6 +54,8 @@ def expr_init(rhs):
     if isinstance(rhs, Matrix) and (hasattr(rhs, "_DIM")):
         return type(rhs)(*rhs.to_list())
     if isinstance(rhs, Matrix):
+        if current_cfg().real_matrix:
+            return rhs
         return Matrix(rhs.to_list())
     if isinstance(rhs, SharedArray):
         return rhs
diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py
index 629751e3412fb..31ff0d47d8699 100644
--- a/python/taichi/lang/matrix.py
+++ b/python/taichi/lang/matrix.py
@@ -97,6 +97,22 @@ def prop_setter(instance, value):
     return cls


+def make_matrix(arr, dt=None, suppress_warning=False, is_ref=False, **kwargs):
+    if not impl.current_cfg().real_matrix or in_python_scope():
+        return Matrix(arr, dt, suppress_warning, is_ref, **kwargs)
+    cast = (lambda x: ops_mod.cast(x, dt)) if dt else (
+        lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x))
+    if len(arr) == 0:
+        return impl.expr_init(impl.expr_init_matrix([0], dt, []))
+    if not isinstance(arr[0], Iterable):
+        return impl.expr_init(
+            impl.expr_init_matrix([len(arr)], dt,
+                                  [cast(elt).ptr for elt in arr]))
+    return impl.expr_init(
+        impl.expr_init_matrix([len(arr), len(arr[0])], dt,
+                              [cast(elt).ptr for row in arr for elt in row]))
+
+
 class _MatrixBaseImpl:
     def __init__(self, m, n, entries):
         self.m = m
diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp
index 4a018afa6bf47..9baad69fb80fb 100644
--- a/taichi/analysis/data_source_analysis.cpp
+++ b/taichi/analysis/data_source_analysis.cpp
@@ -37,6 +37,10 @@ std::vector<Stmt *> get_load_pointers(Stmt *load_stmt) {
     return external_func->arg_stmts;
   } else if (auto ref = load_stmt->cast<ReferenceStmt>()) {
     return {ref->var};
+  } else if (auto matrix_init = load_stmt->cast<MatrixInitStmt>()) {
+    return matrix_init->values;
+  } else if (auto ptr_offset = load_stmt->cast<PtrOffsetStmt>()) {
+    return {ptr_offset->origin};
   } else {
     return std::vector<Stmt *>();
   }
@@ -59,7 +63,7 @@ Stmt *get_store_data(Stmt *store_stmt) {

 std::vector<Stmt *> get_store_destination(Stmt *store_stmt) {
   // If store_stmt provides some data sources, return the pointers of the data.
-  if (store_stmt->is<AllocaStmt>() && !store_stmt->ret_type->is<TensorType>()) {
+  if (store_stmt->is<AllocaStmt>()) {
     // The statement itself provides a data source (const [0]).
     return std::vector<Stmt *>(1, store_stmt);
   } else if (auto local_store = store_stmt->cast<LocalStoreStmt>()) {
diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp
index d1dfd5166c8ba..8ad565fdd84b9 100644
--- a/taichi/analysis/gen_offline_cache_key.cpp
+++ b/taichi/analysis/gen_offline_cache_key.cpp
@@ -159,6 +159,14 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
     emit(expr->dual);
   }

+  void visit(MatrixExpression *expr) override {
+    emit(ExprOpCode::MatrixExpression);
+    emit(expr->dt);
+    for (auto elt : expr->elements) {
+      emit(elt);
+    }
+  }
+
   void visit(IndexExpression *expr) override {
     emit(ExprOpCode::IndexExpression);
     emit(expr->var);
diff --git a/taichi/analysis/offline_cache_util.cpp b/taichi/analysis/offline_cache_util.cpp
index 5c8bd1ee1942c..53337391010de 100644
--- a/taichi/analysis/offline_cache_util.cpp
+++ b/taichi/analysis/offline_cache_util.cpp
@@ -64,6 +64,7 @@ static std::vector<std::uint8_t> get_offline_cache_key_of_compile_config(
   serializer(config->demote_no_access_mesh_fors);
   serializer(config->experimental_auto_mesh_local);
   serializer(config->auto_mesh_local_default_occupacy);
+  serializer(config->real_matrix);
   serializer.finalize();

   return serializer.data;
diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp
index c26957e81f4a8..5f908603c9497 100644
--- a/taichi/analysis/same_statements.cpp
+++ b/taichi/analysis/same_statements.cpp
@@ -196,6 +196,24 @@ class IRNodeComparator : public IRVisitor {
     basic_check(stmt);
   }

+  void visit(MatrixInitStmt *stmt) override {
+    basic_check(stmt);
+    if (!same)
+      return;
+    auto o = other_node_->as<MatrixInitStmt>();
+    if (stmt->values.size() != o->values.size()) {
+      same = false;
+      return;
+    }
+    for (int i = 0; i < stmt->values.size(); ++i) {
+      other_node_ = o->values[i];
+      stmt->values[i]->accept(this);
+      other_node_ = o;
+      if (!same)
+        return;
+    }
+  }
+
   void visit(IfStmt *stmt) override {
     basic_check(stmt);
     if (!same)
diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp
index ee651d61deacf..cf3a2f0f9c3d0 100644
--- a/taichi/codegen/llvm/codegen_llvm.cpp
+++ b/taichi/codegen/llvm/codegen_llvm.cpp
@@ -124,7 +124,7 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) {
 void TaskCodeGenLLVM::visit(AllocaStmt *stmt) {
   if (stmt->ret_type->is<TensorType>()) {
     auto tensor_type = stmt->ret_type->cast<TensorType>();
-    auto type = tlctx->get_data_type(tensor_type->get_element_type());
+    auto type = tlctx->get_data_type(tensor_type);
     auto array_size = tlctx->get_constant(tensor_type->get_num_elements());
     // Return type is [array_size x type]*.
     if (stmt->is_shared) {
@@ -347,24 +347,45 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) {
     llvm_val[stmt] =                                                          \
         builder->CreateIntrinsic(llvm::Intrinsic::x, {input_type}, {input});  \
   }
+  auto get_cast_op = [](auto from, auto to) {
+    llvm::CastInst::CastOps cast_op;
+    if (is_real(from) && is_integral(to)) {
+      cast_op = is_signed(to) ? llvm::Instruction::CastOps::FPToSI
+                              : llvm::Instruction::CastOps::FPToUI;
+    } else if (is_integral(from) && is_real(to)) {
+      cast_op = is_signed(from) ? llvm::Instruction::CastOps::SIToFP
+                                : llvm::Instruction::CastOps::UIToFP;
+    } else {
+      TI_P(data_type_name(from));
+      TI_P(data_type_name(to));
+      TI_NOT_IMPLEMENTED;
+    }
+    return cast_op;
+  };
   if (stmt->op_type == UnaryOpType::cast_value) {
     llvm::CastInst::CastOps cast_op;
     auto from = stmt->operand->ret_type;
     auto to = stmt->cast_type;
     if (from == to) {
       llvm_val[stmt] = llvm_val[stmt->operand];
-    } else if (is_real(from) != is_real(to)) {
-      if (is_real(from) && is_integral(to)) {
-        cast_op = is_signed(to) ? llvm::Instruction::CastOps::FPToSI
-                                : llvm::Instruction::CastOps::FPToUI;
-      } else if (is_integral(from) && is_real(to)) {
-        cast_op = is_signed(from) ? llvm::Instruction::CastOps::SIToFP
-                                  : llvm::Instruction::CastOps::UIToFP;
-      } else {
-        TI_P(data_type_name(from));
-        TI_P(data_type_name(to));
-        TI_NOT_IMPLEMENTED;
+    } else if (from->is<TensorType>()) {
+      TI_ASSERT_INFO(to->is<TensorType>(),
+                     "Only tensor to tensor cast is supported, {} provided",
+                     to->to_string());
+      auto from_ty = from->cast<TensorType>()->get_element_type();
+      auto to_ty = to->cast<TensorType>()->get_element_type();
+      cast_op = get_cast_op(from_ty, to_ty);
+      auto type = tlctx->get_data_type(to->cast<TensorType>());
+      llvm::Value *vec = llvm::UndefValue::get(type);
+      for (int i = 0; i < from->cast<TensorType>()->get_num_elements(); ++i) {
+        auto elem = builder->CreateExtractElement(llvm_val[stmt->operand], i);
+        auto cast_input =
+            builder->CreateCast(cast_op, elem, tlctx->get_data_type(to_ty));
+        vec = builder->CreateInsertElement(vec, cast_input, i);
       }
+      llvm_val[stmt] = vec;
+    } else if (is_real(from) != is_real(to)) {
+      cast_op = get_cast_op(from, to);
       auto cast_type = to->is_primitive(PrimitiveTypeID::f16) ?
                            PrimitiveType::f32 : stmt->cast_type;
@@ -434,8 +455,18 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
   auto op = stmt->op_type;
   auto ret_type = stmt->ret_type;

+  auto is_real_tensor = [](const DataType &dt) {
+    return dt->is<TensorType>() &&
+           is_real(dt->cast<TensorType>()->get_element_type());
+  };
+
+  auto is_integral_tensor = [](const DataType &dt) {
+    return dt->is<TensorType>() &&
+           is_integral(dt->cast<TensorType>()->get_element_type());
+  };
+
   if (op == BinaryOpType::add) {
-    if (is_real(stmt->ret_type)) {
+    if (is_real(stmt->ret_type) || is_real_tensor(stmt->ret_type)) {
       llvm_val[stmt] =
           builder->CreateFAdd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
     } else {
@@ -443,7 +474,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
           builder->CreateAdd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
     }
   } else if (op == BinaryOpType::sub) {
-    if (is_real(stmt->ret_type)) {
+    if (is_real(stmt->ret_type) || is_real_tensor(stmt->ret_type)) {
       llvm_val[stmt] =
           builder->CreateFSub(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
     } else {
@@ -451,7 +482,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
           builder->CreateSub(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
     }
   } else if (op == BinaryOpType::mul) {
-    if (is_real(stmt->ret_type)) {
+    if (is_real(stmt->ret_type) || is_real_tensor(stmt->ret_type)) {
       llvm_val[stmt] =
           builder->CreateFMul(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
     } else {
@@ -459,7 +490,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
           builder->CreateMul(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
     }
   } else if (op == BinaryOpType::floordiv) {
-    if (is_integral(ret_type))
+    if (is_integral(ret_type) || is_integral_tensor(ret_type))
       llvm_val[stmt] =
           create_call(fmt::format("floordiv_{}", data_type_name(ret_type)),
                       {llvm_val[stmt->lhs], llvm_val[stmt->rhs]});
@@ -469,7 +500,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
           llvm::Intrinsic::floor, {tlctx->get_data_type(ret_type)}, {div});
     }
   } else if (op == BinaryOpType::div) {
-    if (is_real(stmt->ret_type)) {
+    if (is_real(stmt->ret_type) || is_real_tensor(stmt->ret_type)) {
       llvm_val[stmt] =
           builder->CreateFDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
     } else {
@@ -506,7 +537,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
         create_call("max_" #x, {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); \
   }

-    if (is_real(ret_type)) {
+    if (is_real(ret_type) || is_real_tensor(ret_type)) {
       llvm_val[stmt] =
           builder->CreateMaxNum(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
     }
@@ -527,7 +558,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
         create_call("min_" #x, {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); \
   }

-    if (is_real(ret_type)) {
+    if (is_real(ret_type) || is_real_tensor(ret_type)) {
       llvm_val[stmt] =
           builder->CreateMinNum(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
     }
@@ -545,13 +576,13 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
     llvm::Value *cmp = nullptr;
     auto input_type = stmt->lhs->ret_type;
     if (op == BinaryOpType::cmp_eq) {
-      if (is_real(input_type)) {
+      if (is_real(input_type) || is_real_tensor(input_type)) {
        cmp = builder->CreateFCmpOEQ(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
      } else {
        cmp = builder->CreateICmpEQ(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
      }
    } else if (op == BinaryOpType::cmp_le) {
-      if (is_real(input_type)) {
+      if (is_real(input_type) || is_real_tensor(input_type)) {
        cmp = builder->CreateFCmpOLE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
      } else {
        if (is_signed(input_type)) {
@@ -563,7 +594,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
      }
    }
  } else if (op == BinaryOpType::cmp_ge) {
-    if (is_real(input_type)) {
+    if (is_real(input_type) || is_real_tensor(input_type)) {
      cmp = builder->CreateFCmpOGE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    } else {
      if (is_signed(input_type)) {
@@ -575,7 +606,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
      }
    }
  } else if (op == BinaryOpType::cmp_lt) {
-    if (is_real(input_type)) {
+    if (is_real(input_type) || is_real_tensor(input_type)) {
      cmp = builder->CreateFCmpOLT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    } else {
      if (is_signed(input_type)) {
@@ -587,7 +618,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
      }
    }
  } else if (op == BinaryOpType::cmp_gt) {
-    if (is_real(input_type)) {
+    if (is_real(input_type) || is_real_tensor(input_type)) {
      cmp = builder->CreateFCmpOGT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    } else {
      if (is_signed(input_type)) {
@@ -599,7 +630,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
      }
    }
  } else if (op == BinaryOpType::cmp_ne) {
-    if (is_real(input_type)) {
+    if (is_real(input_type) || is_real_tensor(input_type)) {
      cmp = builder->CreateFCmpONE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    } else {
      cmp = builder->CreateICmpNE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
@@ -688,6 +719,11 @@ llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) {
    return llvm::Type::getDoubleTy(*llvm_context);
  } else if (dt->is_primitive(PrimitiveTypeID::f16)) {
    return llvm::Type::getHalfTy(*llvm_context);
+  } else if (dt->is<TensorType>()) {
+    auto tensor_type = dt->cast<TensorType>();
+    auto element_type = llvm_type(tensor_type->get_element_type());
+    return llvm::VectorType::get(element_type, tensor_type->get_num_elements(),
+                                 false);
  } else {
    TI_NOT_IMPLEMENTED;
  }
@@ -800,12 +836,20 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) {
    if (std::holds_alternative<Stmt *>(content)) {
      auto arg_stmt = std::get<Stmt *>(content);
      auto value = llvm_val[arg_stmt];
-      if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) ||
-          arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16))
-        value = builder->CreateFPExt(value,
-                                     tlctx->get_data_type(PrimitiveType::f64));
-      args.push_back(value);
-      formats += data_type_format(arg_stmt->ret_type);
+      if (arg_stmt->ret_type->is<TensorType>()) {
+        auto dtype = arg_stmt->ret_type->cast<TensorType>();
+        for (int i = 0; i < dtype->get_num_elements(); ++i) {
+          args.push_back(builder->CreateExtractElement(value, i));
+        }
+        formats += data_type_format(arg_stmt->ret_type);
+      } else {
+        if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) ||
+            arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16))
+          value = builder->CreateFPExt(
+              value, tlctx->get_data_type(PrimitiveType::f64));
+        args.push_back(value);
+        formats += data_type_format(arg_stmt->ret_type);
+      }
    } else {
      auto arg_str = std::get<std::string>(content);
      auto value = builder->CreateGlobalStringPtr(arg_str, "content_string");
@@ -1440,6 +1484,9 @@ void TaskCodeGenLLVM::visit(GlobalStoreStmt *stmt) {
      TI_NOT_IMPLEMENTED;
    }
  } else {
+    TI_TRACE("Store {} to {}", stmt->val->name(), stmt->dest->name());
+    llvm_val[stmt->val]->dump();
+    llvm_val[stmt->dest]->dump();
    builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]);
  }
}
@@ -1731,8 +1778,30 @@ void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) {
    llvm_val[stmt] = builder->CreateGEP(ptr_ty, llvm_val[stmt->origin],
                                        llvm_val[stmt->offset]);
#else
-    llvm_val[stmt] =
-        builder->CreateGEP(llvm_val[stmt->origin], llvm_val[stmt->offset]);
+    if (stmt->origin->ret_type->is<TensorType>() ||
+        (stmt->origin->ret_type->is<PointerType>() &&
+         stmt->origin->ret_type->cast<PointerType>()
+             ->get_pointee_type()
+             ->is<TensorType>())) {
+      TensorType *stmt_dtype;
+      if (stmt->origin->ret_type->is<PointerType>()) {
+        stmt_dtype = stmt->origin->ret_type->cast<PointerType>()
+                         ->get_pointee_type()
+                         ->cast<TensorType>();
+      } else {
+        stmt_dtype = stmt->origin->ret_type->cast<TensorType>();
+      }
+      auto element_dtype = stmt_dtype->get_element_type();
+      auto llvm_type = tlctx->get_data_type(element_dtype);
+      auto casted_ptr = builder->CreateBitCast(
+          llvm_val[stmt->origin], llvm::PointerType::get(llvm_type, 0));
+      llvm_val[stmt] = builder->CreateBitCast(
+          builder->CreateGEP(casted_ptr, llvm_val[stmt->offset]),
+          llvm::PointerType::get(llvm_type, 0));
+    } else {
+      llvm_val[stmt] =
+          builder->CreateGEP(llvm_val[stmt->origin], llvm_val[stmt->offset]);
+    }
#endif
  } else {
    auto origin_address = builder->CreatePtrToInt(
@@ -2508,6 +2577,16 @@ void TaskCodeGenLLVM::visit(MeshPatchIndexStmt *stmt) {
  llvm_val[stmt] = get_arg(2);
}

+void TaskCodeGenLLVM::visit(MatrixInitStmt *stmt) {
+  auto type = tlctx->get_data_type(stmt->ret_type->as<TensorType>());
+  llvm::Value *vec = llvm::UndefValue::get(type);
+  for (int i = 0; i < stmt->values.size(); ++i) {
+    auto *elem = llvm_val[stmt->values[i]];
+    vec = builder->CreateInsertElement(vec, elem, i);
+  }
+  llvm_val[stmt] = vec;
+}
+
 void TaskCodeGenLLVM::eliminate_unused_functions() {
   TaichiLLVMContext::eliminate_unused_functions(
       module.get(), [&](std::string func_name) {
diff --git a/taichi/codegen/llvm/codegen_llvm.h b/taichi/codegen/llvm/codegen_llvm.h
index 6f97ed7dff0f4..356866b12b8c4 100644
--- a/taichi/codegen/llvm/codegen_llvm.h
+++ b/taichi/codegen/llvm/codegen_llvm.h
@@ -369,6 +369,8 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

   void visit(ReferenceStmt *stmt) override;

+  void visit(MatrixInitStmt *stmt) override;
+
   llvm::Value *create_xlogue(std::unique_ptr<Block> &block);

   llvm::Value *create_mesh_xlogue(std::unique_ptr<Block> &block);
diff --git a/taichi/inc/expressions.inc.h b/taichi/inc/expressions.inc.h
index 9b20ba86bd80a..7fb3ef3975947 100644
--- a/taichi/inc/expressions.inc.h
+++ b/taichi/inc/expressions.inc.h
@@ -6,6 +6,7 @@ PER_EXPRESSION(TernaryOpExpression)
 PER_EXPRESSION(InternalFuncCallExpression)
 PER_EXPRESSION(ExternalTensorExpression)
 PER_EXPRESSION(GlobalVariableExpression)
+PER_EXPRESSION(MatrixExpression)
 PER_EXPRESSION(IndexExpression)
 PER_EXPRESSION(StrideExpression)
 PER_EXPRESSION(RangeAssumptionExpression)
diff --git a/taichi/inc/statements.inc.h b/taichi/inc/statements.inc.h
index fe12a8941f7f5..05056ce46b9ae 100644
--- a/taichi/inc/statements.inc.h
+++ b/taichi/inc/statements.inc.h
@@ -38,6 +38,7 @@ PER_STATEMENT(LoopUniqueStmt)
 PER_STATEMENT(AssertStmt)
 PER_STATEMENT(ExternalFuncCallStmt)
 PER_STATEMENT(ExternalTensorShapeAlongAxisStmt)
+PER_STATEMENT(MatrixInitStmt)

 // Locals with reverse-mode autodiff
 PER_STATEMENT(AdStackAllocaStmt)
diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp
index 3c4ddfedf2cac..e49e234ab0635 100644
--- a/taichi/ir/control_flow_graph.cpp
+++ b/taichi/ir/control_flow_graph.cpp
@@ -274,6 +274,19 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access,
       if (regular) {
         result = get_store_forwarding_data(alloca, i);
       }
+      // for (int i = 0; i < local_load->src.size(); ++i) {
+      //   bool regular = true;
+      //   auto alloca = local_load->src[i].var;
+      //   for (int l = 0; l < stmt->width(); l++) {
+      //     if (local_load->src[l].offset != l ||
+      //         local_load->src[l].var != alloca) {
+      //       regular = false;
+      //     }
+      //   }
+      //   if (regular) {
+      //     result = get_store_forwarding_data(alloca, i);
+      //   }
+      // }
     } else if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
       if (!after_lower_access && !autodiff_enabled) {
         result = get_store_forwarding_data(global_load->src, i);
@@ -501,7 +514,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
       }
     }
     auto load_ptrs = irpass::analysis::get_load_pointers(stmt);
-    if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1) {
+    if (load_ptrs.size() == 1 && store_ptrs.empty() && stmt->width() == 1 &&
+        !stmt->is<MatrixInitStmt>()) {
       // Identical load elimination
       auto load_ptr = load_ptrs.front();
       if (!after_lower_access ||
diff --git a/taichi/ir/expression_printer.h b/taichi/ir/expression_printer.h
index 74887c890e099..a32a0468d6907 100644
--- a/taichi/ir/expression_printer.h
+++ b/taichi/ir/expression_printer.h
@@ -110,6 +110,13 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {
     }
   }

+  void visit(MatrixExpression *expr) override {
+    emit('[');
+    emit_vector(expr->elements);
+    emit(']');
+    emit(fmt::format(" (dt={})", expr->dt->to_string()));
+  }
+
   void visit(IndexExpression *expr) override {
     expr->var->accept(this);
     emit('[');
diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp
index 56d5de025a9e8..fb856df25b553 100644
--- a/taichi/ir/frontend_ir.cpp
+++ b/taichi/ir/frontend_ir.cpp
@@ -195,8 +195,8 @@ void BinaryOpExpression::type_check(CompileConfig *config) {
                      binary_op_type_symbol(type), lhs->ret_type->to_string(),
                      rhs->ret_type->to_string()));
   };
-  if (!lhs_type->is<PrimitiveType>() || !rhs_type->is<PrimitiveType>())
-    error();
+  // if (!lhs_type->is<PrimitiveType>() || !rhs_type->is<PrimitiveType>())
+  //   error();
   if (binary_is_bitwise(type) &&
       (!is_integral(lhs_type) || !is_integral(rhs_type)))
     error();
@@ -212,6 +212,26 @@ void BinaryOpExpression::type_check(CompileConfig *config) {
     return;
   }

+  if (lhs_type->is<TensorType>()) {
+    auto dtype = lhs_type->as<TensorType>()->get_element_type();
+    if (rhs_type->is<PrimitiveType>()) {
+      ret_type = promoted_type(dtype, rhs_type);
+    } else {
+      TI_ASSERT(rhs_type->is<TensorType>());
+      auto rhs_tensor_type = rhs_type->cast<TensorType>();
+      if (rhs_tensor_type->get_shape() !=
+          lhs_type->cast<TensorType>()->get_shape())
+        error();
+      auto rhs_elem_type = rhs_type->as<TensorType>()->get_element_type();
+      if (rhs_elem_type != PrimitiveType::unknown)
+        ret_type = promoted_type(dtype, rhs_elem_type);
+    }
+    // TODO: shape check!
+    ret_type = TypeFactory::create_tensor_type(
+        lhs_type->cast<TensorType>()->get_shape(), ret_type);
+    return;
+  }
+
   // Some backends such as vulkan doesn't support fp64
   // Try not promoting to fp64 unless necessary
   if (type == BinaryOpType::atan2) {
@@ -412,6 +432,9 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx,
                          std::vector<int> shape,
                          int stride) {
   flatten_lvalue(var, ctx);
+  if (var->stmt->ret_type->is_primitive(PrimitiveTypeID::unknown)) {
+    var->stmt->ret_type = var->ret_type;
+  }
   Stmt *offset_stmt = ctx->push_back<ConstStmt>(TypedConstant(0));
   for (int i = 0; i < (int)indices.size(); ++i) {
     flatten_rvalue(indices[i], ctx);
@@ -429,6 +452,30 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx,
   return ctx->push_back<PtrOffsetStmt>(var->stmt, offset_stmt);
 }

+void MatrixExpression::type_check(CompileConfig *config) {
+  // TODO: typecheck matrix
+  for (auto &arg : elements) {
+    TI_ASSERT_TYPE_CHECKED(arg);
+  }
+}
+
+void MatrixExpression::flatten(FlattenContext *ctx) {
+  // TODO: implement flatten
+  TI_ASSERT(this->dt->is<TensorType>());
+  // std::vector<Stmt *> values;
+  std::vector<Stmt *> values;
+  for (auto &elt : elements) {
+    flatten_rvalue(elt, ctx);
+    // auto elt_alloca = ctx->push_back<AllocaStmt>(elt->stmt->ret_type);
+    // ctx->push_back<LocalStoreStmt>(elt_alloca, elt->stmt);
+    values.push_back(elt->stmt);
+  }
+  // stmt = ctx->push_back<MatrixInitStmt>(values,
+  //                                       this->dt->as<TensorType>()->get_shape());
+  stmt = ctx->push_back<MatrixInitStmt>(values);
+  stmt->ret_type = this->dt;
+}
+
 bool IndexExpression::is_field() const {
   return var.is<GlobalVariableExpression>();
 }
@@ -960,6 +1007,35 @@ Expr ASTBuilder::expr_alloca() {
   return var;
 }

+Expr ASTBuilder::expr_alloca_local_matrix(const std::vector<int> &shape,
+                                          const std::optional<DataType> &dt,
+                                          const std::vector<Expr> &elements) {
+  auto dtype = dt.value_or(PrimitiveType::unknown);
+  return Expr(std::make_shared<MatrixExpression>(elements, shape, dtype));
+}
+
+Expr ASTBuilder::expr_indexed_matrix(const Expr &matrix,
+                                     const ExprGroup &indices) {
+  TI_ASSERT(matrix.get_ret_type()->is<TensorType>());
+  auto shape = matrix.get_ret_type()->as<TensorType>()->get_shape();
+  if (indices.size() != shape.size()) {
+    std::string shape_str = "[";
+    if (shape.size() > 0) {
+      shape_str += std::to_string(shape[0]);
+      for (int i = 1; i < shape.size(); i++) {
+        shape_str += ", " + std::to_string(shape[i]);
+      }
+    }
+    shape_str += "]";
+    TI_ERROR(
+        "Indexed matrix of shape {} has wrong number of indices. Expected {} "
+        "but got {}.",
+        shape_str, shape.size(), indices.size());
+  }
+  return Expr(std::make_shared<IndexExpression>(matrix, indices));
+}
+
 Expr ASTBuilder::expr_alloca_local_tensor(const std::vector<int> &shape,
                                           const DataType &element_type,
                                           const ExprGroup &elements) {
@@ -1129,6 +1205,8 @@ void flatten_rvalue(Expr ptr, Expression::FlattenContext *ctx) {
     }
   } else if (ptr.is<IndexExpression>()) {
     auto ix = ptr.cast<IndexExpression>();
+    // if (ix->var->ret_type->is<TensorType>())
+    //   return;
     if (ix->is_local()) {
       flatten_local_load(ptr, ctx);
     } else {
diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h
index adbee435b6d7d..ccab1cc5bf7d1 100644
--- a/taichi/ir/frontend_ir.h
+++ b/taichi/ir/frontend_ir.h
@@ -378,7 +378,34 @@ class BinaryOpExpression : public Expression {
   Expr lhs, rhs;

   BinaryOpExpression(const BinaryOpType &type, const Expr &lhs, const Expr &rhs)
-      : type(type), lhs(lhs), rhs(rhs) {
+      : type(type) {
+    auto to_broadcast_tensor = [](const Expr &elt, const DataType &dt) -> Expr {
+      TI_ASSERT(dt->is<TensorType>());
+      auto tensor_type = dt->as<TensorType>();
+      auto elt_type = tensor_type->get_element_type();
+      TI_ASSERT_INFO(elt_type->is<PrimitiveType>(),
+                     "Only primitive types are supported in Tensors, got {}",
+                     elt_type->to_string());
+      std::vector<Expr> broadcast_values(tensor_type->get_num_elements(), elt);
+      return Expr::make<MatrixExpression>(broadcast_values,
+                                          tensor_type->get_shape(), elt_type);
+    };
+
+    auto unify_expr = [&](const Expr &e1, const Expr &e2) {
+      if ((!e1->ret_type->is<TensorType>() &&
+           !e2->ret_type->is<TensorType>()) ||
+          (e1->ret_type->is<TensorType>() && e2->ret_type->is<TensorType>())) {
+        return std::tuple(e1, e2);
+      }
+      if (!e1->ret_type->is<TensorType>()) {
+        return std::tuple(to_broadcast_tensor(e1, e2->ret_type), e2);
+      }
+      return std::tuple(e1, to_broadcast_tensor(e2, e1->ret_type));
+    };
+
+    auto [unified_l, unified_r] = unify_expr(lhs, rhs);
+    this->lhs = unified_l;
+    this->rhs = unified_r;
   }

   void type_check(CompileConfig *config) override;
@@ -503,6 +530,26 @@ class GlobalVariableExpression : public Expression {
   TI_DEFINE_ACCEPT_FOR_EXPRESSION
 };

+class MatrixExpression : public Expression {
+ public:
+  std::vector<Expr> elements;
+  DataType dt;
+
+  MatrixExpression(const std::vector<Expr> &elements,
+                   std::vector<int> shape,
+                   DataType element_type)
+      : elements(elements) {
+    this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type));
+    this->ret_type = this->dt;
+  }
+
+  void type_check(CompileConfig *config) override;
+
+  void flatten(FlattenContext *ctx) override;
+
+  TI_DEFINE_ACCEPT_FOR_EXPRESSION
+};
+
 class IndexExpression : public Expression {
  public:
   // `var` is one of GlobalVariableExpression, ExternalTensorExpression,
@@ -875,6 +922,10 @@ class ASTBuilder {
                                 const ExprGroup &args,
                                 const ExprGroup &outputs);
   Expr expr_alloca();
+  Expr expr_alloca_local_matrix(const std::vector<int> &shape,
+                                const std::optional<DataType> &dt,
+                                const std::vector<Expr> &elements);
+  Expr expr_indexed_matrix(const Expr &matrix, const ExprGroup &indices);
   Expr expr_alloca_local_tensor(const std::vector<int> &shape,
                                 const DataType &element_type,
                                 const ExprGroup &elements);
diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h
index 9a2ea841e6a66..2a9c0a7aa59fa 100644
--- a/taichi/ir/statements.h
+++ b/taichi/ir/statements.h
@@ -367,7 +367,12 @@ class PtrOffsetStmt : public Stmt {

   bool is_local_ptr() const {
     if (origin->is<AllocaStmt>() || origin->is<PtrOffsetStmt>()) {
-      TI_ASSERT_INFO(origin->ret_type->is<TensorType>(),
+      auto is_tensor_type = origin->ret_type->is<PointerType>()
+                                ? origin->ret_type->cast<PointerType>()
+                                      ->get_pointee_type()
+                                      ->is<TensorType>()
+                                : origin->ret_type->is<TensorType>();
+      TI_ASSERT_INFO(is_tensor_type,
                      "PtrOffsetStmt can only be used for Alloca (TensorType).");
     }
     return origin->is<AllocaStmt>() || origin->is<PtrOffsetStmt>();
@@ -601,8 +606,16 @@ class GlobalStoreStmt : public Stmt {
 class LocalLoadStmt : public Stmt {
  public:
   LaneAttribute<LocalAddress> src;
+  std::vector<int> shape;

-  explicit LocalLoadStmt(const LaneAttribute<LocalAddress> &src) : src(src) {
+  explicit LocalLoadStmt(const LaneAttribute<LocalAddress> &src)
+      : src(src), shape({static_cast<int>(src.data.size())}) {
+    TI_STMT_REG_FIELDS;
+  }
+
+  LocalLoadStmt(const LaneAttribute<LocalAddress> &src,
+                const std::vector<int> &shape)
+      : src(src), shape(shape) {
     TI_STMT_REG_FIELDS;
   }

@@ -1807,5 +1820,17 @@ class MeshPatchIndexStmt : public Stmt {
   TI_DEFINE_ACCEPT_AND_CLONE
 };

+class MatrixInitStmt : public Stmt {
+ public:
+  std::vector<Stmt *> values;
+
+  MatrixInitStmt(const std::vector<Stmt *> &values) : values(values) {
+    TI_STMT_REG_FIELDS;
+  }
+
+  TI_STMT_DEF_FIELDS(ret_type, values);
+  TI_DEFINE_ACCEPT_AND_CLONE
+};
+
 }  // namespace lang
 }  // namespace taichi
diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp
index 6b9d1a51e7990..0188722f15e6f 100644
--- a/taichi/ir/type.cpp
+++ b/taichi/ir/type.cpp
@@ -87,6 +87,14 @@ std::string TensorType::to_string() const {
   return s;
 }

+int TensorType::vector_width() const {
+  int vw = 1;
+  for (auto dim : shape_) {
+    vw *= dim;
+  }
+  return vw;
+}
+
 int Type::vector_width() const {
   return 1;  // TODO: CPU vectorization
 }
diff --git a/taichi/ir/type.h b/taichi/ir/type.h
index 339e2553ffb32..fa5fe7a6b8b6b 100644
--- a/taichi/ir/type.h
+++ b/taichi/ir/type.h
@@ -175,6 +175,8 @@ class TensorType : public Type {
     return shape_;
   }

+  int vector_width() const;
+
   Type *get_compute_type() override {
     return this;
   }
@@ -380,6 +382,12 @@ class TypedConstant {
   }

   TypedConstant(DataType dt) : dt(dt) {
+    if (!dt->is<PrimitiveType>()) {
+      assert(false);
+    }
+    TI_ASSERT_INFO(dt->is<PrimitiveType>(),
+                   "TypedConstant can only be PrimitiveType, got {}",
+                   dt->to_string());
     value_bits = 0;
   }
diff --git a/taichi/ir/type_utils.cpp b/taichi/ir/type_utils.cpp
index e49428b022445..dc5f816ecfa19 100644
--- a/taichi/ir/type_utils.cpp
+++ b/taichi/ir/type_utils.cpp
@@ -53,6 +53,36 @@ int data_type_size(DataType t) {
   }
 }

+std::string tensor_type_format_helper(const std::vector<int> &shape,
+                                      std::string format_str,
+                                      int dim) {
+  std::string fmt = "[";
+  for (int i = 0; i < shape[dim]; ++i) {
+    if (dim != shape.size() - 1) {
+      fmt += tensor_type_format_helper(shape, format_str, dim + 1);
+    } else {
+      fmt += format_str;
+    }
+    if (i != shape[dim] - 1) {
+      fmt += ", ";
+      if (dim == 0) {
+        fmt += "\n";
+      }
+    }
+  }
+  fmt += "]";
+  return fmt;
+}
+
+std::string tensor_type_format(DataType t) {
+  TI_ASSERT(t->is<TensorType>());
+  auto tensor_type = t->as<TensorType>();
+  auto shape = tensor_type->get_shape();
+  auto element_type = tensor_type->get_element_type();
+  auto element_type_format = data_type_format(element_type);
+  return tensor_type_format_helper(shape, element_type_format, 0);
+}
+
 std::string data_type_format(DataType dt) {
   if (dt->is_primitive(PrimitiveTypeID::i16)) {
     return "%hd";
@@ -79,6 +109,8 @@ std::string data_type_format(DataType dt) {
     // TaskCodeGenLLVM::visit(PrintStmt *stmt) and
     // TaskCodeGenCUDA::visit(PrintStmt *stmt) for more details.
return "%f"; + } else if (dt->is()) { + return tensor_type_format(dt); } else { TI_NOT_IMPLEMENTED } diff --git a/taichi/ir/type_utils.h b/taichi/ir/type_utils.h index 6d3b97154c94f..74423eaaebdfc 100644 --- a/taichi/ir/type_utils.h +++ b/taichi/ir/type_utils.h @@ -87,6 +87,8 @@ inline bool is_real(DataType dt) { } inline bool is_integral(DataType dt) { + if (dt->is()) + return is_integral(dt->as()->get_element_type()); return dt->is_primitive(PrimitiveTypeID::i8) || dt->is_primitive(PrimitiveTypeID::i16) || dt->is_primitive(PrimitiveTypeID::i32) || @@ -100,6 +102,8 @@ inline bool is_integral(DataType dt) { inline bool is_signed(DataType dt) { // Shall we return false if is_integral returns false? TI_ASSERT(is_integral(dt)); + if (auto t = dt->cast()) + return is_signed(t->get_element_type()); if (auto t = dt->cast()) return t->get_is_signed(); return dt->is_primitive(PrimitiveTypeID::i8) || diff --git a/taichi/program/compile_config.cpp b/taichi/program/compile_config.cpp index 930cc701a258b..5c978523b98e3 100644 --- a/taichi/program/compile_config.cpp +++ b/taichi/program/compile_config.cpp @@ -47,6 +47,7 @@ CompileConfig::CompileConfig() { detect_read_only = true; ndarray_use_cached_allocator = true; use_mesh = false; + real_matrix = false; saturating_grid_dim = 0; max_block_dim = 0; diff --git a/taichi/program/compile_config.h b/taichi/program/compile_config.h index 651bf8d3d4719..0ae10d82a9434 100644 --- a/taichi/program/compile_config.h +++ b/taichi/program/compile_config.h @@ -43,6 +43,7 @@ struct CompileConfig { bool detect_read_only; bool ndarray_use_cached_allocator; bool use_mesh; + bool real_matrix; DataType default_fp; DataType default_ip; DataType default_up; diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 931d54dab086f..bdf0d2df91246 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -62,6 +62,7 @@ Kernel::Kernel(Program &program, void Kernel::compile() { CurrentCallableGuard _(program, this); + TI_TRACE("compiling kernel {}", name); compiled_ = program->compile(*this); } @@ -409,12 +410,14 @@ void Kernel::init(Program &program, // concurrently, we need to lock this block of code together with // taichi::lang::context with a mutex. 
     CurrentCallableGuard _(this->program, this);
+    TI_TRACE("Calling func in {}", name);
     func();
     ir->as<Block>()->kernel = this;
   }

   if (!program.config.lazy_compilation)
     compile();
+  TI_TRACE("Finish compiling {}", name);
 }

 // static
diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp
index a34464d86c957..f8815285fe752 100644
--- a/taichi/python/export_lang.cpp
+++ b/taichi/python/export_lang.cpp
@@ -191,6 +191,7 @@ void export_lang(py::module &m) {
       .def_readwrite("ndarray_use_cached_allocator",
                      &CompileConfig::ndarray_use_cached_allocator)
       .def_readwrite("use_mesh", &CompileConfig::use_mesh)
+      .def_readwrite("real_matrix", &CompileConfig::real_matrix)
       .def_readwrite("cc_compile_cmd", &CompileConfig::cc_compile_cmd)
       .def_readwrite("cc_link_cmd", &CompileConfig::cc_link_cmd)
       .def_readwrite("quant_opt_store_fusion",
@@ -282,6 +283,8 @@ void export_lang(py::module &m) {
       .def("insert_activate", &ASTBuilder::insert_snode_activate)
       .def("insert_external_func_call", &ASTBuilder::insert_external_func_call)
       .def("expr_alloca", &ASTBuilder::expr_alloca)
+      .def("expr_alloca_matrix", &ASTBuilder::expr_alloca_local_matrix)
+      .def("expr_indexed_matrix", &ASTBuilder::expr_indexed_matrix)
       .def("expr_alloca_local_tensor", &ASTBuilder::expr_alloca_local_tensor)
       .def("expr_alloca_shared_array", &ASTBuilder::expr_alloca_shared_array)
       .def("create_assert_stmt", &ASTBuilder::create_assert_stmt)
diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp
index 379acb1efdae0..edde242f48f96 100644
--- a/taichi/runtime/llvm/llvm_context.cpp
+++ b/taichi/runtime/llvm/llvm_context.cpp
@@ -135,6 +135,10 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) {
     return llvm::Type::getInt64Ty(*ctx);
   } else if (dt->is_primitive(PrimitiveTypeID::f16)) {
     return llvm::Type::getHalfTy(*ctx);
+  } else if (dt->is<TensorType>()) {
+    auto vectorty = dt->as<TensorType>();
+    auto dtype = this->get_data_type(vectorty->get_element_type());
+    return llvm::VectorType::get(dtype, vectorty->get_num_elements(), false);
   } else {
     TI_INFO(data_type_name(dt));
     TI_NOT_IMPLEMENTED
diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp
index a3d04ca5936df..c32dc2b2f4a8c 100644
--- a/taichi/transforms/alg_simp.cpp
+++ b/taichi/transforms/alg_simp.cpp
@@ -112,6 +112,11 @@ class AlgSimp : public BasicStmtVisitor {
     if ((fast_math || is_integral(stmt->ret_type)) &&
         (alg_is_zero(lhs) || alg_is_zero(rhs))) {
       // fast_math or integral operands: 0 * a -> 0, a * 0 -> 0
+      if (stmt->ret_type->is<TensorType>() ||
+          stmt->rhs->ret_type->is<TensorType>()) {
+        // TODO: handle 0-tensor
+        return false;
+      }
       replace_with_zero(stmt);
       return true;
     }
@@ -163,8 +168,14 @@ class AlgSimp : public BasicStmtVisitor {
     if ((fast_math || is_integral(stmt->ret_type)) &&
         irpass::analysis::same_value(stmt->lhs, stmt->rhs)) {
       // fast_math or integral operands: a / a -> 1
-      replace_with_one(stmt);
-      return true;
+      if (stmt->lhs->ret_type->is<PrimitiveType>() &&
+          stmt->rhs->ret_type->is<PrimitiveType>()) {
+        replace_with_one(stmt);
+        return true;
+      } else {
+        // TODO: handle tensor division
+        return false;
+      }
     }
     if (fast_math && rhs && is_real(rhs->ret_type) &&
         stmt->op_type != BinaryOpType::floordiv) {
@@ -244,7 +255,13 @@ class AlgSimp : public BasicStmtVisitor {
         (fast_math || is_integral(stmt->ret_type)) &&
         irpass::analysis::same_value(stmt->lhs, stmt->rhs)) {
       // fast_math or integral operands: a -^ a -> 0
-      replace_with_zero(stmt);
+      if (stmt->lhs->ret_type->is<PrimitiveType>() &&
+          stmt->rhs->ret_type->is<PrimitiveType>()) {
+        replace_with_zero(stmt);
+      } else {
+        // TODO: handle tensor operations
+        return;
+      }
     }
   } else if (rhs && stmt->op_type == BinaryOpType::pow) {
     float64 exponent = rhs->val[0].val_cast_to_float64();
@@ -329,6 +346,11 @@ class AlgSimp : public BasicStmtVisitor {
       modifier.erase(stmt);
     } else if (alg_is_zero(lhs) || alg_is_zero(rhs)) {
       // 0 & a -> 0, a & 0 -> 0
+      if (stmt->ret_type->is<TensorType>() ||
+          stmt->rhs->ret_type->is<TensorType>()) {
+        // TODO: support tensor type
+        return;
+      }
       replace_with_zero(stmt);
     } else if (irpass::analysis::same_value(stmt->lhs, stmt->rhs)) {
       // a & a -> a
@@ -343,6 +365,11 @@ class AlgSimp : public BasicStmtVisitor {
       // a << 0 -> a
       // 0 << a -> 0
       // 0 >> a -> 0
+      if (stmt->ret_type->is<TensorType>() ||
+          stmt->rhs->ret_type->is<TensorType>()) {
+        // TODO: support tensor type
+        return;
+      }
       TI_ASSERT(stmt->lhs->ret_type == stmt->ret_type);
       stmt->replace_usages_with(stmt->lhs);
       modifier.erase(stmt);
diff --git a/taichi/transforms/constant_fold.cpp b/taichi/transforms/constant_fold.cpp
index b8bf03d990ae6..6ea23b3c3c621 100644
--- a/taichi/transforms/constant_fold.cpp
+++ b/taichi/transforms/constant_fold.cpp
@@ -133,6 +133,9 @@ class ConstantFold : public BasicStmtVisitor {
   }

   void visit(BinaryOpStmt *stmt) override {
+    if (stmt->lhs->ret_type->is<TensorType>() ||
+        stmt->rhs->ret_type->is<TensorType>())
+      return;
     auto lhs = stmt->lhs->cast<ConstStmt>();
     auto rhs = stmt->rhs->cast<ConstStmt>();
     if (!lhs || !rhs)
diff --git a/taichi/transforms/die.cpp b/taichi/transforms/die.cpp
index 3176d8f576949..f6d696c4da498 100644
--- a/taichi/transforms/die.cpp
+++ b/taichi/transforms/die.cpp
@@ -108,6 +108,13 @@ class DIE : public IRVisitor {
     }
     stmt->all_blocks_accept(this, true);
   }
+
+  void visit(MatrixInitStmt *stmt) override {
+    register_usage(stmt);
+    for (auto &elts : stmt->values) {
+      elts->accept(this);
+    }
+  }
 };

 namespace irpass {
diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp
index ca462e42773e8..a377e24a6dbfe 100644
--- a/taichi/transforms/ir_printer.cpp
+++ b/taichi/transforms/ir_printer.cpp
@@ -37,7 +37,7 @@ std::string block_dim_info(int block_dim) {
 std::string to_string(const LaneAttribute<LocalAddress> &ptr) {
   std::string ret = " [";
   for (int i = 0; i < (int)ptr.size(); i++) {
-    ret += fmt::format("{}[{}]", ptr[i].var->name(), ptr[i].offset);
+    ret += fmt::format("{}", ptr[i].var->name());
     if (i + 1 < (int)ptr.size())
       ret += ", ";
   }
@@ -794,6 +794,19 @@ class IRPrinter : public IRVisitor {
     print("{}{} = ref({})", stmt->type_hint(), stmt->name(), stmt->var->name());
   }

+  void visit(MatrixInitStmt *stmt) override {
+    std::string result = "";
+    result += fmt::format("{}{} = [", stmt->type_hint(), stmt->name());
+    for (int i = 0; i < stmt->values.size(); ++i) {
+      result += stmt->values[i]->name();
+      if (i != stmt->values.size() - 1) {
+        result += ", ";
+      }
+    }
+    result += "]";
+    print(result);
+  }
+
  private:
   std::string expr_to_string(Expr &expr) {
     return expr_to_string(expr.expr.get());
diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp
index e16c62501ba3c..1a69590287319 100644
--- a/taichi/transforms/type_check.cpp
+++ b/taichi/transforms/type_check.cpp
@@ -106,9 +106,21 @@ class TypeCheck : public IRVisitor {
                        .ptr_removed();
       stmt->ret_type = lookup;
     }
-  } else {
+  } else if (stmt->src.size() == 1) {
     auto lookup = stmt->src[0].var->ret_type;
     stmt->ret_type = lookup;
+  } else {
+    TI_ASSERT(stmt->src.size() > 1);
+    auto acc = stmt->src[0].var->ret_type;
+    for (int i = 1; i < stmt->src.size(); i++) {
+      acc = promoted_type(acc, stmt->src[i].var->ret_type);
+    }
+    if (stmt->ret_type != PrimitiveType::unknown) {
+      TI_ASSERT(stmt->ret_type->is<TensorType>());
+      acc = promoted_type(
+          acc, stmt->ret_type->as<TensorType>()->get_element_type());
+    }
+    stmt->ret_type = TypeFactory::create_tensor_type(stmt->shape, acc);
   }
 }
@@ -266,6 +278,7 @@ class TypeCheck : public IRVisitor {
   }

   void cast(Stmt *&val, DataType dt) {
+    TI_TRACE("Cast {} to {}", val->name(), dt->to_string());
     if (val->ret_type == dt)
       return;
@@ -294,10 +307,26 @@ class TypeCheck : public IRVisitor {
     if (stmt->op_type == BinaryOpType::truediv) {
       auto default_fp = config_.default_fp;
       if (!is_real(stmt->lhs->ret_type)) {
-        cast(stmt->lhs, default_fp);
+        if (stmt->lhs->ret_type->is<PrimitiveType>()) {
+          cast(stmt->lhs, default_fp);
+        } else {
+          TI_ASSERT(stmt->lhs->ret_type->is<TensorType>());
+          cast(stmt->lhs,
+               TypeFactory::create_tensor_type(
+                   stmt->lhs->ret_type->as<TensorType>()->get_shape(),
+                   default_fp));
+        }
       }
       if (!is_real(stmt->rhs->ret_type)) {
-        cast(stmt->rhs, default_fp);
+        if (stmt->rhs->ret_type->is<PrimitiveType>()) {
+          cast(stmt->rhs, default_fp);
+        } else {
+          TI_ASSERT(stmt->rhs->ret_type->is<TensorType>());
+          cast(stmt->rhs,
+               TypeFactory::create_tensor_type(
+                   stmt->rhs->ret_type->as<TensorType>()->get_shape(),
+                   default_fp));
+        }
       }
       stmt->op_type = BinaryOpType::div;
     }
@@ -317,6 +346,40 @@ class TypeCheck : public IRVisitor {
       }
     }

+    auto lhs_is_tensor = stmt->lhs->ret_type->is<TensorType>();
+    auto rhs_is_tensor = stmt->rhs->ret_type->is<TensorType>();
+
+    if (lhs_is_tensor || rhs_is_tensor) {
+      auto lhs_dtype =
+          lhs_is_tensor
+              ? DataType(
+                    stmt->lhs->ret_type->as<TensorType>()->get_element_type())
+              : stmt->lhs->ret_type;
+      auto rhs_dtype =
+          rhs_is_tensor
+              ? DataType(
+                    stmt->rhs->ret_type->as<TensorType>()->get_element_type())
+              : stmt->rhs->ret_type;
+      auto dtype = promoted_type(lhs_dtype, rhs_dtype);
+      if (dtype != lhs_dtype)
+        cast(stmt->lhs,
+             lhs_is_tensor
+                 ? TypeFactory::create_tensor_type(
+                       stmt->lhs->ret_type->as<TensorType>()->get_shape(),
+                       dtype)
+                 : dtype);
+      if (dtype != rhs_dtype)
+        cast(stmt->rhs,
+             rhs_is_tensor
+                 ? TypeFactory::create_tensor_type(
+                       stmt->rhs->ret_type->as<TensorType>()->get_shape(),
+                       dtype)
+                 : dtype);
+      // TODO: add shape inference for matrix ops below
+      stmt->ret_type = stmt->lhs->ret_type;
+      return;
+    }
+
     if (stmt->lhs->ret_type != stmt->rhs->ret_type) {
       DataType ret_type;
       if (is_shift_op(stmt->op_type)) {
@@ -535,8 +598,8 @@ class TypeCheck : public IRVisitor {
   }

   void visit(GlobalTemporaryStmt *stmt) override {
-    if (!stmt->ret_type->is<TensorType>())
-      stmt->ret_type.set_is_pointer(true);
+    // if (!stmt->ret_type->is<TensorType>())
+    stmt->ret_type.set_is_pointer(true);
   }

   void visit(InternalFuncStmt *stmt) override {
@@ -553,6 +616,26 @@ class TypeCheck : public IRVisitor {
     stmt->ret_type = stmt->var->ret_type;
     stmt->ret_type.set_is_pointer(true);
   }
+
+  void visit(MatrixInitStmt *stmt) override {
+    TI_ASSERT_INFO(stmt->ret_type->is<TensorType>(),
+                   "Matrix should have tensor type, got {}",
+                   stmt->ret_type->to_string());
+    auto tensor_type = stmt->ret_type->as<TensorType>();
+    auto element_dtype = tensor_type->get_element_type();
+    for (auto elt : stmt->values) {
+      element_dtype = promoted_type(element_dtype, elt->ret_type);
+    }
+    for (int i = 0; i < stmt->values.size(); ++i) {
+      if (element_dtype != stmt->values[i]->ret_type) {
+        cast(stmt->values[i], element_dtype);
+      }
+    }
+    if (element_dtype != tensor_type->get_element_type()) {
+      stmt->ret_type = TypeFactory::create_tensor_type(tensor_type->get_shape(),
+                                                       element_dtype);
+    }
+  }
 };

 namespace irpass {
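Usage sketch (editor's note, not part of the diff above): the new lowering is gated on the `real_matrix` option, which this change adds to `CompileConfig` and exposes through `export_lang.cpp`. Assuming `ti.init()` forwards it like other config flags, a minimal kernel such as the following would exercise `make_matrix`, the `MatrixExpression` / `MatrixInitStmt` lowering, the scalar-broadcast path in `BinaryOpExpression`, and `expr_indexed_matrix`:

import taichi as ti

# `real_matrix` maps to CompileConfig::real_matrix; passing it via ti.init()
# is an assumption based on how the other readwrite config flags are exposed.
ti.init(arch=ti.cpu, real_matrix=True)

@ti.kernel
def axpy() -> ti.f32:
    # ti.Vector(...) now routes through matrix.make_matrix(), emitting a
    # MatrixExpression that lowers to a single MatrixInitStmt (an LLVM
    # <3 x float> vector) instead of per-element stack allocas.
    v = ti.Vector([1.0, 2.0, 3.0])
    w = v * 2.0 + 1.0  # scalar operands are broadcast by unify_expr()
    return w[2]        # indexing lowers through expr_indexed_matrix()

print(axpy())  # expected: 7.0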