diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index eac8111ed3b7a..f4ef60e4da34a 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -266,7 +266,7 @@ bool QuantFloatType::get_is_signed() const { BitStructType::BitStructType( PrimitiveType *physical_type, - const std::vector &member_types, + const std::vector &member_types, const std::vector &member_bit_offsets, const std::vector &member_exponents, const std::vector> &member_exponent_users) @@ -282,7 +282,7 @@ BitStructType::BitStructType( int physical_type_bits = data_type_bits(physical_type_); int member_total_bits = 0; for (auto i = 0; i < member_types_.size(); ++i) { - QuantIntType *component_qit = nullptr; + const QuantIntType *component_qit = nullptr; if (auto qit = member_types_[i]->cast()) { component_qit = qit; } else if (auto qfxt = member_types_[i]->cast()) { diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 13162d7a1e4bb..4fd7bc79c01e1 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -143,7 +143,7 @@ class TI_DLL_EXPORT DataType { TI_IO_DEF(ptr_); private: - Type *ptr_; + const Type *ptr_; }; // Note that all types are immutable once created. @@ -335,7 +335,7 @@ class TI_DLL_EXPORT QuantIntType : public Type { private: // TODO(type): for now we can uniformly use i32 as the "compute_type". It may // be a good idea to make "compute_type" also customizable. - Type *compute_type_{nullptr}; + const Type *compute_type_{nullptr}; int num_bits_{32}; bool is_signed_{true}; }; @@ -349,7 +349,7 @@ class TI_DLL_EXPORT QuantFixedType : public Type { bool get_is_signed() const; - Type *get_digits_type() { + const Type *get_digits_type() { return digits_type_; } @@ -379,7 +379,7 @@ class TI_DLL_EXPORT QuantFloatType : public Type { std::string to_string() const override; - Type *get_digits_type() { + const Type *get_digits_type() { return digits_type_; } @@ -411,7 +411,7 @@ class TI_DLL_EXPORT BitStructType : public Type { public: BitStructType() : Type(TypeKind::BitStruct){}; BitStructType(PrimitiveType *physical_type, - const std::vector &member_types, + const std::vector &member_types, const std::vector &member_bit_offsets, const std::vector &member_exponents, const std::vector> &member_exponent_users); @@ -457,7 +457,7 @@ class TI_DLL_EXPORT BitStructType : public Type { private: PrimitiveType *physical_type_; - std::vector member_types_; + std::vector member_types_; std::vector member_bit_offsets_; std::vector member_exponents_; std::vector> member_exponent_users_; diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index 5fc0e9deb7e49..c1b607e558267 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -12,7 +12,7 @@ TypeFactory &TypeFactory::get_instance() { TypeFactory::TypeFactory() { } -Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) { +const Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) { std::lock_guard _(primitive_mut_); if (primitive_types_.find(id) == primitive_types_.end()) { @@ -22,7 +22,8 @@ Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) { return primitive_types_[id].get(); } -Type *TypeFactory::get_tensor_type(std::vector shape, Type *element) { +const Type *TypeFactory::get_tensor_type(std::vector shape, + Type *element) { std::lock_guard _(tensor_mut_); auto encode = [](const std::vector &shape) -> std::string { @@ -57,7 +58,8 @@ const Type *TypeFactory::get_struct_type( return struct_types_[key].get(); } -Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) { +const Type *TypeFactory::get_pointer_type(const Type *element, + bool is_bit_pointer) { std::lock_guard _(pointer_mut_); auto key = std::make_pair(element, is_bit_pointer); @@ -68,9 +70,9 @@ Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) { return pointer_types_[key].get(); } -Type *TypeFactory::get_quant_int_type(int num_bits, - bool is_signed, - Type *compute_type) { +const Type *TypeFactory::get_quant_int_type(int num_bits, + bool is_signed, + const Type *compute_type) { std::lock_guard _(quant_int_mut_); auto key = std::make_tuple(num_bits, is_signed, compute_type); @@ -81,9 +83,9 @@ Type *TypeFactory::get_quant_int_type(int num_bits, return quant_int_types_[key].get(); } -Type *TypeFactory::get_quant_fixed_type(Type *digits_type, - Type *compute_type, - float64 scale) { +const Type *TypeFactory::get_quant_fixed_type(const Type *digits_type, + const Type *compute_type, + float64 scale) { std::lock_guard _(quant_fixed_mut_); auto key = std::make_tuple(digits_type, compute_type, scale); @@ -94,9 +96,9 @@ Type *TypeFactory::get_quant_fixed_type(Type *digits_type, return quant_fixed_types_[key].get(); } -Type *TypeFactory::get_quant_float_type(Type *digits_type, - Type *exponent_type, - Type *compute_type) { +const Type *TypeFactory::get_quant_float_type(const Type *digits_type, + const Type *exponent_type, + const Type *compute_type) { std::lock_guard _(quant_float_mut_); auto key = std::make_tuple(digits_type, exponent_type, compute_type); @@ -108,8 +110,8 @@ Type *TypeFactory::get_quant_float_type(Type *digits_type, } BitStructType *TypeFactory::get_bit_struct_type( - PrimitiveType *physical_type, - const std::vector &member_types, + const PrimitiveType *physical_type, + const std::vector &member_types, const std::vector &member_bit_offsets, const std::vector &member_exponents, const std::vector> &member_exponent_users) { @@ -121,9 +123,10 @@ BitStructType *TypeFactory::get_bit_struct_type( return bit_struct_types_.back().get(); } -Type *TypeFactory::get_quant_array_type(PrimitiveType *physical_type, - Type *element_type, - int num_elements) { +const Type *TypeFactory::get_quant_array_type( + const PrimitiveType *physical_type, + const Type *element_type, + int num_elements) { std::lock_guard _(quant_array_mut_); quant_array_types_.push_back(std::make_unique( @@ -131,8 +134,9 @@ Type *TypeFactory::get_quant_array_type(PrimitiveType *physical_type, return quant_array_types_.back().get(); } -PrimitiveType *TypeFactory::get_primitive_int_type(int bits, bool is_signed) { - Type *int_type; +const PrimitiveType *TypeFactory::get_primitive_int_type(int bits, + bool is_signed) { + const Type *int_type; if (bits == 8) { int_type = get_primitive_type(PrimitiveTypeID::i8); } else if (bits == 16) { @@ -150,8 +154,8 @@ PrimitiveType *TypeFactory::get_primitive_int_type(int bits, bool is_signed) { return int_type->cast(); } -PrimitiveType *TypeFactory::get_primitive_real_type(int bits) { - Type *real_type; +const PrimitiveType *TypeFactory::get_primitive_real_type(int bits) { + const Type *real_type; if (bits == 16) { real_type = get_primitive_type(PrimitiveTypeID::f16); } else if (bits == 32) { diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index 292973f652e2a..80fef1878c753 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -14,39 +14,42 @@ class TypeFactory { // TODO(type): maybe it makes sense to let each get_X function return X* // instead of generic Type* - Type *get_primitive_type(PrimitiveTypeID id); + const Type *get_primitive_type(PrimitiveTypeID id); - PrimitiveType *get_primitive_int_type(int bits, bool is_signed = true); + const PrimitiveType *get_primitive_int_type(int bits, bool is_signed = true); - PrimitiveType *get_primitive_real_type(int bits); + const PrimitiveType *get_primitive_real_type(int bits); - Type *get_tensor_type(std::vector shape, Type *element); + const Type *get_tensor_type(std::vector shape, Type *element); const Type *get_struct_type(const std::vector &elements, const std::string &layout = "none"); - Type *get_pointer_type(Type *element, bool is_bit_pointer = false); + const Type *get_pointer_type(const Type *element, + bool is_bit_pointer = false); - Type *get_quant_int_type(int num_bits, bool is_signed, Type *compute_type); + const Type *get_quant_int_type(int num_bits, + bool is_signed, + const Type *compute_type); - Type *get_quant_fixed_type(Type *digits_type, - Type *compute_type, - float64 scale); + const Type *get_quant_fixed_type(const Type *digits_type, + const Type *compute_type, + float64 scale); - Type *get_quant_float_type(Type *digits_type, - Type *exponent_type, - Type *compute_type); + const Type *get_quant_float_type(const Type *digits_type, + const Type *exponent_type, + const Type *compute_type); BitStructType *get_bit_struct_type( - PrimitiveType *physical_type, - const std::vector &member_types, + const PrimitiveType *physical_type, + const std::vector &member_types, const std::vector &member_bit_offsets, const std::vector &member_exponents, const std::vector> &member_exponent_users); - Type *get_quant_array_type(PrimitiveType *physical_type, - Type *element_type, - int num_elements); + const Type *get_quant_array_type(const PrimitiveType *physical_type, + const Type *element_type, + int num_elements); static DataType create_tensor_type(std::vector shape, DataType element); @@ -56,9 +59,9 @@ class TypeFactory { std::unordered_map> primitive_types_; std::mutex primitive_mut_; - std::unordered_map, + std::unordered_map, std::unique_ptr, - hashing::Hasher>> + hashing::Hasher>> tensor_types_; std::mutex tensor_mut_; @@ -70,27 +73,29 @@ class TypeFactory { std::mutex struct_mut_; // TODO: is_bit_ptr? - std::unordered_map, + std::unordered_map, std::unique_ptr, - hashing::Hasher>> + hashing::Hasher>> pointer_types_; std::mutex pointer_mut_; - std::unordered_map, + std::unordered_map, std::unique_ptr, - hashing::Hasher>> + hashing::Hasher>> quant_int_types_; std::mutex quant_int_mut_; - std::unordered_map, - std::unique_ptr, - hashing::Hasher>> + std::unordered_map< + std::tuple, + std::unique_ptr, + hashing::Hasher>> quant_fixed_types_; std::mutex quant_fixed_mut_; - std::unordered_map, - std::unique_ptr, - hashing::Hasher>> + std::unordered_map< + std::tuple, + std::unique_ptr, + hashing::Hasher>> quant_float_types_; std::mutex quant_float_mut_; diff --git a/taichi/ir/type_utils.h b/taichi/ir/type_utils.h index 93d6bfc224b08..69292f340e019 100644 --- a/taichi/ir/type_utils.h +++ b/taichi/ir/type_utils.h @@ -242,7 +242,7 @@ class BitStructTypeBuilder { member_bit_offsets_.push_back(member_total_bits_); member_exponents_.push_back(-1); member_exponent_users_.push_back({}); - QuantIntType *member_qit = nullptr; + const QuantIntType *member_qit = nullptr; if (auto qit = member_type->cast()) { member_qit = qit; } else if (auto qfxt = member_type->cast()) { @@ -260,8 +260,8 @@ class BitStructTypeBuilder { return old_num_members; } - PrimitiveType *physical_type_{nullptr}; - std::vector member_types_; + const PrimitiveType *physical_type_{nullptr}; + std::vector member_types_; std::vector member_bit_offsets_; int member_total_bits_{0}; std::vector member_exponents_;