Skip to content

Commit

Permalink
Initial Commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Ris-Bali committed Apr 19, 2023
1 parent 1d8af57 commit f905901
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 55 deletions.
4 changes: 2 additions & 2 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ bool QuantFloatType::get_is_signed() const {

BitStructType::BitStructType(
PrimitiveType *physical_type,
const std::vector<Type *> &member_types,
const std::vector<const Type *> &member_types,
const std::vector<int> &member_bit_offsets,
const std::vector<int> &member_exponents,
const std::vector<std::vector<int>> &member_exponent_users)
Expand All @@ -282,7 +282,7 @@ BitStructType::BitStructType(
int physical_type_bits = data_type_bits(physical_type_);
int member_total_bits = 0;
for (auto i = 0; i < member_types_.size(); ++i) {
QuantIntType *component_qit = nullptr;
const QuantIntType *component_qit = nullptr;
if (auto qit = member_types_[i]->cast<QuantIntType>()) {
component_qit = qit;
} else if (auto qfxt = member_types_[i]->cast<QuantFixedType>()) {
Expand Down
12 changes: 6 additions & 6 deletions taichi/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class TI_DLL_EXPORT DataType {
TI_IO_DEF(ptr_);

private:
Type *ptr_;
const Type *ptr_;
};

// Note that all types are immutable once created.
Expand Down Expand Up @@ -335,7 +335,7 @@ class TI_DLL_EXPORT QuantIntType : public Type {
private:
// TODO(type): for now we can uniformly use i32 as the "compute_type". It may
// be a good idea to make "compute_type" also customizable.
Type *compute_type_{nullptr};
const Type *compute_type_{nullptr};
int num_bits_{32};
bool is_signed_{true};
};
Expand All @@ -349,7 +349,7 @@ class TI_DLL_EXPORT QuantFixedType : public Type {

bool get_is_signed() const;

Type *get_digits_type() {
const Type *get_digits_type() {
return digits_type_;
}

Expand Down Expand Up @@ -379,7 +379,7 @@ class TI_DLL_EXPORT QuantFloatType : public Type {

std::string to_string() const override;

Type *get_digits_type() {
const Type *get_digits_type() {
return digits_type_;
}

Expand Down Expand Up @@ -411,7 +411,7 @@ class TI_DLL_EXPORT BitStructType : public Type {
public:
BitStructType() : Type(TypeKind::BitStruct){};
BitStructType(PrimitiveType *physical_type,
const std::vector<Type *> &member_types,
const std::vector<const Type *> &member_types,
const std::vector<int> &member_bit_offsets,
const std::vector<int> &member_exponents,
const std::vector<std::vector<int>> &member_exponent_users);
Expand Down Expand Up @@ -457,7 +457,7 @@ class TI_DLL_EXPORT BitStructType : public Type {

private:
PrimitiveType *physical_type_;
std::vector<Type *> member_types_;
std::vector<const Type *> member_types_;
std::vector<int> member_bit_offsets_;
std::vector<int> member_exponents_;
std::vector<std::vector<int>> member_exponent_users_;
Expand Down
36 changes: 18 additions & 18 deletions taichi/ir/type_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ TypeFactory &TypeFactory::get_instance() {
TypeFactory::TypeFactory() {
}

Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) {
const Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) {
std::lock_guard<std::mutex> _(primitive_mut_);

if (primitive_types_.find(id) == primitive_types_.end()) {
Expand All @@ -22,7 +22,7 @@ Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) {
return primitive_types_[id].get();
}

Type *TypeFactory::get_tensor_type(std::vector<int> shape, Type *element) {
const Type *TypeFactory::get_tensor_type(std::vector<int> shape, Type *element) {
std::lock_guard<std::mutex> _(tensor_mut_);

auto encode = [](const std::vector<int> &shape) -> std::string {
Expand Down Expand Up @@ -57,7 +57,7 @@ const Type *TypeFactory::get_struct_type(
return struct_types_[key].get();
}

Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) {
const Type *TypeFactory::get_pointer_type(const Type *element, bool is_bit_pointer) {
std::lock_guard<std::mutex> _(pointer_mut_);

auto key = std::make_pair(element, is_bit_pointer);
Expand All @@ -68,9 +68,9 @@ Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) {
return pointer_types_[key].get();
}

Type *TypeFactory::get_quant_int_type(int num_bits,
const Type *TypeFactory::get_quant_int_type(int num_bits,
bool is_signed,
Type *compute_type) {
const Type *compute_type) {
std::lock_guard<std::mutex> _(quant_int_mut_);

auto key = std::make_tuple(num_bits, is_signed, compute_type);
Expand All @@ -81,8 +81,8 @@ Type *TypeFactory::get_quant_int_type(int num_bits,
return quant_int_types_[key].get();
}

Type *TypeFactory::get_quant_fixed_type(Type *digits_type,
Type *compute_type,
const Type *TypeFactory::get_quant_fixed_type(const Type *digits_type,
const Type *compute_type,
float64 scale) {
std::lock_guard<std::mutex> _(quant_fixed_mut_);

Expand All @@ -94,9 +94,9 @@ Type *TypeFactory::get_quant_fixed_type(Type *digits_type,
return quant_fixed_types_[key].get();
}

Type *TypeFactory::get_quant_float_type(Type *digits_type,
Type *exponent_type,
Type *compute_type) {
const Type *TypeFactory::get_quant_float_type(const Type *digits_type,
const Type *exponent_type,
const Type *compute_type) {
std::lock_guard<std::mutex> _(quant_float_mut_);

auto key = std::make_tuple(digits_type, exponent_type, compute_type);
Expand All @@ -108,8 +108,8 @@ Type *TypeFactory::get_quant_float_type(Type *digits_type,
}

BitStructType *TypeFactory::get_bit_struct_type(
PrimitiveType *physical_type,
const std::vector<Type *> &member_types,
const PrimitiveType *physical_type,
const std::vector<const Type *> &member_types,
const std::vector<int> &member_bit_offsets,
const std::vector<int> &member_exponents,
const std::vector<std::vector<int>> &member_exponent_users) {
Expand All @@ -121,8 +121,8 @@ BitStructType *TypeFactory::get_bit_struct_type(
return bit_struct_types_.back().get();
}

Type *TypeFactory::get_quant_array_type(PrimitiveType *physical_type,
Type *element_type,
const Type *TypeFactory::get_quant_array_type(const PrimitiveType *physical_type,
const Type *element_type,
int num_elements) {
std::lock_guard<std::mutex> _(quant_array_mut_);

Expand All @@ -131,8 +131,8 @@ Type *TypeFactory::get_quant_array_type(PrimitiveType *physical_type,
return quant_array_types_.back().get();
}

PrimitiveType *TypeFactory::get_primitive_int_type(int bits, bool is_signed) {
Type *int_type;
const PrimitiveType *TypeFactory::get_primitive_int_type(int bits, bool is_signed) {
const Type *int_type;
if (bits == 8) {
int_type = get_primitive_type(PrimitiveTypeID::i8);
} else if (bits == 16) {
Expand All @@ -150,8 +150,8 @@ PrimitiveType *TypeFactory::get_primitive_int_type(int bits, bool is_signed) {
return int_type->cast<PrimitiveType>();
}

PrimitiveType *TypeFactory::get_primitive_real_type(int bits) {
Type *real_type;
const PrimitiveType *TypeFactory::get_primitive_real_type(int bits) {
const Type *real_type;
if (bits == 16) {
real_type = get_primitive_type(PrimitiveTypeID::f16);
} else if (bits == 32) {
Expand Down
52 changes: 26 additions & 26 deletions taichi/ir/type_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,38 +14,38 @@ class TypeFactory {
// TODO(type): maybe it makes sense to let each get_X function return X*
// instead of generic Type*

Type *get_primitive_type(PrimitiveTypeID id);
const Type *get_primitive_type(PrimitiveTypeID id);

PrimitiveType *get_primitive_int_type(int bits, bool is_signed = true);
const PrimitiveType *get_primitive_int_type(int bits, bool is_signed = true);

PrimitiveType *get_primitive_real_type(int bits);
const PrimitiveType *get_primitive_real_type(int bits);

Type *get_tensor_type(std::vector<int> shape, Type *element);
const Type *get_tensor_type(std::vector<int> shape, Type *element);

const Type *get_struct_type(const std::vector<StructMember> &elements,
const Type *get_struct_type(const std::vector<StructMember> &elements,
const std::string &layout = "none");

Type *get_pointer_type(Type *element, bool is_bit_pointer = false);
const Type *get_pointer_type(const Type *element, bool is_bit_pointer = false);

Type *get_quant_int_type(int num_bits, bool is_signed, Type *compute_type);
const Type *get_quant_int_type(int num_bits, bool is_signed, const Type *compute_type);

Type *get_quant_fixed_type(Type *digits_type,
Type *compute_type,
const Type *get_quant_fixed_type(const Type *digits_type,
const Type *compute_type,
float64 scale);

Type *get_quant_float_type(Type *digits_type,
Type *exponent_type,
Type *compute_type);
const Type *get_quant_float_type(const Type *digits_type,
const Type *exponent_type,
const Type *compute_type);

BitStructType *get_bit_struct_type(
PrimitiveType *physical_type,
const std::vector<Type *> &member_types,
const PrimitiveType *physical_type,
const std::vector<const Type *> &member_types,
const std::vector<int> &member_bit_offsets,
const std::vector<int> &member_exponents,
const std::vector<std::vector<int>> &member_exponent_users);

Type *get_quant_array_type(PrimitiveType *physical_type,
Type *element_type,
const Type *get_quant_array_type(const PrimitiveType *physical_type,
const Type *element_type,
int num_elements);

static DataType create_tensor_type(std::vector<int> shape, DataType element);
Expand All @@ -56,9 +56,9 @@ class TypeFactory {
std::unordered_map<PrimitiveTypeID, std::unique_ptr<Type>> primitive_types_;
std::mutex primitive_mut_;

std::unordered_map<std::pair<std::string, Type *>,
std::unordered_map<std::pair<std::string, const Type *>,
std::unique_ptr<Type>,
hashing::Hasher<std::pair<std::string, Type *>>>
hashing::Hasher<std::pair<std::string, const Type *>>>
tensor_types_;
std::mutex tensor_mut_;

Expand All @@ -70,27 +70,27 @@ class TypeFactory {
std::mutex struct_mut_;

// TODO: is_bit_ptr?
std::unordered_map<std::pair<Type *, bool>,
std::unordered_map<std::pair<const Type *, bool>,
std::unique_ptr<Type>,
hashing::Hasher<std::pair<Type *, bool>>>
hashing::Hasher<std::pair<const Type *, bool>>>
pointer_types_;
std::mutex pointer_mut_;

std::unordered_map<std::tuple<int, bool, Type *>,
std::unordered_map<std::tuple<int, bool, const Type *>,
std::unique_ptr<Type>,
hashing::Hasher<std::tuple<int, bool, Type *>>>
hashing::Hasher<std::tuple<int, bool, const Type *>>>
quant_int_types_;
std::mutex quant_int_mut_;

std::unordered_map<std::tuple<Type *, Type *, float64>,
std::unordered_map<std::tuple<const Type *, const Type *, float64>,
std::unique_ptr<Type>,
hashing::Hasher<std::tuple<Type *, Type *, float64>>>
hashing::Hasher<std::tuple<const Type *, const Type *, float64>>>
quant_fixed_types_;
std::mutex quant_fixed_mut_;

std::unordered_map<std::tuple<Type *, Type *, Type *>,
std::unordered_map<std::tuple<const Type *, const Type *, const Type *>,
std::unique_ptr<Type>,
hashing::Hasher<std::tuple<Type *, Type *, Type *>>>
hashing::Hasher<std::tuple<const Type *, const Type *, const Type *>>>
quant_float_types_;
std::mutex quant_float_mut_;

Expand Down
6 changes: 3 additions & 3 deletions taichi/ir/type_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ class BitStructTypeBuilder {
member_bit_offsets_.push_back(member_total_bits_);
member_exponents_.push_back(-1);
member_exponent_users_.push_back({});
QuantIntType *member_qit = nullptr;
const QuantIntType *member_qit = nullptr;
if (auto qit = member_type->cast<QuantIntType>()) {
member_qit = qit;
} else if (auto qfxt = member_type->cast<QuantFixedType>()) {
Expand All @@ -260,8 +260,8 @@ class BitStructTypeBuilder {
return old_num_members;
}

PrimitiveType *physical_type_{nullptr};
std::vector<Type *> member_types_;
const PrimitiveType *physical_type_{nullptr};
std::vector<const Type *> member_types_;
std::vector<int> member_bit_offsets_;
int member_total_bits_{0};
std::vector<int> member_exponents_;
Expand Down

0 comments on commit f905901

Please sign in to comment.