Skip to content

Commit

Permalink
[ir] Change type maps to unordered maps and add mutexes
Browse files Browse the repository at this point in the history
  • Loading branch information
lin-hitonami committed Dec 28, 2022
1 parent 449a7f6 commit bd5e4b5
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 15 deletions.
18 changes: 17 additions & 1 deletion taichi/ir/type_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ TypeFactory &TypeFactory::get_instance() {
TypeFactory::TypeFactory() {
}


Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) {
std::lock_guard<std::mutex> _(mut_);
std::lock_guard<std::mutex> _(primitive_mut_);

if (primitive_types_.find(id) == primitive_types_.end()) {
primitive_types_[id] = std::make_unique<PrimitiveType>(id);
Expand All @@ -23,6 +24,8 @@ Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) {
}

Type *TypeFactory::get_tensor_type(std::vector<int> shape, Type *element) {
std::lock_guard<std::mutex> _(tensor_mut_);

auto encode = [](const std::vector<int> &shape) -> std::string {
std::string s;
for (int i = 0; i < (int)shape.size(); ++i)
Expand All @@ -38,6 +41,7 @@ Type *TypeFactory::get_tensor_type(std::vector<int> shape, Type *element) {

Type *TypeFactory::get_struct_type(const std::vector<const Type *> &elements) {
std::lock_guard<std::mutex> _(struct_mut_);

if (struct_types_.find(elements) == struct_types_.end()) {
for (const auto &element : elements) {
TI_ASSERT_INFO(
Expand All @@ -51,6 +55,8 @@ Type *TypeFactory::get_struct_type(const std::vector<const Type *> &elements) {
}

Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) {
std::lock_guard<std::mutex> _(pointer_mut_);

auto key = std::make_pair(element, is_bit_pointer);
if (pointer_types_.find(key) == pointer_types_.end()) {
pointer_types_[key] =
Expand All @@ -62,6 +68,8 @@ Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) {
Type *TypeFactory::get_quant_int_type(int num_bits,
bool is_signed,
Type *compute_type) {
std::lock_guard<std::mutex> _(quant_int_mut_);

auto key = std::make_tuple(num_bits, is_signed, compute_type);
if (quant_int_types_.find(key) == quant_int_types_.end()) {
quant_int_types_[key] =
Expand All @@ -73,6 +81,8 @@ Type *TypeFactory::get_quant_int_type(int num_bits,
Type *TypeFactory::get_quant_fixed_type(Type *digits_type,
Type *compute_type,
float64 scale) {
std::lock_guard<std::mutex> _(quant_fixed_mut_);

auto key = std::make_tuple(digits_type, compute_type, scale);
if (quant_fixed_types_.find(key) == quant_fixed_types_.end()) {
quant_fixed_types_[key] =
Expand All @@ -84,6 +94,8 @@ Type *TypeFactory::get_quant_fixed_type(Type *digits_type,
Type *TypeFactory::get_quant_float_type(Type *digits_type,
Type *exponent_type,
Type *compute_type) {
std::lock_guard<std::mutex> _(quant_float_mut_);

auto key = std::make_tuple(digits_type, exponent_type, compute_type);
if (quant_float_types_.find(key) == quant_float_types_.end()) {
quant_float_types_[key] = std::make_unique<QuantFloatType>(
Expand All @@ -98,6 +110,8 @@ BitStructType *TypeFactory::get_bit_struct_type(
const std::vector<int> &member_bit_offsets,
const std::vector<int> &member_exponents,
const std::vector<std::vector<int>> &member_exponent_users) {
std::lock_guard<std::mutex> _(bit_struct_mut_);

bit_struct_types_.push_back(std::make_unique<BitStructType>(
physical_type, member_types, member_bit_offsets, member_exponents,
member_exponent_users));
Expand All @@ -107,6 +121,8 @@ BitStructType *TypeFactory::get_bit_struct_type(
Type *TypeFactory::get_quant_array_type(PrimitiveType *physical_type,
Type *element_type,
int num_elements) {
std::lock_guard<std::mutex> _(quant_array_mut_);

quant_array_types_.push_back(std::make_unique<QuantArrayType>(
physical_type, element_type, num_elements));
return quant_array_types_.back().get();
Expand Down
39 changes: 25 additions & 14 deletions taichi/ir/type_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,13 @@ class TypeFactory {
TypeFactory();

std::unordered_map<PrimitiveTypeID, std::unique_ptr<Type>> primitive_types_;
std::mutex primitive_mut_;

// TODO: use unordered map
std::map<std::pair<int, Type *>, std::unique_ptr<Type>> vector_types_;

// TODO: use unordered map
std::map<std::pair<std::string, Type *>, std::unique_ptr<Type>> tensor_types_;
std::unordered_map<std::pair<std::string, Type *>,
std::unique_ptr<Type>,
hashing::Hasher<std::pair<std::string, Type *>>>
tensor_types_;
std::mutex tensor_mut_;

std::unordered_map<std::vector<const Type *>,
std::unique_ptr<Type>,
Expand All @@ -67,27 +68,37 @@ class TypeFactory {
std::mutex struct_mut_;

// TODO: is_bit_ptr?
std::map<std::pair<Type *, bool>, std::unique_ptr<Type>> pointer_types_;
std::unordered_map<std::pair<Type *, bool>,
std::unique_ptr<Type>,
hashing::Hasher<std::pair<Type *, bool>>>
pointer_types_;
std::mutex pointer_mut_;

// TODO: use unordered map
std::map<std::tuple<int, bool, Type *>, std::unique_ptr<Type>>
std::unordered_map<std::tuple<int, bool, Type *>,
std::unique_ptr<Type>,
hashing::Hasher<std::tuple<int, bool, Type *>>>
quant_int_types_;
std::mutex quant_int_mut_;

// TODO: use unordered map
std::map<std::tuple<Type *, Type *, float64>, std::unique_ptr<Type>>
std::unordered_map<std::tuple<Type *, Type *, float64>,
std::unique_ptr<Type>,
hashing::Hasher<std::tuple<Type *, Type *, float64>>>
quant_fixed_types_;
std::mutex quant_fixed_mut_;

// TODO: use unordered map
std::map<std::tuple<Type *, Type *, Type *>, std::unique_ptr<Type>>
std::unordered_map<std::tuple<Type *, Type *, Type *>,
std::unique_ptr<Type>,
hashing::Hasher<std::tuple<Type *, Type *, Type *>>>
quant_float_types_;
std::mutex quant_float_mut_;

// TODO: avoid duplication
std::vector<std::unique_ptr<BitStructType>> bit_struct_types_;
std::mutex bit_struct_mut_;

// TODO: avoid duplication
std::vector<std::unique_ptr<Type>> quant_array_types_;

std::mutex mut_;
std::mutex quant_array_mut_;
};

DataType promoted_type(DataType a, DataType b);
Expand Down
32 changes: 32 additions & 0 deletions taichi/util/hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,36 @@ struct Hasher<std::vector<T>> {
return ret;
}
};

template <typename T1, typename T2>
struct Hasher<std::pair<T1, T2>> {
public:
size_t operator()(std::pair<T1, T2> const &val) const {
size_t ret = Hasher<T1>{}(val.first);
hash_combine(ret, val.second);
return ret;
}
};

template <typename... Ts>
struct Hasher<std::tuple<Ts...>> {
public:
size_t operator()(std::tuple<Ts...> const &val) const {
return hash<std::tuple_size_v<std::tuple<Ts...>> - 1>(val);
};

private:
template <int N>
size_t hash(std::tuple<Ts...> const &val) const {
size_t ret = hash<N - 1>(val);
hash_combine(ret, std::get<N>(val));
return ret;
}
template <>
size_t hash<0>(std::tuple<Ts...> const &val) const {
return Hasher<std::tuple_element_t<0, std::tuple<Ts...>>>{}(
std::get<0>(val));
}
};

} // namespace taichi::hashing

0 comments on commit bd5e4b5

Please sign in to comment.