diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index d726cb2a77b16..c1b607e558267 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -22,7 +22,8 @@ const Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) { return primitive_types_[id].get(); } -const Type *TypeFactory::get_tensor_type(std::vector shape, Type *element) { +const Type *TypeFactory::get_tensor_type(std::vector shape, + Type *element) { std::lock_guard _(tensor_mut_); auto encode = [](const std::vector &shape) -> std::string { @@ -57,7 +58,8 @@ const Type *TypeFactory::get_struct_type( return struct_types_[key].get(); } -const Type *TypeFactory::get_pointer_type(const Type *element, bool is_bit_pointer) { +const Type *TypeFactory::get_pointer_type(const Type *element, + bool is_bit_pointer) { std::lock_guard _(pointer_mut_); auto key = std::make_pair(element, is_bit_pointer); @@ -69,8 +71,8 @@ const Type *TypeFactory::get_pointer_type(const Type *element, bool is_bit_point } const Type *TypeFactory::get_quant_int_type(int num_bits, - bool is_signed, - const Type *compute_type) { + bool is_signed, + const Type *compute_type) { std::lock_guard _(quant_int_mut_); auto key = std::make_tuple(num_bits, is_signed, compute_type); @@ -82,8 +84,8 @@ const Type *TypeFactory::get_quant_int_type(int num_bits, } const Type *TypeFactory::get_quant_fixed_type(const Type *digits_type, - const Type *compute_type, - float64 scale) { + const Type *compute_type, + float64 scale) { std::lock_guard _(quant_fixed_mut_); auto key = std::make_tuple(digits_type, compute_type, scale); @@ -95,8 +97,8 @@ const Type *TypeFactory::get_quant_fixed_type(const Type *digits_type, } const Type *TypeFactory::get_quant_float_type(const Type *digits_type, - const Type *exponent_type, - const Type *compute_type) { + const Type *exponent_type, + const Type *compute_type) { std::lock_guard _(quant_float_mut_); auto key = std::make_tuple(digits_type, exponent_type, compute_type); @@ -121,9 +123,10 @@ BitStructType *TypeFactory::get_bit_struct_type( return bit_struct_types_.back().get(); } -const Type *TypeFactory::get_quant_array_type(const PrimitiveType *physical_type, - const Type *element_type, - int num_elements) { +const Type *TypeFactory::get_quant_array_type( + const PrimitiveType *physical_type, + const Type *element_type, + int num_elements) { std::lock_guard _(quant_array_mut_); quant_array_types_.push_back(std::make_unique( @@ -131,7 +134,8 @@ const Type *TypeFactory::get_quant_array_type(const PrimitiveType *physical_type return quant_array_types_.back().get(); } -const PrimitiveType *TypeFactory::get_primitive_int_type(int bits, bool is_signed) { +const PrimitiveType *TypeFactory::get_primitive_int_type(int bits, + bool is_signed) { const Type *int_type; if (bits == 8) { int_type = get_primitive_type(PrimitiveTypeID::i8); diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index 1a4f0ec60ac66..80fef1878c753 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -14,28 +14,31 @@ class TypeFactory { // TODO(type): maybe it makes sense to let each get_X function return X* // instead of generic Type* - const Type *get_primitive_type(PrimitiveTypeID id); + const Type *get_primitive_type(PrimitiveTypeID id); -const PrimitiveType *get_primitive_int_type(int bits, bool is_signed = true); + const PrimitiveType *get_primitive_int_type(int bits, bool is_signed = true); -const PrimitiveType *get_primitive_real_type(int bits); + const PrimitiveType *get_primitive_real_type(int bits); -const Type *get_tensor_type(std::vector shape, Type *element); + const Type *get_tensor_type(std::vector shape, Type *element); -const Type *get_struct_type(const std::vector &elements, + const Type *get_struct_type(const std::vector &elements, const std::string &layout = "none"); -const Type *get_pointer_type(const Type *element, bool is_bit_pointer = false); + const Type *get_pointer_type(const Type *element, + bool is_bit_pointer = false); -const Type *get_quant_int_type(int num_bits, bool is_signed, const Type *compute_type); + const Type *get_quant_int_type(int num_bits, + bool is_signed, + const Type *compute_type); -const Type *get_quant_fixed_type(const Type *digits_type, - const Type *compute_type, - float64 scale); + const Type *get_quant_fixed_type(const Type *digits_type, + const Type *compute_type, + float64 scale); -const Type *get_quant_float_type(const Type *digits_type, - const Type *exponent_type, - const Type *compute_type); + const Type *get_quant_float_type(const Type *digits_type, + const Type *exponent_type, + const Type *compute_type); BitStructType *get_bit_struct_type( const PrimitiveType *physical_type, @@ -44,9 +47,9 @@ const Type *get_quant_float_type(const Type *digits_type, const std::vector &member_exponents, const std::vector> &member_exponent_users); -const Type *get_quant_array_type(const PrimitiveType *physical_type, - const Type *element_type, - int num_elements); + const Type *get_quant_array_type(const PrimitiveType *physical_type, + const Type *element_type, + int num_elements); static DataType create_tensor_type(std::vector shape, DataType element); @@ -82,15 +85,17 @@ const Type *get_quant_array_type(const PrimitiveType *physical_type, quant_int_types_; std::mutex quant_int_mut_; - std::unordered_map, - std::unique_ptr, - hashing::Hasher>> + std::unordered_map< + std::tuple, + std::unique_ptr, + hashing::Hasher>> quant_fixed_types_; std::mutex quant_fixed_mut_; - std::unordered_map, - std::unique_ptr, - hashing::Hasher>> + std::unordered_map< + std::tuple, + std::unique_ptr, + hashing::Hasher>> quant_float_types_; std::mutex quant_float_mut_;