diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index 1ecb88686f5fe..40882dea5ccb4 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -12,8 +12,9 @@ TypeFactory &TypeFactory::get_instance() { TypeFactory::TypeFactory() { } + Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) { - std::lock_guard _(mut_); + std::lock_guard _(primitive_mut_); if (primitive_types_.find(id) == primitive_types_.end()) { primitive_types_[id] = std::make_unique(id); @@ -23,6 +24,8 @@ Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) { } Type *TypeFactory::get_tensor_type(std::vector shape, Type *element) { + std::lock_guard _(tensor_mut_); + auto encode = [](const std::vector &shape) -> std::string { std::string s; for (int i = 0; i < (int)shape.size(); ++i) @@ -38,6 +41,7 @@ Type *TypeFactory::get_tensor_type(std::vector shape, Type *element) { Type *TypeFactory::get_struct_type(const std::vector &elements) { std::lock_guard _(struct_mut_); + if (struct_types_.find(elements) == struct_types_.end()) { for (const auto &element : elements) { TI_ASSERT_INFO( @@ -51,6 +55,8 @@ Type *TypeFactory::get_struct_type(const std::vector &elements) { } Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) { + std::lock_guard _(pointer_mut_); + auto key = std::make_pair(element, is_bit_pointer); if (pointer_types_.find(key) == pointer_types_.end()) { pointer_types_[key] = @@ -62,6 +68,8 @@ Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) { Type *TypeFactory::get_quant_int_type(int num_bits, bool is_signed, Type *compute_type) { + std::lock_guard _(quant_int_mut_); + auto key = std::make_tuple(num_bits, is_signed, compute_type); if (quant_int_types_.find(key) == quant_int_types_.end()) { quant_int_types_[key] = @@ -73,6 +81,8 @@ Type *TypeFactory::get_quant_int_type(int num_bits, Type *TypeFactory::get_quant_fixed_type(Type *digits_type, Type *compute_type, float64 scale) { + std::lock_guard _(quant_fixed_mut_); + auto key = std::make_tuple(digits_type, compute_type, scale); if (quant_fixed_types_.find(key) == quant_fixed_types_.end()) { quant_fixed_types_[key] = @@ -84,6 +94,8 @@ Type *TypeFactory::get_quant_fixed_type(Type *digits_type, Type *TypeFactory::get_quant_float_type(Type *digits_type, Type *exponent_type, Type *compute_type) { + std::lock_guard _(quant_float_mut_); + auto key = std::make_tuple(digits_type, exponent_type, compute_type); if (quant_float_types_.find(key) == quant_float_types_.end()) { quant_float_types_[key] = std::make_unique( @@ -98,6 +110,8 @@ BitStructType *TypeFactory::get_bit_struct_type( const std::vector &member_bit_offsets, const std::vector &member_exponents, const std::vector> &member_exponent_users) { + std::lock_guard _(bit_struct_mut_); + bit_struct_types_.push_back(std::make_unique( physical_type, member_types, member_bit_offsets, member_exponents, member_exponent_users)); @@ -107,6 +121,8 @@ BitStructType *TypeFactory::get_bit_struct_type( Type *TypeFactory::get_quant_array_type(PrimitiveType *physical_type, Type *element_type, int num_elements) { + std::lock_guard _(quant_array_mut_); + quant_array_types_.push_back(std::make_unique( physical_type, element_type, num_elements)); return quant_array_types_.back().get(); diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index 9df35d3ba394c..d413c8385100d 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -53,12 +53,13 @@ class TypeFactory { TypeFactory(); std::unordered_map> primitive_types_; + std::mutex primitive_mut_; - // TODO: use unordered map - std::map, std::unique_ptr> vector_types_; - - // TODO: use unordered map - std::map, std::unique_ptr> tensor_types_; + std::unordered_map, + std::unique_ptr, + hashing::Hasher>> + tensor_types_; + std::mutex tensor_mut_; std::unordered_map, std::unique_ptr, @@ -67,27 +68,37 @@ class TypeFactory { std::mutex struct_mut_; // TODO: is_bit_ptr? - std::map, std::unique_ptr> pointer_types_; + std::unordered_map, + std::unique_ptr, + hashing::Hasher>> + pointer_types_; + std::mutex pointer_mut_; - // TODO: use unordered map - std::map, std::unique_ptr> + std::unordered_map, + std::unique_ptr, + hashing::Hasher>> quant_int_types_; + std::mutex quant_int_mut_; - // TODO: use unordered map - std::map, std::unique_ptr> + std::unordered_map, + std::unique_ptr, + hashing::Hasher>> quant_fixed_types_; + std::mutex quant_fixed_mut_; - // TODO: use unordered map - std::map, std::unique_ptr> + std::unordered_map, + std::unique_ptr, + hashing::Hasher>> quant_float_types_; + std::mutex quant_float_mut_; // TODO: avoid duplication std::vector> bit_struct_types_; + std::mutex bit_struct_mut_; // TODO: avoid duplication std::vector> quant_array_types_; - - std::mutex mut_; + std::mutex quant_array_mut_; }; DataType promoted_type(DataType a, DataType b); diff --git a/taichi/util/hash.h b/taichi/util/hash.h index e6042d0714a05..74c4363204851 100644 --- a/taichi/util/hash.h +++ b/taichi/util/hash.h @@ -33,4 +33,36 @@ struct Hasher> { return ret; } }; + +template +struct Hasher> { + public: + size_t operator()(std::pair const &val) const { + size_t ret = Hasher{}(val.first); + hash_combine(ret, val.second); + return ret; + } +}; + +template +struct Hasher> { + public: + size_t operator()(std::tuple const &val) const { + return hash> - 1>(val); + }; + + private: + template + size_t hash(std::tuple const &val) const { + size_t ret = hash(val); + hash_combine(ret, std::get(val)); + return ret; + } + template <> + size_t hash<0>(std::tuple const &val) const { + return Hasher>>{}( + std::get<0>(val)); + } +}; + } // namespace taichi::hashing