From f905901ad5d8047f462a21a99f4073e1ade8c431 Mon Sep 17 00:00:00 2001 From: Ris-Bali Date: Wed, 19 Apr 2023 17:20:46 +0530 Subject: [PATCH 1/2] Initial Commit --- taichi/ir/type.cpp | 4 +-- taichi/ir/type.h | 12 ++++----- taichi/ir/type_factory.cpp | 36 +++++++++++++------------- taichi/ir/type_factory.h | 52 +++++++++++++++++++------------------- taichi/ir/type_utils.h | 6 ++--- 5 files changed, 55 insertions(+), 55 deletions(-) diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index eac8111ed3b7a..f4ef60e4da34a 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -266,7 +266,7 @@ bool QuantFloatType::get_is_signed() const { BitStructType::BitStructType( PrimitiveType *physical_type, - const std::vector &member_types, + const std::vector &member_types, const std::vector &member_bit_offsets, const std::vector &member_exponents, const std::vector> &member_exponent_users) @@ -282,7 +282,7 @@ BitStructType::BitStructType( int physical_type_bits = data_type_bits(physical_type_); int member_total_bits = 0; for (auto i = 0; i < member_types_.size(); ++i) { - QuantIntType *component_qit = nullptr; + const QuantIntType *component_qit = nullptr; if (auto qit = member_types_[i]->cast()) { component_qit = qit; } else if (auto qfxt = member_types_[i]->cast()) { diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 13162d7a1e4bb..4fd7bc79c01e1 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -143,7 +143,7 @@ class TI_DLL_EXPORT DataType { TI_IO_DEF(ptr_); private: - Type *ptr_; + const Type *ptr_; }; // Note that all types are immutable once created. @@ -335,7 +335,7 @@ class TI_DLL_EXPORT QuantIntType : public Type { private: // TODO(type): for now we can uniformly use i32 as the "compute_type". It may // be a good idea to make "compute_type" also customizable. - Type *compute_type_{nullptr}; + const Type *compute_type_{nullptr}; int num_bits_{32}; bool is_signed_{true}; }; @@ -349,7 +349,7 @@ class TI_DLL_EXPORT QuantFixedType : public Type { bool get_is_signed() const; - Type *get_digits_type() { + const Type *get_digits_type() { return digits_type_; } @@ -379,7 +379,7 @@ class TI_DLL_EXPORT QuantFloatType : public Type { std::string to_string() const override; - Type *get_digits_type() { + const Type *get_digits_type() { return digits_type_; } @@ -411,7 +411,7 @@ class TI_DLL_EXPORT BitStructType : public Type { public: BitStructType() : Type(TypeKind::BitStruct){}; BitStructType(PrimitiveType *physical_type, - const std::vector &member_types, + const std::vector &member_types, const std::vector &member_bit_offsets, const std::vector &member_exponents, const std::vector> &member_exponent_users); @@ -457,7 +457,7 @@ class TI_DLL_EXPORT BitStructType : public Type { private: PrimitiveType *physical_type_; - std::vector member_types_; + std::vector member_types_; std::vector member_bit_offsets_; std::vector member_exponents_; std::vector> member_exponent_users_; diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index 5fc0e9deb7e49..d726cb2a77b16 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -12,7 +12,7 @@ TypeFactory &TypeFactory::get_instance() { TypeFactory::TypeFactory() { } -Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) { +const Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) { std::lock_guard _(primitive_mut_); if (primitive_types_.find(id) == primitive_types_.end()) { @@ -22,7 +22,7 @@ Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) { return primitive_types_[id].get(); } -Type *TypeFactory::get_tensor_type(std::vector shape, Type *element) { +const Type *TypeFactory::get_tensor_type(std::vector shape, Type *element) { std::lock_guard _(tensor_mut_); auto encode = [](const std::vector &shape) -> std::string { @@ -57,7 +57,7 @@ const Type *TypeFactory::get_struct_type( return struct_types_[key].get(); } -Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) { +const Type *TypeFactory::get_pointer_type(const Type *element, bool is_bit_pointer) { std::lock_guard _(pointer_mut_); auto key = std::make_pair(element, is_bit_pointer); @@ -68,9 +68,9 @@ Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) { return pointer_types_[key].get(); } -Type *TypeFactory::get_quant_int_type(int num_bits, +const Type *TypeFactory::get_quant_int_type(int num_bits, bool is_signed, - Type *compute_type) { + const Type *compute_type) { std::lock_guard _(quant_int_mut_); auto key = std::make_tuple(num_bits, is_signed, compute_type); @@ -81,8 +81,8 @@ Type *TypeFactory::get_quant_int_type(int num_bits, return quant_int_types_[key].get(); } -Type *TypeFactory::get_quant_fixed_type(Type *digits_type, - Type *compute_type, +const Type *TypeFactory::get_quant_fixed_type(const Type *digits_type, + const Type *compute_type, float64 scale) { std::lock_guard _(quant_fixed_mut_); @@ -94,9 +94,9 @@ Type *TypeFactory::get_quant_fixed_type(Type *digits_type, return quant_fixed_types_[key].get(); } -Type *TypeFactory::get_quant_float_type(Type *digits_type, - Type *exponent_type, - Type *compute_type) { +const Type *TypeFactory::get_quant_float_type(const Type *digits_type, + const Type *exponent_type, + const Type *compute_type) { std::lock_guard _(quant_float_mut_); auto key = std::make_tuple(digits_type, exponent_type, compute_type); @@ -108,8 +108,8 @@ Type *TypeFactory::get_quant_float_type(Type *digits_type, } BitStructType *TypeFactory::get_bit_struct_type( - PrimitiveType *physical_type, - const std::vector &member_types, + const PrimitiveType *physical_type, + const std::vector &member_types, const std::vector &member_bit_offsets, const std::vector &member_exponents, const std::vector> &member_exponent_users) { @@ -121,8 +121,8 @@ BitStructType *TypeFactory::get_bit_struct_type( return bit_struct_types_.back().get(); } -Type *TypeFactory::get_quant_array_type(PrimitiveType *physical_type, - Type *element_type, +const Type *TypeFactory::get_quant_array_type(const PrimitiveType *physical_type, + const Type *element_type, int num_elements) { std::lock_guard _(quant_array_mut_); @@ -131,8 +131,8 @@ Type *TypeFactory::get_quant_array_type(PrimitiveType *physical_type, return quant_array_types_.back().get(); } -PrimitiveType *TypeFactory::get_primitive_int_type(int bits, bool is_signed) { - Type *int_type; +const PrimitiveType *TypeFactory::get_primitive_int_type(int bits, bool is_signed) { + const Type *int_type; if (bits == 8) { int_type = get_primitive_type(PrimitiveTypeID::i8); } else if (bits == 16) { @@ -150,8 +150,8 @@ PrimitiveType *TypeFactory::get_primitive_int_type(int bits, bool is_signed) { return int_type->cast(); } -PrimitiveType *TypeFactory::get_primitive_real_type(int bits) { - Type *real_type; +const PrimitiveType *TypeFactory::get_primitive_real_type(int bits) { + const Type *real_type; if (bits == 16) { real_type = get_primitive_type(PrimitiveTypeID::f16); } else if (bits == 32) { diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index 292973f652e2a..1a4f0ec60ac66 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -14,38 +14,38 @@ class TypeFactory { // TODO(type): maybe it makes sense to let each get_X function return X* // instead of generic Type* - Type *get_primitive_type(PrimitiveTypeID id); + const Type *get_primitive_type(PrimitiveTypeID id); - PrimitiveType *get_primitive_int_type(int bits, bool is_signed = true); +const PrimitiveType *get_primitive_int_type(int bits, bool is_signed = true); - PrimitiveType *get_primitive_real_type(int bits); +const PrimitiveType *get_primitive_real_type(int bits); - Type *get_tensor_type(std::vector shape, Type *element); +const Type *get_tensor_type(std::vector shape, Type *element); - const Type *get_struct_type(const std::vector &elements, +const Type *get_struct_type(const std::vector &elements, const std::string &layout = "none"); - Type *get_pointer_type(Type *element, bool is_bit_pointer = false); +const Type *get_pointer_type(const Type *element, bool is_bit_pointer = false); - Type *get_quant_int_type(int num_bits, bool is_signed, Type *compute_type); +const Type *get_quant_int_type(int num_bits, bool is_signed, const Type *compute_type); - Type *get_quant_fixed_type(Type *digits_type, - Type *compute_type, +const Type *get_quant_fixed_type(const Type *digits_type, + const Type *compute_type, float64 scale); - Type *get_quant_float_type(Type *digits_type, - Type *exponent_type, - Type *compute_type); +const Type *get_quant_float_type(const Type *digits_type, + const Type *exponent_type, + const Type *compute_type); BitStructType *get_bit_struct_type( - PrimitiveType *physical_type, - const std::vector &member_types, + const PrimitiveType *physical_type, + const std::vector &member_types, const std::vector &member_bit_offsets, const std::vector &member_exponents, const std::vector> &member_exponent_users); - Type *get_quant_array_type(PrimitiveType *physical_type, - Type *element_type, +const Type *get_quant_array_type(const PrimitiveType *physical_type, + const Type *element_type, int num_elements); static DataType create_tensor_type(std::vector shape, DataType element); @@ -56,9 +56,9 @@ class TypeFactory { std::unordered_map> primitive_types_; std::mutex primitive_mut_; - std::unordered_map, + std::unordered_map, std::unique_ptr, - hashing::Hasher>> + hashing::Hasher>> tensor_types_; std::mutex tensor_mut_; @@ -70,27 +70,27 @@ class TypeFactory { std::mutex struct_mut_; // TODO: is_bit_ptr? - std::unordered_map, + std::unordered_map, std::unique_ptr, - hashing::Hasher>> + hashing::Hasher>> pointer_types_; std::mutex pointer_mut_; - std::unordered_map, + std::unordered_map, std::unique_ptr, - hashing::Hasher>> + hashing::Hasher>> quant_int_types_; std::mutex quant_int_mut_; - std::unordered_map, + std::unordered_map, std::unique_ptr, - hashing::Hasher>> + hashing::Hasher>> quant_fixed_types_; std::mutex quant_fixed_mut_; - std::unordered_map, + std::unordered_map, std::unique_ptr, - hashing::Hasher>> + hashing::Hasher>> quant_float_types_; std::mutex quant_float_mut_; diff --git a/taichi/ir/type_utils.h b/taichi/ir/type_utils.h index 93d6bfc224b08..69292f340e019 100644 --- a/taichi/ir/type_utils.h +++ b/taichi/ir/type_utils.h @@ -242,7 +242,7 @@ class BitStructTypeBuilder { member_bit_offsets_.push_back(member_total_bits_); member_exponents_.push_back(-1); member_exponent_users_.push_back({}); - QuantIntType *member_qit = nullptr; + const QuantIntType *member_qit = nullptr; if (auto qit = member_type->cast()) { member_qit = qit; } else if (auto qfxt = member_type->cast()) { @@ -260,8 +260,8 @@ class BitStructTypeBuilder { return old_num_members; } - PrimitiveType *physical_type_{nullptr}; - std::vector member_types_; + const PrimitiveType *physical_type_{nullptr}; + std::vector member_types_; std::vector member_bit_offsets_; int member_total_bits_{0}; std::vector member_exponents_; From 08659107a5976dc51847e42a5cab10e7c46fa49b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Apr 2023 11:55:03 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/ir/type_factory.cpp | 28 ++++++++++++---------- taichi/ir/type_factory.h | 49 +++++++++++++++++++++----------------- 2 files changed, 43 insertions(+), 34 deletions(-) diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index d726cb2a77b16..c1b607e558267 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -22,7 +22,8 @@ const Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) { return primitive_types_[id].get(); } -const Type *TypeFactory::get_tensor_type(std::vector shape, Type *element) { +const Type *TypeFactory::get_tensor_type(std::vector shape, + Type *element) { std::lock_guard _(tensor_mut_); auto encode = [](const std::vector &shape) -> std::string { @@ -57,7 +58,8 @@ const Type *TypeFactory::get_struct_type( return struct_types_[key].get(); } -const Type *TypeFactory::get_pointer_type(const Type *element, bool is_bit_pointer) { +const Type *TypeFactory::get_pointer_type(const Type *element, + bool is_bit_pointer) { std::lock_guard _(pointer_mut_); auto key = std::make_pair(element, is_bit_pointer); @@ -69,8 +71,8 @@ const Type *TypeFactory::get_pointer_type(const Type *element, bool is_bit_point } const Type *TypeFactory::get_quant_int_type(int num_bits, - bool is_signed, - const Type *compute_type) { + bool is_signed, + const Type *compute_type) { std::lock_guard _(quant_int_mut_); auto key = std::make_tuple(num_bits, is_signed, compute_type); @@ -82,8 +84,8 @@ const Type *TypeFactory::get_quant_int_type(int num_bits, } const Type *TypeFactory::get_quant_fixed_type(const Type *digits_type, - const Type *compute_type, - float64 scale) { + const Type *compute_type, + float64 scale) { std::lock_guard _(quant_fixed_mut_); auto key = std::make_tuple(digits_type, compute_type, scale); @@ -95,8 +97,8 @@ const Type *TypeFactory::get_quant_fixed_type(const Type *digits_type, } const Type *TypeFactory::get_quant_float_type(const Type *digits_type, - const Type *exponent_type, - const Type *compute_type) { + const Type *exponent_type, + const Type *compute_type) { std::lock_guard _(quant_float_mut_); auto key = std::make_tuple(digits_type, exponent_type, compute_type); @@ -121,9 +123,10 @@ BitStructType *TypeFactory::get_bit_struct_type( return bit_struct_types_.back().get(); } -const Type *TypeFactory::get_quant_array_type(const PrimitiveType *physical_type, - const Type *element_type, - int num_elements) { +const Type *TypeFactory::get_quant_array_type( + const PrimitiveType *physical_type, + const Type *element_type, + int num_elements) { std::lock_guard _(quant_array_mut_); quant_array_types_.push_back(std::make_unique( @@ -131,7 +134,8 @@ const Type *TypeFactory::get_quant_array_type(const PrimitiveType *physical_type return quant_array_types_.back().get(); } -const PrimitiveType *TypeFactory::get_primitive_int_type(int bits, bool is_signed) { +const PrimitiveType *TypeFactory::get_primitive_int_type(int bits, + bool is_signed) { const Type *int_type; if (bits == 8) { int_type = get_primitive_type(PrimitiveTypeID::i8); diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index 1a4f0ec60ac66..80fef1878c753 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -14,28 +14,31 @@ class TypeFactory { // TODO(type): maybe it makes sense to let each get_X function return X* // instead of generic Type* - const Type *get_primitive_type(PrimitiveTypeID id); + const Type *get_primitive_type(PrimitiveTypeID id); -const PrimitiveType *get_primitive_int_type(int bits, bool is_signed = true); + const PrimitiveType *get_primitive_int_type(int bits, bool is_signed = true); -const PrimitiveType *get_primitive_real_type(int bits); + const PrimitiveType *get_primitive_real_type(int bits); -const Type *get_tensor_type(std::vector shape, Type *element); + const Type *get_tensor_type(std::vector shape, Type *element); -const Type *get_struct_type(const std::vector &elements, + const Type *get_struct_type(const std::vector &elements, const std::string &layout = "none"); -const Type *get_pointer_type(const Type *element, bool is_bit_pointer = false); + const Type *get_pointer_type(const Type *element, + bool is_bit_pointer = false); -const Type *get_quant_int_type(int num_bits, bool is_signed, const Type *compute_type); + const Type *get_quant_int_type(int num_bits, + bool is_signed, + const Type *compute_type); -const Type *get_quant_fixed_type(const Type *digits_type, - const Type *compute_type, - float64 scale); + const Type *get_quant_fixed_type(const Type *digits_type, + const Type *compute_type, + float64 scale); -const Type *get_quant_float_type(const Type *digits_type, - const Type *exponent_type, - const Type *compute_type); + const Type *get_quant_float_type(const Type *digits_type, + const Type *exponent_type, + const Type *compute_type); BitStructType *get_bit_struct_type( const PrimitiveType *physical_type, @@ -44,9 +47,9 @@ const Type *get_quant_float_type(const Type *digits_type, const std::vector &member_exponents, const std::vector> &member_exponent_users); -const Type *get_quant_array_type(const PrimitiveType *physical_type, - const Type *element_type, - int num_elements); + const Type *get_quant_array_type(const PrimitiveType *physical_type, + const Type *element_type, + int num_elements); static DataType create_tensor_type(std::vector shape, DataType element); @@ -82,15 +85,17 @@ const Type *get_quant_array_type(const PrimitiveType *physical_type, quant_int_types_; std::mutex quant_int_mut_; - std::unordered_map, - std::unique_ptr, - hashing::Hasher>> + std::unordered_map< + std::tuple, + std::unique_ptr, + hashing::Hasher>> quant_fixed_types_; std::mutex quant_fixed_mut_; - std::unordered_map, - std::unique_ptr, - hashing::Hasher>> + std::unordered_map< + std::tuple, + std::unique_ptr, + hashing::Hasher>> quant_float_types_; std::mutex quant_float_mut_;