From 449a7f60a57fd0e2213bc8fe984cd6627feef9ce Mon Sep 17 00:00:00 2001 From: Lin Jiang Date: Wed, 28 Dec 2022 14:22:08 +0800 Subject: [PATCH] [ir] Add struct type to CHI-IR (#6982) Issue: #6983 ### Brief Summary Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- taichi/ir/type.cpp | 25 +++++++++++++++ taichi/ir/type.h | 46 +++++++++++++++++++++++++++- taichi/ir/type_factory.cpp | 14 +++++++++ taichi/ir/type_factory.h | 9 ++++++ taichi/python/export_lang.cpp | 10 ++++++ taichi/runtime/llvm/llvm_context.cpp | 7 +++++ taichi/util/hash.h | 36 ++++++++++++++++++++++ 7 files changed, 146 insertions(+), 1 deletion(-) create mode 100644 taichi/util/hash.h diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index df1d32b818a30..9a58cd98a6be4 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -103,6 +103,31 @@ std::string TensorType::to_string() const { return s; } +std::string StructType::to_string() const { + std::string s = "struct("; + for (int i = 0; i < elements_.size(); i++) { + if (i) { + s += ", "; + } + s += std::to_string(i) + ": " + elements_[i]->to_string(); + } + s += ")"; + return s; +} + +Type *StructType::get_element_type(const std::vector &indices) const { + const Type *type_now = this; + for (auto ind : indices) { + if (auto tensor_type = type_now->cast()) { + TI_ASSERT(ind < tensor_type->get_num_elements()) + type_now = tensor_type->get_element_type(); + } else { + type_now = type_now->as()->elements_[ind]; + } + } + return (Type *)type_now; +} + bool Type::is_primitive(PrimitiveTypeID type) const { if (auto p = cast()) { return p->type == type; diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 8b7be6df6b43f..cd550a6b55ba5 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -40,6 +40,14 @@ class TI_DLL_EXPORT Type { return p; } + template + const T *as() const { + auto p = dynamic_cast(this); + TI_ASSERT_INFO(p != nullptr, "Cannot treat {} as {}", this->to_string(), + typeid(T).name()); + return p; + } + bool is_primitive(PrimitiveTypeID type) const; virtual Type *get_compute_type() { @@ -56,7 +64,7 @@ class TI_DLL_EXPORT DataType { DataType(); // NOLINTNEXTLINE(google-explicit-constructor) - DataType(Type *ptr) : ptr_(ptr) { + DataType(const Type *ptr) : ptr_((Type *)ptr) { } DataType(const DataType &o) : ptr_(o.ptr_) { @@ -193,6 +201,42 @@ class TensorType : public Type { Type *element_{nullptr}; }; +class StructType : public Type { + public: + explicit StructType(std::vector elements) + : elements_(std::move(elements)) { + } + + std::string to_string() const override; + + Type *get_element_type(const std::vector &indices) const; + const std::vector &elements() const { + return elements_; + } + + int get_num_elements() const { + int num = 0; + for (const auto &element : elements_) { + if (auto struct_type = element->cast()) { + num += struct_type->get_num_elements(); + } else if (auto tensor_type = element->cast()) { + num += tensor_type->get_num_elements(); + } else { + TI_ASSERT(element->is()); + num += 1; + } + } + return num; + } + + Type *get_compute_type() override { + return this; + } + + private: + std::vector elements_; +}; + class QuantIntType : public Type { public: QuantIntType(int num_bits, bool is_signed, Type *compute_type = nullptr); diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index 9c5cc7b9e03e2..1ecb88686f5fe 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -36,6 +36,20 @@ Type *TypeFactory::get_tensor_type(std::vector shape, Type *element) { return tensor_types_[key].get(); } +Type *TypeFactory::get_struct_type(const std::vector &elements) { + std::lock_guard _(struct_mut_); + if (struct_types_.find(elements) == struct_types_.end()) { + for (const auto &element : elements) { + TI_ASSERT_INFO( + element->is() || element->is() || + element->is() || element->is(), + "Unsupported struct element type: " + element->to_string()); + } + struct_types_[elements] = std::make_unique(elements); + } + return struct_types_[elements].get(); +} + Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) { auto key = std::make_pair(element, is_bit_pointer); if (pointer_types_.find(key) == pointer_types_.end()) { diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index a273091579ea9..9df35d3ba394c 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -1,6 +1,7 @@ #pragma once #include "taichi/ir/type.h" +#include "taichi/util/hash.h" #include @@ -21,6 +22,8 @@ class TypeFactory { Type *get_tensor_type(std::vector shape, Type *element); + Type *get_struct_type(const std::vector &elements); + Type *get_pointer_type(Type *element, bool is_bit_pointer = false); Type *get_quant_int_type(int num_bits, bool is_signed, Type *compute_type); @@ -57,6 +60,12 @@ class TypeFactory { // TODO: use unordered map std::map, std::unique_ptr> tensor_types_; + std::unordered_map, + std::unique_ptr, + hashing::Hasher>> + struct_types_; + std::mutex struct_mut_; + // TODO: is_bit_ptr? std::map, std::unique_ptr> pointer_types_; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 24709e899edf0..f0f42d142c602 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -1151,6 +1151,16 @@ void export_lang(py::module &m) { const DataType &element_type) { return factory->create_tensor_type(shape, element_type); }, + py::return_value_policy::reference) + .def( + "get_struct_type", + [&](TypeFactory *factory, std::vector elements) { + std::vector types; + for (auto &element : elements) { + types.push_back(element); + } + return DataType(factory->get_struct_type(types)); + }, py::return_value_policy::reference); m.def("get_type_factory_instance", TypeFactory::get_instance, diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index 7d2e7433e5db1..5c9d8d04b531c 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -148,6 +148,13 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) { /*scalable=*/false); } return llvm::ArrayType::get(element_type, num_elements); + } else if (dt->is()) { + std::vector types; + auto struct_type = dt->cast(); + for (const auto &element : struct_type->elements()) { + types.push_back(get_data_type(element)); + } + return llvm::StructType::get(*ctx, types); } else { TI_INFO(data_type_name(dt)); TI_NOT_IMPLEMENTED; diff --git a/taichi/util/hash.h b/taichi/util/hash.h new file mode 100644 index 0000000000000..e6042d0714a05 --- /dev/null +++ b/taichi/util/hash.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include + +namespace taichi::hashing { + +template +struct Hasher { + public: + size_t operator()(T const &val) const { + return std::hash{}(val); + } +}; + +namespace { +template +inline void hash_combine(size_t &seed, T const &value) { + // Reference: + // https://www.boost.org/doc/libs/1_55_0/doc/html/hash/reference.html#boost.hash_combine + seed ^= Hasher{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} +} // namespace + +template +struct Hasher> { + public: + size_t operator()(std::vector const &vec) const { + size_t ret = 0; + for (const auto &i : vec) { + hash_combine(ret, i); + } + return ret; + } +}; +} // namespace taichi::hashing