Skip to content

Commit

Permalink
[ir] Add struct type to CHI-IR (#6982)
Browse files Browse the repository at this point in the history
Issue: #6983 

### Brief Summary

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
lin-hitonami and pre-commit-ci[bot] authored Dec 28, 2022
1 parent 59d8909 commit 449a7f6
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 1 deletion.
25 changes: 25 additions & 0 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,31 @@ std::string TensorType::to_string() const {
return s;
}

std::string StructType::to_string() const {
std::string s = "struct(";
for (int i = 0; i < elements_.size(); i++) {
if (i) {
s += ", ";
}
s += std::to_string(i) + ": " + elements_[i]->to_string();
}
s += ")";
return s;
}

Type *StructType::get_element_type(const std::vector<int> &indices) const {
const Type *type_now = this;
for (auto ind : indices) {
if (auto tensor_type = type_now->cast<TensorType>()) {
TI_ASSERT(ind < tensor_type->get_num_elements())
type_now = tensor_type->get_element_type();
} else {
type_now = type_now->as<StructType>()->elements_[ind];
}
}
return (Type *)type_now;
}

bool Type::is_primitive(PrimitiveTypeID type) const {
if (auto p = cast<PrimitiveType>()) {
return p->type == type;
Expand Down
46 changes: 45 additions & 1 deletion taichi/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ class TI_DLL_EXPORT Type {
return p;
}

template <typename T>
const T *as() const {
auto p = dynamic_cast<const T *>(this);
TI_ASSERT_INFO(p != nullptr, "Cannot treat {} as {}", this->to_string(),
typeid(T).name());
return p;
}

bool is_primitive(PrimitiveTypeID type) const;

virtual Type *get_compute_type() {
Expand All @@ -56,7 +64,7 @@ class TI_DLL_EXPORT DataType {
DataType();

// NOLINTNEXTLINE(google-explicit-constructor)
DataType(Type *ptr) : ptr_(ptr) {
DataType(const Type *ptr) : ptr_((Type *)ptr) {
}

DataType(const DataType &o) : ptr_(o.ptr_) {
Expand Down Expand Up @@ -193,6 +201,42 @@ class TensorType : public Type {
Type *element_{nullptr};
};

class StructType : public Type {
public:
explicit StructType(std::vector<const Type *> elements)
: elements_(std::move(elements)) {
}

std::string to_string() const override;

Type *get_element_type(const std::vector<int> &indices) const;
const std::vector<const Type *> &elements() const {
return elements_;
}

int get_num_elements() const {
int num = 0;
for (const auto &element : elements_) {
if (auto struct_type = element->cast<StructType>()) {
num += struct_type->get_num_elements();
} else if (auto tensor_type = element->cast<TensorType>()) {
num += tensor_type->get_num_elements();
} else {
TI_ASSERT(element->is<PrimitiveType>());
num += 1;
}
}
return num;
}

Type *get_compute_type() override {
return this;
}

private:
std::vector<const Type *> elements_;
};

class QuantIntType : public Type {
public:
QuantIntType(int num_bits, bool is_signed, Type *compute_type = nullptr);
Expand Down
14 changes: 14 additions & 0 deletions taichi/ir/type_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ Type *TypeFactory::get_tensor_type(std::vector<int> shape, Type *element) {
return tensor_types_[key].get();
}

Type *TypeFactory::get_struct_type(const std::vector<const Type *> &elements) {
std::lock_guard<std::mutex> _(struct_mut_);
if (struct_types_.find(elements) == struct_types_.end()) {
for (const auto &element : elements) {
TI_ASSERT_INFO(
element->is<PrimitiveType>() || element->is<TensorType>() ||
element->is<StructType>() || element->is<PointerType>(),
"Unsupported struct element type: " + element->to_string());
}
struct_types_[elements] = std::make_unique<StructType>(elements);
}
return struct_types_[elements].get();
}

Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) {
auto key = std::make_pair(element, is_bit_pointer);
if (pointer_types_.find(key) == pointer_types_.end()) {
Expand Down
9 changes: 9 additions & 0 deletions taichi/ir/type_factory.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "taichi/ir/type.h"
#include "taichi/util/hash.h"

#include <mutex>

Expand All @@ -21,6 +22,8 @@ class TypeFactory {

Type *get_tensor_type(std::vector<int> shape, Type *element);

Type *get_struct_type(const std::vector<const Type *> &elements);

Type *get_pointer_type(Type *element, bool is_bit_pointer = false);

Type *get_quant_int_type(int num_bits, bool is_signed, Type *compute_type);
Expand Down Expand Up @@ -57,6 +60,12 @@ class TypeFactory {
// TODO: use unordered map
std::map<std::pair<std::string, Type *>, std::unique_ptr<Type>> tensor_types_;

std::unordered_map<std::vector<const Type *>,
std::unique_ptr<Type>,
hashing::Hasher<std::vector<const Type *>>>
struct_types_;
std::mutex struct_mut_;

// TODO: is_bit_ptr?
std::map<std::pair<Type *, bool>, std::unique_ptr<Type>> pointer_types_;

Expand Down
10 changes: 10 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1151,6 +1151,16 @@ void export_lang(py::module &m) {
const DataType &element_type) {
return factory->create_tensor_type(shape, element_type);
},
py::return_value_policy::reference)
.def(
"get_struct_type",
[&](TypeFactory *factory, std::vector<DataType> elements) {
std::vector<const Type *> types;
for (auto &element : elements) {
types.push_back(element);
}
return DataType(factory->get_struct_type(types));
},
py::return_value_policy::reference);

m.def("get_type_factory_instance", TypeFactory::get_instance,
Expand Down
7 changes: 7 additions & 0 deletions taichi/runtime/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) {
/*scalable=*/false);
}
return llvm::ArrayType::get(element_type, num_elements);
} else if (dt->is<StructType>()) {
std::vector<llvm::Type *> types;
auto struct_type = dt->cast<StructType>();
for (const auto &element : struct_type->elements()) {
types.push_back(get_data_type(element));
}
return llvm::StructType::get(*ctx, types);
} else {
TI_INFO(data_type_name(dt));
TI_NOT_IMPLEMENTED;
Expand Down
36 changes: 36 additions & 0 deletions taichi/util/hash.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#pragma once

#include <functional>
#include <stddef.h>

namespace taichi::hashing {

template <typename T>
struct Hasher {
public:
size_t operator()(T const &val) const {
return std::hash<T>{}(val);
}
};

namespace {
template <typename T>
inline void hash_combine(size_t &seed, T const &value) {
// Reference:
// https://www.boost.org/doc/libs/1_55_0/doc/html/hash/reference.html#boost.hash_combine
seed ^= Hasher<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
} // namespace

template <typename T>
struct Hasher<std::vector<T>> {
public:
size_t operator()(std::vector<T> const &vec) const {
size_t ret = 0;
for (const auto &i : vec) {
hash_combine(ret, i);
}
return ret;
}
};
} // namespace taichi::hashing

0 comments on commit 449a7f6

Please sign in to comment.