From fbd3e920466dbc4c4852780361d6d7a61729628f Mon Sep 17 00:00:00 2001
From: CanftIn
Date: Wed, 3 Jul 2024 14:13:54 +0800
Subject: [PATCH] add node struct

---
 .github/{workflow => workflows}/ci.yml |   0
 src/ir/array_ref.h                     | 186 +++++++++++++++++++++
 src/ir/graph_node_list.h               | 158 ++++++++++++++++++
 src/ir/ir.h                            | 222 +++++++++++++++++++++++++
 4 files changed, 566 insertions(+)
 rename .github/{workflow => workflows}/ci.yml (100%)
 create mode 100644 src/ir/array_ref.h
 create mode 100644 src/ir/graph_node_list.h

diff --git a/.github/workflow/ci.yml b/.github/workflows/ci.yml
similarity index 100%
rename from .github/workflow/ci.yml
rename to .github/workflows/ci.yml
diff --git a/src/ir/array_ref.h b/src/ir/array_ref.h
new file mode 100644
index 0000000..a08685a
--- /dev/null
+++ b/src/ir/array_ref.h
@@ -0,0 +1,186 @@
+// Copyright (c) ONNX Project Contributors
+
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+// ATTENTION: The code in this file is highly EXPERIMENTAL.
+// Adventurous users should note that the APIs will probably change.
+
+//===--- ArrayRef.h - Array Reference Wrapper -------------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+// ONNX: modified from llvm::ArrayRef.
+// removed llvm-specific functionality
+// removed some implicit const -> non-const conversions that rely on
+// complicated std::enable_if meta-programming
+// removed a bunch of slice variants for simplicity...
+
+#pragma once
+#include <assert.h>
+
+#include <array>
+#include <vector>
+
+namespace my_ai_training::ir {
+/// ArrayRef - Represent a constant reference to an array (0 or more elements
+/// consecutively in memory), i.e. a start pointer and a length. It allows
+/// various APIs to take consecutive elements easily and conveniently.
+///
+/// This class does not own the underlying data, it is expected to be used in
+/// situations where the data resides in some other buffer, whose lifetime
+/// extends past that of the ArrayRef. For this reason, it is not in general
+/// safe to store an ArrayRef.
+///
+/// This is intended to be trivially copyable, so it should be passed by
+/// value.
+template <typename T>
+class ArrayRef {
+ public:
+  typedef const T* iterator;
+  typedef const T* const_iterator;
+  typedef size_t size_type;
+
+  typedef std::reverse_iterator<iterator> reverse_iterator;
+
+ private:
+  /// The start of the array, in an external buffer.
+  const T* Data;
+
+  /// The number of elements.
+  size_type Length;
+
+ public:
+  /// @name Constructors
+  /// @{
+
+  /// Construct an empty ArrayRef.
+  /*implicit*/ ArrayRef() : Data(nullptr), Length(0) {}
+
+  /// Construct an ArrayRef from a single element.
+  /*implicit*/ ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}
+
+  /// Construct an ArrayRef from a pointer and length.
+  /*implicit*/ ArrayRef(const T* data, size_t length)
+      : Data(data), Length(length) {}
+
+  /// Construct an ArrayRef from a range.
+  ArrayRef(const T* begin, const T* end) : Data(begin), Length(end - begin) {}
+
+  /// Construct an ArrayRef from a std::vector.
+  template <typename A>
+  /*implicit*/ ArrayRef(const std::vector<T, A>& Vec)
+      : Data(Vec.data()), Length(Vec.size()) {}
+
+  /// Construct an ArrayRef from a std::array
+  template <size_t N>
+  /*implicit*/ constexpr ArrayRef(const std::array<T, N>& Arr)
+      : Data(Arr.data()), Length(N) {}
+
+  /// Construct an ArrayRef from a C array.
+  template <size_t N>
+  /*implicit*/ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}
+
+  /// Construct an ArrayRef from a std::initializer_list.
+  /*implicit*/ ArrayRef(const std::initializer_list<T>& Vec)
+      : Data(Vec.begin() == Vec.end() ? (T*)nullptr : Vec.begin()),
+        Length(Vec.size()) {}
+
+  /// @}
+  /// @name Simple Operations
+  /// @{
+
+  iterator begin() const { return Data; }
+  iterator end() const { return Data + Length; }
+
+  reverse_iterator rbegin() const { return reverse_iterator(end()); }
+  reverse_iterator rend() const { return reverse_iterator(begin()); }
+
+  /// empty - Check if the array is empty.
+  bool empty() const { return Length == 0; }
+
+  const T* data() const { return Data; }
+
+  /// size - Get the array size.
+  size_t size() const { return Length; }
+
+  /// front - Get the first element.
+  const T& front() const {
+    assert(!empty());
+    return Data[0];
+  }
+
+  /// back - Get the last element.
+  const T& back() const {
+    assert(!empty());
+    return Data[Length - 1];
+  }
+
+  /// equals - Check for element-wise equality.
+  bool equals(ArrayRef RHS) const {
+    if (Length != RHS.Length) return false;
+    return std::equal(begin(), end(), RHS.begin());
+  }
+
+  /// slice(n, m) - Chop off the first N elements of the array, and keep M
+  /// elements in the array.
+  ArrayRef<T> slice(size_t N, size_t M) const {
+    assert(N + M <= size() && "Invalid specifier");
+    return ArrayRef<T>(data() + N, M);
+  }
+
+  /// slice(n) - Chop off the first N elements of the array.
+  ArrayRef<T> slice(size_t N) const { return slice(N, size() - N); }
+
+  /// @}
+  /// @name Operator Overloads
+  /// @{
+  const T& operator[](size_t Index) const {
+    assert(Index < Length && "Invalid index!");
+    return Data[Index];
+  }
+
+  /// Vector compatibility
+  const T& at(size_t Index) const {
+    assert(Index < Length && "Invalid index!");
+    return Data[Index];
+  }
+
+  /// Disallow accidental assignment from a temporary.
+  ///
+  /// The declaration here is extra complicated so that "arrayRef = {}"
+  /// continues to select the move assignment operator.
+  template <typename U>
+  typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type&
+  operator=(U&& Temporary) = delete;
+
+  /// Disallow accidental assignment from a temporary.
+  ///
+  /// The declaration here is extra complicated so that "arrayRef = {}"
+  /// continues to select the move assignment operator.
+  template <typename U>
+  typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type&
+  operator=(std::initializer_list<U>) = delete;
+
+  /// @}
+  /// @name Expensive Operations
+  /// @{
+  std::vector<T> vec() const { return std::vector<T>(Data, Data + Length); }
+
+  /// @}
+  /// @name Conversion operators
+  /// @{
+  operator std::vector<T>() const {
+    return std::vector<T>(Data, Data + Length);
+  }
+
+  /// @}
+};
+
+}  // namespace my_ai_training::ir
diff --git a/src/ir/graph_node_list.h b/src/ir/graph_node_list.h
new file mode 100644
index 0000000..2841b02
--- /dev/null
+++ b/src/ir/graph_node_list.h
@@ -0,0 +1,158 @@
+#pragma once
+
+#include "ir/assertions.h"
+
+namespace my_ai_training::ir {
+
+// Intrusive doubly linked lists with sane reverse iterators.
+// The header file is named graph_node_list.h because it is ONLY
+// used for Graph's Node lists, and if you want to use it for other
+// things, you will have to do some refactoring.
+//
+// At the moment, the templated type T must support a few operations:
+//
+// - It must have a field: T* next_in_graph[2] = { nullptr, nullptr };
+//   which are used for the intrusive linked list pointers.
+//
+// - It must have a method 'destroy()', which removes T from the
+//   list and frees a T.
+//
+// In practice, we are only using it with Node and const Node. 'destroy()'
+// needs to be renegotiated if you want to use this somewhere else.
+//
+// Besides the benefits of being intrusive, unlike std::list, these lists handle
+// forward and backward iteration uniformly because we require a
+// "before-first-element" sentinel. This means that reverse iterators
+// physically point to the element they logically point to, rather than
+// the off-by-one behavior for all standard library reverse iterators.
+
+static constexpr size_t kNextDirection = 0;
+static constexpr size_t kPrevDirection = 1;
+
+template <typename T>
+struct generic_graph_node_list;
+
+template <typename T>
+struct generic_graph_node_list_iterator;
+
+struct Node;
+using graph_node_list = generic_graph_node_list<Node>;
+using const_graph_node_list = generic_graph_node_list<const Node>;
+using graph_node_list_iterator = generic_graph_node_list_iterator<Node>;
+using const_graph_node_list_iterator =
+    generic_graph_node_list_iterator<const Node>;
+
+template <typename T>
+struct generic_graph_node_list_iterator final {
+  generic_graph_node_list_iterator() : cur(nullptr), d(kNextDirection) {}
+  generic_graph_node_list_iterator(T* cur, size_t d) : cur(cur), d(d) {}
+  T* operator*() const { return cur; }
+  T* operator->() const { return cur; }
+  generic_graph_node_list_iterator& operator++() {
+    ONNX_ASSERT(cur);
+    cur = cur->next_in_graph[d];
+    return *this;
+  }
+  generic_graph_node_list_iterator operator++(int) {
+    generic_graph_node_list_iterator old = *this;
+    ++(*this);
+    return old;
+  }
+  generic_graph_node_list_iterator& operator--() {
+    ONNX_ASSERT(cur);
+    cur = cur->next_in_graph[reverseDir()];
+    return *this;
+  }
+  generic_graph_node_list_iterator operator--(int) {
+    generic_graph_node_list_iterator old = *this;
+    --(*this);
+    return old;
+  }
+
+  // erase cur without invalidating this iterator
+  // named differently from destroy so that ->/. bugs do not
+  // silently cause the wrong one to be called.
+  // iterator will point to the previous entry after call
+  void destroyCurrent() {
+    T* n = cur;
+    cur = cur->next_in_graph[reverseDir()];
+    n->destroy();
+  }
+  generic_graph_node_list_iterator reverse() {
+    return generic_graph_node_list_iterator(cur, reverseDir());
+  }
+
+ private:
+  size_t reverseDir() {
+    return d == kNextDirection ? kPrevDirection : kNextDirection;
+  }
+  T* cur;
+  size_t d;  // direction 0 is forward 1 is reverse, see next_in_graph
+};
+
+template <typename T>
+struct generic_graph_node_list final {
+  using iterator = generic_graph_node_list_iterator<T>;
+  using const_iterator = generic_graph_node_list_iterator<const T>;
+  generic_graph_node_list_iterator<T> begin() {
+    return generic_graph_node_list_iterator<T>(head->next_in_graph[d], d);
+  }
+  generic_graph_node_list_iterator<const T> begin() const {
+    return generic_graph_node_list_iterator<const T>(head->next_in_graph[d], d);
+  }
+  generic_graph_node_list_iterator<T> end() {
+    return generic_graph_node_list_iterator<T>(head, d);
+  }
+  generic_graph_node_list_iterator<const T> end() const {
+    return generic_graph_node_list_iterator<const T>(head, d);
+  }
+  generic_graph_node_list_iterator<T> rbegin() { return reverse().begin(); }
+  generic_graph_node_list_iterator<const T> rbegin() const {
+    return reverse().begin();
+  }
+  generic_graph_node_list_iterator<T> rend() { return reverse().end(); }
+  generic_graph_node_list_iterator<const T> rend() const {
+    return reverse().end();
+  }
+  generic_graph_node_list reverse() {
+    return generic_graph_node_list(
+        head, d == kNextDirection ? kPrevDirection : kNextDirection);
+  }
+  const generic_graph_node_list reverse() const {
+    return generic_graph_node_list(
+        head, d == kNextDirection ? kPrevDirection : kNextDirection);
+  }
+  generic_graph_node_list(T* head, size_t d) : head(head), d(d) {}
+
+ private:
+  T* head;
+  size_t d;
+};
+
+template <typename T>
+static inline bool operator==(generic_graph_node_list_iterator<T> a,
+                              generic_graph_node_list_iterator<T> b) {
+  return *a == *b;
+}
+
+template <typename T>
+static inline bool operator!=(generic_graph_node_list_iterator<T> a,
+                              generic_graph_node_list_iterator<T> b) {
+  return *a != *b;
+}
+
+}  // namespace my_ai_training::ir
+
+namespace std {
+
+template <typename T>
+struct iterator_traits<
+    my_ai_training::ir::generic_graph_node_list_iterator<T>> {
+  using difference_type = int64_t;
+  using value_type = T*;
+  using pointer = T**;
+  using reference = T*&;
+  using iterator_category = bidirectional_iterator_tag;
+};
+
+}  // namespace std
diff --git a/src/ir/ir.h b/src/ir/ir.h
index b07f4e3..fde37e6 100644
--- a/src/ir/ir.h
+++ b/src/ir/ir.h
@@ -16,7 +16,10 @@
 #include <utility>
 #include <vector>
 
