diff --git a/misc/prtags.json b/misc/prtags.json index f085670436ee1..69324f2740278 100644 --- a/misc/prtags.json +++ b/misc/prtags.json @@ -33,5 +33,6 @@ "error" : "Error messages", "blender" : "Blender intergration", "export" : "Exporting kernels", + "type" : "Type system", "release" : "Release" } diff --git a/python/taichi/lang/util.py b/python/taichi/lang/util.py index 6849cb28ef549..dd0cac67b6dfd 100644 --- a/python/taichi/lang/util.py +++ b/python/taichi/lang/util.py @@ -30,9 +30,9 @@ def is_taichi_class(rhs): # Real types -float32 = taichi_lang_core.DataType.float32 +float32 = taichi_lang_core.DataType_float32 f32 = float32 -float64 = taichi_lang_core.DataType.float64 +float64 = taichi_lang_core.DataType_float64 f64 = float64 real_types = [f32, f64, float] @@ -40,22 +40,22 @@ def is_taichi_class(rhs): # Integer types -int8 = taichi_lang_core.DataType.int8 +int8 = taichi_lang_core.DataType_int8 i8 = int8 -int16 = taichi_lang_core.DataType.int16 +int16 = taichi_lang_core.DataType_int16 i16 = int16 -int32 = taichi_lang_core.DataType.int32 +int32 = taichi_lang_core.DataType_int32 i32 = int32 -int64 = taichi_lang_core.DataType.int64 +int64 = taichi_lang_core.DataType_int64 i64 = int64 -uint8 = taichi_lang_core.DataType.uint8 +uint8 = taichi_lang_core.DataType_uint8 u8 = uint8 -uint16 = taichi_lang_core.DataType.uint16 +uint16 = taichi_lang_core.DataType_uint16 u16 = uint16 -uint32 = taichi_lang_core.DataType.uint32 +uint32 = taichi_lang_core.DataType_uint32 u32 = uint32 -uint64 = taichi_lang_core.DataType.uint64 +uint64 = taichi_lang_core.DataType_uint64 u64 = uint64 integer_types = [i8, i16, i32, i64, u8, u16, u32, u64, int] diff --git a/taichi/backends/cc/codegen_cc.cpp b/taichi/backends/cc/codegen_cc.cpp index eca9371f9d66c..2170057ebda42 100644 --- a/taichi/backends/cc/codegen_cc.cpp +++ b/taichi/backends/cc/codegen_cc.cpp @@ -257,19 +257,17 @@ class CCTransformer : public IRVisitor { } static std::string _get_libc_function_name(std::string name, DataType dt) { - switch (dt) { - case DataType::i32: - return name; - case DataType::i64: - return "ll" + name; - case DataType::f32: - return name + "f"; - case DataType::f64: - return name; - default: - TI_ERROR("Unsupported function \"{}\" for DataType={} on C backend", - name, data_type_name(dt)); - } + if (dt == DataType::i32) + return name; + else if (dt == DataType::i64) + return "ll" + name; + else if (dt == DataType::f32) + return name + "f"; + else if (dt == DataType::f64) + return name; + else + TI_ERROR("Unsupported function \"{}\" for DataType={} on C backend", name, + data_type_name(dt)); } static std::string get_libc_function_name(std::string name, DataType dt) { @@ -598,7 +596,7 @@ class CCTransformer : public IRVisitor { void emit_header(std::string f, Args &&... args) { line_appender_header.append(std::move(f), std::move(args)...); } -}; +}; // namespace cccp std::unique_ptr CCKernelGen::compile() { auto program = kernel->program.cc_program.get(); diff --git a/taichi/backends/metal/data_types.cpp b/taichi/backends/metal/data_types.cpp index 121c07205bf79..11fd443ec666c 100644 --- a/taichi/backends/metal/data_types.cpp +++ b/taichi/backends/metal/data_types.cpp @@ -4,28 +4,24 @@ TLANG_NAMESPACE_BEGIN namespace metal { MetalDataType to_metal_type(DataType dt) { - switch (dt) { -#define METAL_CASE(x) \ - case DataType::x: \ - return MetalDataType::x - - METAL_CASE(f32); - METAL_CASE(f64); - METAL_CASE(i8); - METAL_CASE(i16); - METAL_CASE(i32); - METAL_CASE(i64); - METAL_CASE(u8); - METAL_CASE(u16); - METAL_CASE(u32); - METAL_CASE(u64); - METAL_CASE(unknown); -#undef METAL_CASE - - default: - TI_NOT_IMPLEMENTED; - break; +#define METAL_CASE(x) else if (dt == DataType::x) return MetalDataType::x + if (false) { } + METAL_CASE(f32); + METAL_CASE(f64); + METAL_CASE(i8); + METAL_CASE(i16); + METAL_CASE(i32); + METAL_CASE(i64); + METAL_CASE(u8); + METAL_CASE(u16); + METAL_CASE(u32); + METAL_CASE(u64); + METAL_CASE(unknown); + else { + TI_NOT_IMPLEMENTED; + } +#undef METAL_CASE return MetalDataType::unknown; } diff --git a/taichi/backends/opengl/opengl_data_types.h b/taichi/backends/opengl/opengl_data_types.h index d247bdb5830f5..2834db90a992a 100644 --- a/taichi/backends/opengl/opengl_data_types.h +++ b/taichi/backends/opengl/opengl_data_types.h @@ -8,20 +8,16 @@ namespace opengl { inline std::string opengl_data_type_name(DataType dt) { // https://www.khronos.org/opengl/wiki/Data_Type_(GLSL) - switch (dt) { - case DataType::f32: - return "float"; - case DataType::f64: - return "double"; - case DataType::i32: - return "int"; - case DataType::i64: - return "int64_t"; - default: - TI_NOT_IMPLEMENTED; - break; - } - return ""; + if (dt == DataType::f32) + return "float"; + else if (dt == DataType::f64) + return "double"; + else if (dt == DataType::i32) + return "int"; + else if (dt == DataType::i64) + return "int64_t"; + else + TI_NOT_IMPLEMENTED; } inline bool is_opengl_binary_op_infix(BinaryOpType type) { @@ -36,15 +32,12 @@ inline bool is_opengl_binary_op_different_return_type(BinaryOpType type) { } inline int opengl_data_address_shifter(DataType type) { - switch (type) { - case DataType::f32: - case DataType::i32: - return 2; - case DataType::f64: - case DataType::i64: - return 3; - default: - TI_NOT_IMPLEMENTED + if (type == DataType::f32 || type == DataType::i32) + return 2; + else if (type == DataType::f64 || type == DataType::i64) { + return 3; + } else { + TI_NOT_IMPLEMENTED } } diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index b80a2a7525af1..4bbc4efe46995 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -860,7 +860,7 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) { dest_ty = tlctx->get_data_type(stmt->ret_type.data_type); auto dest_bits = dest_ty->getPrimitiveSizeInBits(); auto truncated = builder->CreateTrunc( - raw_arg, Type::getIntNTy(*llvm_context, dest_bits)); + raw_arg, llvm::Type::getIntNTy(*llvm_context, dest_bits)); llvm_val[stmt] = builder->CreateBitCast(truncated, dest_ty); } } @@ -1327,7 +1327,7 @@ void CodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt, bool spmd) { // per-leaf-block for loop auto loop_index = - create_entry_block_alloca(Type::getInt32Ty(*llvm_context)); + create_entry_block_alloca(llvm::Type::getInt32Ty(*llvm_context)); llvm::Value *thread_idx = nullptr, *block_dim = nullptr; diff --git a/taichi/lang_util.cpp b/taichi/lang_util.cpp index 68b4b3b56876b..31147e83c63ef 100644 --- a/taichi/lang_util.cpp +++ b/taichi/lang_util.cpp @@ -1,6 +1,6 @@ // Definitions of utility functions and enums -#include "lang_util.h" +#include "taichi/lang_util.h" #include "taichi/math/linalg.h" #include "taichi/program/arch.h" @@ -29,6 +29,37 @@ real get_cpu_frequency() { real default_measurement_time = 1; +// Note: these primitive types should never be freed. They are supposed to live +// together with the process. +#define PER_TYPE(x) \ + DataType DataType::x = \ + DataType(new PrimitiveType(PrimitiveType::primitive_type::x)); +#include "taichi/inc/data_type.inc.h" +#undef PER_TYPE + +DataType PrimitiveType::get(PrimitiveType::primitive_type t) { + if (false) { + } +#define PER_TYPE(x) else if (t == primitive_type::x) return DataType::x; +#include "taichi/inc/data_type.inc.h" +#undef PER_TYPE + else { + TI_NOT_IMPLEMENTED + } +} + +std::size_t DataType::hash() const { + if (auto primitive = dynamic_cast(ptr_)) { + return (std::size_t)primitive->type; + } else { + TI_NOT_IMPLEMENTED + } +} + +std::string PrimitiveType::to_string() const { + return data_type_name(DataType(this)); +} + real measure_cpe(std::function target, int64 elements_per_call, real time_second) { @@ -65,30 +96,26 @@ real measure_cpe(std::function target, } std::string data_type_name(DataType t) { - switch (t) { -#define REGISTER_DATA_TYPE(i, j) \ - case DataType::i: \ - return #j; - - REGISTER_DATA_TYPE(f16, float16); - REGISTER_DATA_TYPE(f32, float32); - REGISTER_DATA_TYPE(f64, float64); - REGISTER_DATA_TYPE(u1, int1); - REGISTER_DATA_TYPE(i8, int8); - REGISTER_DATA_TYPE(i16, int16); - REGISTER_DATA_TYPE(i32, int32); - REGISTER_DATA_TYPE(i64, int64); - REGISTER_DATA_TYPE(u8, uint8); - REGISTER_DATA_TYPE(u16, uint16); - REGISTER_DATA_TYPE(u32, uint32); - REGISTER_DATA_TYPE(u64, uint64); - REGISTER_DATA_TYPE(gen, generic); - REGISTER_DATA_TYPE(unknown, unknown); +#define REGISTER_DATA_TYPE(i, j) else if (t == DataType::i) return #j + if (false) { + } + REGISTER_DATA_TYPE(f16, float16); + REGISTER_DATA_TYPE(f32, float32); + REGISTER_DATA_TYPE(f64, float64); + REGISTER_DATA_TYPE(u1, int1); + REGISTER_DATA_TYPE(i8, int8); + REGISTER_DATA_TYPE(i16, int16); + REGISTER_DATA_TYPE(i32, int32); + REGISTER_DATA_TYPE(i64, int64); + REGISTER_DATA_TYPE(u8, uint8); + REGISTER_DATA_TYPE(u16, uint16); + REGISTER_DATA_TYPE(u32, uint32); + REGISTER_DATA_TYPE(u64, uint64); + REGISTER_DATA_TYPE(gen, generic); + REGISTER_DATA_TYPE(unknown, unknown); #undef REGISTER_DATA_TYPE - default: - TI_NOT_IMPLEMENTED - } + else TI_NOT_IMPLEMENTED } std::string data_type_format(DataType dt) { @@ -110,48 +137,42 @@ std::string data_type_format(DataType dt) { } int data_type_size(DataType t) { - switch (t) { - case DataType::f16: - return 2; - case DataType::gen: - return 0; - case DataType::unknown: - return -1; - -#define REGISTER_DATA_TYPE(i, j) \ - case DataType::i: \ - return sizeof(j); - - REGISTER_DATA_TYPE(f32, float32); - REGISTER_DATA_TYPE(f64, float64); - REGISTER_DATA_TYPE(i8, int8); - REGISTER_DATA_TYPE(i16, int16); - REGISTER_DATA_TYPE(i32, int32); - REGISTER_DATA_TYPE(i64, int64); - REGISTER_DATA_TYPE(u8, uint8); - REGISTER_DATA_TYPE(u16, uint16); - REGISTER_DATA_TYPE(u32, uint32); - REGISTER_DATA_TYPE(u64, uint64); + if (false) { + } else if (t == DataType::f16) + return 2; + else if (t == DataType::gen) + return 0; + else if (t == DataType::unknown) + return -1; + +#define REGISTER_DATA_TYPE(i, j) else if (t == DataType::i) return sizeof(j) + + REGISTER_DATA_TYPE(f32, float32); + REGISTER_DATA_TYPE(f64, float64); + REGISTER_DATA_TYPE(i8, int8); + REGISTER_DATA_TYPE(i16, int16); + REGISTER_DATA_TYPE(i32, int32); + REGISTER_DATA_TYPE(i64, int64); + REGISTER_DATA_TYPE(u8, uint8); + REGISTER_DATA_TYPE(u16, uint16); + REGISTER_DATA_TYPE(u32, uint32); + REGISTER_DATA_TYPE(u64, uint64); #undef REGISTER_DATA_TYPE - default: - TI_NOT_IMPLEMENTED + else { + TI_NOT_IMPLEMENTED } } std::string data_type_short_name(DataType t) { - switch (t) { -#define PER_TYPE(i) \ - case DataType::i: \ - return #i; - + if (false) { + } +#define PER_TYPE(i) else if (t == DataType::i) return #i; #include "taichi/inc/data_type.inc.h" - #undef PER_TYPE - default: - TI_NOT_IMPLEMENTED - } -} + else + TI_NOT_IMPLEMENTED +} // namespace lang std::string snode_type_name(SNodeType t) { switch (t) { @@ -328,8 +349,9 @@ namespace { class TypePromotionMapping { public: TypePromotionMapping() { -#define TRY_SECOND(x, y) \ - mapping[std::make_pair(get_data_type(), get_data_type())] = \ +#define TRY_SECOND(x, y) \ + mapping[std::make_pair(to_primitive_type(get_data_type()), \ + to_primitive_type(get_data_type()))] = \ get_data_type() + std::declval())>(); #define TRY_FIRST(x) \ TRY_SECOND(x, float32); \ @@ -355,11 +377,19 @@ class TypePromotionMapping { TRY_FIRST(uint64); } DataType query(DataType x, DataType y) { - return mapping[std::make_pair(x, y)]; + return mapping[std::make_pair(to_primitive_type(x), to_primitive_type(y))]; } private: - std::map, DataType> mapping; + std::map< + std::pair, + DataType> + mapping; + static PrimitiveType::primitive_type to_primitive_type(const DataType d) { + auto primitive = dynamic_cast(d.get_ptr()); + TI_ASSERT(primitive); + return primitive->type; + }; }; TypePromotionMapping type_promotion_mapping; } // namespace diff --git a/taichi/lang_util.h b/taichi/lang_util.h index 2b8344da61a3c..b3572447716c5 100644 --- a/taichi/lang_util.h +++ b/taichi/lang_util.h @@ -19,10 +19,64 @@ struct Context; using FunctionType = std::function; -enum class DataType : int { +class Type { + public: + virtual std::string to_string() const = 0; + virtual ~Type() { + } +}; + +// A "Type" handle. This should be removed later. +class DataType { + public: +#define PER_TYPE(x) static DataType x; +#include "taichi/inc/data_type.inc.h" +#undef PER_TYPE + DataType() : ptr_(unknown.ptr_) { + } + + DataType(const Type *ptr) : ptr_(ptr) { + } + + bool operator==(const DataType &o) const { + return ptr_ == o.ptr_; + } + + bool operator!=(const DataType &o) const { + return !(*this == o); + } + + std::size_t hash() const; + + std::string to_string() const { + return ptr_->to_string(); + }; + + // TODO: DataType itself should be a pointer in the future + const Type *get_ptr() const { + return ptr_; + } + + private: + const Type *ptr_; +}; + +class PrimitiveType : public Type { + public: + enum class primitive_type : int { #define PER_TYPE(x) x, #include "taichi/inc/data_type.inc.h" #undef PER_TYPE + }; + + primitive_type type; + + PrimitiveType(primitive_type type) : type(type) { + } + + std::string to_string() const override; + + static DataType get(primitive_type type); }; template @@ -78,7 +132,7 @@ enum class UnaryOpType : int { std::string unary_op_type_name(UnaryOpType type); -inline bool unary_op_is_cast(UnaryOpType op) { +inline bool constexpr unary_op_is_cast(UnaryOpType op) { return op == UnaryOpType::cast_value || op == UnaryOpType::cast_bits; } @@ -88,41 +142,39 @@ inline bool constexpr is_trigonometric(UnaryOpType op) { op == UnaryOpType::tan || op == UnaryOpType::tanh; } -inline bool constexpr is_real(DataType dt) { +inline bool is_real(DataType dt) { return dt == DataType::f16 || dt == DataType::f32 || dt == DataType::f64; } -inline bool constexpr is_integral(DataType dt) { +inline bool is_integral(DataType dt) { return dt == DataType::i8 || dt == DataType::i16 || dt == DataType::i32 || dt == DataType::i64 || dt == DataType::u8 || dt == DataType::u16 || dt == DataType::u32 || dt == DataType::u64; } -inline bool constexpr is_signed(DataType dt) { +inline bool is_signed(DataType dt) { TI_ASSERT(is_integral(dt)); return dt == DataType::i8 || dt == DataType::i16 || dt == DataType::i32 || dt == DataType::i64; } -inline bool constexpr is_unsigned(DataType dt) { +inline bool is_unsigned(DataType dt) { TI_ASSERT(is_integral(dt)); return !is_signed(dt); } inline DataType to_unsigned(DataType dt) { TI_ASSERT(is_signed(dt)); - switch (dt) { - case DataType::i8: - return DataType::u8; - case DataType::i16: - return DataType::u16; - case DataType::i32: - return DataType::u32; - case DataType::i64: - return DataType::u64; - default: - return DataType::unknown; - } + if (dt == DataType::i8) + return DataType::u8; + else if (dt == DataType::i16) + return DataType::u16; + else if (dt == DataType::i32) + return DataType::u32; + else if (dt == DataType::i64) + return DataType::u64; + else + return DataType::unknown; } inline bool needs_grad(DataType dt) { @@ -130,8 +182,8 @@ inline bool needs_grad(DataType dt) { } // Regular binary ops: -// Operations that take two oprands, and returns a single operand with the same -// type +// Operations that take two oprands, and returns a single operand with the +// same type enum class BinaryOpType : int { #define PER_BINARY_OP(x) x, diff --git a/taichi/llvm/llvm_context.cpp b/taichi/llvm/llvm_context.cpp index 10701e2d651de..f3526493d7de1 100644 --- a/taichi/llvm/llvm_context.cpp +++ b/taichi/llvm/llvm_context.cpp @@ -373,7 +373,7 @@ std::unique_ptr TaichiLLVMContext::clone_runtime_module() { auto patch_intrinsic = [&](std::string name, Intrinsic::ID intrin, bool ret = true, - std::vector types = {}, + std::vector types = {}, std::vector extra_args = {}) { auto func = runtime_module->getFunction(name); TI_ERROR_UNLESS(func, "Function {} not found", name); diff --git a/taichi/program/program.h b/taichi/program/program.h index 645ef5770406b..7db40e8123480 100644 --- a/taichi/program/program.h +++ b/taichi/program/program.h @@ -57,9 +57,8 @@ template <> struct hash { std::size_t operator()(taichi::lang::JITEvaluatorId const &id) const noexcept { - return ((std::size_t)id.op | ((std::size_t)id.ret << 8) | - ((std::size_t)id.lhs << 16) | ((std::size_t)id.rhs << 24) | - ((std::size_t)id.is_binary << 31)) ^ + return ((std::size_t)id.op | (id.ret.hash() << 8) | (id.lhs.hash() << 16) | + (id.rhs.hash() << 24) | ((std::size_t)id.is_binary << 31)) ^ (std::hash{}(id.thread_id) << 32); } }; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index adf530e8503bf..4a3d6404cd467 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -82,6 +82,25 @@ void export_lang(py::module &m) { #undef PER_EXTENSION .export_values(); + py::class_(m, "DataType") + .def(py::self == py::self) + .def(py::pickle( + [](const DataType &dt) { + // Note: this only works for primitive types, which is fine for now. + auto primitive = dynamic_cast(dt.get_ptr()); + TI_ASSERT(primitive); + return py::make_tuple((std::size_t)primitive->type); + }, + [](py::tuple t) { + if (t.size() != 1) + throw std::runtime_error("Invalid state!"); + + DataType dt = PrimitiveType::get( + (PrimitiveType::primitive_type)(t[0].cast())); + + return dt; + })); + py::class_(m, "CompileConfig") .def(py::init<>()) .def_readwrite("arch", &CompileConfig::arch) @@ -529,11 +548,10 @@ void export_lang(py::module &m) { unary.export_values(); m.def("make_unary_op_expr", Expr::make); - - auto &&data_type = py::enum_(m, "DataType", py::arithmetic()); - for (int t = 0; t <= (int)DataType::unknown; t++) - data_type.value(data_type_name(DataType(t)).c_str(), DataType(t)); - data_type.export_values(); +#define PER_TYPE(x) \ + m.attr(("DataType_" + data_type_name(DataType::x)).c_str()) = DataType::x; +#include "taichi/inc/data_type.inc.h" +#undef PER_TYPE m.def("is_integral", is_integral); m.def("is_signed", is_signed); @@ -672,7 +690,7 @@ void export_lang(py::module &m) { #if defined(TI_WITH_CUDA) return CUDAContext::get_instance().get_compute_capability(); #else - TI_NOT_IMPLEMENTED + TI_NOT_IMPLEMENTED #endif } else { TI_ERROR("Key {} not supported in query_int64", key); diff --git a/taichi/transforms/constant_fold.cpp b/taichi/transforms/constant_fold.cpp index 6c2f5713f3d3f..8c44a6dddb108 100644 --- a/taichi/transforms/constant_fold.cpp +++ b/taichi/transforms/constant_fold.cpp @@ -76,15 +76,11 @@ class ConstantFold : public BasicStmtVisitor { // ConstStmt of `bad` types like `i8` is not supported by LLVM. // Discussion: // https://github.com/taichi-dev/taichi/pull/839#issuecomment-625902727 - switch (dt) { - case DataType::i32: - case DataType::f32: - case DataType::i64: - case DataType::f64: - return true; - default: - return false; - } + if (dt == DataType::i32 || dt == DataType::f32 || dt == DataType::i64 || + dt == DataType::f64) + return true; + else + return false; } static bool jit_evaluate_binary_op(TypedConstant &ret, diff --git a/tests/python/test_ad_if.py b/tests/python/test_ad_if.py index 4020cc9a654ad..41c053e542f65 100644 --- a/tests/python/test_ad_if.py +++ b/tests/python/test_ad_if.py @@ -208,8 +208,8 @@ def func(): assert x.grad[1] == -0.25 -@ti.require(ti.extension.adstack, ti.extension.data64) -@ti.all_archs_with(default_fp=ti.f64) +@ti.test(require=[ti.extension.adstack, ti.extension.data64], + default_fp=ti.f64) def test_ad_if_parallel_complex_f64(): x = ti.field(ti.f64, shape=2) y = ti.field(ti.f64, shape=2)