Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NODE][REFACTOR] Refactor reflection system in node. #4189

Merged
merged 3 commits into from
Oct 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/api_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class EnvFuncNode : public Node {
/*! \brief constructor */
EnvFuncNode() {}

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
}

Expand Down
6 changes: 3 additions & 3 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class ConstIntBoundNode : public Node {
int64_t min_value;
int64_t max_value;

void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("min_value", &min_value);
v->Visit("max_value", &max_value);
}
Expand Down Expand Up @@ -162,7 +162,7 @@ class ModularSetNode : public Node {
/*! \brief The base */
int64_t base;

void VisitAttrs(tvm::AttrVisitor* v) final {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("coeff", &coeff);
v->Visit("base", &base);
}
Expand Down Expand Up @@ -351,7 +351,7 @@ enum SignType {
*/
struct IntSetNode : public Node {
static constexpr const char* _type_key = "IntSet";
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Object);
};

/*!
Expand Down
10 changes: 6 additions & 4 deletions include/tvm/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class AttrFieldInfoNode : public Node {
/*! \brief detailed description of the type */
std::string description;

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("type_info", &type_info);
v->Visit("description", &description);
Expand Down Expand Up @@ -197,7 +197,7 @@ class AttrsHash {
size_t operator()(const std::string& value) const {
return std::hash<std::string>()(value);
}
size_t operator()(const Type& value) const {
size_t operator()(const DataType& value) const {
return std::hash<int>()(
static_cast<int>(value.code()) |
(static_cast<int>(value.bits()) << 8) |
Expand All @@ -221,6 +221,8 @@ class BaseAttrsNode : public Node {
public:
using TVMArgs = runtime::TVMArgs;
using TVMRetValue = runtime::TVMRetValue;
// visit function
virtual void VisitAttrs(AttrVisitor* v) {}
/*!
* \brief Initialize the attributes by sequence of arguments
* \param args The postional arguments in the form
Expand Down Expand Up @@ -753,12 +755,12 @@ class AttrNonDefaultVisitor {
template<typename DerivedType>
class AttrsNode : public BaseAttrsNode {
public:
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
::tvm::detail::AttrNormalVisitor vis(v);
self()->__VisitAttrs__(vis);
}

void VisitNonDefaultAttrs(AttrVisitor* v) final {
void VisitNonDefaultAttrs(AttrVisitor* v) {
::tvm::detail::AttrNonDefaultVisitor vis(v);
self()->__VisitAttrs__(vis);
}
Expand Down
169 changes: 1 addition & 168 deletions include/tvm/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,89 +19,16 @@

/*!
* \file tvm/base.h
* \brief Defines the base data structure
* \brief Base utilities
*/
#ifndef TVM_BASE_H_
#define TVM_BASE_H_

#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <tvm/node/node.h>
#include <string>
#include <memory>
#include <functional>
#include <utility>
#include "runtime/registry.h"

namespace tvm {

using ::tvm::Node;
using ::tvm::NodeRef;
using ::tvm::AttrVisitor;

/*!
* \brief Macro to define common node ref methods.
* \param TypeName The name of the NodeRef.
* \param BaseTypeName The Base type.
* \param NodeName The node container type.
*/
#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \
TypeName() {} \
explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \
: BaseTypeName(n) {} \
const NodeName* operator->() const { \
return static_cast<const NodeName*>(data_.get()); \
} \
operator bool() const { return this->defined(); } \
using ContainerType = NodeName;

/*!
* \brief Macro to define CopyOnWrite function in a NodeRef.
* \param NodeName The Type of the Node.
*
* CopyOnWrite will generate a unique copy of the internal node.
* The node will be copied if it is referenced by multiple places.
* The function returns the raw pointer to the node to allow modification
* of the content.
*
* \code
*
* MyCOWNodeRef ref, ref2;
* ref2 = ref;
* ref.CopyOnWrite()->value = new_value;
* assert(ref2->value == old_value);
* assert(ref->value == new_value);
*
* \endcode
*/
#define TVM_DEFINE_NODE_REF_COW(NodeName) \
NodeName* CopyOnWrite() { \
CHECK(data_ != nullptr); \
if (!data_.unique()) { \
NodePtr<NodeName> n = make_node<NodeName>(*(operator->())); \
ObjectPtr<Object>(std::move(n)).swap(data_); \
} \
return static_cast<NodeName*>(data_.get()); \
}

/*! \brief Macro to make it easy to define node ref type given node */
#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \
class TypeName : public ::tvm::NodeRef { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \
}; \

/*!
* \brief Macro to make it easy to define node ref type that
* has a CopyOnWrite member function.
*/
#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \
class TypeName : public BaseType { \
public: \
TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \
TVM_DEFINE_NODE_REF_COW(NodeName); \
};

/*!
* \brief RAII wrapper function to enter and exit a context object
* similar to python's with syntax.
Expand Down Expand Up @@ -146,100 +73,6 @@ class With {
ContextType ctx_;
};

/*!
* \brief save the node as well as all the node it depends on as json.
* This can be used to serialize any TVM object
*
* \return the string representation of the node.
*/
std::string SaveJSON(const NodeRef& node);

/*!
* \brief Internal implementation of LoadJSON
* Load tvm Node object from json and return a shared_ptr of Node.
* \param json_str The json string to load from.
*
* \return The shared_ptr of the Node.
*/
ObjectPtr<Object> LoadJSON_(std::string json_str);

/*!
* \brief Load the node from json string.
* This can be used to deserialize any TVM object.
*
* \param json_str The json string to load from.
*
* \tparam NodeType the nodetype
*
* \code
* Expr e = LoadJSON<Expr>(json_str);
* \endcode
*/
template<typename NodeType,
typename = typename std::enable_if<std::is_base_of<NodeRef, NodeType>::value>::type >
inline NodeType LoadJSON(const std::string& json_str) {
return NodeType(LoadJSON_(json_str));
}

/*!
* \brief Registry entry for NodeFactory.
*
* There are two types of Nodes that can be serialized.
* The normal node requires a registration a creator function that
* constructs an empty Node of the corresponding type.
*
* The global singleton(e.g. global operator) where only global_key need to be serialized,
* in this case, FGlobalKey need to be defined.
*/
struct NodeFactoryReg {
/*!
* \brief creator function.
* \param global_key Key that identifies a global single object.
* If this is not empty then FGlobalKey
* \return The created function.
*/
using FCreate = std::function<NodePtr<Node>(const std::string& global_key)>;
/*!
* \brief Global key function, only needed by global objects.
* \param node The node pointer.
* \return node The global key to the node.
*/
using FGlobalKey = std::function<std::string(const Node* node)>;
/*! \brief registered name */
std::string name;
/*!
* \brief The creator function
*/
FCreate fcreator = nullptr;
/*!
* \brief The global key function.
*/
FGlobalKey fglobal_key = nullptr;
// setter of creator
NodeFactoryReg& set_creator(FCreate f) { // NOLINT(*)
this->fcreator = f;
return *this;
}
// setter of creator
NodeFactoryReg& set_global_key(FGlobalKey f) { // NOLINT(*)
this->fglobal_key = f;
return *this;
}
// global registry singleton
TVM_DLL static ::dmlc::Registry<::tvm::NodeFactoryReg> *Registry();
};

/*!
* \brief Register a Node type
* \note This is necessary to enable serialization of the Node.
*/
#define TVM_REGISTER_NODE_TYPE(TypeName) \
TVM_REGISTER_OBJECT_TYPE(TypeName); \
static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \
.set_creator([](const std::string&) { return ::tvm::make_node<TypeName>(); })


#define TVM_STRINGIZE_DETAIL(x) #x
#define TVM_STRINGIZE(x) TVM_STRINGIZE_DETAIL(x)
#define TVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" TVM_STRINGIZE(__LINE__))
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class BufferNode : public Node {
/*! \brief constructor */
BufferNode() {}

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("data", &data);
v->Visit("dtype", &dtype);
v->Visit("shape", &shape);
Expand Down
6 changes: 4 additions & 2 deletions include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class TargetNode : public Node {
/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("target_name", &target_name);
v->Visit("device_name", &device_name);
v->Visit("device_type", &device_type);
Expand Down Expand Up @@ -229,7 +229,7 @@ class BuildConfigNode : public Node {
/*! \brief Whether to disable loop vectorization. */
bool disable_vectorize = false;

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
v->Visit("double_buffer_split_loop", &double_buffer_split_loop);
Expand Down Expand Up @@ -473,6 +473,8 @@ class GenericFuncNode : public Node {
/* \brief map from keys to registered functions */
std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_;

void VisitAttrs(AttrVisitor* v) {}

static constexpr const char* _type_key = "GenericFunc";
TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node);
};
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ struct ChannelNode : public Node {
/*! \brief default data type in read/write */
Type dtype;
// visit all attributes
void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("handle_var", &handle_var);
v->Visit("dtype", &dtype);
}
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/data_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class LayoutNode : public Node {
*/
Array<IterVar> axes;

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("axes", &axes);
}
Expand Down Expand Up @@ -325,7 +325,7 @@ class BijectiveLayoutNode : public Node {
/*! \brief The destination layout */
Layout dst_layout;

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("src_layout", &src_layout);
v->Visit("dst_layout", &dst_layout);
v->Visit("forward_rule", &forward_rule);
Expand Down
12 changes: 7 additions & 5 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
#include <string>
#include <algorithm>
#include <unordered_map>
#include <iostream>
#include "base.h"
#include "dtype.h"
#include "node/node.h"
#include "node/container.h"
#include "node/ir_functor.h"
#include "runtime/c_runtime_api.h"
Expand Down Expand Up @@ -110,7 +112,7 @@ class Variable : public ExprNode {

static Var make(DataType dtype, std::string name_hint);

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("name", &name_hint);
}
Expand Down Expand Up @@ -164,7 +166,7 @@ class IntImm : public ExprNode {
/*! \brief the Internal value. */
int64_t value;

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("value", &value);
}
Expand Down Expand Up @@ -230,7 +232,7 @@ class RangeNode : public Node {
RangeNode() {}
RangeNode(Expr min, Expr extent) : min(min), extent(extent) {}

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("min", &min);
v->Visit("extent", &extent);
}
Expand Down Expand Up @@ -406,7 +408,7 @@ class IterVarNode : public Node {
*/
std::string thread_tag;

void VisitAttrs(AttrVisitor* v) final {
void VisitAttrs(AttrVisitor* v) {
v->Visit("dom", &dom);
v->Visit("var", &var);
v->Visit("iter_type", &iter_type);
Expand Down Expand Up @@ -490,7 +492,7 @@ class IRPrinter {
};

// default print function for all nodes
inline std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*)
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
IRPrinter(os).Print(n);
return os;
}
Expand Down
Loading