-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
566 additions
and
0 deletions.
There are no files selected for viewing
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
// Copyright (c) ONNX Project Contributors | ||
|
||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
// ATTENTION: The code in this file is highly EXPERIMENTAL. | ||
// Adventurous users should note that the APIs will probably change. | ||
|
||
//===--- ArrayRef.h - Array Reference Wrapper -------------------*- C++ -*-===// | ||
// | ||
// The LLVM Compiler Infrastructure | ||
// | ||
// This file is distributed under the University of Illinois Open Source | ||
// License. See LICENSE.TXT for details. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
// ONNX: modified from llvm::ArrayRef. | ||
// removed llvm-specific functionality | ||
// removed some implicit const -> non-const conversions that rely on | ||
// complicated std::enable_if meta-programming | ||
// removed a bunch of slice variants for simplicity... | ||
|
||
#pragma once | ||
#include <assert.h> | ||
|
||
#include <array> | ||
#include <vector> | ||
|
||
namespace my_ai_training::ir { | ||
/// ArrayRef - Represent a constant reference to an array (0 or more elements | ||
/// consecutively in memory), i.e. a start pointer and a length. It allows | ||
/// various APIs to take consecutive elements easily and conveniently. | ||
/// | ||
/// This class does not own the underlying data, it is expected to be used in | ||
/// situations where the data resides in some other buffer, whose lifetime | ||
/// extends past that of the ArrayRef. For this reason, it is not in general | ||
/// safe to store an ArrayRef. | ||
/// | ||
/// This is intended to be trivially copyable, so it should be passed by | ||
/// value. | ||
template <typename T> | ||
class ArrayRef { | ||
public: | ||
typedef const T* iterator; | ||
typedef const T* const_iterator; | ||
typedef size_t size_type; | ||
|
||
typedef std::reverse_iterator<iterator> reverse_iterator; | ||
|
||
private: | ||
/// The start of the array, in an external buffer. | ||
const T* Data; | ||
|
||
/// The number of elements. | ||
size_type Length; | ||
|
||
public: | ||
/// @name Constructors | ||
/// @{ | ||
|
||
/// Construct an empty ArrayRef. | ||
/*implicit*/ ArrayRef() : Data(nullptr), Length(0) {} | ||
|
||
/// Construct an ArrayRef from a single element. | ||
/*implicit*/ ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {} | ||
|
||
/// Construct an ArrayRef from a pointer and length. | ||
/*implicit*/ ArrayRef(const T* data, size_t length) | ||
: Data(data), Length(length) {} | ||
|
||
/// Construct an ArrayRef from a range. | ||
ArrayRef(const T* begin, const T* end) : Data(begin), Length(end - begin) {} | ||
|
||
/// Construct an ArrayRef from a std::vector. | ||
template <typename A> | ||
/*implicit*/ ArrayRef(const std::vector<T, A>& Vec) | ||
: Data(Vec.data()), Length(Vec.size()) {} | ||
|
||
/// Construct an ArrayRef from a std::array | ||
template <size_t N> | ||
/*implicit*/ constexpr ArrayRef(const std::array<T, N>& Arr) | ||
: Data(Arr.data()), Length(N) {} | ||
|
||
/// Construct an ArrayRef from a C array. | ||
template <size_t N> | ||
/*implicit*/ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {} | ||
|
||
/// Construct an ArrayRef from a std::initializer_list. | ||
/*implicit*/ ArrayRef(const std::initializer_list<T>& Vec) | ||
: Data(Vec.begin() == Vec.end() ? (T*)nullptr : Vec.begin()), | ||
Length(Vec.size()) {} | ||
|
||
/// @} | ||
/// @name Simple Operations | ||
/// @{ | ||
|
||
iterator begin() const { return Data; } | ||
iterator end() const { return Data + Length; } | ||
|
||
reverse_iterator rbegin() const { return reverse_iterator(end()); } | ||
reverse_iterator rend() const { return reverse_iterator(begin()); } | ||
|
||
/// empty - Check if the array is empty. | ||
bool empty() const { return Length == 0; } | ||
|
||
const T* data() const { return Data; } | ||
|
||
/// size - Get the array size. | ||
size_t size() const { return Length; } | ||
|
||
/// front - Get the first element. | ||
const T& front() const { | ||
assert(!empty()); | ||
return Data[0]; | ||
} | ||
|
||
/// back - Get the last element. | ||
const T& back() const { | ||
assert(!empty()); | ||
return Data[Length - 1]; | ||
} | ||
|
||
/// equals - Check for element-wise equality. | ||
bool equals(ArrayRef RHS) const { | ||
if (Length != RHS.Length) return false; | ||
return std::equal(begin(), end(), RHS.begin()); | ||
} | ||
|
||
/// slice(n, m) - Chop off the first N elements of the array, and keep M | ||
/// elements in the array. | ||
ArrayRef<T> slice(size_t N, size_t M) const { | ||
assert(N + M <= size() && "Invalid specifier"); | ||
return ArrayRef<T>(data() + N, M); | ||
} | ||
|
||
/// slice(n) - Chop off the first N elements of the array. | ||
ArrayRef<T> slice(size_t N) const { return slice(N, size() - N); } | ||
|
||
/// @} | ||
/// @name Operator Overloads | ||
/// @{ | ||
const T& operator[](size_t Index) const { | ||
assert(Index < Length && "Invalid index!"); | ||
return Data[Index]; | ||
} | ||
|
||
/// Vector compatibility | ||
const T& at(size_t Index) const { | ||
assert(Index < Length && "Invalid index!"); | ||
return Data[Index]; | ||
} | ||
|
||
/// Disallow accidental assignment from a temporary. | ||
/// | ||
/// The declaration here is extra complicated so that "arrayRef = {}" | ||
/// continues to select the move assignment operator. | ||
template <typename U> | ||
typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type& | ||
operator=(U&& Temporary) = delete; | ||
|
||
/// Disallow accidental assignment from a temporary. | ||
/// | ||
/// The declaration here is extra complicated so that "arrayRef = {}" | ||
/// continues to select the move assignment operator. | ||
template <typename U> | ||
typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type& | ||
operator=(std::initializer_list<U>) = delete; | ||
|
||
/// @} | ||
/// @name Expensive Operations | ||
/// @{ | ||
std::vector<T> vec() const { return std::vector<T>(Data, Data + Length); } | ||
|
||
/// @} | ||
/// @name Conversion operators | ||
/// @{ | ||
operator std::vector<T>() const { | ||
return std::vector<T>(Data, Data + Length); | ||
} | ||
|
||
/// @} | ||
}; | ||
|
||
} // namespace my_ai_training::ir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
#pragma once | ||
|
||
#include "ir/assertions.h" | ||
|
||
namespace my_ai_training::ir { | ||
|
||
// Intrusive doubly linked lists with sane reverse iterators. | ||
// The header file is named graph_node_list.h because it is ONLY | ||
// used for Graph's Node lists, and if you want to use it for other | ||
// things, you will have to do some refactoring. | ||
// | ||
// At the moment, the templated type T must support a few operations: | ||
// | ||
// - It must have a field: T* next_in_graph[2] = { nullptr, nullptr }; | ||
// which are used for the intrusive linked list pointers. | ||
// | ||
// - It must have a method 'destroy()', which removes T from the | ||
// list and frees a T. | ||
// | ||
// In practice, we are only using it with Node and const Node. 'destroy()' | ||
// needs to be renegotiated if you want to use this somewhere else. | ||
// | ||
// Besides the benefits of being intrusive, unlike std::list, these lists handle | ||
// forward and backward iteration uniformly because we require a | ||
// "before-first-element" sentinel. This means that reverse iterators | ||
// physically point to the element they logically point to, rather than | ||
// the off-by-one behavior for all standard library reverse iterators. | ||
|
||
static constexpr size_t kNextDirection = 0; | ||
static constexpr size_t kPrevDirection = 1; | ||
|
||
template <typename T> | ||
struct generic_graph_node_list; | ||
|
||
template <typename T> | ||
struct generic_graph_node_list_iterator; | ||
|
||
struct Node; | ||
using graph_node_list = generic_graph_node_list<Node>; | ||
using const_graph_node_list = generic_graph_node_list<const Node>; | ||
using graph_node_list_iterator = generic_graph_node_list_iterator<Node>; | ||
using const_graph_node_list_iterator = | ||
generic_graph_node_list_iterator<const Node>; | ||
|
||
template <typename T> | ||
struct generic_graph_node_list_iterator final { | ||
generic_graph_node_list_iterator() : cur(nullptr), d(kNextDirection) {} | ||
generic_graph_node_list_iterator(T* cur, size_t d) : cur(cur), d(d) {} | ||
T* operator*() const { return cur; } | ||
T* operator->() const { return cur; } | ||
generic_graph_node_list_iterator& operator++() { | ||
ONNX_ASSERT(cur); | ||
cur = cur->next_in_graph[d]; | ||
return *this; | ||
} | ||
generic_graph_node_list_iterator operator++(int) { | ||
generic_graph_node_list_iterator old = *this; | ||
++(*this); | ||
return old; | ||
} | ||
generic_graph_node_list_iterator& operator--() { | ||
ONNX_ASSERT(cur); | ||
cur = cur->next_in_graph[reverseDir()]; | ||
return *this; | ||
} | ||
generic_graph_node_list_iterator operator--(int) { | ||
generic_graph_node_list_iterator old = *this; | ||
--(*this); | ||
return old; | ||
} | ||
|
||
// erase cur without invalidating this iterator | ||
// named differently from destroy so that ->/. bugs do not | ||
// silently cause the wrong one to be called. | ||
// iterator will point to the previous entry after call | ||
void destroyCurrent() { | ||
T* n = cur; | ||
cur = cur->next_in_graph[reverseDir()]; | ||
n->destroy(); | ||
} | ||
generic_graph_node_list_iterator reverse() { | ||
return generic_graph_node_list_iterator(cur, reverseDir()); | ||
} | ||
|
||
private: | ||
size_t reverseDir() { | ||
return d == kNextDirection ? kPrevDirection : kNextDirection; | ||
} | ||
T* cur; | ||
size_t d; // direction 0 is forward 1 is reverse, see next_in_graph | ||
}; | ||
|
||
template <typename T> | ||
struct generic_graph_node_list final { | ||
using iterator = generic_graph_node_list_iterator<T>; | ||
using const_iterator = generic_graph_node_list_iterator<const T>; | ||
generic_graph_node_list_iterator<T> begin() { | ||
return generic_graph_node_list_iterator<T>(head->next_in_graph[d], d); | ||
} | ||
generic_graph_node_list_iterator<const T> begin() const { | ||
return generic_graph_node_list_iterator<const T>(head->next_in_graph[d], d); | ||
} | ||
generic_graph_node_list_iterator<T> end() { | ||
return generic_graph_node_list_iterator<T>(head, d); | ||
} | ||
generic_graph_node_list_iterator<const T> end() const { | ||
return generic_graph_node_list_iterator<const T>(head, d); | ||
} | ||
generic_graph_node_list_iterator<T> rbegin() { return reverse().begin(); } | ||
generic_graph_node_list_iterator<const T> rbegin() const { | ||
return reverse().begin(); | ||
} | ||
generic_graph_node_list_iterator<T> rend() { return reverse().end(); } | ||
generic_graph_node_list_iterator<const T> rend() const { | ||
return reverse().end(); | ||
} | ||
generic_graph_node_list reverse() { | ||
return generic_graph_node_list( | ||
head, d == kNextDirection ? kPrevDirection : kNextDirection); | ||
} | ||
const generic_graph_node_list reverse() const { | ||
return generic_graph_node_list( | ||
head, d == kNextDirection ? kPrevDirection : kNextDirection); | ||
} | ||
generic_graph_node_list(T* head, size_t d) : head(head), d(d) {} | ||
|
||
private: | ||
T* head; | ||
size_t d; | ||
}; | ||
|
||
template <typename T> | ||
static inline bool operator==(generic_graph_node_list_iterator<T> a, | ||
generic_graph_node_list_iterator<T> b) { | ||
return *a == *b; | ||
} | ||
|
||
template <typename T> | ||
static inline bool operator!=(generic_graph_node_list_iterator<T> a, | ||
generic_graph_node_list_iterator<T> b) { | ||
return *a != *b; | ||
} | ||
|
||
} // namespace my_ai_training::ir | ||
|
||
namespace std { | ||
|
||
template <typename T> | ||
struct iterator_traits< | ||
my_ai_training::ir::generic_graph_node_list_iterator<T>> { | ||
using difference_type = int64_t; | ||
using value_type = T*; | ||
using pointer = T**; | ||
using reference = T*&; | ||
using iterator_category = bidirectional_iterator_tag; | ||
}; | ||
|
||
} // namespace std |
Oops, something went wrong.