From 78ca6fc8b00eb26dbba335c1fad8f46ae62b4822 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 24 Oct 2019 13:40:04 -0700 Subject: [PATCH] [NODE][REFACTOR] Refactor reflection system in node. (#4189) * [NODE][REFACTOR] Refactor reflection system in node. - Removed the old Node, Node is now just an alias of runtime::Object - Introduce ReflectionVTable, a new columnar dispatcher to support reflection - This allows us to remove vtable from most node objects - The VisitAttrs are registered via TVM_RESGITER_NODE_TYPE, they are no longer virtual. - Consolidated serialization and reflection features into node. * Explicit type qualification when calling destructor. * Fix SPIRV, more comments --- include/tvm/api_registry.h | 2 +- include/tvm/arithmetic.h | 6 +- include/tvm/attrs.h | 10 +- include/tvm/base.h | 169 +--------- include/tvm/buffer.h | 2 +- include/tvm/build_module.h | 6 +- include/tvm/channel.h | 2 +- include/tvm/data_layout.h | 4 +- include/tvm/expr.h | 12 +- include/tvm/ir.h | 66 ++-- include/tvm/lowered_func.h | 2 +- include/tvm/node/container.h | 13 +- include/tvm/node/node.h | 199 ++++++------ include/tvm/node/reflection.h | 241 ++++++++++++++ include/tvm/node/serialization.h | 51 +++ include/tvm/operation.h | 12 +- include/tvm/packed_func_ext.h | 38 +-- include/tvm/relay/adt.h | 16 +- include/tvm/relay/base.h | 6 +- include/tvm/relay/expr.h | 24 +- include/tvm/relay/interpreter.h | 10 +- include/tvm/relay/module.h | 2 +- include/tvm/relay/op.h | 4 +- include/tvm/relay/transform.h | 7 +- include/tvm/relay/type.h | 20 +- include/tvm/runtime/device_api.h | 1 + include/tvm/runtime/memory.h | 25 +- include/tvm/runtime/object.h | 28 +- include/tvm/runtime/packed_func.h | 18 +- include/tvm/runtime/registry.h | 42 +-- include/tvm/schedule.h | 14 +- include/tvm/target_info.h | 2 +- include/tvm/tensor.h | 2 +- include/tvm/tensor_intrin.h | 4 +- nnvm/src/compiler/compile_engine.h | 4 +- nnvm/src/compiler/graph_hash.h | 6 +- nnvm/src/compiler/graph_runtime.cc | 7 +- nnvm/src/compiler/graph_runtime.h | 4 +- src/README.md | 3 +- src/api/api_base.cc | 5 +- src/api/dsl_api.cc | 190 ----------- src/arithmetic/bound_deducer.cc | 8 +- src/arithmetic/canonical_simplify.cc | 7 +- src/arithmetic/int_set.cc | 2 + src/arithmetic/int_set.h | 2 +- src/codegen/spirv/intrin_rule_spirv.cc | 6 +- src/lang/api_registry.cc | 6 +- src/lang/ir.cc | 2 + src/lang/target_info.cc | 6 +- src/node/reflection.cc | 306 ++++++++++++++++++ .../reflection.cc => node/serialization.cc} | 277 +++++----------- src/relay/backend/compile_engine.h | 8 +- src/relay/backend/interpreter.cc | 7 +- src/relay/backend/param_dict.cc | 12 +- src/relay/backend/param_dict.h | 6 +- src/relay/ir/adt.cc | 1 - src/relay/ir/base.cc | 4 +- src/relay/ir/op.cc | 2 +- src/relay/ir/pretty_printer.cc | 17 +- src/relay/ir/type_functor.cc | 6 +- src/relay/pass/alter_op_layout.cc | 2 +- src/relay/pass/device_annotation.cc | 18 +- src/relay/pass/eta_expand.cc | 4 +- src/relay/pass/fold_scale_axis.cc | 4 +- src/relay/pass/forward_rewrite.cc | 6 +- src/relay/pass/pass_manager.cc | 6 +- src/relay/pass/quantize/annotate.cc | 2 +- src/relay/pass/quantize/partition.cc | 2 +- src/relay/pass/quantize/quantize.cc | 2 - src/relay/pass/quantize/quantize.h | 2 +- src/relay/pass/quantize/realize.cc | 2 +- src/relay/pass/type_solver.cc | 6 +- src/relay/pass/util.cc | 8 +- src/runtime/object.cc | 16 +- tests/cpp/build_module_test.cc | 1 + tests/cpp/packed_func_test.cc | 1 + 76 files changed, 1105 insertions(+), 941 deletions(-) create mode 100644 include/tvm/node/reflection.h create mode 100644 include/tvm/node/serialization.h delete mode 100644 src/api/dsl_api.cc create mode 100644 src/node/reflection.cc rename src/{lang/reflection.cc => node/serialization.cc} (64%) diff --git a/include/tvm/api_registry.h b/include/tvm/api_registry.h index dbd097293593..c41c3087f4ac 100644 --- a/include/tvm/api_registry.h +++ b/include/tvm/api_registry.h @@ -58,7 +58,7 @@ class EnvFuncNode : public Node { /*! \brief constructor */ EnvFuncNode() {} - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index e81fa0afd254..bda6ac647f55 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -60,7 +60,7 @@ class ConstIntBoundNode : public Node { int64_t min_value; int64_t max_value; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("min_value", &min_value); v->Visit("max_value", &max_value); } @@ -162,7 +162,7 @@ class ModularSetNode : public Node { /*! \brief The base */ int64_t base; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("coeff", &coeff); v->Visit("base", &base); } @@ -351,7 +351,7 @@ enum SignType { */ struct IntSetNode : public Node { static constexpr const char* _type_key = "IntSet"; - TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node); + TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Object); }; /*! diff --git a/include/tvm/attrs.h b/include/tvm/attrs.h index fb8927a75613..2fbb9e6a866e 100644 --- a/include/tvm/attrs.h +++ b/include/tvm/attrs.h @@ -115,7 +115,7 @@ class AttrFieldInfoNode : public Node { /*! \brief detailed description of the type */ std::string description; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); v->Visit("type_info", &type_info); v->Visit("description", &description); @@ -197,7 +197,7 @@ class AttrsHash { size_t operator()(const std::string& value) const { return std::hash()(value); } - size_t operator()(const Type& value) const { + size_t operator()(const DataType& value) const { return std::hash()( static_cast(value.code()) | (static_cast(value.bits()) << 8) | @@ -221,6 +221,8 @@ class BaseAttrsNode : public Node { public: using TVMArgs = runtime::TVMArgs; using TVMRetValue = runtime::TVMRetValue; + // visit function + virtual void VisitAttrs(AttrVisitor* v) {} /*! * \brief Initialize the attributes by sequence of arguments * \param args The postional arguments in the form @@ -753,12 +755,12 @@ class AttrNonDefaultVisitor { template class AttrsNode : public BaseAttrsNode { public: - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { ::tvm::detail::AttrNormalVisitor vis(v); self()->__VisitAttrs__(vis); } - void VisitNonDefaultAttrs(AttrVisitor* v) final { + void VisitNonDefaultAttrs(AttrVisitor* v) { ::tvm::detail::AttrNonDefaultVisitor vis(v); self()->__VisitAttrs__(vis); } diff --git a/include/tvm/base.h b/include/tvm/base.h index a42de10abef2..9b3b4cd3e8df 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -19,89 +19,16 @@ /*! * \file tvm/base.h - * \brief Defines the base data structure + * \brief Base utilities */ #ifndef TVM_BASE_H_ #define TVM_BASE_H_ #include -#include -#include -#include -#include -#include #include -#include "runtime/registry.h" namespace tvm { -using ::tvm::Node; -using ::tvm::NodeRef; -using ::tvm::AttrVisitor; - -/*! - * \brief Macro to define common node ref methods. - * \param TypeName The name of the NodeRef. - * \param BaseTypeName The Base type. - * \param NodeName The node container type. - */ -#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \ - TypeName() {} \ - explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \ - : BaseTypeName(n) {} \ - const NodeName* operator->() const { \ - return static_cast(data_.get()); \ - } \ - operator bool() const { return this->defined(); } \ - using ContainerType = NodeName; - -/*! - * \brief Macro to define CopyOnWrite function in a NodeRef. - * \param NodeName The Type of the Node. - * - * CopyOnWrite will generate a unique copy of the internal node. - * The node will be copied if it is referenced by multiple places. - * The function returns the raw pointer to the node to allow modification - * of the content. - * - * \code - * - * MyCOWNodeRef ref, ref2; - * ref2 = ref; - * ref.CopyOnWrite()->value = new_value; - * assert(ref2->value == old_value); - * assert(ref->value == new_value); - * - * \endcode - */ -#define TVM_DEFINE_NODE_REF_COW(NodeName) \ - NodeName* CopyOnWrite() { \ - CHECK(data_ != nullptr); \ - if (!data_.unique()) { \ - NodePtr n = make_node(*(operator->())); \ - ObjectPtr(std::move(n)).swap(data_); \ - } \ - return static_cast(data_.get()); \ - } - -/*! \brief Macro to make it easy to define node ref type given node */ -#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \ - class TypeName : public ::tvm::NodeRef { \ - public: \ - TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \ - }; \ - -/*! - * \brief Macro to make it easy to define node ref type that - * has a CopyOnWrite member function. - */ -#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \ - class TypeName : public BaseType { \ - public: \ - TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \ - TVM_DEFINE_NODE_REF_COW(NodeName); \ - }; - /*! * \brief RAII wrapper function to enter and exit a context object * similar to python's with syntax. @@ -146,100 +73,6 @@ class With { ContextType ctx_; }; -/*! - * \brief save the node as well as all the node it depends on as json. - * This can be used to serialize any TVM object - * - * \return the string representation of the node. - */ -std::string SaveJSON(const NodeRef& node); - -/*! - * \brief Internal implementation of LoadJSON - * Load tvm Node object from json and return a shared_ptr of Node. - * \param json_str The json string to load from. - * - * \return The shared_ptr of the Node. - */ -ObjectPtr LoadJSON_(std::string json_str); - -/*! - * \brief Load the node from json string. - * This can be used to deserialize any TVM object. - * - * \param json_str The json string to load from. - * - * \tparam NodeType the nodetype - * - * \code - * Expr e = LoadJSON(json_str); - * \endcode - */ -template::value>::type > -inline NodeType LoadJSON(const std::string& json_str) { - return NodeType(LoadJSON_(json_str)); -} - -/*! - * \brief Registry entry for NodeFactory. - * - * There are two types of Nodes that can be serialized. - * The normal node requires a registration a creator function that - * constructs an empty Node of the corresponding type. - * - * The global singleton(e.g. global operator) where only global_key need to be serialized, - * in this case, FGlobalKey need to be defined. - */ -struct NodeFactoryReg { - /*! - * \brief creator function. - * \param global_key Key that identifies a global single object. - * If this is not empty then FGlobalKey - * \return The created function. - */ - using FCreate = std::function(const std::string& global_key)>; - /*! - * \brief Global key function, only needed by global objects. - * \param node The node pointer. - * \return node The global key to the node. - */ - using FGlobalKey = std::function; - /*! \brief registered name */ - std::string name; - /*! - * \brief The creator function - */ - FCreate fcreator = nullptr; - /*! - * \brief The global key function. - */ - FGlobalKey fglobal_key = nullptr; - // setter of creator - NodeFactoryReg& set_creator(FCreate f) { // NOLINT(*) - this->fcreator = f; - return *this; - } - // setter of creator - NodeFactoryReg& set_global_key(FGlobalKey f) { // NOLINT(*) - this->fglobal_key = f; - return *this; - } - // global registry singleton - TVM_DLL static ::dmlc::Registry<::tvm::NodeFactoryReg> *Registry(); -}; - -/*! - * \brief Register a Node type - * \note This is necessary to enable serialization of the Node. - */ -#define TVM_REGISTER_NODE_TYPE(TypeName) \ - TVM_REGISTER_OBJECT_TYPE(TypeName); \ - static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \ - ::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \ - .set_creator([](const std::string&) { return ::tvm::make_node(); }) - - #define TVM_STRINGIZE_DETAIL(x) #x #define TVM_STRINGIZE(x) TVM_STRINGIZE_DETAIL(x) #define TVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" TVM_STRINGIZE(__LINE__)) diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index f18ed9206db3..d2c2b40661e2 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -135,7 +135,7 @@ class BufferNode : public Node { /*! \brief constructor */ BufferNode() {} - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("data", &data); v->Visit("dtype", &dtype); v->Visit("shape", &shape); diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index c985fbe17546..7114a4550331 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -61,7 +61,7 @@ class TargetNode : public Node { /*! \return the full device string to pass to codegen::Build */ TVM_DLL const std::string& str() const; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("target_name", &target_name); v->Visit("device_name", &device_name); v->Visit("device_type", &device_type); @@ -229,7 +229,7 @@ class BuildConfigNode : public Node { /*! \brief Whether to disable loop vectorization. */ bool disable_vectorize = false; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("data_alignment", &data_alignment); v->Visit("offset_factor", &offset_factor); v->Visit("double_buffer_split_loop", &double_buffer_split_loop); @@ -473,6 +473,8 @@ class GenericFuncNode : public Node { /* \brief map from keys to registered functions */ std::unordered_map dispatch_dict_; + void VisitAttrs(AttrVisitor* v) {} + static constexpr const char* _type_key = "GenericFunc"; TVM_DECLARE_NODE_TYPE_INFO(GenericFuncNode, Node); }; diff --git a/include/tvm/channel.h b/include/tvm/channel.h index 346291a6b06a..3a40a787d891 100644 --- a/include/tvm/channel.h +++ b/include/tvm/channel.h @@ -54,7 +54,7 @@ struct ChannelNode : public Node { /*! \brief default data type in read/write */ Type dtype; // visit all attributes - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("handle_var", &handle_var); v->Visit("dtype", &dtype); } diff --git a/include/tvm/data_layout.h b/include/tvm/data_layout.h index ad3da6b347af..5e2cc08660db 100644 --- a/include/tvm/data_layout.h +++ b/include/tvm/data_layout.h @@ -104,7 +104,7 @@ class LayoutNode : public Node { */ Array axes; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); v->Visit("axes", &axes); } @@ -325,7 +325,7 @@ class BijectiveLayoutNode : public Node { /*! \brief The destination layout */ Layout dst_layout; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("src_layout", &src_layout); v->Visit("dst_layout", &dst_layout); v->Visit("forward_rule", &forward_rule); diff --git a/include/tvm/expr.h b/include/tvm/expr.h index d884a4d61748..ea578152899d 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -27,8 +27,10 @@ #include #include #include +#include #include "base.h" #include "dtype.h" +#include "node/node.h" #include "node/container.h" #include "node/ir_functor.h" #include "runtime/c_runtime_api.h" @@ -110,7 +112,7 @@ class Variable : public ExprNode { static Var make(DataType dtype, std::string name_hint); - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("name", &name_hint); } @@ -164,7 +166,7 @@ class IntImm : public ExprNode { /*! \brief the Internal value. */ int64_t value; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("value", &value); } @@ -230,7 +232,7 @@ class RangeNode : public Node { RangeNode() {} RangeNode(Expr min, Expr extent) : min(min), extent(extent) {} - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("min", &min); v->Visit("extent", &extent); } @@ -406,7 +408,7 @@ class IterVarNode : public Node { */ std::string thread_tag; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dom", &dom); v->Visit("var", &var); v->Visit("iter_type", &iter_type); @@ -490,7 +492,7 @@ class IRPrinter { }; // default print function for all nodes -inline std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*) +inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*) IRPrinter(os).Print(n); return os; } diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 37718fe1b3c7..b6c3028d892f 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -45,7 +45,7 @@ class UIntImm : public ExprNode { /*! \brief The constant value content. */ uint64_t value; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("value", &value); } @@ -62,7 +62,7 @@ class FloatImm : public ExprNode { /*! \brief The constant value content. */ double value; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("value", &value); } @@ -79,7 +79,7 @@ class StringImm : public ExprNode { /*! \brief The constant value content. */ std::string value; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("value", &value); } @@ -99,7 +99,7 @@ class Cast : public ExprNode { /*! \brief Original data type. */ Expr value; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("value", &value); } @@ -122,7 +122,7 @@ class BinaryOpNode : public ExprNode { /*! \brief The right operand. */ Expr b; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &(this->type)); v->Visit("a", &a); v->Visit("b", &b); @@ -214,7 +214,7 @@ class CmpOpNode : public ExprNode { /*! \brief The right operand. */ Expr b; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &(this->type)); v->Visit("a", &a); v->Visit("b", &b); @@ -278,7 +278,7 @@ class And : public ExprNode { /*! \brief The right operand. */ Expr b; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &(this->type)); v->Visit("a", &a); v->Visit("b", &b); @@ -298,7 +298,7 @@ class Or : public ExprNode { /*! \brief The right operand. */ Expr b; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("a", &a); v->Visit("b", &b); @@ -316,7 +316,7 @@ class Not : public ExprNode { /*! \brief The input operand. */ Expr a; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("a", &a); } @@ -343,7 +343,7 @@ class Select : public ExprNode { /*! \brief value to be returned when condition is false. */ Expr false_value; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("condition", &condition); v->Visit("true_value", &true_value); @@ -380,7 +380,7 @@ class Load : public ExprNode { /*! \brief The predicate to mask which lanes would be loaded. */ Expr predicate; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("buffer_var", &buffer_var); v->Visit("index", &index); @@ -411,7 +411,7 @@ class Ramp : public ExprNode { /*! \brief Total number of lanes. */ int lanes; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("base", &base); v->Visit("stride", &stride); @@ -432,7 +432,7 @@ class Broadcast : public ExprNode { /*! \brief The numerb of lanes. */ int lanes; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("value", &value); v->Visit("lanes", &lanes); @@ -456,7 +456,7 @@ class Let : public ExprNode { /*! \brief The result expression. */ Expr body; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("var", &var); v->Visit("value", &value); @@ -522,7 +522,7 @@ class Call : public ExprNode { /*! \brief The output value index if func's value is a tuple. */ int value_index{0}; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("name", &name); v->Visit("args", &args); @@ -592,7 +592,7 @@ class Shuffle : public ExprNode { /*! \brief The indices of each element. */ Array indices; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("vectors", &vectors); v->Visit("indices", &indices); } @@ -652,7 +652,7 @@ class CommReducerNode : public Node { Array result, Array identity_element); - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("lhs", &lhs); v->Visit("rhs", &rhs); v->Visit("result", &result); @@ -694,7 +694,7 @@ class Reduce : public ExprNode { Expr condition, int value_index); - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &type); v->Visit("combiner", &combiner); v->Visit("source", &source); @@ -710,7 +710,7 @@ class Reduce : public ExprNode { /*! \brief Any shape. */ class Any : public ExprNode { public: - void VisitAttrs(AttrVisitor* v) final {} + void VisitAttrs(AttrVisitor* v) {} /*! \brief Convert to var. */ Var ToVar() const { return Variable::make(Int(32), "any_dim"); @@ -735,7 +735,7 @@ class LetStmt : public StmtNode { /*! \brief The body block. */ Stmt body; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("var", &var); v->Visit("value", &value); v->Visit("body", &body); @@ -768,7 +768,7 @@ class AttrStmt : public StmtNode { /*! \brief The body statement to be executed */ Stmt body; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("node", &node); v->Visit("attr_key", &attr_key); v->Visit("value", &value); @@ -799,7 +799,7 @@ class AssertStmt : public StmtNode { */ Stmt body; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("condition", &condition); v->Visit("message", &message); v->Visit("body", &body); @@ -822,7 +822,7 @@ class ProducerConsumer : public StmtNode { /*! \brief Body to be executed. */ Stmt body; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("func", &func); v->Visit("is_producer", &is_producer); v->Visit("body", &body); @@ -863,7 +863,7 @@ class Store : public StmtNode { /*! \brief The predicate to mask which lanes would be stored. */ Expr predicate; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("buffer_var", &buffer_var); v->Visit("value", &value); v->Visit("index", &index); @@ -893,7 +893,7 @@ class Provide : public StmtNode { /*! \brief The index arguments of the function. */ Array args; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("func", &func); v->Visit("value_index", &value_index); v->Visit("value", &value); @@ -929,7 +929,7 @@ class Allocate : public StmtNode { Expr new_expr; std::string free_function; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("buffer_var", &buffer_var); v->Visit("dtype", &type); v->Visit("extents", &extents); @@ -972,7 +972,7 @@ class Free : public StmtNode { /*! \brief The buffer variable. */ Var buffer_var; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("buffer_var", &buffer_var); } @@ -1001,7 +1001,7 @@ class Realize : public StmtNode { /*! \brief The body of realization. */ Stmt body; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("func", &func); v->Visit("value_index", &value_index); v->Visit("dtype", &type); @@ -1031,7 +1031,7 @@ class Block : public StmtNode { /*! \brief The restof statments. */ Stmt rest; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("first", &first); v->Visit("rest", &rest); } @@ -1055,7 +1055,7 @@ class IfThenElse : public StmtNode { /*! \brief The branch to be executed when condition is false, can be null. */ Stmt else_case; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("condition", &condition); v->Visit("then_case", &then_case); v->Visit("else_case", &else_case); @@ -1078,7 +1078,7 @@ class Evaluate : public StmtNode { /*! \brief The expression to be evaluated. */ Expr value; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("value", &value); } @@ -1142,7 +1142,7 @@ class For : public StmtNode { DeviceAPI device_api, Stmt body); - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("loop_var", &loop_var); v->Visit("min", &min); v->Visit("extent", &extent); @@ -1169,7 +1169,7 @@ class Prefetch : public StmtNode { /*! \brief Bounds to be prefetched. */ Region bounds; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("func", &func); v->Visit("value_index", &value_index); v->Visit("type", &type); diff --git a/include/tvm/lowered_func.h b/include/tvm/lowered_func.h index e2147d036587..6709f545cb39 100644 --- a/include/tvm/lowered_func.h +++ b/include/tvm/lowered_func.h @@ -119,7 +119,7 @@ class LoweredFuncNode : public ir::FunctionBaseNode { int num_outputs() const final { return 1; } - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); v->Visit("args", &args); v->Visit("thread_axis", &thread_axis); diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index 2e1a978f4806..c36c6c141451 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -40,8 +40,7 @@ class ArrayNode : public Node { /*! \brief the data content */ std::vector data; - void VisitAttrs(AttrVisitor* visitor) final { - // Visitor to array have no effect. + void VisitAttrs(AttrVisitor* visitor) { } static constexpr const char* _type_key = "Array"; @@ -51,9 +50,9 @@ class ArrayNode : public Node { /*! \brief map node content */ class MapNode : public Node { public: - void VisitAttrs(AttrVisitor* visitor) final { - // Visitor to map have no effect. + void VisitAttrs(AttrVisitor* visitor) { } + /*! \brief The corresponding conatiner type */ using ContainerType = std::unordered_map< ObjectRef, @@ -71,12 +70,12 @@ class MapNode : public Node { /*! \brief specialized map node with string as key */ class StrMapNode : public Node { public: - void VisitAttrs(AttrVisitor* visitor) final { - // Visitor to map have no effect. - } /*! \brief The corresponding conatiner type */ using ContainerType = std::unordered_map; + void VisitAttrs(AttrVisitor* visitor) { + } + /*! \brief the data content */ ContainerType data; diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index 8203ee69f686..4014c3700596 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -18,113 +18,68 @@ */ /*! * \file tvm/node/node.h - * \brief Node system data structure. + * \brief Definitions and helper macros for IR/AST nodes. + * + * The node folder contains base utilities for IR/AST nodes, + * invariant of which specific language dialect. + * + * We implement AST/IR nodes as sub-classes of runtime::Object. + * The base class Node is just an alias of runtime::Object. + * + * Besides the runtime type checking provided by Object, + * node folder contains additional functionalities such as + * reflection and serialization, which are important features + * for building a compiler infra. */ #ifndef TVM_NODE_NODE_H_ #define TVM_NODE_NODE_H_ -#include #include #include #include -#include +#include + #include #include #include #include - namespace tvm { -// forward declaration -class DataType; -class Node; -class NodeRef; -/*! - * \brief Visitor class to each node content. - * The content is going to be called for each field. - */ -class TVM_DLL AttrVisitor { - public: -//! \cond Doxygen_Suppress - virtual ~AttrVisitor() = default; - virtual void Visit(const char* key, double* value) = 0; - virtual void Visit(const char* key, int64_t* value) = 0; - virtual void Visit(const char* key, uint64_t* value) = 0; - virtual void Visit(const char* key, int* value) = 0; - virtual void Visit(const char* key, bool* value) = 0; - virtual void Visit(const char* key, std::string* value) = 0; - virtual void Visit(const char* key, void** value) = 0; - virtual void Visit(const char* key, DataType* value) = 0; - virtual void Visit(const char* key, NodeRef* value) = 0; - virtual void Visit(const char* key, runtime::NDArray* value) = 0; - virtual void Visit(const char* key, runtime::ObjectRef* value) = 0; - template::value>::type> - void Visit(const char* key, ENum* ptr) { - static_assert(std::is_same::type>::value, - "declare enum to be enum int to use visitor"); - this->Visit(key, reinterpret_cast(ptr)); - } -//! \endcond -}; +using runtime::TypeIndex; +using runtime::Object; +using runtime::ObjectPtr; +using runtime::ObjectRef; +using runtime::GetRef; +using runtime::Downcast; +using runtime::ObjectHash; +using runtime::ObjectEqual; +using runtime::make_object; -/*! \brief Reuse the type index in he runtime. */ -using TypeIndex = runtime::TypeIndex; +using NodeHash = ObjectHash; +using NodeEqual = ObjectEqual; +using Node = Object; /*! - * \brief base class of node container in DSL AST. + * \brief Base class of all references to AST/IR nodes. */ -class Node : public runtime::Object { +class NodeRef : public ObjectRef { public: - /*! \brief virtual destructor */ - virtual ~Node() {} - - /*! - * \brief Apply visitor to each field of the Node - * Visitor could mutate the content of the node. - * override if Node contains attribute fields. - * \param visitor The visitor - */ - virtual void VisitAttrs(AttrVisitor* visitor) {} - - static constexpr const char* _type_key = "Node"; - static constexpr uint32_t _type_index = TypeIndex::kDynamic; - - TVM_DECLARE_BASE_OBJECT_INFO(Node, runtime::Object); + NodeRef() {} + explicit NodeRef(ObjectPtr n) : ObjectRef(n) {} }; - /*! - * \brief Base class of all node reference object - * NodeRef is just a alias of ObjectRef. + * \brief Allocate a node object. + * \param args arguments to the constructor. + * \tparam T the node type. + * \return The NodePtr to the allocated object. + * \note This function is an alias of make_object. */ -class NodeRef : public runtime::ObjectRef { - public: - /*! \brief type indicate the container type */ - using ContainerType = Node; - - /*! \return the internal node pointer */ - const Node* get() const { - return static_cast(ObjectRef::get()); - } - /*! \return the internal node pointer */ - const Node* operator->() const { - return get(); - } - /*! - * \brief A more powerful version of as that also works with - * intermediate base types. - * \tparam T the target type, must be subtype of IRNode - */ - template - const T *as_derived() const { - return as(); - } - /*! \brief default constructor */ - NodeRef() = default; - explicit NodeRef(runtime::ObjectPtr ptr) : ObjectRef(ptr) {} -}; +template +inline NodePtr make_node(Args&&... args) { + return runtime::make_object(std::forward(args)...); +} /*! * \brief helper macro to declare type information in a base node. @@ -139,27 +94,67 @@ class NodeRef : public runtime::ObjectRef { TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, Parent); -using runtime::Object; -using runtime::ObjectPtr; -using runtime::ObjectRef; -using runtime::GetRef; -using runtime::Downcast; -using runtime::make_object; -using runtime::ObjectHash; -using runtime::ObjectEqual; +/*! + * \brief Macro to define common node ref methods. + * \param TypeName The name of the NodeRef. + * \param BaseTypeName The Base type. + * \param NodeName The node container type. + */ +#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \ + TypeName() {} \ + explicit TypeName(::tvm::ObjectPtr<::tvm::Object> n) \ + : BaseTypeName(n) {} \ + const NodeName* operator->() const { \ + return static_cast(data_.get()); \ + } \ + operator bool() const { return this->defined(); } \ + using ContainerType = NodeName; -using NodeHash = ObjectHash; -using NodeEqual = ObjectEqual; +/*! + * \brief Macro to define CopyOnWrite function in a NodeRef. + * \param NodeName The Type of the Node. + * + * CopyOnWrite will generate a unique copy of the internal node. + * The node will be copied if it is referenced by multiple places. + * The function returns the raw pointer to the node to allow modification + * of the content. + * + * \code + * + * MyCOWNodeRef ref, ref2; + * ref2 = ref; + * ref.CopyOnWrite()->value = new_value; + * assert(ref2->value == old_value); + * assert(ref->value == new_value); + * + * \endcode + */ +#define TVM_DEFINE_NODE_REF_COW(NodeName) \ + NodeName* CopyOnWrite() { \ + CHECK(data_ != nullptr); \ + if (!data_.unique()) { \ + NodePtr n = make_node(*(operator->())); \ + ObjectPtr(std::move(n)).swap(data_); \ + } \ + return static_cast(data_.get()); \ + } + +/*! \brief Macro to make it easy to define node ref type given node */ +#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \ + class TypeName : public ::tvm::NodeRef { \ + public: \ + TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \ + }; \ /*! - * \brief Allocate a node object. - * \param args arguments to the constructor. - * \tparam T the node type. - * \return The NodePtr to the allocated object. + * \brief Macro to make it easy to define node ref type that + * has a CopyOnWrite member function. */ -template -inline NodePtr make_node(Args&&... args) { - return runtime::make_object(std::forward(args)...); -} +#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \ + class TypeName : public BaseType { \ + public: \ + TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \ + TVM_DEFINE_NODE_REF_COW(NodeName); \ + }; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h new file mode 100644 index 000000000000..e6caa443ab9c --- /dev/null +++ b/include/tvm/node/reflection.h @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/node/reflection.h + * \brief Reflection and serialization of compiler IR/AST nodes. + */ +#ifndef TVM_NODE_REFLECTION_H_ +#define TVM_NODE_REFLECTION_H_ + +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { + +// forward declaration +class DataType; + +using runtime::Object; +using runtime::ObjectPtr; +using runtime::ObjectRef; + +/*! + * \brief Visitor class for to get the attributesof a AST/IR node. + * The content is going to be called for each field. + * + * Each objects that wants reflection will need to implement + * a VisitAttrs function and call visitor->Visit on each of its field. + */ +class TVM_DLL AttrVisitor { + public: +//! \cond Doxygen_Suppress + virtual ~AttrVisitor() = default; + virtual void Visit(const char* key, double* value) = 0; + virtual void Visit(const char* key, int64_t* value) = 0; + virtual void Visit(const char* key, uint64_t* value) = 0; + virtual void Visit(const char* key, int* value) = 0; + virtual void Visit(const char* key, bool* value) = 0; + virtual void Visit(const char* key, std::string* value) = 0; + virtual void Visit(const char* key, void** value) = 0; + virtual void Visit(const char* key, DataType* value) = 0; + virtual void Visit(const char* key, runtime::NDArray* value) = 0; + virtual void Visit(const char* key, runtime::ObjectRef* value) = 0; + template::value>::type> + void Visit(const char* key, ENum* ptr) { + static_assert(std::is_same::type>::value, + "declare enum to be enum int to use visitor"); + this->Visit(key, reinterpret_cast(ptr)); + } +//! \endcond +}; + +/*! + * \brief Virtual function table to support IR/AST node reflection. + * + * Functions are stored in columar manner. + * Each column is a vector indexed by Object's type_index. + */ +class ReflectionVTable { + public: + /*! + * \brief Visitor function. + * \note We use function pointer, instead of std::function + * to reduce the dispatch overhead as field visit + * does not need as much customization. + */ + typedef void (*FVisitAttrs)(Object* self, AttrVisitor* visitor); + /*! + * \brief creator function. + * \param global_key Key that identifies a global single object. + * If this is not empty then FGlobalKey must be defined for the object. + * \return The created function. + */ + using FCreate = std::function(const std::string& global_key)>; + /*! + * \brief Global key function, only needed by global objects. + * \param node The node pointer. + * \return node The global key to the node. + */ + using FGlobalKey = std::function; + /*! + * \brief Dispatch the VisitAttrs function. + * \param self The pointer to the object. + * \param visitor The attribute visitor. + */ + inline void VisitAttrs(Object* self, AttrVisitor* visitor) const; + /*! + * \brief Get global key of the object, if any. + * \param self The pointer to the object. + * \return the global key if object has one, otherwise return empty string. + */ + inline std::string GetGlobalKey(Object* self) const; + /*! + * \brief Create an initial object using default constructor + * by type_key and global key. + * + * \param type_key The type key of the object. + * \param global_key A global key that can be used to uniquely identify the object if any. + */ + TVM_DLL ObjectPtr CreateInitObject(const std::string& type_key, + const std::string& global_key = "") const; + /*! + * \brief Get an field object by the attr name. + * \param self The pointer to the object. + * \param attr_name The name of the field. + * \return The corresponding attribute value. + * \note This function will throw an exception if the object does not contain the field. + */ + TVM_DLL runtime::TVMRetValue GetAttr(Object* self, const std::string& attr_name) const; + + /*! + * \brief List all the fields in the object. + * \return All the fields. + */ + TVM_DLL std::vector ListAttrNames(Object* self) const; + + /*! \return The global singleton. */ + TVM_DLL static ReflectionVTable* Global(); + + class Registry; + template + inline Registry Register(); + + private: + /*! \brief Attribute visitor. */ + std::vector fvisit_attrs_; + /*! \brief Creation function. */ + std::vector fcreate_; + /*! \brief Global key function. */ + std::vector fglobal_key_; +}; + +/*! \brief Registry of a reflection table. */ +class ReflectionVTable::Registry { + public: + Registry(ReflectionVTable* parent, uint32_t type_index) + : parent_(parent), type_index_(type_index) { } + /*! + * \brief Set fcreate function. + * \param f The creator function. + * \return rference to self. + */ + Registry& set_creator(FCreate f) { // NOLINT(*) + CHECK_LT(type_index_, parent_->fcreate_.size()); + parent_->fcreate_[type_index_] = f; + return *this; + } + /*! + * \brief Set global_key function. + * \param f The creator function. + * \return rference to self. + */ + Registry& set_global_key(FGlobalKey f) { // NOLINT(*) + CHECK_LT(type_index_, parent_->fglobal_key_.size()); + parent_->fglobal_key_[type_index_] = f; + return *this; + } + + private: + ReflectionVTable* parent_; + uint32_t type_index_; +}; + +/*! + * \brief Register a node type to object registry and reflection registry. + * \param TypeName The name of the type. + * \note This macro will call TVM_REGISTER_OBJECT_TYPE for the type as well. + */ +#define TVM_REGISTER_NODE_TYPE(TypeName) \ + TVM_REGISTER_OBJECT_TYPE(TypeName); \ + static DMLC_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry & \ + __make_Node ## _ ## TypeName ## __ = \ + ::tvm::ReflectionVTable::Global()->Register() \ + .set_creator([](const std::string&) { \ + return ::tvm::runtime::make_object(); \ + }) + +// Implementation details +template +inline ReflectionVTable::Registry +ReflectionVTable::Register() { + uint32_t tindex = T::RuntimeTypeIndex(); + if (tindex >= fvisit_attrs_.size()) { + fvisit_attrs_.resize(tindex + 1, nullptr); + fcreate_.resize(tindex + 1, nullptr); + fglobal_key_.resize(tindex + 1, nullptr); + } + // functor that implemnts the redirection. + struct Functor { + static void VisitAttrs(Object* self, AttrVisitor* v) { + static_cast(self)->VisitAttrs(v); + } + }; + + fvisit_attrs_[tindex] = Functor::VisitAttrs; + return Registry(this, tindex); +} + +inline void ReflectionVTable:: +VisitAttrs(Object* self, AttrVisitor* visitor) const { + uint32_t tindex = self->type_index(); + if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) { + LOG(FATAL) << "TypeError: " << self->GetTypeKey() + << " is not registered via TVM_REGISTER_NODE_TYPE"; + } + fvisit_attrs_[tindex](self, visitor); +} + +inline std::string ReflectionVTable::GetGlobalKey(Object* self) const { + uint32_t tindex = self->type_index(); + if (tindex < fglobal_key_.size() && fglobal_key_[tindex] != nullptr) { + return fglobal_key_[tindex](self); + } else { + return std::string(); + } +} + +} // namespace tvm +#endif // TVM_NODE_REFLECTION_H_ diff --git a/include/tvm/node/serialization.h b/include/tvm/node/serialization.h new file mode 100644 index 000000000000..ac675946e0eb --- /dev/null +++ b/include/tvm/node/serialization.h @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Utility functions for serialization. + * \file tvm/node/serialization.h + */ +#ifndef TVM_NODE_SERIALIZATION_H_ +#define TVM_NODE_SERIALIZATION_H_ + +#include +#include + +#include + +namespace tvm { +/*! + * \brief save the node as well as all the node it depends on as json. + * This can be used to serialize any TVM object + * + * \return the string representation of the node. + */ +TVM_DLL std::string SaveJSON(const runtime::ObjectRef& node); + +/*! + * \brief Internal implementation of LoadJSON + * Load tvm Node object from json and return a shared_ptr of Node. + * \param json_str The json string to load from. + * + * \return The shared_ptr of the Node. + */ +TVM_DLL runtime::ObjectRef LoadJSON(std::string json_str); + +} // namespace tvm +#endif // TVM_NODE_SERIALIZATION_H_ diff --git a/include/tvm/operation.h b/include/tvm/operation.h index b942464d4907..f53c1ce56a93 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -188,7 +188,7 @@ class PlaceholderOpNode : public OperationNode { const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); v->Visit("tag", &tag); v->Visit("attrs", &attrs); @@ -259,7 +259,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { bool debug_keep_trivial_loop) const final; size_t num_schedulable_dims() const final; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); v->Visit("tag", &tag); v->Visit("attrs", &attrs); @@ -312,7 +312,7 @@ class TensorComputeOpNode : public BaseComputeOpNode { bool debug_keep_trivial_loop) const final; size_t num_schedulable_dims() const final; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); v->Visit("tag", &tag); v->Visit("axis", &axis); @@ -394,7 +394,7 @@ class ScanOpNode : public OperationNode { const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); v->Visit("tag", &tag); v->Visit("attrs", &attrs); @@ -461,7 +461,7 @@ class ExternOpNode : public OperationNode { const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); v->Visit("tag", &tag); v->Visit("attrs", &attrs); @@ -529,7 +529,7 @@ class HybridOpNode : public OperationNode { const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); v->Visit("tag", &tag); v->Visit("attrs", &attrs); diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index 48d46fdf2fc6..71f8f55b2655 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -20,7 +20,7 @@ /*! * \file tvm/packed_func_ext.h * \brief Extension package to PackedFunc - * This enales pass NodeRef types into/from PackedFunc. + * This enales pass ObjectRef types into/from PackedFunc. */ #ifndef TVM_PACKED_FUNC_EXT_H_ #define TVM_PACKED_FUNC_EXT_H_ @@ -129,18 +129,18 @@ inline std::string ObjectTypeName() { // extensions for tvm arg value -template -inline TNodeRef TVMArgValue::AsNodeRef() const { +template +inline TObjectRef TVMArgValue::AsObjectRef() const { static_assert( - std::is_base_of::value, - "Conversion only works for NodeRef"); - if (type_code_ == kNull) return TNodeRef(NodePtr(nullptr)); + std::is_base_of::value, + "Conversion only works for ObjectRef"); + if (type_code_ == kNull) return TObjectRef(NodePtr(nullptr)); TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); Object* ptr = static_cast(value_.v_handle); - CHECK(ObjectTypeChecker::Check(ptr)) - << "Expected type " << ObjectTypeName() + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected type " << ObjectTypeName() << " but get " << ptr->GetTypeKey(); - return TNodeRef(ObjectPtr(ptr)); + return TObjectRef(ObjectPtr(ptr)); } inline TVMArgValue::operator tvm::Expr() const { @@ -184,28 +184,28 @@ inline TVMArgValue::operator tvm::Integer() const { return Integer(ObjectPtr(ptr)); } -template +template inline bool TVMPODValue_::IsObjectRef() const { TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); Object* ptr = static_cast(value_.v_handle); - return ObjectTypeChecker::Check(ptr); + return ObjectTypeChecker::Check(ptr); } // extensions for TVMRetValue -template -inline TNodeRef TVMRetValue::AsNodeRef() const { +template +inline TObjectRef TVMRetValue::AsObjectRef() const { static_assert( - std::is_base_of::value, - "Conversion only works for NodeRef"); - if (type_code_ == kNull) return TNodeRef(); + std::is_base_of::value, + "Conversion only works for ObjectRef"); + if (type_code_ == kNull) return TObjectRef(); TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); Object* ptr = static_cast(value_.v_handle); - CHECK(ObjectTypeChecker::Check(ptr)) - << "Expected type " << ObjectTypeName() + CHECK(ObjectTypeChecker::Check(ptr)) + << "Expected type " << ObjectTypeName() << " but get " << ptr->GetTypeKey(); - return TNodeRef(ObjectPtr(ptr)); + return TObjectRef(ObjectPtr(ptr)); } // type related stuffs diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index e54d88d5a393..a74353239a00 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -66,7 +66,7 @@ class PatternWildcardNode : public PatternNode { TVM_DLL static PatternWildcard make(); - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); } @@ -88,7 +88,7 @@ class PatternVarNode : public PatternNode { TVM_DLL static PatternVar make(tvm::relay::Var var); - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("var", &var); v->Visit("span", &span); } @@ -122,7 +122,7 @@ class ConstructorNode : public ExprNode { tvm::Array inputs, GlobalTypeVar belong_to); - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); v->Visit("inputs", &inputs); v->Visit("belong_to", &belong_to); @@ -151,7 +151,7 @@ class PatternConstructorNode : public PatternNode { TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array var); - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("constructor", &constructor); v->Visit("patterns", &patterns); v->Visit("span", &span); @@ -175,7 +175,7 @@ class PatternTupleNode : public PatternNode { TVM_DLL static PatternTuple make(tvm::Array var); - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("patterns", &patterns); v->Visit("span", &span); } @@ -213,7 +213,7 @@ class TypeDataNode : public TypeNode { /*! \brief The constructors. */ tvm::Array constructors; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("header", &header); v->Visit("type_vars", &type_vars); v->Visit("constructors", &constructors); @@ -240,7 +240,7 @@ class ClauseNode : public Node { /*! \brief The resulting value. */ Expr rhs; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("lhs", &lhs); v->Visit("rhs", &rhs); } @@ -269,7 +269,7 @@ class MatchNode : public ExprNode { */ bool complete; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); v->Visit("clauses", &clauses); v->Visit("complete", &complete); diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 15330b00e961..5a2326ece05d 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -107,7 +107,7 @@ class SourceNameNode : public Node { /*! \brief The source name. */ std::string name; // override attr visitor - void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); } + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } static constexpr const char* _type_key = "relay.SourceName"; TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node); @@ -160,7 +160,7 @@ class SpanNode : public Node { /*! \brief column offset */ int col_offset; // override attr visitor - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("source", &source); v->Visit("lineno", &lineno); v->Visit("col_offset", &col_offset); @@ -204,7 +204,7 @@ class IdNode : public Node { */ std::string name_hint; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); } diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 281b99297e78..6df4273d34c0 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -95,7 +95,7 @@ class ConstantNode : public ExprNode { return data->ndim == 0; } - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); @@ -117,7 +117,7 @@ class TupleNode : public ExprNode { /*! \brief the fields of the tuple */ tvm::Array fields; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); @@ -165,7 +165,7 @@ class VarNode : public ExprNode { return vid->name_hint; } - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("vid", &vid); v->Visit("type_annotation", &type_annotation); v->Visit("span", &span); @@ -197,7 +197,7 @@ class GlobalVarNode : public ExprNode { /*! \brief The name of the variable, this only acts as a hint. */ std::string name_hint; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); @@ -243,7 +243,7 @@ class FunctionNode : public ExprNode { */ tvm::Attrs attrs; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("params", ¶ms); v->Visit("body", &body); v->Visit("ret_type", &ret_type); @@ -327,7 +327,7 @@ class CallNode : public ExprNode { */ tvm::Array type_args; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("op", &op); v->Visit("args", &args); v->Visit("attrs", &attrs); @@ -369,7 +369,7 @@ class LetNode : public ExprNode { /*! \brief The body of the let binding */ Expr body; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("var", &var); v->Visit("value", &value); v->Visit("body", &body); @@ -407,7 +407,7 @@ class IfNode : public ExprNode { /*! \brief The expression evaluated when condition is false */ Expr false_branch; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("cond", &cond); v->Visit("true_branch", &true_branch); v->Visit("false_branch", &false_branch); @@ -432,7 +432,7 @@ class TupleGetItemNode : public ExprNode { /*! \brief which value to get */ int index; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("tuple_value", &tuple); v->Visit("index", &index); v->Visit("span", &span); @@ -454,7 +454,7 @@ class RefCreateNode : public ExprNode { /*! \brief The initial value of the Reference. */ Expr value; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); @@ -475,7 +475,7 @@ class RefReadNode : public ExprNode { /*! \brief The Reference Expression. */ Expr ref; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("ref", &ref); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); @@ -498,7 +498,7 @@ class RefWriteNode : public ExprNode { /*! \brief The value to write into. */ Expr value; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("ref", &ref); v->Visit("value", &value); v->Visit("span", &span); diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index f0b1e7ce8a26..3bdc125f9938 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -106,7 +106,7 @@ class ClosureNode : public ValueNode { ClosureNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("env", &env); v->Visit("func", &func); } @@ -154,7 +154,7 @@ struct TupleValueNode : ValueNode { TupleValueNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } TVM_DLL static TupleValue make(tvm::Array value); @@ -173,7 +173,7 @@ struct TensorValueNode : ValueNode { TensorValueNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); } /*! \brief Build a value from an NDArray. */ TVM_DLL static TensorValue make(runtime::NDArray data); @@ -192,7 +192,7 @@ struct RefValueNode : ValueNode { RefValueNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); } @@ -215,7 +215,7 @@ struct ConstructorValueNode : ValueNode { /*! \brief Optional field tracking ADT constructor. */ Constructor constructor; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("tag", &tag); v->Visit("fields", &fields); v->Visit("constructor", &constructor); diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 10d72349d0f5..160ae5db8265 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -68,7 +68,7 @@ class ModuleNode : public RelayNode { ModuleNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("functions", &functions); v->Visit("type_definitions", &type_definitions); v->Visit("global_var_map_", &global_var_map_); diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index 572c194bc269..7d2a1f653a93 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -24,6 +24,8 @@ #ifndef TVM_RELAY_OP_H_ #define TVM_RELAY_OP_H_ +#include + #include #include #include @@ -82,7 +84,7 @@ class OpNode : public relay::ExprNode { */ int32_t support_level = 10; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); v->Visit("op_type", &op_type); v->Visit("description", &description); diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 08ea3075cb83..82144d76e565 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -101,7 +101,7 @@ class PassContextNode : public RelayNode { PassContextNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("opt_level", &opt_level); v->Visit("fallback_device", &fallback_device); v->Visit("required_pass", &required_pass); @@ -196,7 +196,7 @@ class PassInfoNode : public RelayNode { PassInfoNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("opt_level", &opt_level); v->Visit("name", &name); v->Visit("required", &required); @@ -221,6 +221,7 @@ class Pass; */ class PassNode : public RelayNode { public: + virtual ~PassNode() {} /*! * \brief Get the pass information/meta data. */ virtual PassInfo Info() const = 0; @@ -247,7 +248,7 @@ class PassNode : public RelayNode { virtual Module operator()(const Module& mod, const PassContext& pass_ctx) const = 0; - void VisitAttrs(tvm::AttrVisitor* v) override {} + void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "relay.Pass"; TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode); diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index a5cc3c83383e..e0c056c1216b 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -96,7 +96,7 @@ class TensorTypeNode : public BaseTensorTypeNode { /*! \brief The content data type */ DataType dtype; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("shape", &shape); v->Visit("dtype", &dtype); v->Visit("span", &span); @@ -159,7 +159,7 @@ class TypeVarNode : public TypeNode { /*! \brief The kind of type parameter */ Kind kind; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("var", &var); v->Visit("kind", &kind); v->Visit("span", &span); @@ -188,7 +188,7 @@ class GlobalTypeVarNode : public TypeNode { /*! \brief The kind of type parameter */ Kind kind; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("var", &var); v->Visit("kind", &kind); v->Visit("span", &span); @@ -216,7 +216,7 @@ class TypeCallNode : public TypeNode { /*! \brief The arguments. */ tvm::Array args; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("func", &func); v->Visit("args", &args); v->Visit("span", &span); @@ -245,7 +245,7 @@ class IncompleteTypeNode : public TypeNode { public: Kind kind; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("kind", &kind); v->Visit("span", &span); } @@ -297,7 +297,7 @@ class FuncTypeNode : public TypeNode { */ tvm::Array type_constraints; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("arg_types", &arg_types); v->Visit("ret_type", &ret_type); v->Visit("type_params", &type_params); @@ -330,7 +330,7 @@ class TupleTypeNode : public TypeNode { TupleTypeNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); v->Visit("span", &span); } @@ -357,7 +357,7 @@ class RefTypeNode : public TypeNode { RefTypeNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); v->Visit("span", &span); } @@ -417,7 +417,7 @@ class TypeReporterNode : public Node { TVM_DLL virtual Module GetModule() = 0; // solver is not serializable. - void VisitAttrs(tvm::AttrVisitor* v) final {} + void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "relay.TypeReporter"; TVM_DECLARE_NODE_TYPE_INFO(TypeReporterNode, Node); @@ -488,7 +488,7 @@ class TypeRelationNode : public TypeConstraintNode { /*! \brief Attributes to the relation function */ Attrs attrs; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("func", &func); v->Visit("args", &args); v->Visit("num_inputs", &num_inputs); diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 68029c13cb93..bb362dcdec66 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -230,6 +230,7 @@ inline std::ostream& operator<<(std::ostream& os, DLContext ctx) { // NOLINT(*) os << runtime::DeviceName(device_type) << "(" << ctx.device_id << ")"; return os; } + #endif } // namespace runtime } // namespace tvm diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/memory.h index 01c08d324fcb..d28552eaf7fd 100644 --- a/include/tvm/runtime/memory.h +++ b/include/tvm/runtime/memory.h @@ -82,6 +82,8 @@ class SimpleObjAllocator : template class Handler { public: + using StorageType = typename std::aligned_storage::type; + template static T* New(SimpleObjAllocator*, Args&&... args) { // NOTE: the first argument is not needed for SimpleObjAllocator @@ -91,7 +93,15 @@ class SimpleObjAllocator : // In the case of an object pool, an allocator needs to create // a special chunk memory that hides reference to the allocator // and call allocator's release function in the deleter. - return new T(std::forward(args)...); + + // NOTE2: Use inplace new to allocate + // This is used to get rid of warning when deleting a virtual + // class with non-virtual destructor. + // We are fine here as we captured the right deleter during construction. + // This is also the right way to get storage type for an object pool. + StorageType* data = new StorageType(); + new (data) T(std::forward(args)...); + return reinterpret_cast(data); } static Object::FDeleter Deleter() { @@ -99,8 +109,17 @@ class SimpleObjAllocator : } private: - static void Deleter_(Object* ptr) { - delete static_cast(ptr); + static void Deleter_(Object* objptr) { + // NOTE: this is important to cast back to T* + // because objptr and tptr may not be the same + // depending on how sub-class allocates the space. + T* tptr = static_cast(objptr); + // It is important to do tptr->T::~T(), + // so that we explicitly call the specific destructor + // instead of tptr->~T(), which could mean the intention + // call a virtual destructor(which may not be available and is not required). + tptr->T::~T(); + delete reinterpret_cast(tptr); } }; }; diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 143f3bb35220..cc4a295cc5d4 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -23,6 +23,7 @@ #ifndef TVM_RUNTIME_OBJECT_H_ #define TVM_RUNTIME_OBJECT_H_ +#include #include #include #include @@ -189,7 +190,7 @@ class Object { * \param key The type key. * \return the result. */ - TVM_DLL static uint32_t TypeKey2Index(const char* key); + TVM_DLL static uint32_t TypeKey2Index(const std::string& key); #if TVM_OBJECT_ATOMIC_REF_COUNTER using RefCounterType = std::atomic; @@ -197,18 +198,24 @@ class Object { using RefCounterType = int32_t; #endif - // Object type properties static constexpr const char* _type_key = "Object"; - static constexpr bool _type_final = false; - static constexpr uint32_t _type_child_slots = 0; - static constexpr bool _type_child_slots_can_overflow = true; + static uint32_t _GetOrAllocRuntimeTypeIndex() { - return 0; + return TypeIndex::kRoot; } static uint32_t RuntimeTypeIndex() { - return 0; + return TypeIndex::kRoot; } + // Default object type properties for sub-classes + static constexpr bool _type_final = false; + static constexpr uint32_t _type_child_slots = 0; + static constexpr bool _type_child_slots_can_overflow = true; + // NOTE: the following field is not type index of Object + // but was intended to be used by sub-classes as default value. + // The type index of Object is TypeIndex::kRoot + static constexpr uint32_t _type_index = TypeIndex::kDynamic; + // Default constructor and copy constructor Object() {} // Override the copy and assign constructors to do nothing. @@ -262,13 +269,12 @@ class Object { * \return The allocated type index. */ TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex( - const char* key, + const std::string& key, uint32_t static_tindex, uint32_t parent_tindex, uint32_t type_child_slots, bool type_child_slots_can_overflow); - private: // reference counter related operations /*! \brief developer function, increases reference counter. */ inline void IncRef(); @@ -621,8 +627,8 @@ struct ObjectEqual { */ #define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ static const uint32_t RuntimeTypeIndex() { \ - if (_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \ - return _type_index; \ + if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \ + return TypeName::_type_index; \ } \ return _GetOrAllocRuntimeTypeIndex(); \ } \ diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 649a5058a9a5..a42946ac2d2c 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -51,8 +51,6 @@ namespace tvm { class Integer; class DataType; class Expr; -class Node; -class NodeRef; namespace runtime { @@ -516,9 +514,9 @@ class TVMPODValue_ { CHECK_LT(type_code_, kExtEnd); return static_cast(value_.v_handle)[0]; } - template::value>::type> + std::is_class::value>::type> inline bool IsObjectRef() const; int type_code() const { return type_code_; @@ -620,8 +618,8 @@ class TVMArgValue : public TVMPODValue_ { return value_; } // Deferred extension handler. - template - inline TNodeRef AsNodeRef() const; + template + inline TObjectRef AsObjectRef() const; template::value>::type> @@ -834,13 +832,13 @@ class TVMRetValue : public TVMPODValue_ { type_code_ != kStr) << "TVMRetValue.value can only be used for POD data"; return value_; } - // NodeRef related extenstions: in tvm/packed_func_ext.h + // ObjectRef related extenstions: in tvm/packed_func_ext.h template::value>::type> inline operator T() const; - template - inline TNodeRef AsNodeRef() const; + template + inline TObjectRef AsObjectRef() const; // type related inline operator tvm::DataType() const; inline TVMRetValue& operator=(const tvm::DataType& other); @@ -1306,7 +1304,7 @@ template struct TVMValueCast { static T Apply(const TSrc* self) { static_assert(!is_ext && !is_nd, "The default case accepts only non-extensions"); - return self->template AsNodeRef(); + return self->template AsObjectRef(); } }; diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h index 40e1a520cb67..d668984f50e2 100644 --- a/include/tvm/runtime/registry.h +++ b/include/tvm/runtime/registry.h @@ -91,7 +91,7 @@ class Registry { * Note that this will ignore default arg values and always require all arguments to be provided. * * \code - * + * * int multiply(int x, int y) { * return x * y; * } @@ -115,7 +115,7 @@ class Registry { * Note that this will ignore default arg values and always require all arguments to be provided. * * \code - * + * * // node subclass: * struct Example { * int doThing(int x); @@ -143,7 +143,7 @@ class Registry { * Note that this will ignore default arg values and always require all arguments to be provided. * * \code - * + * * // node subclass: * struct Example { * int doThing(int x); @@ -168,22 +168,22 @@ class Registry { /*! * \brief set the body of the function to be the passed method pointer. - * Used when calling a method on a Node subclass through a NodeRef subclass. + * Used when calling a method on a Node subclass through a ObjectRef subclass. * Note that this will ignore default arg values and always require all arguments to be provided. * * \code - * + * * // node subclass: * struct ExampleNode: BaseNode { * int doThing(int x); * } - * + * * // noderef subclass - * struct Example; + * struct Example; * * TVM_REGISTER_API("Example_doThing") * .set_body_method(&ExampleNode::doThing); // will have type int(Example, int) - * + * * // note that just doing: * // .set_body_method(&ExampleNode::doThing); * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue. @@ -191,15 +191,15 @@ class Registry { * \endcode * * \param f the method pointer to forward to. - * \tparam TNodeRef the node reference type to call the method on + * \tparam TObjectRef the node reference type to call the method on * \tparam TNode the node type containing the method (inferred). * \tparam R the return type of the function (inferred). * \tparam Args the argument types of the function (inferred). */ - template::value>::type> + template::value>::type> Registry& set_body_method(R (TNode::*f)(Args...)) { - return set_body_typed([f](TNodeRef ref, Args... params) { + return set_body_typed([f](TObjectRef ref, Args... params) { TNode* target = ref.operator->(); // call method pointer return (target->*f)(params...); @@ -208,22 +208,22 @@ class Registry { /*! * \brief set the body of the function to be the passed method pointer. - * Used when calling a method on a Node subclass through a NodeRef subclass. + * Used when calling a method on a Node subclass through a ObjectRef subclass. * Note that this will ignore default arg values and always require all arguments to be provided. * * \code - * + * * // node subclass: * struct ExampleNode: BaseNode { * int doThing(int x); * } - * + * * // noderef subclass - * struct Example; + * struct Example; * * TVM_REGISTER_API("Example_doThing") * .set_body_method(&ExampleNode::doThing); // will have type int(Example, int) - * + * * // note that just doing: * // .set_body_method(&ExampleNode::doThing); * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue. @@ -231,15 +231,15 @@ class Registry { * \endcode * * \param f the method pointer to forward to. - * \tparam TNodeRef the node reference type to call the method on + * \tparam TObjectRef the node reference type to call the method on * \tparam TNode the node type containing the method (inferred). * \tparam R the return type of the function (inferred). * \tparam Args the argument types of the function (inferred). */ - template::value>::type> + template::value>::type> Registry& set_body_method(R (TNode::*f)(Args...) const) { - return set_body_typed([f](TNodeRef ref, Args... params) { + return set_body_typed([f](TObjectRef ref, Args... params) { const TNode* target = ref.operator->(); // call method pointer return (target->*f)(params...); diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index 36265667e5b6..3f4ee38a7695 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -495,7 +495,7 @@ class StageNode : public Node { /*! \brief Number of direct child stages, only used for group stage.*/ int num_child_stages{0}; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("op", &op); v->Visit("origin_op", &origin_op); v->Visit("all_iter_vars", &all_iter_vars); @@ -540,7 +540,7 @@ class ScheduleNode : public Node { */ std::unordered_map op2stage_cache_; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("outputs", &outputs); v->Visit("stages", &stages); v->Visit("groups", &groups); @@ -617,7 +617,7 @@ class IterVarAttrNode : public Node { */ Array pragma_values; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("iter_type", &iter_type); v->Visit("bind_thread", &bind_thread); v->Visit("prefetch_data", &prefetch_data); @@ -657,7 +657,7 @@ class SplitNode : public IterVarRelationNode { /*! \brief Number of parts, only factor or nparts can be given */ Expr nparts; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("parent", &parent); v->Visit("outer", &outer); v->Visit("inner", &inner); @@ -687,7 +687,7 @@ class FuseNode : public IterVarRelationNode { /*! \brief The target domain */ IterVar fused; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("outer", &outer); v->Visit("inner", &inner); v->Visit("fused", &fused); @@ -712,7 +712,7 @@ class RebaseNode : public IterVarRelationNode { /*! \brief The inner domain */ IterVar rebased; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("parent", &parent); v->Visit("rebased", &rebased); } @@ -732,7 +732,7 @@ class SingletonNode : public IterVarRelationNode { /*! \brief The singleton iterator */ IterVar iter; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("iter", &iter); } diff --git a/include/tvm/target_info.h b/include/tvm/target_info.h index 1e3a7686ca00..86cb0e275609 100644 --- a/include/tvm/target_info.h +++ b/include/tvm/target_info.h @@ -47,7 +47,7 @@ struct MemoryInfoNode : public Node { */ Expr head_address; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("unit_bits", &unit_bits); v->Visit("max_num_bits", &max_num_bits); v->Visit("max_simd_bits", &max_simd_bits); diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index 6471c9c69a62..599d6ff657d1 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -171,7 +171,7 @@ class TensorNode : public Node { /*! \brief constructor */ TensorNode() {} - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("shape", &shape); v->Visit("dtype", &dtype); v->Visit("op", &op); diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h index 152a27f6e2a9..0d4795ad5440 100644 --- a/include/tvm/tensor_intrin.h +++ b/include/tvm/tensor_intrin.h @@ -87,7 +87,7 @@ class TensorIntrinNode : public Node { /*! \brief constructor */ TensorIntrinNode() {} - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); v->Visit("op", &op); v->Visit("inputs", &inputs); @@ -152,7 +152,7 @@ class TensorIntrinCallNode : public Node { /*! \brief scalar expression inputs */ Array scalar_inputs; - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("intrin", &intrin); v->Visit("tensors", &tensors); v->Visit("regions", ®ions); diff --git a/nnvm/src/compiler/compile_engine.h b/nnvm/src/compiler/compile_engine.h index e8d33cb4be7e..ec9a13b13b17 100644 --- a/nnvm/src/compiler/compile_engine.h +++ b/nnvm/src/compiler/compile_engine.h @@ -55,7 +55,7 @@ struct GraphFuncNode : public tvm::Node { /*! \brief The lowered functions */ tvm::Array funcs; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("target", &target); v->Visit("func_name", &func_name); v->Visit("inputs", &inputs); @@ -78,7 +78,7 @@ struct GraphCacheEntryNode : public tvm::Node { /*! \brief Index of the master node for calling schedule*/ int master_idx; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("graph_func", &graph_func); v->Visit("use_count", &use_count); v->Visit("master_idx", &master_idx); diff --git a/nnvm/src/compiler/graph_hash.h b/nnvm/src/compiler/graph_hash.h index aed3462cf128..6966a152224b 100644 --- a/nnvm/src/compiler/graph_hash.h +++ b/nnvm/src/compiler/graph_hash.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -48,7 +48,7 @@ struct GraphKeyNode : public tvm::Node { // The graph hash key is ensured always not to be 0 mutable size_t cache_hash_key_{0}; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("inputs", &inputs); v->Visit("target", &target); } diff --git a/nnvm/src/compiler/graph_runtime.cc b/nnvm/src/compiler/graph_runtime.cc index 3bfebe3ba4e8..d8ff3bf34bf8 100644 --- a/nnvm/src/compiler/graph_runtime.cc +++ b/nnvm/src/compiler/graph_runtime.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,11 +18,12 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file graph_runtime.cc * \brief Interface code with TVM graph runtime. */ #include +#include + #include #include "graph_runtime.h" diff --git a/nnvm/src/compiler/graph_runtime.h b/nnvm/src/compiler/graph_runtime.h index 7b324ba100ad..770c98e83261 100644 --- a/nnvm/src/compiler/graph_runtime.h +++ b/nnvm/src/compiler/graph_runtime.h @@ -61,13 +61,13 @@ struct NDArrayWrapperNode : public ::tvm::Node { std::string name; tvm::runtime::NDArray array; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); v->Visit("array", &array); } static constexpr const char* _type_key = "NDArrayWrapper"; - TVM_DECLARE_NODE_TYPE_INFO(NDArrayWrapperNode, Node); + TVM_DECLARE_NODE_TYPE_INFO(NDArrayWrapperNode, tvm::Node); }; TVM_DEFINE_NODE_REF(NDArrayWrapper, NDArrayWrapperNode); diff --git a/src/README.md b/src/README.md index 0c6f30a881b8..599f41dfdc5f 100644 --- a/src/README.md +++ b/src/README.md @@ -22,6 +22,8 @@ There can be internal header files within each module that sit in src. ## Modules - common: Internal common utilities. +- runtime: Minimum runtime related codes. +- node: base infra for IR/AST nodes that is dialect independent. - api: API function registration. - lang: The definition of DSL related data structure. - arithmetic: Arithmetic expression and set simplification. @@ -29,7 +31,6 @@ There can be internal header files within each module that sit in src. - schedule: The operations on the schedule graph before converting to IR. - pass: The optimization pass on the IR structure. - codegen: The code generator. -- runtime: Minimum runtime related codes. - autotvm: The auto-tuning module. - relay: Implementation of Relay. The second generation of NNVM, a new IR for deep learning frameworks. - contrib: Contrib extension libraries. diff --git a/src/api/api_base.cc b/src/api/api_base.cc index c25c35f636e6..42367efb15bb 100644 --- a/src/api/api_base.cc +++ b/src/api/api_base.cc @@ -26,6 +26,7 @@ #include #include #include +#include namespace tvm { TVM_REGISTER_API("_format_str") @@ -43,10 +44,10 @@ TVM_REGISTER_API("_raw_ptr") }); TVM_REGISTER_API("_save_json") -.set_body_typed(SaveJSON); +.set_body_typed(SaveJSON); TVM_REGISTER_API("_load_json") -.set_body_typed(LoadJSON); +.set_body_typed(LoadJSON); TVM_REGISTER_API("_TVMSetStream") .set_body_typed(TVMSetStream); diff --git a/src/api/dsl_api.cc b/src/api/dsl_api.cc deleted file mode 100644 index 64805c9e8aa0..000000000000 --- a/src/api/dsl_api.cc +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Implementation of DSL API - * \file dsl_api.cc - */ -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace runtime { - -struct APIAttrGetter : public AttrVisitor { - std::string skey; - TVMRetValue* ret; - bool found_ref_object{false}; - - void Visit(const char* key, double* value) final { - if (skey == key) *ret = value[0]; - } - void Visit(const char* key, int64_t* value) final { - if (skey == key) *ret = value[0]; - } - void Visit(const char* key, uint64_t* value) final { - CHECK_LE(value[0], static_cast(std::numeric_limits::max())) - << "cannot return too big constant"; - if (skey == key) *ret = static_cast(value[0]); - } - void Visit(const char* key, int* value) final { - if (skey == key) *ret = static_cast(value[0]); - } - void Visit(const char* key, bool* value) final { - if (skey == key) *ret = static_cast(value[0]); - } - void Visit(const char* key, void** value) final { - if (skey == key) *ret = static_cast(value[0]); - } - void Visit(const char* key, Type* value) final { - if (skey == key) *ret = value[0]; - } - void Visit(const char* key, std::string* value) final { - if (skey == key) *ret = value[0]; - } - void Visit(const char* key, NodeRef* value) final { - if (skey == key) { - *ret = value[0]; - found_ref_object = true; - } - } - void Visit(const char* key, runtime::NDArray* value) final { - if (skey == key) { - *ret = value[0]; - found_ref_object = true; - } - } - void Visit(const char* key, runtime::ObjectRef* value) final { - if (skey == key) { - *ret = value[0]; - found_ref_object = true; - } - } -}; - -struct APIAttrDir : public AttrVisitor { - std::vector* names; - - void Visit(const char* key, double* value) final { - names->push_back(key); - } - void Visit(const char* key, int64_t* value) final { - names->push_back(key); - } - void Visit(const char* key, uint64_t* value) final { - names->push_back(key); - } - void Visit(const char* key, bool* value) final { - names->push_back(key); - } - void Visit(const char* key, int* value) final { - names->push_back(key); - } - void Visit(const char* key, void** value) final { - names->push_back(key); - } - void Visit(const char* key, Type* value) final { - names->push_back(key); - } - void Visit(const char* key, std::string* value) final { - names->push_back(key); - } - void Visit(const char* key, NodeRef* value) final { - names->push_back(key); - } - void Visit(const char* key, runtime::NDArray* value) final { - names->push_back(key); - } - void Visit(const char* key, runtime::ObjectRef* value) final { - names->push_back(key); - } -}; - -struct NodeAPI { - static void GetAttr(TVMArgs args, TVMRetValue* ret) { - NodeRef ref = args[0]; - Node* tnode = const_cast(ref.get()); - APIAttrGetter getter; - getter.skey = args[1].operator std::string(); - getter.ret = ret; - - bool success; - if (getter.skey == "type_key") { - *ret = tnode->GetTypeKey(); - success = true; - } else if (!tnode->IsInstance()) { - tnode->VisitAttrs(&getter); - success = getter.found_ref_object || ret->type_code() != kNull; - } else { - // specially handle dict attr - DictAttrsNode* dnode = static_cast(tnode); - auto it = dnode->dict.find(getter.skey); - if (it != dnode->dict.end()) { - success = true; - *ret = (*it).second; - } else { - success = false; - } - } - if (!success) { - LOG(FATAL) << "AttributeError: " << tnode->GetTypeKey() - << " object has no attributed " << getter.skey; - } - } - - static void ListAttrNames(TVMArgs args, TVMRetValue* ret) { - NodeRef ref = args[0]; - Node* tnode = const_cast(ref.get()); - auto names = std::make_shared >(); - APIAttrDir dir; - dir.names = names.get(); - - if (!tnode->IsInstance()) { - tnode->VisitAttrs(&dir); - } else { - // specially handle dict attr - DictAttrsNode* dnode = static_cast(tnode); - for (const auto& kv : dnode->dict) { - names->push_back(kv.first); - } - } - - *ret = PackedFunc([names](TVMArgs args, TVMRetValue *rv) { - int64_t i = args[0]; - if (i == -1) { - *rv = static_cast(names->size()); - } else { - *rv = (*names)[i]; - } - }); - } -}; - -TVM_REGISTER_GLOBAL("_NodeGetAttr") -.set_body(NodeAPI::GetAttr); - -TVM_REGISTER_GLOBAL("_NodeListAttrNames") -.set_body(NodeAPI::ListAttrNames); - -} // namespace runtime -} // namespace tvm diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 6f7b4d78da05..9c3a706e2ad0 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -53,17 +53,17 @@ class VariablePathFinder: public IRVisitor { if (!found_) path_.pop_back(); } - std::vector path_; + std::vector path_; private: bool found_{false}; Expr target_; - std::unordered_set visited_; + std::unordered_set visited_; }; // get the path to the variable, // return empty vector to represent failure -std::vector GetPath(Expr target, Expr expr) { +std::vector GetPath(Expr target, Expr expr) { VariablePathFinder v(target); v.Visit(expr); return v.path_; @@ -189,7 +189,7 @@ class BoundDeducer: public IRVisitor { const std::unordered_map& hint_map_; const std::unordered_map& relax_map_; ExprIntSetMap expr_map_; - std::vector path_; + std::vector path_; size_t iter_{0}; // internal analzyer Analyzer analyzer_; diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index 02e8079c9c7b..1b576a645824 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -43,6 +43,7 @@ class SplitExpr; */ class CanonicalExprNode : public BaseExprNode { public: + virtual ~CanonicalExprNode() {} /*! * \brief Return the normal Expr that is equivalent to self. * \note Can mutate the internal data structure. @@ -51,7 +52,7 @@ class CanonicalExprNode : public BaseExprNode { virtual Expr Normalize() const = 0; // overrides - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { } static constexpr const char* _type_key = "arith.CanonicalExpr"; @@ -485,7 +486,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { * \return Normalized expr. */ Expr Normalize(Expr expr) { - if (const auto* op = expr.as_derived()) { + if (const auto* op = expr.as()) { return op->Normalize(); } else { return expr; @@ -503,7 +504,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { if (const auto* op = expr.as()) { if (op->base == 0 && op->args.size() == 1) return op->args[0]; } - if (const auto* op = expr.as_derived()) { + if (const auto* op = expr.as()) { expr = op->Normalize(); } NodePtr n = make_node(); diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 313b34ded034..409477578758 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -807,6 +807,8 @@ IntSet EvalSet(Range r, return EvalSet(r, ConvertDomMap(dom_map)); } +TVM_REGISTER_NODE_TYPE(IntervalSetNode); + TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch([](const IntervalSetNode *op, IRPrinter *p) { p->stream << "IntervalSet" diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index 306361868759..831b44409030 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -47,7 +47,7 @@ class IntervalSetNode : public IntSetNode { Expr max_value; // visitor overload. - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("min_value", &min_value); v->Visit("max_value", &max_value); } diff --git a/src/codegen/spirv/intrin_rule_spirv.cc b/src/codegen/spirv/intrin_rule_spirv.cc index a046cc4f458c..fca9aa203f80 100644 --- a/src/codegen/spirv/intrin_rule_spirv.cc +++ b/src/codegen/spirv/intrin_rule_spirv.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,9 +18,9 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file intrin_rule_spirv.cc */ +#include #include #include #include diff --git a/src/lang/api_registry.cc b/src/lang/api_registry.cc index e041f3a2dd2d..cd3d43b7dcf3 100644 --- a/src/lang/api_registry.cc +++ b/src/lang/api_registry.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -62,7 +62,7 @@ TVM_REGISTER_API("_EnvFuncGetPackedFunc") TVM_REGISTER_NODE_TYPE(EnvFuncNode) .set_creator(CreateEnvNode) -.set_global_key([](const Node* n) { +.set_global_key([](const Object* n) { return static_cast(n)->name; }); diff --git a/src/lang/ir.cc b/src/lang/ir.cc index 48b486a7e13b..04e04aef455c 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -1150,6 +1150,8 @@ TVM_REGISTER_NODE_TYPE(Select); TVM_REGISTER_NODE_TYPE(Load); TVM_REGISTER_NODE_TYPE(Ramp); TVM_REGISTER_NODE_TYPE(Broadcast); +TVM_REGISTER_NODE_TYPE(Shuffle); +TVM_REGISTER_NODE_TYPE(Prefetch); TVM_REGISTER_NODE_TYPE(Call); TVM_REGISTER_NODE_TYPE(Let); TVM_REGISTER_NODE_TYPE(LetStmt); diff --git a/src/lang/target_info.cc b/src/lang/target_info.cc index ff6a35286f20..481a9269193b 100644 --- a/src/lang/target_info.cc +++ b/src/lang/target_info.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,9 +18,9 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file target_info.cc */ +#include #include #include diff --git a/src/node/reflection.cc b/src/node/reflection.cc new file mode 100644 index 000000000000..e92ca92834a2 --- /dev/null +++ b/src/node/reflection.cc @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Reflection utilities. + * \file node/reflection.cc + */ +#include +#include +#include +#include +#include + +namespace tvm { + +// Attr getter. +class AttrGetter : public AttrVisitor { + public: + const std::string& skey; + TVMRetValue* ret; + + AttrGetter(const std::string &skey, + TVMRetValue* ret) + : skey(skey), ret(ret) {} + + bool found_ref_object{false}; + + void Visit(const char* key, double* value) final { + if (skey == key) *ret = value[0]; + } + void Visit(const char* key, int64_t* value) final { + if (skey == key) *ret = value[0]; + } + void Visit(const char* key, uint64_t* value) final { + CHECK_LE(value[0], static_cast(std::numeric_limits::max())) + << "cannot return too big constant"; + if (skey == key) *ret = static_cast(value[0]); + } + void Visit(const char* key, int* value) final { + if (skey == key) *ret = static_cast(value[0]); + } + void Visit(const char* key, bool* value) final { + if (skey == key) *ret = static_cast(value[0]); + } + void Visit(const char* key, void** value) final { + if (skey == key) *ret = static_cast(value[0]); + } + void Visit(const char* key, Type* value) final { + if (skey == key) *ret = value[0]; + } + void Visit(const char* key, std::string* value) final { + if (skey == key) *ret = value[0]; + } + + void Visit(const char* key, runtime::NDArray* value) final { + if (skey == key) { + *ret = value[0]; + found_ref_object = true; + } + } + void Visit(const char* key, runtime::ObjectRef* value) final { + if (skey == key) { + *ret = value[0]; + found_ref_object = true; + } + } +}; + +runtime::TVMRetValue ReflectionVTable::GetAttr( + Object* self, const std::string& field_name) const { + runtime::TVMRetValue ret; + AttrGetter getter(field_name, &ret); + + bool success; + if (getter.skey == "type_key") { + ret = self->GetTypeKey(); + success = true; + } else if (!self->IsInstance()) { + VisitAttrs(self, &getter); + success = getter.found_ref_object || ret.type_code() != kNull; + } else { + // specially handle dict attr + DictAttrsNode* dnode = static_cast(self); + auto it = dnode->dict.find(getter.skey); + if (it != dnode->dict.end()) { + success = true; + ret = (*it).second; + } else { + success = false; + } + } + if (!success) { + LOG(FATAL) << "AttributeError: " << self->GetTypeKey() + << " object has no attributed " << getter.skey; + } + return ret; +} + +// List names; +class AttrDir : public AttrVisitor { + public: + std::vector* names; + + void Visit(const char* key, double* value) final { + names->push_back(key); + } + void Visit(const char* key, int64_t* value) final { + names->push_back(key); + } + void Visit(const char* key, uint64_t* value) final { + names->push_back(key); + } + void Visit(const char* key, bool* value) final { + names->push_back(key); + } + void Visit(const char* key, int* value) final { + names->push_back(key); + } + void Visit(const char* key, void** value) final { + names->push_back(key); + } + void Visit(const char* key, Type* value) final { + names->push_back(key); + } + void Visit(const char* key, std::string* value) final { + names->push_back(key); + } + void Visit(const char* key, runtime::NDArray* value) final { + names->push_back(key); + } + void Visit(const char* key, runtime::ObjectRef* value) final { + names->push_back(key); + } +}; + +std::vector +ReflectionVTable::ListAttrNames(Object* self) const { + std::vector names; + AttrDir dir; + dir.names = &names; + + if (!self->IsInstance()) { + VisitAttrs(self, &dir); + } else { + // specially handle dict attr + DictAttrsNode* dnode = static_cast(self); + for (const auto& kv : dnode->dict) { + names.push_back(kv.first); + } + } + return names; +} + +ReflectionVTable* ReflectionVTable::Global() { + static ReflectionVTable inst; + return &inst; +} + +ObjectPtr +ReflectionVTable::CreateInitObject(const std::string& type_key, + const std::string& global_key) const { + uint32_t tindex = Object::TypeKey2Index(type_key); + if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) { + LOG(FATAL) << "TypeError: " << type_key + << " is not registered via TVM_REGISTER_NODE_TYPE"; + } + return fcreate_[tindex](global_key); +} + +class NodeAttrSetter : public AttrVisitor { + public: + std::string type_key; + std::unordered_map attrs; + + void Visit(const char* key, double* value) final { + *value = GetAttr(key).operator double(); + } + void Visit(const char* key, int64_t* value) final { + *value = GetAttr(key).operator int64_t(); + } + void Visit(const char* key, uint64_t* value) final { + *value = GetAttr(key).operator uint64_t(); + } + void Visit(const char* key, int* value) final { + *value = GetAttr(key).operator int(); + } + void Visit(const char* key, bool* value) final { + *value = GetAttr(key).operator bool(); + } + void Visit(const char* key, std::string* value) final { + *value = GetAttr(key).operator std::string(); + } + void Visit(const char* key, void** value) final { + *value = GetAttr(key).operator void*(); + } + void Visit(const char* key, DataType* value) final { + *value = GetAttr(key).operator DataType(); + } + void Visit(const char* key, runtime::NDArray* value) final { + *value = GetAttr(key).operator runtime::NDArray(); + } + void Visit(const char* key, ObjectRef* value) final { + *value = GetAttr(key).operator ObjectRef(); + } + + private: + runtime::TVMArgValue GetAttr(const char* key) { + auto it = attrs.find(key); + if (it == attrs.end()) { + LOG(FATAL) << type_key << ": require field " << key; + } + runtime::TVMArgValue v = it->second; + attrs.erase(it); + return v; + } +}; + +void InitNodeByPackedArgs(Object* n, const TVMArgs& args) { + NodeAttrSetter setter; + setter.type_key = n->GetTypeKey(); + CHECK_EQ(args.size() % 2, 0); + for (int i = 0; i < args.size(); i += 2) { + setter.attrs.emplace(args[i].operator std::string(), + args[i + 1]); + } + auto* reflection = ReflectionVTable::Global(); + reflection->VisitAttrs(n, &setter); + + if (setter.attrs.size() != 0) { + std::ostringstream os; + os << setter.type_key << " does not contain field "; + for (const auto &kv : setter.attrs) { + os << " " << kv.first; + } + LOG(FATAL) << os.str(); + } +} + +// Expose to FFI APIs. +void NodeGetAttr(TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* self = static_cast(args[0].value().v_handle); + *ret = ReflectionVTable::Global()->GetAttr(self, args[1]); +} + +void NodeListAttrNames(TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args[0].type_code(), kObjectHandle); + Object* self = static_cast(args[0].value().v_handle); + + auto names = std::make_shared >( + ReflectionVTable::Global()->ListAttrNames(self)); + + *ret = PackedFunc([names](TVMArgs args, TVMRetValue *rv) { + int64_t i = args[0]; + if (i == -1) { + *rv = static_cast(names->size()); + } else { + *rv = (*names)[i]; + } + }); +} + +// API function to make node. +// args format: +// key1, value1, ..., key_n, value_n +void MakeNode(const TVMArgs& args, TVMRetValue* rv) { + std::string type_key = args[0]; + std::string empty_str; + TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1); + auto* reflection = ReflectionVTable::Global(); + ObjectPtr n = reflection->CreateInitObject(type_key); + if (n->IsInstance()) { + static_cast(n.get())->InitByPackedArgs(kwargs); + } else { + InitNodeByPackedArgs(n.get(), kwargs); + } + *rv = ObjectRef(n); +} + + +TVM_REGISTER_GLOBAL("_NodeGetAttr") +.set_body(NodeGetAttr); + +TVM_REGISTER_GLOBAL("_NodeListAttrNames") +.set_body(NodeListAttrNames); + +TVM_REGISTER_GLOBAL("make._Node") +.set_body(MakeNode); + +} // namespace tvm diff --git a/src/lang/reflection.cc b/src/node/serialization.cc similarity index 64% rename from src/lang/reflection.cc rename to src/node/serialization.cc index 8e2c3fe7cd15..d270e72d3958 100644 --- a/src/lang/reflection.cc +++ b/src/node/serialization.cc @@ -18,50 +18,42 @@ */ /*! - * \file reflection.cc - * \brief Utilities to save/load/construct TVM objects + * \file node/serialization.cc + * \brief Utilities to serialize TVM AST/IR objects. */ -#include -#include -#include -#include -#include -#include -#include #include #include + +#include +#include +#include +#include +#include +#include + #include -#include "../common/base64.h" +#include -namespace dmlc { -DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg); -} // namespace dmlc +#include "../common/base64.h" namespace tvm { -::dmlc::Registry* NodeFactoryReg::Registry() { - return ::dmlc::Registry::Get(); -} - -inline std::string Type2String(const Type& t) { +inline std::string Type2String(const DataType& t) { return runtime::TVMType2String(Type2TVMType(t)); } - inline Type String2Type(std::string s) { return TVMType2Type(runtime::String2TVMType(s)); } -using runtime::Object; -using runtime::ObjectRef; - // indexer to index all the nodes class NodeIndexer : public AttrVisitor { public: - std::unordered_map node_index{{nullptr, 0}}; - std::vector node_list{nullptr}; - std::unordered_map tensor_index; - std::vector tensor_list; + std::unordered_map node_index_{{nullptr, 0}}; + std::vector node_list_{nullptr}; + std::unordered_map tensor_index_; + std::vector tensor_list_; + ReflectionVTable* reflection_ = ReflectionVTable::Global(); void Visit(const char* key, double* value) final {} void Visit(const char* key, int64_t* value) final {} @@ -70,17 +62,14 @@ class NodeIndexer : public AttrVisitor { void Visit(const char* key, bool* value) final {} void Visit(const char* key, std::string* value) final {} void Visit(const char* key, void** value) final {} - void Visit(const char* key, Type* value) final {} - void Visit(const char* key, NodeRef* value) final { - MakeIndex(const_cast(value->get())); - } + void Visit(const char* key, DataType* value) final {} void Visit(const char* key, runtime::NDArray* value) final { DLTensor* ptr = const_cast((*value).operator->()); - if (tensor_index.count(ptr)) return; - CHECK_EQ(tensor_index.size(), tensor_list.size()); - tensor_index[ptr] = tensor_list.size(); - tensor_list.push_back(ptr); + if (tensor_index_.count(ptr)) return; + CHECK_EQ(tensor_index_.size(), tensor_list_.size()); + tensor_index_[ptr] = tensor_list_.size(); + tensor_list_.push_back(ptr); } void Visit(const char* key, ObjectRef* value) final { @@ -88,15 +77,14 @@ class NodeIndexer : public AttrVisitor { } // make index of all the children of node - void MakeIndex(Object* ptr) { - if (ptr == nullptr) return; - CHECK(ptr->IsInstance()); - auto* node = static_cast(ptr); + void MakeIndex(Object* node) { + if (node == nullptr) return; + CHECK(node->IsInstance()); - if (node_index.count(node)) return; - CHECK_EQ(node_index.size(), node_list.size()); - node_index[node] = node_list.size(); - node_list.push_back(node); + if (node_index_.count(node)) return; + CHECK_EQ(node_index_.size(), node_list_.size()); + node_index_[node] = node_list_.size(); + node_list_.push_back(node); if (node->IsInstance()) { ArrayNode* n = static_cast(node); @@ -115,7 +103,7 @@ class NodeIndexer : public AttrVisitor { MakeIndex(const_cast(kv.second.get())); } } else { - static_cast(node)->VisitAttrs(this); + reflection_->VisitAttrs(node, this); } } }; @@ -123,17 +111,17 @@ class NodeIndexer : public AttrVisitor { // use map so attributes are ordered. using AttrMap = std::map; -// A Node structure for JSON node. +/*! \brief Node structure for json format. */ struct JSONNode { - // The type key of the data + /*! \brief The type of key of the object. */ std::string type_key; - // The global key for global object + /*! \brief The global key for global object. */ std::string global_key; - // the attributes + /*! \brief the attributes */ AttrMap attrs; - // container keys + /*! \brief keys of a map. */ std::vector keys; - // container data + /*! \brief values of a map or array. */ std::vector data; void Save(dmlc::JSONWriter *writer) const { @@ -169,11 +157,14 @@ struct JSONNode { } }; +// Helper class to populate the json node +// using the existing index. class JSONAttrGetter : public AttrVisitor { public: const std::unordered_map* node_index_; const std::unordered_map* tensor_index_; JSONNode* node_; + ReflectionVTable* reflection_ = ReflectionVTable::Global(); void Visit(const char* key, double* value) final { node_->attrs[key] = std::to_string(*value); @@ -196,40 +187,36 @@ class JSONAttrGetter : public AttrVisitor { void Visit(const char* key, void** value) final { LOG(FATAL) << "not allowed to serialize a pointer"; } - void Visit(const char* key, Type* value) final { + void Visit(const char* key, DataType* value) final { node_->attrs[key] = Type2String(*value); } - void Visit(const char* key, NodeRef* value) final { - node_->attrs[key] = std::to_string( - node_index_->at(const_cast(value->get()))); - } + void Visit(const char* key, runtime::NDArray* value) final { node_->attrs[key] = std::to_string( tensor_index_->at(const_cast((*value).operator->()))); } + void Visit(const char* key, ObjectRef* value) final { - LOG(FATAL) << "Do not support json serialize non-node object"; + node_->attrs[key] = std::to_string( + node_index_->at(const_cast(value->get()))); } + // Get the node - void Get(Object* ptr) { - if (ptr == nullptr) { + void Get(Object* node) { + if (node == nullptr) { node_->type_key.clear(); return; } - CHECK(ptr->IsInstance()); - auto* node = static_cast(ptr); node_->type_key = node->GetTypeKey(); + node_->global_key = reflection_->GetGlobalKey(node); + // No need to recursively visit fields of global singleton + // They are registered via the environment. + if (node_->global_key.length() != 0) return; - // sepcially handle global object - auto* f = dmlc::Registry::Find(node_->type_key); - CHECK(f != nullptr) - << "Node type \'" << node_->type_key << "\' is not registered in TVM"; - if (f->fglobal_key != nullptr) { - node_->global_key = f->fglobal_key(node); - return; - } + // populates the fields. node_->attrs.clear(); node_->data.clear(); + if (node->IsInstance()) { ArrayNode* n = static_cast(node); for (size_t i = 0; i < n->data.size(); ++i) { @@ -252,23 +239,22 @@ class JSONAttrGetter : public AttrVisitor { node_index_->at(const_cast(kv.second.get()))); } } else { - // do not need to recover content of global singleton object - // they are registered via the environment - auto* f = dmlc::Registry::Find(node->GetTypeKey()); - if (f != nullptr && f->fglobal_key != nullptr) return; // recursively index normal object. - node->VisitAttrs(this); + reflection_->VisitAttrs(node, this); } } }; +// Helper class to set the attributes of a node +// from given json node. class JSONAttrSetter : public AttrVisitor { public: const std::vector >* node_list_; const std::vector* tensor_list_; - JSONNode* node_; + ReflectionVTable* reflection_ = ReflectionVTable::Global(); + std::string GetValue(const char* key) const { auto it = node_->attrs.find(key); if (it == node_->attrs.end()) { @@ -305,16 +291,10 @@ class JSONAttrSetter : public AttrVisitor { void Visit(const char* key, void** value) final { LOG(FATAL) << "not allowed to deserialize a pointer"; } - void Visit(const char* key, Type* value) final { + void Visit(const char* key, DataType* value) final { std::string stype = GetValue(key); *value = String2Type(stype); } - void Visit(const char* key, NodeRef* value) final { - size_t index; - ParseValue(key, &index); - CHECK_LE(index, node_list_->size()); - *value = NodeRef(node_list_->at(index)); - } void Visit(const char* key, runtime::NDArray* value) final { size_t index; ParseValue(key, &index); @@ -322,14 +302,15 @@ class JSONAttrSetter : public AttrVisitor { *value = tensor_list_->at(index); } void Visit(const char* key, ObjectRef* value) final { - LOG(FATAL) << "Do not support json serialize non-node object"; + size_t index; + ParseValue(key, &index); + CHECK_LE(index, node_list_->size()); + *value = ObjectRef(node_list_->at(index)); } // set node to be current JSONNode - void Set(Object* ptr) { - if (ptr == nullptr) return; + void Set(Object* node) { + if (node == nullptr) return; - CHECK(ptr->IsInstance()); - auto* node = static_cast(ptr); if (node->IsInstance()) { ArrayNode* n = static_cast(node); n->data.clear(); @@ -351,7 +332,7 @@ class JSONAttrSetter : public AttrVisitor { = ObjectRef(node_list_->at(node_->data[i])); } } else { - node->VisitAttrs(this); + reflection_->VisitAttrs(node, this); } } }; @@ -393,18 +374,18 @@ struct JSONGraph { NodeIndexer indexer; indexer.MakeIndex(const_cast(root.get())); JSONAttrGetter getter; - getter.node_index_ = &indexer.node_index; - getter.tensor_index_ = &indexer.tensor_index; - for (Object* n : indexer.node_list) { + getter.node_index_ = &indexer.node_index_; + getter.tensor_index_ = &indexer.tensor_index_; + for (Object* n : indexer.node_list_) { JSONNode jnode; getter.node_ = &jnode; getter.Get(n); g.nodes.emplace_back(std::move(jnode)); } g.attrs["tvm_version"] = TVM_VERSION; - g.root = indexer.node_index.at(const_cast(root.get())); + g.root = indexer.node_index_.at(const_cast(root.get())); // serialize tensor - for (DLTensor* tensor : indexer.tensor_list) { + for (DLTensor* tensor : indexer.tensor_list_) { std::string blob; dmlc::MemoryStringStream mstrm(&blob); common::Base64OutStream b64strm(&mstrm); @@ -416,7 +397,7 @@ struct JSONGraph { } }; -std::string SaveJSON(const NodeRef& n) { +std::string SaveJSON(const ObjectRef& n) { auto jgraph = JSONGraph::Create(n); std::ostringstream os; dmlc::JSONWriter writer(&os); @@ -424,8 +405,7 @@ std::string SaveJSON(const NodeRef& n) { return os.str(); } -ObjectPtr LoadJSON_(std::string json_str) { - LOG(INFO) << json_str; +ObjectRef LoadJSON(std::string json_str) { std::istringstream is(json_str); dmlc::JSONReader reader(&is); JSONGraph jgraph; @@ -442,16 +422,18 @@ ObjectPtr LoadJSON_(std::string json_str) { CHECK(temp.Load(&b64strm)); tensors.emplace_back(temp); } + ReflectionVTable* reflection = ReflectionVTable::Global(); + // node 0 is always null nodes.reserve(jgraph.nodes.size()); + for (const JSONNode& jnode : jgraph.nodes) { if (jnode.type_key.length() != 0) { - auto* f = dmlc::Registry::Find(jnode.type_key); - CHECK(f != nullptr) - << "Node type \'" << jnode.type_key << "\' is not registered in TVM"; - nodes.emplace_back(f->fcreator(jnode.global_key)); + ObjectPtr node = + reflection->CreateInitObject(jnode.type_key, jnode.global_key); + nodes.emplace_back(node); } else { - nodes.emplace_back(NodePtr()); + nodes.emplace_back(ObjectPtr()); } } CHECK_EQ(nodes.size(), jgraph.nodes.size()); @@ -467,101 +449,6 @@ ObjectPtr LoadJSON_(std::string json_str) { setter.Set(nodes[i].get()); } } - return nodes.at(jgraph.root); + return ObjectRef(nodes.at(jgraph.root)); } - -class NodeAttrSetter : public AttrVisitor { - public: - std::string type_key; - std::unordered_map attrs; - - void Visit(const char* key, double* value) final { - *value = GetAttr(key).operator double(); - } - void Visit(const char* key, int64_t* value) final { - *value = GetAttr(key).operator int64_t(); - } - void Visit(const char* key, uint64_t* value) final { - *value = GetAttr(key).operator uint64_t(); - } - void Visit(const char* key, int* value) final { - *value = GetAttr(key).operator int(); - } - void Visit(const char* key, bool* value) final { - *value = GetAttr(key).operator bool(); - } - void Visit(const char* key, std::string* value) final { - *value = GetAttr(key).operator std::string(); - } - void Visit(const char* key, void** value) final { - *value = GetAttr(key).operator void*(); - } - void Visit(const char* key, Type* value) final { - *value = GetAttr(key).operator Type(); - } - void Visit(const char* key, NodeRef* value) final { - *value = GetAttr(key).operator NodeRef(); - } - void Visit(const char* key, runtime::NDArray* value) final { - *value = GetAttr(key).operator runtime::NDArray(); - } - void Visit(const char* key, ObjectRef* value) final { - *value = GetAttr(key).operator ObjectRef(); - } - - private: - runtime::TVMArgValue GetAttr(const char* key) { - auto it = attrs.find(key); - if (it == attrs.end()) { - LOG(FATAL) << type_key << ": require field " << key; - } - runtime::TVMArgValue v = it->second; - attrs.erase(it); - return v; - } -}; - - -void InitNodeByPackedArgs(Node* n, const TVMArgs& args) { - NodeAttrSetter setter; - setter.type_key = n->GetTypeKey(); - CHECK_EQ(args.size() % 2, 0); - for (int i = 0; i < args.size(); i += 2) { - setter.attrs.emplace(args[i].operator std::string(), - args[i + 1]); - } - n->VisitAttrs(&setter); - if (setter.attrs.size() != 0) { - std::ostringstream os; - os << setter.type_key << " does not contain field "; - for (const auto &kv : setter.attrs) { - os << " " << kv.first; - } - LOG(FATAL) << os.str(); - } -} - -// API function to make node. -// args format: -// key1, value1, ..., key_n, value_n -void MakeNode(const TVMArgs& args, TVMRetValue* rv) { - std::string type_key = args[0]; - std::string empty_str; - auto* f = dmlc::Registry::Find(type_key); - CHECK(f != nullptr) - << "Node type \'" << type_key << "\' is not registered in TVM"; - TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1); - CHECK(f->fglobal_key == nullptr) - << "Cannot make node type \'" << type_key << "\' with global_key."; - NodePtr n = f->fcreator(empty_str); - if (n->IsInstance()) { - static_cast(n.get())->InitByPackedArgs(kwargs); - } else { - InitNodeByPackedArgs(n.get(), kwargs); - } - *rv = NodeRef(n); -} - -TVM_REGISTER_GLOBAL("make._Node") -.set_body(MakeNode); } // namespace tvm diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index e09ae0648534..65f5eed8d405 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -59,7 +59,7 @@ struct CachedFuncNode : public Node { /*! \brief Parameter usage states in the shape function. */ tvm::Array shape_func_param_states; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("target", &target); v->Visit("func_name", &func_name); v->Visit("inputs", &inputs); @@ -84,7 +84,7 @@ class CCacheKeyNode : public Node { /*! \brief The hardware target.*/ Target target; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("source_func", &source_func); v->Visit("target", &target); } @@ -141,7 +141,7 @@ class CCacheValueNode : public Node { /*! \brief usage statistics */ int use_count{0}; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("cached_func", &cached_func); v->Visit("use_count", &use_count); } @@ -191,7 +191,7 @@ class CompileEngineNode : public Node { virtual void Clear() = 0; // VisitAttrs - void VisitAttrs(AttrVisitor*) final {} + void VisitAttrs(AttrVisitor*) {} static constexpr const char* _type_key = "relay.CompileEngine"; TVM_DECLARE_NODE_TYPE_INFO(CompileEngineNode, Node); diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 2703b1c8634a..8c6daceedd5c 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file src/tvm/relay/interpreter.cc * \brief An interpreter for the Relay IR. */ @@ -116,6 +115,8 @@ RefValue RefValueNode::make(Value value) { TVM_REGISTER_API("relay._make.RefValue") .set_body_typed(RefValueNode::make); +TVM_REGISTER_NODE_TYPE(RefValueNode); + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const RefValueNode* node, tvm::IRPrinter* p) { @@ -135,6 +136,8 @@ ConstructorValue ConstructorValueNode::make(int32_t tag, TVM_REGISTER_API("relay._make.ConstructorValue") .set_body_typed(ConstructorValueNode::make); +TVM_REGISTER_NODE_TYPE(ConstructorValueNode); + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const ConstructorValueNode* node, tvm::IRPrinter* p) { @@ -207,7 +210,7 @@ class InterpreterStateNode : public Node { /*! \brief The call stack of the interpreter. */ Stack stack; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("current_expr", ¤t_expr); v->Visit("stack", &stack); } diff --git a/src/relay/backend/param_dict.cc b/src/relay/backend/param_dict.cc index 0b9a299ae59b..9bde3a0b4edd 100644 --- a/src/relay/backend/param_dict.cc +++ b/src/relay/backend/param_dict.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,19 +18,21 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file param_dict.cc * \brief Implementation and registration of parameter dictionary * serializing/deserializing functions. */ -#include "param_dict.h" - +#include #include #include #include #include +#include "param_dict.h" + + + namespace tvm { namespace relay { diff --git a/src/relay/backend/param_dict.h b/src/relay/backend/param_dict.h index 296c71ced644..e7695dc74c09 100644 --- a/src/relay/backend/param_dict.h +++ b/src/relay/backend/param_dict.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -45,7 +45,7 @@ struct NamedNDArrayNode : public ::tvm::Node { std::string name; tvm::runtime::NDArray array; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); v->Visit("array", &array); } diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index 9c670bf47e8c..12cebe5f5d3c 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file src/tvm/ir/adt.cc * \brief AST nodes for Relay algebraic data types (ADTs). */ diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 2032112f2a85..80f07904662f 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -61,7 +61,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(SourceNameNode) .set_creator(GetSourceNameNode) -.set_global_key([](const Node* n) { +.set_global_key([](const Object* n) { return static_cast(n)->name; }); @@ -88,7 +88,7 @@ TVM_REGISTER_NODE_TYPE(IdNode); TVM_REGISTER_API("relay._base.set_span") .set_body_typed([](NodeRef node_ref, Span sp) { - auto rn = node_ref.as_derived(); + auto rn = node_ref.as(); CHECK(rn); rn->span = sp; }); diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index b0f889c4a489..7bfe41c05058 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -195,7 +195,7 @@ NodePtr CreateOp(const std::string& name) { TVM_REGISTER_NODE_TYPE(OpNode) .set_creator(CreateOp) -.set_global_key([](const Node* n) { +.set_global_key([](const Object* n) { return static_cast(n)->name; }); diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 394ec7eaab82..b2a8396706f2 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -32,7 +32,7 @@ * - Otherwise, inline if the node is at the end of a scope and is used at most once. */ -#include +#include #include #include #include @@ -214,7 +214,7 @@ class PrettyPrinter : } Doc PrintFinal(const NodeRef& node) { - if (node.as_derived()) { + if (node.as()) { Expr expr = Downcast(node); dg_ = DependencyGraph::Create(&arena_, expr); } @@ -237,13 +237,13 @@ class PrettyPrinter : std::vector PrintFuncAttrs(const Attrs& attrs); Doc Print(const NodeRef& node, bool meta = false, bool try_inline = false) { - if (node.as_derived()) { + if (node.as()) { return PrintExpr(Downcast(node), meta, try_inline); - } else if (node.as_derived()) { + } else if (node.as()) { return PrintType(Downcast(node), meta); - } else if (node.as_derived()) { + } else if (node.as()) { return PrintPattern(Downcast(node), meta); - } else if (node.as_derived()) { + } else if (node.as()) { return PrintMod(Downcast(node)); } else { Doc doc; @@ -924,14 +924,11 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { void Visit(const char* key, DataType* value) final { PrintKV(key, PrintString(runtime::TVMType2String(Type2TVMType(*value)))); } - void Visit(const char* key, NodeRef* value) final { - PrintKV(key, parent_->PrintAttr(*value)); - } void Visit(const char* key, runtime::NDArray* value) final { LOG(FATAL) << "do not allow NDarray as argument"; } void Visit(const char* key, runtime::ObjectRef* obj) final { - LOG(FATAL) << "do not allow Object as argument"; + PrintKV(key, parent_->PrintAttr(*obj)); } private: diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc index cde68c50daef..b93d9cc79433 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/relay/ir/type_functor.cc @@ -132,7 +132,7 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) { if (const TypeVarNode* tin = new_type_param.as()) { type_params.push_back(GetRef(tin)); } else { - LOG(FATAL) << new_type_param << std::endl; + LOG(FATAL) << new_type_param; } } @@ -141,10 +141,10 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) { auto new_type_cs = VisitType(type_cs); changed = changed || !new_type_cs.same_as(type_cs); if (const TypeConstraintNode* tin = - new_type_cs.as_derived()) { + new_type_cs.as()) { type_constraints.push_back(GetRef(tin)); } else { - LOG(FATAL) << new_type_cs << std::endl; + LOG(FATAL) << new_type_cs; } } diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index 9143ae3a43b7..bbfb97c56dc2 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -140,7 +140,7 @@ class LayoutAlternatedExprNode : public TempExprNode { return tmp_memorizer.Transform(value, new_layout, old_layout); } - void VisitAttrs(AttrVisitor *v) final { + void VisitAttrs(AttrVisitor *v) { v->Visit("value", &value); v->Visit("old_layout", &old_layout); v->Visit("new_layout", &new_layout); diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index 94d09b7c236c..21992ab7abb7 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -18,8 +18,6 @@ */ /*! - * Copyright (c) 2018 by Contributors - * * \file deivce_annotation.cc * \brief Passes to rewrite annotated program and retrieve the device allocation * of expression. @@ -46,13 +44,15 @@ namespace relay { namespace { bool IsOnDeviceNode(const ExprNode* node) { - const auto* call_node = dynamic_cast(node); - return call_node != nullptr && call_node->attrs.as(); + if (!node->IsInstance()) return false; + const auto* call_node = static_cast(node); + return call_node->attrs.as(); } bool IsDeviceCopyNode(const ExprNode* node) { - const auto* call_node = dynamic_cast(node); - return call_node != nullptr && call_node->attrs.as(); + if (!node->IsInstance()) return false; + const auto* call_node = static_cast(node); + return call_node->attrs.as(); } } // namespace @@ -447,7 +447,8 @@ class DeviceInfo { static const ExprNode* GetDeviceCopyNode(const ExprNode* node) { if (IsDeviceCopyNode(node)) { return node; - } else if (const auto* call_node = dynamic_cast(node)) { + } else if (node->IsInstance()) { + const auto* call_node = static_cast(node); if (const auto* fn = call_node->op.as()) { const ExprNode* body = fn->body.operator->(); if (IsDeviceCopyNode(body)) { @@ -472,7 +473,8 @@ class DeviceInfo { for (auto it = post_visitor_.post_dfs_order_.crbegin(); it != post_visitor_.post_dfs_order_.crend(); ++it) { if (const auto* node = GetDeviceCopyNode(it->first)) { - last_copy_node = dynamic_cast(node); + CHECK(node->IsInstance()); + last_copy_node = static_cast(node); const auto* attrs = last_copy_node->attrs.as(); cur_dev_type = attrs->src_dev_type; if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type; diff --git a/src/relay/pass/eta_expand.cc b/src/relay/pass/eta_expand.cc index 612ababfe044..a5d04871ba95 100644 --- a/src/relay/pass/eta_expand.cc +++ b/src/relay/pass/eta_expand.cc @@ -37,14 +37,14 @@ Expr EtaExpand(const Expr& e, const Module& mod) { Type ret_type; if (e->IsInstance()) { - auto gvar_node = e.as_derived(); + auto gvar_node = e.as(); auto func = mod->Lookup(GetRef(gvar_node)); original_params = func->params; original_type_params = func->type_params; ret_type = func->ret_type; } else { CHECK(e->IsInstance()); - auto func = GetRef(e.as_derived()); + auto func = GetRef(e.as()); original_params = func->params; original_type_params = func->type_params; ret_type = func->ret_type; diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index 6defa35b5106..e13a50a99c58 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -176,7 +176,7 @@ class ScaledExprNode : public TempExprNode { return value; } - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("value", &value); v->Visit("axes", &axes); v->Visit("scale", &scale); @@ -664,7 +664,7 @@ class BackwardTransformerNode : } // solver is not serializable. - void VisitAttrs(tvm::AttrVisitor* v) final {} + void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "relay.fold_scale_axis.FBackwardTransformer"; TVM_DECLARE_NODE_TYPE_INFO(BackwardTransformerNode, Node); diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc index 6c66d6e982a7..f7d463a0547e 100644 --- a/src/relay/pass/forward_rewrite.cc +++ b/src/relay/pass/forward_rewrite.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -47,7 +47,7 @@ class TempRealizer : private ExprMutator { return it->second; } else { Expr res; - if (const auto* temp = expr.as_derived()) { + if (const auto* temp = expr.as()) { res = temp->Realize(); } else { diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 928d8bd180e5..d2688620b0c3 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -102,7 +102,7 @@ class ModulePassNode : public PassNode { ModulePassNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } @@ -156,7 +156,7 @@ class FunctionPassNode : public PassNode { FunctionPassNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } @@ -211,7 +211,7 @@ class SequentialNode : public PassNode { /*! \brief A list of passes that used to compose a sequential pass. */ tvm::Array passes; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); v->Visit("passes", &passes); } diff --git a/src/relay/pass/quantize/annotate.cc b/src/relay/pass/quantize/annotate.cc index 38ffd9b59892..31e95fc6fb8d 100644 --- a/src/relay/pass/quantize/annotate.cc +++ b/src/relay/pass/quantize/annotate.cc @@ -41,7 +41,7 @@ class QAnnotateExprNode : public TempExprNode { Expr expr; QAnnotateKind kind; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("expr", &expr); v->Visit("kind", &kind); } diff --git a/src/relay/pass/quantize/partition.cc b/src/relay/pass/quantize/partition.cc index 6c7dc504b05e..f66aed3549a2 100644 --- a/src/relay/pass/quantize/partition.cc +++ b/src/relay/pass/quantize/partition.cc @@ -42,7 +42,7 @@ class QPartitionExprNode : public TempExprNode { /*! \brief The original expression */ Expr expr; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("expr", &expr); } diff --git a/src/relay/pass/quantize/quantize.cc b/src/relay/pass/quantize/quantize.cc index dafbc1d1007f..3d0e71edfb7c 100644 --- a/src/relay/pass/quantize/quantize.cc +++ b/src/relay/pass/quantize/quantize.cc @@ -18,8 +18,6 @@ */ /*! - * Copyright (c) 2018 by Contributors - * * \file quantize.cc * * \brief transform a graph to a low-bit graph diff --git a/src/relay/pass/quantize/quantize.h b/src/relay/pass/quantize/quantize.h index f193f9a63e0a..412bce0a394e 100644 --- a/src/relay/pass/quantize/quantize.h +++ b/src/relay/pass/quantize/quantize.h @@ -76,7 +76,7 @@ class QConfigNode : public Node { bool round_for_shift = true; Array debug_enabled_ops = Array(NodePtr(nullptr)); - void VisitAttrs(AttrVisitor* v) final { + void VisitAttrs(AttrVisitor* v) { v->Visit("nbit_input", &nbit_input); v->Visit("nbit_weight", &nbit_weight); v->Visit("nbit_activation", &nbit_activation); diff --git a/src/relay/pass/quantize/realize.cc b/src/relay/pass/quantize/realize.cc index cd367fdc0e5f..bdd0d732d146 100644 --- a/src/relay/pass/quantize/realize.cc +++ b/src/relay/pass/quantize/realize.cc @@ -56,7 +56,7 @@ class QRealizeIntExprNode : public QRealizeExprNode { Expr dom_scale; DataType dtype; - void VisitAttrs(tvm::AttrVisitor* v) final { + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); v->Visit("dom_scale", &dom_scale); v->Visit("dtype", &dtype); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index f2bf46af4b28..6035790225aa 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -153,7 +153,7 @@ class TypeSolver::Unifier : public TypeFunctor { // default: unify only if alpha-equal Type VisitTypeDefault_(const Node* op, const Type& tn) final { NodeRef nr = GetRef(op); - Type t1 = GetRef(nr.as_derived()); + Type t1 = GetRef(nr.as()); if (!AlphaEqual(t1, tn)) { return Type(nullptr); } @@ -411,7 +411,7 @@ class TypeSolver::Propagator : public TypeFunctor { void VisitTypeDefault_(const Node* op) override { NodeRef nr = GetRef(op); - Type t = GetRef(nr.as_derived()); + Type t = GetRef(nr.as()); UpdateRelSet(t); } @@ -495,7 +495,7 @@ class TypeSolver::Merger : public TypeFunctor { void VisitTypeDefault_(const Node* op) override { NodeRef nr = GetRef(op); - Type t = GetRef(nr.as_derived()); + Type t = GetRef(nr.as()); TransferLinks(t); } diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 90c3de857329..fe1cc14b304d 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -280,7 +280,7 @@ TVM_REGISTER_API("relay._analysis.free_vars") TVM_REGISTER_API("relay._analysis.bound_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; - if (x.as_derived()) { + if (x.as()) { *ret = BoundVars(Downcast(x)); } else { *ret = BoundVars(Downcast(x)); @@ -294,7 +294,7 @@ TVM_REGISTER_API("relay._analysis.free_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; Module mod = args[1]; - if (x.as_derived()) { + if (x.as()) { *ret = FreeTypeVars(Downcast(x), mod); } else { *ret = FreeTypeVars(Downcast(x), mod); @@ -305,7 +305,7 @@ TVM_REGISTER_API("relay._analysis.bound_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; Module mod = args[1]; - if (x.as_derived()) { + if (x.as()) { *ret = BoundTypeVars(Downcast(x), mod); } else { *ret = BoundTypeVars(Downcast(x), mod); @@ -316,7 +316,7 @@ TVM_REGISTER_API("relay._analysis.all_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; Module mod = args[1]; - if (x.as_derived()) { + if (x.as()) { *ret = AllTypeVars(Downcast(x), mod); } else { *ret = AllTypeVars(Downcast(x), mod); diff --git a/src/runtime/object.cc b/src/runtime/object.cc index d07612f6a963..5d71c2fd2fa1 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -73,13 +73,12 @@ class TypeContext { return child_tindex == parent_tindex; } - uint32_t GetOrAllocRuntimeTypeIndex(const char* key, + uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey, uint32_t static_tindex, uint32_t parent_tindex, uint32_t num_child_slots, bool child_slots_can_overflow) { std::lock_guard lock(mutex_); - std::string skey = key; auto it = type_key2index_.find(skey); if (it != type_key2index_.end()) { return it->second; @@ -106,7 +105,7 @@ class TypeContext { << "Conflicting static index " << static_tindex << " between " << type_table_[allocated_tindex].name << " and " - << key; + << skey; } else if (pinfo.allocated_slots + num_slots < pinfo.num_slots) { // allocate the slot from parent's reserved pool allocated_tindex = parent_tindex + pinfo.allocated_slots; @@ -152,11 +151,10 @@ class TypeContext { return type_table_[tindex].name_hash; } - uint32_t TypeKey2Index(const char* key) { - std::string skey = key; + uint32_t TypeKey2Index(const std::string& skey) { auto it = type_key2index_.find(skey); CHECK(it != type_key2index_.end()) - << "Cannot find type " << key; + << "Cannot find type " << skey; return it->second; } @@ -176,7 +174,7 @@ class TypeContext { std::unordered_map type_key2index_; }; -uint32_t Object::GetOrAllocRuntimeTypeIndex(const char* key, +uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key, uint32_t static_tindex, uint32_t parent_tindex, uint32_t num_child_slots, @@ -198,7 +196,7 @@ size_t Object::TypeIndex2KeyHash(uint32_t tindex) { return TypeContext::Global()->TypeIndex2KeyHash(tindex); } -uint32_t Object::TypeKey2Index(const char* key) { +uint32_t Object::TypeKey2Index(const std::string& key) { return TypeContext::Global()->TypeKey2Index(key); } @@ -210,7 +208,7 @@ class TVMObjectCAPI { } } - static uint32_t TypeKey2Index(const char* type_key) { + static uint32_t TypeKey2Index(const std::string& type_key) { return Object::TypeKey2Index(type_key); } }; diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index a7237db482ac..6e43b408978a 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index 4baf649c6e49..70a4c32bedac 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include