Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 19, 2023
1 parent f905901 commit 0865910
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 34 deletions.
28 changes: 16 additions & 12 deletions taichi/ir/type_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ const Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) {
return primitive_types_[id].get();
}

const Type *TypeFactory::get_tensor_type(std::vector<int> shape, Type *element) {
const Type *TypeFactory::get_tensor_type(std::vector<int> shape,
Type *element) {
std::lock_guard<std::mutex> _(tensor_mut_);

auto encode = [](const std::vector<int> &shape) -> std::string {
Expand Down Expand Up @@ -57,7 +58,8 @@ const Type *TypeFactory::get_struct_type(
return struct_types_[key].get();
}

const Type *TypeFactory::get_pointer_type(const Type *element, bool is_bit_pointer) {
const Type *TypeFactory::get_pointer_type(const Type *element,
bool is_bit_pointer) {
std::lock_guard<std::mutex> _(pointer_mut_);

auto key = std::make_pair(element, is_bit_pointer);
Expand All @@ -69,8 +71,8 @@ const Type *TypeFactory::get_pointer_type(const Type *element, bool is_bit_point
}

const Type *TypeFactory::get_quant_int_type(int num_bits,
bool is_signed,
const Type *compute_type) {
bool is_signed,
const Type *compute_type) {
std::lock_guard<std::mutex> _(quant_int_mut_);

auto key = std::make_tuple(num_bits, is_signed, compute_type);
Expand All @@ -82,8 +84,8 @@ const Type *TypeFactory::get_quant_int_type(int num_bits,
}

const Type *TypeFactory::get_quant_fixed_type(const Type *digits_type,
const Type *compute_type,
float64 scale) {
const Type *compute_type,
float64 scale) {
std::lock_guard<std::mutex> _(quant_fixed_mut_);

auto key = std::make_tuple(digits_type, compute_type, scale);
Expand All @@ -95,8 +97,8 @@ const Type *TypeFactory::get_quant_fixed_type(const Type *digits_type,
}

const Type *TypeFactory::get_quant_float_type(const Type *digits_type,
const Type *exponent_type,
const Type *compute_type) {
const Type *exponent_type,
const Type *compute_type) {
std::lock_guard<std::mutex> _(quant_float_mut_);

auto key = std::make_tuple(digits_type, exponent_type, compute_type);
Expand All @@ -121,17 +123,19 @@ BitStructType *TypeFactory::get_bit_struct_type(
return bit_struct_types_.back().get();
}

const Type *TypeFactory::get_quant_array_type(const PrimitiveType *physical_type,
const Type *element_type,
int num_elements) {
const Type *TypeFactory::get_quant_array_type(
const PrimitiveType *physical_type,
const Type *element_type,
int num_elements) {
std::lock_guard<std::mutex> _(quant_array_mut_);

quant_array_types_.push_back(std::make_unique<QuantArrayType>(
physical_type, element_type, num_elements));
return quant_array_types_.back().get();
}

const PrimitiveType *TypeFactory::get_primitive_int_type(int bits, bool is_signed) {
const PrimitiveType *TypeFactory::get_primitive_int_type(int bits,
bool is_signed) {
const Type *int_type;
if (bits == 8) {
int_type = get_primitive_type(PrimitiveTypeID::i8);
Expand Down
49 changes: 27 additions & 22 deletions taichi/ir/type_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,31 @@ class TypeFactory {
// TODO(type): maybe it makes sense to let each get_X function return X*
// instead of generic Type*

const Type *get_primitive_type(PrimitiveTypeID id);
const Type *get_primitive_type(PrimitiveTypeID id);

const PrimitiveType *get_primitive_int_type(int bits, bool is_signed = true);
const PrimitiveType *get_primitive_int_type(int bits, bool is_signed = true);

const PrimitiveType *get_primitive_real_type(int bits);
const PrimitiveType *get_primitive_real_type(int bits);

const Type *get_tensor_type(std::vector<int> shape, Type *element);
const Type *get_tensor_type(std::vector<int> shape, Type *element);

const Type *get_struct_type(const std::vector<StructMember> &elements,
const Type *get_struct_type(const std::vector<StructMember> &elements,
const std::string &layout = "none");

const Type *get_pointer_type(const Type *element, bool is_bit_pointer = false);
const Type *get_pointer_type(const Type *element,
bool is_bit_pointer = false);

const Type *get_quant_int_type(int num_bits, bool is_signed, const Type *compute_type);
const Type *get_quant_int_type(int num_bits,
bool is_signed,
const Type *compute_type);

const Type *get_quant_fixed_type(const Type *digits_type,
const Type *compute_type,
float64 scale);
const Type *get_quant_fixed_type(const Type *digits_type,
const Type *compute_type,
float64 scale);

const Type *get_quant_float_type(const Type *digits_type,
const Type *exponent_type,
const Type *compute_type);
const Type *get_quant_float_type(const Type *digits_type,
const Type *exponent_type,
const Type *compute_type);

BitStructType *get_bit_struct_type(
const PrimitiveType *physical_type,
Expand All @@ -44,9 +47,9 @@ const Type *get_quant_float_type(const Type *digits_type,
const std::vector<int> &member_exponents,
const std::vector<std::vector<int>> &member_exponent_users);

const Type *get_quant_array_type(const PrimitiveType *physical_type,
const Type *element_type,
int num_elements);
const Type *get_quant_array_type(const PrimitiveType *physical_type,
const Type *element_type,
int num_elements);

static DataType create_tensor_type(std::vector<int> shape, DataType element);

Expand Down Expand Up @@ -82,15 +85,17 @@ const Type *get_quant_array_type(const PrimitiveType *physical_type,
quant_int_types_;
std::mutex quant_int_mut_;

std::unordered_map<std::tuple<const Type *, const Type *, float64>,
std::unique_ptr<Type>,
hashing::Hasher<std::tuple<const Type *, const Type *, float64>>>
std::unordered_map<
std::tuple<const Type *, const Type *, float64>,
std::unique_ptr<Type>,
hashing::Hasher<std::tuple<const Type *, const Type *, float64>>>
quant_fixed_types_;
std::mutex quant_fixed_mut_;

std::unordered_map<std::tuple<const Type *, const Type *, const Type *>,
std::unique_ptr<Type>,
hashing::Hasher<std::tuple<const Type *, const Type *, const Type *>>>
std::unordered_map<
std::tuple<const Type *, const Type *, const Type *>,
std::unique_ptr<Type>,
hashing::Hasher<std::tuple<const Type *, const Type *, const Type *>>>
quant_float_types_;
std::mutex quant_float_mut_;

Expand Down

0 comments on commit 0865910

Please sign in to comment.