Skip to content

Commit

Permalink
add node struct
Browse files Browse the repository at this point in the history
  • Loading branch information
CanftIn committed Jul 3, 2024
1 parent 1488a79 commit fbd3e92
Show file tree
Hide file tree
Showing 4 changed files with 566 additions and 0 deletions.
File renamed without changes.
186 changes: 186 additions & 0 deletions src/ir/array_ref.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
// Copyright (c) ONNX Project Contributors

/*
* SPDX-License-Identifier: Apache-2.0
*/

// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.

//===--- ArrayRef.h - Array Reference Wrapper -------------------*- C++ -*-===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

// ONNX: modified from llvm::ArrayRef.
// removed llvm-specific functionality
// removed some implicit const -> non-const conversions that rely on
// complicated std::enable_if meta-programming
// removed a bunch of slice variants for simplicity...

#pragma once
#include <assert.h>

#include <array>
#include <vector>

namespace my_ai_training::ir {
/// ArrayRef - Represent a constant reference to an array (0 or more elements
/// consecutively in memory), i.e. a start pointer and a length. It allows
/// various APIs to take consecutive elements easily and conveniently.
///
/// This class does not own the underlying data, it is expected to be used in
/// situations where the data resides in some other buffer, whose lifetime
/// extends past that of the ArrayRef. For this reason, it is not in general
/// safe to store an ArrayRef.
///
/// This is intended to be trivially copyable, so it should be passed by
/// value.
template <typename T>
class ArrayRef {
public:
typedef const T* iterator;
typedef const T* const_iterator;
typedef size_t size_type;

typedef std::reverse_iterator<iterator> reverse_iterator;

private:
/// The start of the array, in an external buffer.
const T* Data;

/// The number of elements.
size_type Length;

public:
/// @name Constructors
/// @{

/// Construct an empty ArrayRef.
/*implicit*/ ArrayRef() : Data(nullptr), Length(0) {}

/// Construct an ArrayRef from a single element.
/*implicit*/ ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}

/// Construct an ArrayRef from a pointer and length.
/*implicit*/ ArrayRef(const T* data, size_t length)
: Data(data), Length(length) {}

/// Construct an ArrayRef from a range.
ArrayRef(const T* begin, const T* end) : Data(begin), Length(end - begin) {}

/// Construct an ArrayRef from a std::vector.
template <typename A>
/*implicit*/ ArrayRef(const std::vector<T, A>& Vec)
: Data(Vec.data()), Length(Vec.size()) {}

/// Construct an ArrayRef from a std::array
template <size_t N>
/*implicit*/ constexpr ArrayRef(const std::array<T, N>& Arr)
: Data(Arr.data()), Length(N) {}

/// Construct an ArrayRef from a C array.
template <size_t N>
/*implicit*/ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}

/// Construct an ArrayRef from a std::initializer_list.
/*implicit*/ ArrayRef(const std::initializer_list<T>& Vec)
: Data(Vec.begin() == Vec.end() ? (T*)nullptr : Vec.begin()),
Length(Vec.size()) {}

/// @}
/// @name Simple Operations
/// @{

iterator begin() const { return Data; }
iterator end() const { return Data + Length; }

reverse_iterator rbegin() const { return reverse_iterator(end()); }
reverse_iterator rend() const { return reverse_iterator(begin()); }

/// empty - Check if the array is empty.
bool empty() const { return Length == 0; }

const T* data() const { return Data; }

/// size - Get the array size.
size_t size() const { return Length; }

/// front - Get the first element.
const T& front() const {
assert(!empty());
return Data[0];
}

/// back - Get the last element.
const T& back() const {
assert(!empty());
return Data[Length - 1];
}

/// equals - Check for element-wise equality.
bool equals(ArrayRef RHS) const {
if (Length != RHS.Length) return false;
return std::equal(begin(), end(), RHS.begin());
}

/// slice(n, m) - Chop off the first N elements of the array, and keep M
/// elements in the array.
ArrayRef<T> slice(size_t N, size_t M) const {
assert(N + M <= size() && "Invalid specifier");
return ArrayRef<T>(data() + N, M);
}

/// slice(n) - Chop off the first N elements of the array.
ArrayRef<T> slice(size_t N) const { return slice(N, size() - N); }

/// @}
/// @name Operator Overloads
/// @{
const T& operator[](size_t Index) const {
assert(Index < Length && "Invalid index!");
return Data[Index];
}

/// Vector compatibility
const T& at(size_t Index) const {
assert(Index < Length && "Invalid index!");
return Data[Index];
}

/// Disallow accidental assignment from a temporary.
///
/// The declaration here is extra complicated so that "arrayRef = {}"
/// continues to select the move assignment operator.
template <typename U>
typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type&
operator=(U&& Temporary) = delete;

/// Disallow accidental assignment from a temporary.
///
/// The declaration here is extra complicated so that "arrayRef = {}"
/// continues to select the move assignment operator.
template <typename U>
typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type&
operator=(std::initializer_list<U>) = delete;

/// @}
/// @name Expensive Operations
/// @{
std::vector<T> vec() const { return std::vector<T>(Data, Data + Length); }

/// @}
/// @name Conversion operators
/// @{
operator std::vector<T>() const {
return std::vector<T>(Data, Data + Length);
}

