Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ir] Change type maps to unordered maps and add mutexes #7000

Merged
merged 2 commits into from
Dec 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion taichi/ir/type_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ TypeFactory::TypeFactory() {
}

Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) {
std::lock_guard<std::mutex> _(mut_);
std::lock_guard<std::mutex> _(primitive_mut_);

if (primitive_types_.find(id) == primitive_types_.end()) {
primitive_types_[id] = std::make_unique<PrimitiveType>(id);
Expand All @@ -23,6 +23,8 @@ Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) {
}

Type *TypeFactory::get_tensor_type(std::vector<int> shape, Type *element) {
std::lock_guard<std::mutex> _(tensor_mut_);

auto encode = [](const std::vector<int> &shape) -> std::string {
std::string s;
for (int i = 0; i < (int)shape.size(); ++i)
Expand All @@ -38,6 +40,7 @@ Type *TypeFactory::get_tensor_type(std::vector<int> shape, Type *element) {

Type *TypeFactory::get_struct_type(const std::vector<const Type *> &elements) {
std::lock_guard<std::mutex> _(struct_mut_);

if (struct_types_.find(elements) == struct_types_.end()) {
for (const auto &element : elements) {
TI_ASSERT_INFO(
Expand All @@ -51,6 +54,8 @@ Type *TypeFactory::get_struct_type(const std::vector<const Type *> &elements) {
}

Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) {
std::lock_guard<std::mutex> _(pointer_mut_);

auto key = std::make_pair(element, is_bit_pointer);
if (pointer_types_.find(key) == pointer_types_.end()) {
pointer_types_[key] =
Expand All @@ -62,6 +67,8 @@ Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) {
Type *TypeFactory::get_quant_int_type(int num_bits,
bool is_signed,
Type *compute_type) {
std::lock_guard<std::mutex> _(quant_int_mut_);

auto key = std::make_tuple(num_bits, is_signed, compute_type);
if (quant_int_types_.find(key) == quant_int_types_.end()) {
quant_int_types_[key] =
Expand All @@ -73,6 +80,8 @@ Type *TypeFactory::get_quant_int_type(int num_bits,
Type *TypeFactory::get_quant_fixed_type(Type *digits_type,
Type *compute_type,
float64 scale) {
std::lock_guard<std::mutex> _(quant_fixed_mut_);

auto key = std::make_tuple(digits_type, compute_type, scale);
if (quant_fixed_types_.find(key) == quant_fixed_types_.end()) {
quant_fixed_types_[key] =
Expand All @@ -84,6 +93,8 @@ Type *TypeFactory::get_quant_fixed_type(Type *digits_type,
Type *TypeFactory::get_quant_float_type(Type *digits_type,
Type *exponent_type,
Type *compute_type) {
std::lock_guard<std::mutex> _(quant_float_mut_);

auto key = std::make_tuple(digits_type, exponent_type, compute_type);
if (quant_float_types_.find(key) == quant_float_types_.end()) {
quant_float_types_[key] = std::make_unique<QuantFloatType>(
Expand All @@ -98,6 +109,8 @@ BitStructType *TypeFactory::get_bit_struct_type(
const std::vector<int> &member_bit_offsets,
const std::vector<int> &member_exponents,
const std::vector<std::vector<int>> &member_exponent_users) {
std::lock_guard<std::mutex> _(bit_struct_mut_);

bit_struct_types_.push_back(std::make_unique<BitStructType>(
physical_type, member_types, member_bit_offsets, member_exponents,
member_exponent_users));
Expand All @@ -107,6 +120,8 @@ BitStructType *TypeFactory::get_bit_struct_type(
Type *TypeFactory::get_quant_array_type(PrimitiveType *physical_type,
Type *element_type,
int num_elements) {
std::lock_guard<std::mutex> _(quant_array_mut_);

quant_array_types_.push_back(std::make_unique<QuantArrayType>(
physical_type, element_type, num_elements));
return quant_array_types_.back().get();
Expand Down
39 changes: 25 additions & 14 deletions taichi/ir/type_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,13 @@ class TypeFactory {
TypeFactory();

std::unordered_map<PrimitiveTypeID, std::unique_ptr<Type>> primitive_types_;
std::mutex primitive_mut_;

// TODO: use unordered map
std::map<std::pair<int, Type *>, std::unique_ptr<Type>> vector_types_;

// TODO: use unordered map
std::map<std::pair<std::string, Type *>, std::unique_ptr<Type>> tensor_types_;
std::unordered_map<std::pair<std::string, Type *>,
std::unique_ptr<Type>,
hashing::Hasher<std::pair<std::string, Type *>>>
tensor_types_;
std::mutex tensor_mut_;

std::unordered_map<std::vector<const Type *>,
std::unique_ptr<Type>,
Expand All @@ -67,27 +68,37 @@ class TypeFactory {
std::mutex struct_mut_;

// TODO: is_bit_ptr?
std::map<std::pair<Type *, bool>, std::unique_ptr<Type>> pointer_types_;
std::unordered_map<std::pair<Type *, bool>,
std::unique_ptr<Type>,
hashing::Hasher<std::pair<Type *, bool>>>
pointer_types_;
std::mutex pointer_mut_;

// TODO: use unordered map
std::map<std::tuple<int, bool, Type *>, std::unique_ptr<Type>>
std::unordered_map<std::tuple<int, bool, Type *>,
std::unique_ptr<Type>,
hashing::Hasher<std::tuple<int, bool, Type *>>>
quant_int_types_;
std::mutex quant_int_mut_;

// TODO: use unordered map
std::map<std::tuple<Type *, Type *, float64>, std::unique_ptr<Type>>
std::unordered_map<std::tuple<Type *, Type *, float64>,
std::unique_ptr<Type>,
hashing::Hasher<std::tuple<Type *, Type *, float64>>>
quant_fixed_types_;
std::mutex quant_fixed_mut_;

// TODO: use unordered map
std::map<std::tuple<Type *, Type *, Type *>, std::unique_ptr<Type>>
std::unordered_map<std::tuple<Type *, Type *, Type *>,
std::unique_ptr<Type>,
hashing::Hasher<std::tuple<Type *, Type *, Type *>>>
quant_float_types_;
std::mutex quant_float_mut_;

// TODO: avoid duplication
std::vector<std::unique_ptr<BitStructType>> bit_struct_types_;
std::mutex bit_struct_mut_;

// TODO: avoid duplication
std::vector<std::unique_ptr<Type>> quant_array_types_;

std::mutex mut_;
std::mutex quant_array_mut_;
};

DataType promoted_type(DataType a, DataType b);
Expand Down
32 changes: 32 additions & 0 deletions taichi/util/hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,36 @@ struct Hasher<std::vector<T>> {
return ret;
}
};

template <typename T1, typename T2>
struct Hasher<std::pair<T1, T2>> {
public:
size_t operator()(std::pair<T1, T2> const &val) const {
size_t ret = Hasher<T1>{}(val.first);
hash_combine(ret, val.second);
return ret;
}
};

template <typename... Ts>
struct Hasher<std::tuple<Ts...>> {
public:
size_t operator()(std::tuple<Ts...> const &val) const {
return hash<std::tuple_size_v<std::tuple<Ts...>> - 1>(val);
};

private:
template <int N>
size_t hash(std::tuple<Ts...> const &val) const {
size_t ret = hash<N - 1>(val);
hash_combine(ret, std::get<N>(val));
return ret;
}
template <>
size_t hash<0>(std::tuple<Ts...> const &val) const {
return Hasher<std::tuple_element_t<0, std::tuple<Ts...>>>{}(
std::get<0>(val));
}
};

} // namespace taichi::hashing