+#include "ir/array_ref.h"
 #include "ir/assertions.h"
+#include "ir/graph_node_list.h"
+#include "ir/interned_strings.h"
 
 #define MY_AI_TRAINING_DISALLOW_COPY_AND_ASSIGN(TypeName) \
   TypeName(const TypeName&) = delete;                     \
@@ -24,6 +27,16 @@
 
 namespace my_ai_training::ir {
 
+namespace {  // internal/private API
+
+std::string toVarName(size_t i) {
+  std::ostringstream oss;
+  oss << "_v_" << i;
+  return oss.str();
+}
+
+}  // namespace
+
 struct Graph;
 struct Node;
 
@@ -86,6 +99,215 @@ static inline const char* toString(AttributeKind kind) {
   return names[int(kind)];
 }
 
+// Each use is represented by this type, see Node::uses()
+// 'user' is the consumer of the value, offset is the index into
+// 'user's input this where the produces will be found.
+struct Use final {
+  Use(Node* user, size_t offset) : user(user), offset(offset) {}
+  Node* user;
+  size_t offset;
+};
+
+static inline bool operator==(const Use& a, const Use& b) {
+  return a.user == b.user && a.offset == b.offset;
+}
+
+// the list types are intentionally simple, but we type-def
+// them here so if we need to change them, refactoring will be easier
+using node_list = std::vector<Node*>;
+using value_list = std::vector<Value*>;
+using use_list = std::vector<Use>;
+using NodeKind = Symbol;
+
+struct Value final {
+  MY_AI_TRAINING_DISALLOW_COPY_AND_ASSIGN(Value);
+  Value(Node* node, size_t offset);
+  Value(Value&&) = default;
+  Value& operator=(Value&&) = default;
+  ~Value() = default;
+
+ private:
+  friend struct Node;
+  friend struct Graph;
+  Node* node_;
+  size_t offset_;
+  size_t unique_ = 0;  // unique id
+  size_t stage_ = 0;   // 0-forward, 1-backward, 2-double-backward,...
+  use_list uses_in_current_graph_;
+  bool has_unique_name_;
+  std::string unique_name_;
+  int32_t elem_type_;
+  bool has_sizes_;
+  std::vector<Dimension> sizes_;
+
+ public:
+  Value* setElemType(int32_t elem_type) {
+    elem_type_ = elem_type;
+    return this;
+  }
+
+  int32_t elemType() const { return elem_type_; }
+
+  bool has_sizes() const { return has_sizes_; }
+
+  Value* setSizes(std::vector<Dimension> sizes) {
+    has_sizes_ = true;
+    sizes_ = std::move(sizes);
+    return this;
+  }
+
+  Value* wipeSizes() {
+    has_sizes_ = false;
+    sizes_ = std::vector<Dimension>();
+    return this;
+  }
+
+  const std::vector<Dimension>& sizes() const { return sizes_; }
+
+  size_t unique() const { return unique_; }
+
+  bool has_unique_name() const { return has_unique_name_; }
+  std::string uniqueName() const {
+    if (has_unique_name()) return unique_name_;
+    return toVarName(unique());
+  }
+
+  Value* setUniqueName(const std::string& name,
+                       bool rename_subgraph_captured_nodes = true);
+  Value* setStage(size_t s) {
+    stage_ = s;
+    return this;
+  }
+  size_t stage() const { return stage_; }
+  Node* node() { return node_; }
+  size_t offset() const { return offset_; }
+  const Node* node() const { return node_; }
+  Graph* owningGraph();
+  const Graph* owningGraph() const;
+  // TODO: make this more const correct
+  const use_list uses() const;
+
+  // Replaces all uses of this node with 'newValue'.
+  //
+  // Given:   %3 = f(%1, %2)
+  //          %4 = g(%3)
+  //          %5 = h(%3, %3)
+  // Execute: %3.replaceAllUsesWith(%6)
+  // Result:  %3 = f(%1, %2)
+  //          %4 = g(%6)
+  //          %5 = h(%6, %6)
+  void replaceAllUsesWith(Value* newValue);
+
+  Value* copyMetadata(Value* from) {
+    setElemType(from->elemType());
+    setSizes(from->sizes());
+    if (from->has_unique_name()) {
+      setUniqueName(from->uniqueName());
+    }
+    return this;
+  }
+};
+
+struct Node {
+  MY_AI_TRAINING_DISALLOW_COPY_AND_ASSIGN(Node);
+  friend struct Graph;
+  friend struct Value;
+  friend graph_node_list;
+  friend const_graph_node_list;
+  friend graph_node_list_iterator;
+  friend const_graph_node_list_iterator;
+
+ private:
+  // each node but Return/Param
+  // is associated with exactly one place in the node list...
+  // of the graph_
+  // this circular is a doubly-linked list, the Return node is used as the
+  // sentinel for the beginning and end of the list such that the list never has
+  // null pointers next_in_graph[0] is next pointer next_in_graph[1] is prev
+  // pointer using an array to allow the same iterator class for forward and
+  // reverse node lists This list represents a topological sort
+
+  Node* next_in_graph[2] = {nullptr, nullptr};
+  Node*& next() { return next_in_graph[kNextDirection]; }
+  Node*& prev() { return next_in_graph[kPrevDirection]; }
+  Node* const& next() const { return next_in_graph[kNextDirection]; }
+  Node* const& prev() const { return next_in_graph[kPrevDirection]; }
+
+  const NodeKind kind_;
+  std::vector<Value*> inputs_;
+  std::vector<Value*> outputs_;
+  Graph* graph_;
+  size_t stage_;
+  bool has_name_;
+  std::string name_;
+  bool has_domain_;
+  std::string domain_;
+  bool has_doc_string_;
+  std::string doc_string_;
+  bool has_overload_;
+  std::string overload_;
+
+ protected:
+  Node(Graph* graph_, NodeKind kind_);  // defined after graph
+
+ public:
+  bool has_name() const { return has_name_; }
+  const std::string& name() const { return name_; }
+  void setName(std::string name) {
+    has_name_ = true;
+    name_ = std::move(name);
+  }
+  bool has_domain() const { return has_domain_; }
+  const std::string& domain() const { return domain_; }
+  void setDomain(std::string domain) {
+    has_domain_ = true;
+    domain_ = std::move(domain);
+  }
+  bool has_overload() const { return has_overload_; }
+  const std::string& overload() const { return overload_; }
+  void setOverload(std::string overload) {
+    has_overload_ = true;
+    overload_ = std::move(overload);
+  }
+  bool has_doc_string() const { return has_doc_string_; }
+  const std::string& docString() const { return doc_string_; }
+  void setDocString(std::string doc_string) {
+    has_doc_string_ = true;
+    doc_string_ = std::move(doc_string);
+  }
+  NodeKind kind() const { return kind_; }
+  Graph* owningGraph() { return graph_; }
+  const Graph* owningGraph() const { return graph_; }
+  size_t stage() const { return stage_; }
+  Node* setStage(size_t s) {
+    stage_ = s;
+    return this;
+  }
+  // NB: This returns an ArrayRef; that means that it will
+  // get invalidated if you resize inputs (e.g., using addInput)
+  // We can't return a std::vector<Node*>& because there's no
+  // way to soundly cast to std::vector<const Node*> (an insane
+  // implementation of std::vector could make this representationally
+  // different.)
+  ArrayRef<Value*> inputs() { return inputs_; }
+  ArrayRef<const Value*> inputs() const {
+    // Vectors are not convertible in const-ness of elements, but
+    // raw pointers are.
+    return {inputs_.data(), inputs_.size()};
+  }
+  // NB: This returns an ArrayRef; that means that it will
+  // get invalidated if you resize inputs (e.g., using addInput)
+  // We can't return a std::vector<Node*>& because there's no
+  // way to soundly cast to std::vector<const Node*> (an insane
+  // implementation of std::vector could make this representationally
+  // different.)
+  ArrayRef<Value*> outputs() { return outputs_; }
+  ArrayRef<const Value*> outputs() const {
+    // Vectors are not convertible in const-ness of elements, but
+    // raw pointers are.
+    return {outputs_.data(), outputs_.size()};
+  }
+};
 
 }  // namespace my_ai_training::ir