/// @}
};

} // namespace my_ai_training::ir
158 changes: 158 additions & 0 deletions src/ir/graph_node_list.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
#pragma once

#include "ir/assertions.h"

namespace my_ai_training::ir {

// Intrusive doubly linked lists with sane reverse iterators.
// The header file is named graph_node_list.h because it is ONLY
// used for Graph's Node lists, and if you want to use it for other
// things, you will have to do some refactoring.
//
// At the moment, the templated type T must support a few operations:
//
// - It must have a field: T* next_in_graph[2] = { nullptr, nullptr };
// which are used for the intrusive linked list pointers.
//
// - It must have a method 'destroy()', which removes T from the
// list and frees a T.
//
// In practice, we are only using it with Node and const Node. 'destroy()'
// needs to be renegotiated if you want to use this somewhere else.
//
// Besides the benefits of being intrusive, unlike std::list, these lists handle
// forward and backward iteration uniformly because we require a
// "before-first-element" sentinel. This means that reverse iterators
// physically point to the element they logically point to, rather than
// the off-by-one behavior for all standard library reverse iterators.

static constexpr size_t kNextDirection = 0;
static constexpr size_t kPrevDirection = 1;

template <typename T>
struct generic_graph_node_list;

template <typename T>
struct generic_graph_node_list_iterator;

struct Node;
using graph_node_list = generic_graph_node_list<Node>;
using const_graph_node_list = generic_graph_node_list<const Node>;
using graph_node_list_iterator = generic_graph_node_list_iterator<Node>;
using const_graph_node_list_iterator =
generic_graph_node_list_iterator<const Node>;

template <typename T>
struct generic_graph_node_list_iterator final {
generic_graph_node_list_iterator() : cur(nullptr), d(kNextDirection) {}
generic_graph_node_list_iterator(T* cur, size_t d) : cur(cur), d(d) {}
T* operator*() const { return cur; }
T* operator->() const { return cur; }
generic_graph_node_list_iterator& operator++() {
ONNX_ASSERT(cur);
cur = cur->next_in_graph[d];
return *this;
}
generic_graph_node_list_iterator operator++(int) {
generic_graph_node_list_iterator old = *this;
++(*this);
return old;
}
generic_graph_node_list_iterator& operator--() {
ONNX_ASSERT(cur);
cur = cur->next_in_graph[reverseDir()];
return *this;
}
generic_graph_node_list_iterator operator--(int) {
generic_graph_node_list_iterator old = *this;
--(*this);
return old;
}

// erase cur without invalidating this iterator
// named differently from destroy so that ->/. bugs do not
// silently cause the wrong one to be called.
// iterator will point to the previous entry after call
void destroyCurrent() {
T* n = cur;
cur = cur->next_in_graph[reverseDir()];
n->destroy();
}
generic_graph_node_list_iterator reverse() {
return generic_graph_node_list_iterator(cur, reverseDir());
}

private:
size_t reverseDir() {
return d == kNextDirection ? kPrevDirection : kNextDirection;
}
T* cur;
size_t d; // direction 0 is forward 1 is reverse, see next_in_graph
};

template <typename T>
struct generic_graph_node_list final {
using iterator = generic_graph_node_list_iterator<T>;
using const_iterator = generic_graph_node_list_iterator<const T>;
generic_graph_node_list_iterator<T> begin() {
return generic_graph_node_list_iterator<T>(head->next_in_graph[d], d);
}
generic_graph_node_list_iterator<const T> begin() const {
return generic_graph_node_list_iterator<const T>(head->next_in_graph[d], d);
}
generic_graph_node_list_iterator<T> end() {
return generic_graph_node_list_iterator<T>(head, d);
}
generic_graph_node_list_iterator<const T> end() const {
return generic_graph_node_list_iterator<const T>(head, d);
}
generic_graph_node_list_iterator<T> rbegin() { return reverse().begin(); }
generic_graph_node_list_iterator<const T> rbegin() const {
return reverse().begin();
}
generic_graph_node_list_iterator<T> rend() { return reverse().end(); }
generic_graph_node_list_iterator<const T> rend() const {
return reverse().end();
}
generic_graph_node_list reverse() {
return generic_graph_node_list(
head, d == kNextDirection ? kPrevDirection : kNextDirection);
}
const generic_graph_node_list reverse() const {
return generic_graph_node_list(
head, d == kNextDirection ? kPrevDirection : kNextDirection);
}
generic_graph_node_list(T* head, size_t d) : head(head), d(d) {}

private:
T* head;
size_t d;
};

template <typename T>
static inline bool operator==(generic_graph_node_list_iterator<T> a,
generic_graph_node_list_iterator<T> b) {
return *a == *b;
}

template <typename T>
static inline bool operator!=(generic_graph_node_list_iterator<T> a,
generic_graph_node_list_iterator<T> b) {
return *a != *b;
}

} // namespace my_ai_training::ir

namespace std {

template <typename T>
struct iterator_traits<
my_ai_training::ir::generic_graph_node_list_iterator<T>> {
using difference_type = int64_t;
using value_type = T*;
using pointer = T**;
using reference = T*&;
using iterator_category = bidirectional_iterator_tag;
};

} // namespace std
Loading

0 comments on commit fbd3e92

Please sign in to comment.