From f23d19495c7ec773c69b680b56feb2f36da0e1f6 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 14 Oct 2019 22:46:35 -0700 Subject: [PATCH] [RFC][RUNTIME] Introduce new object protocol. (#4115) * [RUNTIME] Introduce new object protocol. This PR introduces a new object protocol to unify the node and object. We also updated the existing runtime::vm code to make use of the new system. Update to the node will be done in a follow up PR. Other changes: - Remove object related code in json serializer as that code logic was not complete and we have a separate serializer for VM, can revisit later. * address review comment * Fix the child slot logic --- Makefile | 2 +- include/tvm/node/node.h | 4 +- include/tvm/runtime/memory.h | 115 +++++ include/tvm/runtime/object.h | 622 ++++++++++++++++++--------- include/tvm/runtime/packed_func.h | 25 +- include/tvm/runtime/vm.h | 101 ++++- python/tvm/_ffi/vmobj.py | 6 +- src/api/dsl_api.cc | 4 +- src/lang/reflection.cc | 30 +- src/relay/backend/vm/compiler.cc | 2 +- src/relay/backend/vm/deserializer.cc | 2 +- src/relay/backend/vm/serializer.cc | 7 +- src/relay/ir/pretty_printer.cc | 2 +- src/runtime/object.cc | 188 ++++++++ src/runtime/vm/object.cc | 115 ++--- src/runtime/vm/profiler/vm.cc | 2 +- src/runtime/vm/profiler/vm.h | 2 +- src/runtime/vm/vm.cc | 91 ++-- tests/cpp/object_protocol_test.cc | 103 +++++ tests/python/relay/test_vm.py | 2 +- 20 files changed, 1041 insertions(+), 384 deletions(-) create mode 100644 include/tvm/runtime/memory.h create mode 100644 src/runtime/object.cc create mode 100644 tests/cpp/object_protocol_test.cc diff --git a/Makefile b/Makefile index ce3c4757c474..d3ad1030b9f2 100644 --- a/Makefile +++ b/Makefile @@ -70,7 +70,7 @@ cpplint: python3 3rdparty/dmlc-core/scripts/lint.py vta cpp vta/include vta/src python3 3rdparty/dmlc-core/scripts/lint.py topi cpp topi/include; python3 3rdparty/dmlc-core/scripts/lint.py nnvm cpp nnvm/include nnvm/src; - python3 3rdparty/dmlc-core/scripts/lint.py tvm cpp include src verilog\ + python3 3rdparty/dmlc-core/scripts/lint.py tvm cpp include src \ examples/extension/src examples/graph_executor/src pylint: diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index e79b73d3beb9..cb18e46e9a5c 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -42,7 +42,7 @@ namespace runtime { // forward declaration class NDArray; // forward declaration -class Object; +class ObjectRef; } // namespace runtime /*! @@ -63,7 +63,7 @@ class TVM_DLL AttrVisitor { virtual void Visit(const char* key, DataType* value) = 0; virtual void Visit(const char* key, NodeRef* value) = 0; virtual void Visit(const char* key, runtime::NDArray* value) = 0; - virtual void Visit(const char* key, runtime::Object* value) = 0; + virtual void Visit(const char* key, runtime::ObjectRef* value) = 0; template::value>::type> void Visit(const char* key, ENum* ptr) { diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/memory.h new file mode 100644 index 000000000000..6b4f01e4ac9b --- /dev/null +++ b/include/tvm/runtime/memory.h @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/runtime/memory.h + * \brief Runtime memory management. + */ +#ifndef TVM_RUNTIME_MEMORY_H_ +#define TVM_RUNTIME_MEMORY_H_ + +#include +#include +#include "object.h" + +namespace tvm { +namespace runtime { +/*! + * \brief Allocate an object using default allocator. + * \param args arguments to the constructor. + * \tparam T the node type. + * \return The NodePtr to the allocated object. + */ +template +inline ObjectPtr make_object(Args&&... args); + +// Detail implementations after this +// +// The current design allows swapping the +// allocator pattern when necessary. +// +// Possible future allocator optimizations: +// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr) +// - Thread-local object pools: one pool per size and alignment requirement. +// - Can specialize by type of object to give the specific allocator to each object. + +/*! + * \brief Base class of object allocators that implements make. + * Use curiously recurring template pattern. + * + * \tparam Derived The derived class. + */ +template +class ObjAllocatorBase { + public: + /*! + * \tparam T The type to be allocated. + * \tparam Args The constructor signature. + * \param args The arguments. + */ + template + inline ObjectPtr make(Args&&... args) { + using Handler = typename Derived::template Handler; + static_assert(std::is_base_of::value, + "make_node can only be used to create NodeBase"); + T* ptr = Handler::New(static_cast(this), + std::forward(args)...); + ptr->type_index_ = T::type_index(); + ptr->deleter_ = Handler::Deleter(); + return ObjectPtr(ptr); + } +}; + +// Simple allocator that uses new/delete. +class SimpleObjAllocator : + public ObjAllocatorBase { + public: + template + class Handler { + public: + template + static T* New(SimpleObjAllocator*, Args&&... args) { + // NOTE: the first argument is not needed for SimpleObjAllocator + // It is reserved for special allocators that needs to recycle + // the object to itself (e.g. in the case of object pool). + // + // In the case of an object pool, an allocator needs to create + // a special chunk memory that hides reference to the allocator + // and call allocator's release function in the deleter. + return new T(std::forward(args)...); + } + + static Object::FDeleter Deleter() { + return Deleter_; + } + + private: + static void Deleter_(Object* ptr) { + delete static_cast(ptr); + } + }; +}; + +template +inline ObjectPtr make_object(Args&&... args) { + return SimpleObjAllocator().make(std::forward(args)...); +} + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_MEMORY_H_ diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index e1e00ff59113..7b0653ae5485 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -16,117 +16,249 @@ * specific language governing permissions and limitations * under the License. */ - /*! - * Copyright (c) 2019 by Contributors * \file tvm/runtime/object.h * \brief A managed object in the TVM runtime. */ #ifndef TVM_RUNTIME_OBJECT_H_ #define TVM_RUNTIME_OBJECT_H_ -#include -#include +#include +#include #include -#include +#include "c_runtime_api.h" + +/*! + * \brief Whether or not use atomic reference counter. + * If the reference counter is not atomic, + * an object cannot be owned by multiple threads. + * We can, however, move an object across threads + */ +#ifndef TVM_OBJECT_ATOMIC_REF_COUNTER +#define TVM_OBJECT_ATOMIC_REF_COUNTER 1 +#endif + +#if TVM_OBJECT_ATOMIC_REF_COUNTER +#include +#endif // TVM_OBJECT_ATOMIC_REF_COUNTER namespace tvm { namespace runtime { -template -class ObjectPtr; -class Object; - -enum struct ObjectTag { - /*! \brief The tag of a tensor. */ - kTensor = 0U, - /*! \brief The tag of a closure. */ - kClosure = 1U, - /*! \brief The tag of a structure. */ - kDatatype = 2U, +/*! \brief list of the type index. */ +enum TypeIndex { + /*! \brief Root object type. */ + kRoot = 0, + kVMTensor = 1, + kVMClosure = 2, + kVMDatatype = 3, + kStaticIndexEnd, + /*! \brief Type index is allocated during runtime. */ + kDynamic = kStaticIndexEnd }; -std::ostream& operator<<(std::ostream& os, const ObjectTag&); - -struct ObjectCell { +/*! + * \brief base class of all object containers. + * + * Sub-class of objects should declare the following static constexpr fields: + * + * - _type_index: + * Static type index of the object, if assigned to TypeIndex::kDynamic + * the type index will be assigned during runtime. + * Runtime type index can be accessed by ObjectType::type_index(); + * - _type_key: + * The unique string identifier of tyep type. + * - _type_final: + * Whether the type is terminal type(there is no subclass of the type in the object system). + * This field is automatically set by marco TVM_DECLARE_FINAL_OBJECT_INFO + * It is still OK to sub-class a terminal object type T and construct it using make_object. + * But IsInstance check will only show that the object type is T(instead of the sub-class). + * + * The following two fields are necessary for base classes that can be sub-classed. + * + * - _type_child_slots: + * Number of reserved type index slots for child classes. + * Used for runtime optimization for type checking in IsInstance. + * If an object's type_index is within range of [type_index, type_index + _type_child_slots] + * Then the object can be quickly decided as sub-class of the current object class. + * If not, a fallback mechanism is used to check the global type table. + * Recommendation: set to estimate number of children needed. + * - _type_child_slots_can_overflow: + * Whether we can add additional child classes even if the number of child classes + * exceeds the _type_child_slots. A fallback mechanism to check global type table will be used. + * Recommendation: set to false for optimal runtime speed if we know exact number of children. + * + * Two macros are used to declare helper functions in the object: + * - Use TVM_DECLARE_BASE_OBJECT_INFO for object classes that can be sub-classed. + * - Use TVM_DECLARE_FINAL_OBJECT_INFO for object classes that cannot be sub-classed. + * + * New objects can be created using make_object function. + * Which will automatically populate the type_index and deleter of the object. + * + * \sa make_object + * \sa ObjectPtr + * \sa ObjectRef + * + * \code + * + * // Create a base object + * class BaseObj : public Object { + * public: + * // object fields + * int field0; + * + * // object properties + * static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + * static constexpr const char* _type_key = "test.BaseObj"; + * TVM_DECLARE_BASE_OBJECT_INFO(BaseObj, Object); + * }; + * + * class ObjLeaf : public ObjBase { + * public: + * // fields + * int child_field0; + * // object properties + * static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + * static constexpr const char* _type_key = "test.LeafObj"; + * TVM_DECLARE_BASE_OBJECT_INFO(LeaffObj, Object); + * }; + * + * // The following code should be put into a cc file. + * TVM_REGISTER_OBJECT_TYPE(ObjBase); + * TVM_REGISTER_OBJECT_TYPE(ObjLeaf); + * + * // Usage example. + * void TestObjects() { + * // create an object + * ObjectRef leaf_ref(make_object()); + * // cast to a specific instance + * const LeafObj* leaf_ptr = leaf_ref.as(); + * CHECK(leaf_ptr != nullptr); + * // can also cast to the base class. + * CHECK(leaf_ref.as() != nullptr); + * } + * + * \endcode + */ +class Object { public: /*! - * \brief The type of object deleter. - * \param The self pointer to the ObjectCell. + * \brief Object deleter + * \param self pointer to the Object. */ - typedef void (*FDeleter)(ObjectCell* self); - - /*! \brief The tag of the object. - * - * Describes which type of value - * is represented by this object. + typedef void (*FDeleter)(Object* self); + /*! \return The internal type index of the object. */ + uint32_t type_index() const { + return type_index_; + } + /*! + * Check if the object is an instance of TargetType. + * \tparam TargetType The target type to be checked. + * \return Whether the target type is true. */ - ObjectTag tag; + template + inline bool IsInstance() const; + +#if TVM_OBJECT_ATOMIC_REF_COUNTER + using RefCounterType = std::atomic; +#else + using RefCounterType = int32_t; +#endif + + // Object type properties + static constexpr const char* _type_key = "Object"; + static constexpr bool _type_final = false; + static constexpr uint32_t _type_child_slots = 0; + static constexpr bool _type_child_slots_can_overflow = true; + static const uint32_t _GetOrAllocRuntimeTypeIndex() { + return 0; + } + protected: + // The fields of the base object cell. + /*! \brief Type index(tag) that indicates the type of the object. */ + uint32_t type_index_{0}; + /*! \brief The internal reference counter */ + RefCounterType ref_counter_{0}; /*! - * \brief Increment the reference count. + * \brief deleter of this object to enable customized allocation. + * If the deleter is nullptr, no deletion will be performed. + * The creator of the object must always set the deleter field properly. */ - void IncRef() { ref_counter_.fetch_add(1, std::memory_order_relaxed); } + FDeleter deleter_ = nullptr; + // Invariant checks. + static_assert(sizeof(int32_t) == sizeof(RefCounterType) && + alignof(int32_t) == sizeof(RefCounterType), + "RefCounter ABI check."); /*! - * \brief Decrement the reference count. + * \brief Get the type index using type key. + * + * When the function is first time called for a type, + * it will register the type to the type table in the runtime. + * If the static_tindex is TypeIndex::kDynamic, the function will + * allocate a runtime type index. + * Otherwise, we will populate the type table and return the static index. + * + * \param key the type key. + * \param static_tindex The current _type_index field. + * can be TypeIndex::kDynamic. + * \param parent_tindex The index of the parent. + * \param type_child_slots Number of slots reserved for its children. + * \param type_child_slots_can_overflow Whether to allow child to overflow the slots. + * \return The allocated type index. */ - void DecRef() { - if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) { - std::atomic_thread_fence(std::memory_order_acquire); - if (this->deleter_ != nullptr) { - (*this->deleter_)(this); - } - } - } + TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex( + const char* key, + uint32_t static_tindex, + uint32_t parent_tindex, + uint32_t type_child_slots, + bool type_child_slots_can_overflow); - protected: - // default constructor and copy constructor - ObjectCell() {} - - explicit ObjectCell(ObjectTag tag) : tag(tag) {} - - // override the copy and assign constructors to do nothing. - // This is to make sure only contents, but not deleter and ref_counter - // are copied when a child class copies itself. - ObjectCell(const ObjectCell& other) { // NOLINT(*) - } - - ObjectCell(ObjectCell&& other) { // NOLINT(*) - } - - ObjectCell& operator=(const ObjectCell& other) { // NOLINT(*) - return *this; - } + /*! + * \brief Get the type key of the corresponding index from runtime. + * \param tindex The type index. + */ + TVM_DLL static std::string TypeIndex2Key(uint32_t tindex); - ObjectCell& operator=(ObjectCell&& other) { // NOLINT(*) - return *this; - } + /*! + * \brief Get the type index of the corresponding key from runtime. + * \param key The type key. + */ + TVM_DLL static uint32_t TypeKey2Index(const char* key); private: - /*! \brief Internal reference counter */ - std::atomic ref_counter_{0}; + // reference counter related operations + /*! \brief developer function, increases reference counter. */ + inline void IncRef(); /*! - * \brief deleter of this object to enable customized allocation. - * If the deleter is nullptr, no deletion will be performed. - * The creator of the Node must always set the deleter field properly. + * \brief developer function, decrease reference counter. + * \note The deleter will be called when ref_counter_ becomes zero. */ - FDeleter deleter_ = nullptr; - - int use_count() const { return ref_counter_.load(std::memory_order_relaxed); } - - // friend declaration - template + inline void DecRef(); + /*! + * \return The usage count of the cell. + * \note We use stl style naming to be consistent with known API in shared_ptr. + */ + inline int use_count() const; + /*! + * \brief Check of this object is derived from the parent. + * \param parent_tindex The parent type index. + * \return The derivation results. + */ + TVM_DLL bool DerivedFrom(uint32_t parent_tindex) const; + // friend classes + template + friend class ObjAllocatorBase; + template friend class ObjectPtr; - - template - friend ObjectPtr MakeObject(Args&&...); + friend class TVMRetValue; }; /*! * \brief A custom smart pointer for Object. - * must be subclass of NodeBase * \tparam T the content data type. + * \sa make_object */ template class ObjectPtr { @@ -159,7 +291,6 @@ class ObjectPtr { : data_(other.data_) { other.data_ = nullptr; } - /*! * \brief move constructor * \param other The value to be moved @@ -171,10 +302,10 @@ class ObjectPtr { "can only assign of child class ObjectPtr to parent"); other.data_ = nullptr; } - /*! \brief destructor */ - ~ObjectPtr() { this->reset(); } - + ~ObjectPtr() { + this->reset(); + } /*! * \brief Swap this array with another Object * \param other The other Object @@ -182,24 +313,24 @@ class ObjectPtr { void swap(ObjectPtr& other) { // NOLINT(*) std::swap(data_, other.data_); } - /*! * \return Get the content of the pointer */ - T* get() const { return static_cast(data_); } - + T* get() const { + return static_cast(data_); + } /*! * \return The pointer */ - T* operator->() const { return get(); } - + T* operator->() const { + return get(); + } /*! * \return The reference */ T& operator*() const { // NOLINT(*) return *get(); } - /*! * \brief copy assignmemt * \param other The value to be assigned. @@ -211,7 +342,6 @@ class ObjectPtr { ObjectPtr(other).swap(*this); // NOLINT(*) return *this; } - /*! * \brief move assignmemt * \param other The value to be assigned. @@ -222,7 +352,6 @@ class ObjectPtr { ObjectPtr(std::move(other)).swap(*this); // NOLINT(*) return *this; } - /*! \brief reset the content of ptr to be nullptr */ void reset() { if (data_ != nullptr) { @@ -230,163 +359,238 @@ class ObjectPtr { data_ = nullptr; } } - /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { return data_ != nullptr ? data_->use_count() : 0; } - + int use_count() const { + return data_ != nullptr ? data_->use_count() : 0; + } /*! \return whether the reference is unique */ - bool unique() const { return data_ != nullptr && data_->use_count() == 1; } - + bool unique() const { + return data_ != nullptr && data_->use_count() == 1; + } /*! \return Whether two ObjectPtr do not equal each other */ - bool operator==(const ObjectPtr& other) const { return data_ == other.data_; } - + bool operator==(const ObjectPtr& other) const { + return data_ == other.data_; + } /*! \return Whether two ObjectPtr equals each other */ - bool operator!=(const ObjectPtr& other) const { return data_ != other.data_; } - + bool operator!=(const ObjectPtr& other) const { + return data_ != other.data_; + } /*! \return Whether the pointer is nullptr */ - bool operator==(std::nullptr_t null) const { return data_ == nullptr; } - + bool operator==(std::nullptr_t null) const { + return data_ == nullptr; + } /*! \return Whether the pointer is not nullptr */ - bool operator!=(std::nullptr_t null) const { return data_ != nullptr; } - - /* ObjectPtr's support custom allocators. - * - * The below allocator represents the simplest - * possible impl. It can be easily swapped - * for customized executor's, different allocation - * strategies, and so on. - * - * See memory.h for more discussion on NodePtr's - * allocator. - */ - class StdAllocator { - public: - template - static T* New(Args&&... args) { - return new T(std::forward(args)...); - } - - static ObjectCell::FDeleter Deleter() { return Deleter_; } - - private: - static void Deleter_(ObjectCell* ptr) { delete static_cast(ptr); } - }; - - template - ObjectPtr As() const { - auto ptr = reinterpret_cast(get()); - return ObjectPtr(ptr); + bool operator!=(std::nullptr_t null) const { + return data_ != nullptr; } private: /*! \brief internal pointer field */ - ObjectCell* data_{nullptr}; + Object* data_{nullptr}; /*! * \brief constructor from NodeBase - * \param data The node base pointer + * \param data The data pointer */ - // TODO(jroesch): NodePtr design doesn't really work here due to the passing. - public: - explicit ObjectPtr(ObjectCell* data) : data_(data) { + explicit ObjectPtr(Object* data) : data_(data) { if (data != nullptr) { data_->IncRef(); } } - - private: - template - friend ObjectPtr MakeObject(Args&&...); - template + // friend classes + friend class Object; + friend class ObjectRef; + template friend class ObjectPtr; - friend class NDArray; + template + friend class ObjAllocatorBase; friend class TVMPODValue_; - friend class TVMArgValue; + friend class TVMArgsSetter; friend class TVMRetValue; - friend class RPCWrappedFunc; }; -struct TensorCell; -struct DatatypeCell; -struct ClosureCell; +/*! \brief Base class of all object reference */ +class ObjectRef { + public: + /*! \brief default constructor */ + ObjectRef() = default; + /*! \brief Constructor from existing object ptr */ + explicit ObjectRef(ObjectPtr data) : data_(data) {} + /*! \return the internal object pointer */ + inline const Object* get() const; + /*! \return the internal node pointer */ + inline const Object* operator->() const; + /*! + * \brief Try to downcast the internal Object to a + * raw pointer of a corresponding type. + * + * The function will return a nullptr if the cast failed. + * + * if (const Add *add = node_ref.As()) { + * // This is an add node + * } + * \tparam ObjectType the target type, must be a subtype of Object/ + */ + template + inline const ObjectType* as() const; + + /*! \brief type indicate the container type */ + using ContainerType = Object; + + protected: + /*! \brief Internal pointer that backs the reference. */ + ObjectPtr data_; + // friend classes. + friend class TVMRetValue; + friend class TVMArgsSetter; +}; /*! - * \brief A managed object in the TVM runtime. - * - * For example a tuple, list, closure, and so on. + * \brief helper macro to declare a base object type that can be inheritated. + * \param TypeName The name of the current type. + * \param ParentType The name of the ParentType + */ +#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ + static const uint32_t type_index() { \ + if (_type_index != TypeIndex::kDynamic) return _type_index; \ + return _GetOrAllocRuntimeTypeIndex(); \ + } \ + static const uint32_t _GetOrAllocRuntimeTypeIndex() { \ + static uint32_t tidx = GetOrAllocRuntimeTypeIndex( \ + TypeName::_type_key, \ + TypeName::_type_index, \ + ParentType::_GetOrAllocRuntimeTypeIndex(), \ + TypeName::_type_child_slots, \ + TypeName::_type_child_slots_can_overflow); \ + return tidx; \ + } \ + +/*! + * \brief helper macro to declare type information in a final class. + * \param TypeName The name of the current type. + * \param ParentType The name of the ParentType + */ +#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \ + static const constexpr bool _type_final = true; \ + static const constexpr int _type_child_slots = 0; \ + TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ + + +/*! + * \brief Helper macro to register the object type to runtime. + * Makes sure that the runtime type table is correctly populated. * - * Maintains a reference count for the object. + * Use this macro in the cc file for each terminal class. */ -class Object { - public: - ObjectPtr ptr_; - explicit Object(ObjectPtr ptr) : ptr_(ptr) {} - explicit Object(ObjectCell* ptr) : ptr_(ptr) {} - Object() : ptr_() {} - Object(const Object& obj) : ptr_(obj.ptr_) {} - ObjectCell* operator->() { return this->ptr_.operator->(); } - const ObjectCell* operator->() const { return this->ptr_.operator->(); } - - /*! \brief Construct a tensor object. */ - static Object Tensor(const NDArray& data); - /*! \brief Construct a datatype object. */ - static Object Datatype(size_t tag, const std::vector& fields); - /*! \brief Construct a tuple object. */ - static Object Tuple(const std::vector& fields); - /*! \brief Construct a closure object. */ - static Object Closure(size_t func_index, const std::vector& free_vars); - - ObjectPtr AsTensor() const; - ObjectPtr AsDatatype() const; - ObjectPtr AsClosure() const; -}; +#define TVM_REGISTER_OBJECT_TYPE(TypeName) \ + static DMLC_ATTRIBUTE_UNUSED uint32_t __make_Object_tidx ## _ ## TypeName ## __ = \ + TypeName::_GetOrAllocRuntimeTypeIndex() -/*! \brief An object containing an NDArray. */ -struct TensorCell : public ObjectCell { - /*! \brief The NDArray. */ - NDArray data; - explicit TensorCell(const NDArray& data) : ObjectCell(ObjectTag::kTensor), data(data) {} -}; -/*! \brief An object representing a structure or enumeration. */ -struct DatatypeCell : public ObjectCell { - /*! \brief The tag representing the constructor used. */ - size_t tag; - /*! \brief The fields of the structure. */ - std::vector fields; +#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + TypeName() {} \ + explicit TypeName( \ + ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ + : ParentType(n) {} \ + const ObjectName* operator->() const { \ + return static_cast(data_.get()); \ + } \ + operator bool() const { return data_ != nullptr; } \ + using ContainerType = ObjectName; - DatatypeCell(size_t tag, const std::vector& fields) - : ObjectCell(ObjectTag::kDatatype), tag(tag), fields(fields) {} -}; -/*! \brief An object representing a closure. */ -struct ClosureCell : public ObjectCell { - /*! \brief The index into the VM function table. */ - size_t func_index; - /*! \brief The free variables of the closure. */ - std::vector free_vars; +// Implementations details below +// Object reference counting. +#if TVM_OBJECT_ATOMIC_REF_COUNTER - ClosureCell(size_t func_index, const std::vector& free_vars) - : ObjectCell(ObjectTag::kClosure), func_index(func_index), free_vars(free_vars) {} -}; +inline void Object::IncRef() { + ref_counter_.fetch_add(1, std::memory_order_relaxed); +} -/*! \brief Extract the NDArray from a tensor object. */ -NDArray ToNDArray(const Object& obj); +inline void Object::DecRef() { + if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) { + std::atomic_thread_fence(std::memory_order_acquire); + if (this->deleter_ != nullptr) { + (*this->deleter_)(this); + } + } +} -/*! - * \brief Allocate a node object. - * \param args arguments to the constructor. - * \tparam T the node type. - * \return The NodePtr to the allocated object. - */ -template -inline ObjectPtr MakeObject(Args&&... args) { - using Allocator = typename ObjectPtr::StdAllocator; - static_assert(std::is_base_of::value, "MakeObject can only be used to create "); - T* node = Allocator::New(std::forward(args)...); - node->deleter_ = Allocator::Deleter(); - return ObjectPtr(node); +inline int Object::use_count() const { + return ref_counter_.load(std::memory_order_relaxed); +} + +#else + +inline void Object::IncRef() { + ++ref_counter_; } +inline void Object::DecRef() { + if (--ref_counter == 0) { + if (this->deleter_ != nullptr) { + (*this->deleter_)(this); + } + } +} + +inline int Object::use_count() const { + return ref_counter_; +} + +#endif // TVM_OBJECT_ATOMIC_REF_COUNTER + +template +inline bool Object::IsInstance() const { + const Object* self = this; + // NOTE: the following code can be optimized by + // compiler dead-code elimination for already known constants. + if (self != nullptr) { + // Everything is a subclass of object. + if (std::is_same::value) return true; + if (TargetType::_type_final) { + // if the target type is a final type + // then we only need to check the equivalence. + return self->type_index_ == TargetType::type_index(); + } else { + // if target type is a non-leaf type + // Check if type index falls into the range of reserved slots. + uint32_t begin = TargetType::type_index(); + // The condition will be optimized by constant-folding. + if (TargetType::_type_child_slots != 0) { + uint32_t end = begin + TargetType::_type_child_slots; + if (self->type_index_ >= begin && self->type_index_ < end) return true; + } else { + if (self->type_index_ == begin) return true; + } + if (!TargetType::_type_child_slots_can_overflow) return false; + // Invariance: parent index is always smaller than the child. + if (self->type_index_ < TargetType::type_index()) return false; + // The rare slower-path, check type hierachy. + return self->DerivedFrom(TargetType::type_index()); + } + } else { + return false; + } +} + +inline const Object* ObjectRef::get() const { + return data_.data_; +} + +inline const Object* ObjectRef::operator->() const { + return get(); +} + +template +inline const ObjectType* ObjectRef::as() const { + if (data_ != nullptr && + data_->IsInstance()) { + return static_cast(data_.get()); + } else { + return nullptr; + } +} } // namespace runtime } // namespace tvm + #endif // TVM_RUNTIME_OBJECT_H_ diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 1ebddb805d0c..5b71bbc66142 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -489,10 +489,10 @@ class TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer); return NDArray(static_cast(value_.v_handle)); } - operator Object() const { - if (type_code_ == kNull) return Object(); + operator ObjectRef() const { + if (type_code_ == kNull) return ObjectRef(ObjectPtr(nullptr)); TVM_CHECK_TYPE_CODE(type_code_, kObjectCell); - return Object(static_cast(value_.v_handle)); + return ObjectRef(ObjectPtr(static_cast(value_.v_handle))); } operator TVMContext() const { TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); @@ -566,7 +566,7 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator TVMContext; - using TVMPODValue_::operator Object; + using TVMPODValue_::operator ObjectRef; // conversion operator. operator std::string() const { @@ -662,7 +662,7 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator NDArray; - using TVMPODValue_::operator Object; + using TVMPODValue_::operator ObjectRef; TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); } @@ -759,11 +759,12 @@ class TVMRetValue : public TVMPODValue_ { other.data_ = nullptr; return *this; } - TVMRetValue& operator=(Object other) { + TVMRetValue& operator=(ObjectRef other) { this->Clear(); type_code_ = kObjectCell; - value_.v_handle = other.ptr_.data_; - other.ptr_.data_ = nullptr; + // move the handle out + value_.v_handle = other.data_.data_; + other.data_.data_ = nullptr; return *this; } TVMRetValue& operator=(PackedFunc f) { @@ -862,7 +863,7 @@ class TVMRetValue : public TVMPODValue_ { break; } case kObjectCell: { - *this = other.operator Object(); + *this = other.operator ObjectRef(); break; } default: { @@ -913,7 +914,7 @@ class TVMRetValue : public TVMPODValue_ { break; } case kObjectCell: { - static_cast(value_.v_handle)->DecRef(); + static_cast(value_.v_handle)->DecRef(); break; } } @@ -1161,6 +1162,10 @@ class TVMArgsSetter { values_[i].v_handle = value.data_; type_codes_[i] = kNDArrayContainer; } + void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*) + values_[i].v_handle = value.data_.data_; + type_codes_[i] = kObjectCell; + } void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*) if (value.type_code() == kStr) { values_[i].v_str = value.ptr()->c_str(); diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index fe871882935f..aa8543d569af 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2018 by Contributors * \file tvm/runtime/vm.h * \brief A virtual machine for executing Relay programs. */ @@ -36,6 +35,75 @@ namespace tvm { namespace runtime { namespace vm { +/*! \brief An object containing an NDArray. */ +class TensorObj : public Object { + public: + /*! \brief The NDArray. */ + NDArray data; + + static constexpr const uint32_t _type_index = TypeIndex::kVMTensor; + static constexpr const char* _type_key = "vm.Tensor"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorObj, Object); +}; + +/*! \brief reference to tensor. */ +class Tensor : public ObjectRef { + public: + explicit Tensor(NDArray data); + + TVM_DEFINE_OBJECT_REF_METHODS(Tensor, ObjectRef, TensorObj); +}; + + +/*! \brief An object representing a structure or enumeration. */ +class DatatypeObj : public Object { + public: + /*! \brief The tag representing the constructor used. */ + size_t tag; + /*! \brief The fields of the structure. */ + std::vector fields; + + static constexpr const uint32_t _type_index = TypeIndex::kVMDatatype; + static constexpr const char* _type_key = "vm.Datatype"; + TVM_DECLARE_FINAL_OBJECT_INFO(DatatypeObj, Object); +}; + +/*! \brief reference to data type. */ +class Datatype : public ObjectRef { + public: + Datatype(size_t tag, std::vector fields); + + /*! + * \brief construct a tuple object. + * \param fields The fields of the tuple. + * \return The constructed tuple type. + */ + static Datatype Tuple(std::vector fields); + + TVM_DEFINE_OBJECT_REF_METHODS(Datatype, ObjectRef, DatatypeObj); +}; + +/*! \brief An object representing a closure. */ +class ClosureObj : public Object { + public: + /*! \brief The index into the VM function table. */ + size_t func_index; + /*! \brief The free variables of the closure. */ + std::vector free_vars; + + static constexpr const uint32_t _type_index = TypeIndex::kVMClosure; + static constexpr const char* _type_key = "vm.Closure"; + TVM_DECLARE_FINAL_OBJECT_INFO(ClosureObj, Object); +}; + +/*! \brief reference to closure. */ +class Closure : public ObjectRef { + public: + Closure(size_t func_index, std::vector free_vars); + + TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj); +}; + /*! \brief Magic number for NDArray list file */ constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7; @@ -193,7 +261,7 @@ struct Instruction { static Instruction Ret(RegName return_reg); /*! \brief Construct a fatal instruction. * \return The fatal instruction. - * */ + * */ static Instruction Fatal(); /*! \brief Construct a invoke packed instruction. * \param packed_index The index of the packed function. @@ -348,7 +416,7 @@ struct VMFrame { const Instruction* code; /*! \brief Statically allocated space for objects */ - std::vector register_file; + std::vector register_file; /*! \brief Register in caller's frame to put return value */ RegName caller_return_register; @@ -406,8 +474,11 @@ class VirtualMachine : public runtime::ModuleNode { * * \note The return value will be stored in the last output_size slots of args. */ - virtual void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, - Index output_size, const std::vector& args); + virtual void InvokePacked(Index packed_index, + const PackedFunc& func, + Index arg_count, + Index output_size, + const std::vector& args); virtual ~VirtualMachine() {} @@ -424,7 +495,7 @@ class VirtualMachine : public runtime::ModuleNode { /*! \brief The current stack of call frames. */ std::vector frames; /*! \brief The global constant pool. */ - std::vector constants; + std::vector constants; /*! \brief The fuction table index of the current function. */ Index func_index; /*! \brief The current pointer to the code section. */ @@ -433,7 +504,7 @@ class VirtualMachine : public runtime::ModuleNode { Index pc; /*! \brief The special return register. */ - Object return_register; + ObjectRef return_register; /*! \brief The set of TVM contexts the VM is currently executing on. */ std::vector ctxs; @@ -449,13 +520,13 @@ class VirtualMachine : public runtime::ModuleNode { * \param reg The register to write to. * \param obj The object to write to. */ - inline void WriteRegister(RegName reg, const Object& obj); + inline void WriteRegister(RegName reg, const ObjectRef& obj); /*! \brief Read a VM register. * \param reg The register to read from. * \return The read object. */ - inline Object ReadRegister(RegName reg) const; + inline ObjectRef ReadRegister(RegName reg) const; /*! \brief Read a VM register and cast it to int32_t * \param reg The register to read from. @@ -468,15 +539,16 @@ class VirtualMachine : public runtime::ModuleNode { * \param args The arguments to the function. * \return The object representing the result. */ - Object Invoke(const VMFunction& func, const std::vector& args); + ObjectRef Invoke(const VMFunction& func, const std::vector& args); // TODO(@jroesch): I really would like this to be a global variable. - /*! \brief Invoke a VM function by name. + /*! + * \brief Invoke a VM function by name. * \param name The function's name. * \param args The arguments to the function. * \return The object representing the result. */ - Object Invoke(const std::string& name, const std::vector& args); + ObjectRef Invoke(const std::string& name, const std::vector& args); VirtualMachine() : functions(), frames(), func_index(0), code(nullptr), pc(0) {} @@ -513,11 +585,10 @@ class VirtualMachine : public runtime::ModuleNode { * * This does not begin execution of the VM. */ - void InvokeGlobal(const VMFunction& func, const std::vector& args); - + void InvokeGlobal(const VMFunction& func, const std::vector& args); /*! \brief The parameter name to data mapping. */ - std::unordered_map params_; + std::unordered_map params_; }; } // namespace vm diff --git a/python/tvm/_ffi/vmobj.py b/python/tvm/_ffi/vmobj.py index 95b831ad50e8..ea3431aa973c 100644 --- a/python/tvm/_ffi/vmobj.py +++ b/python/tvm/_ffi/vmobj.py @@ -44,9 +44,9 @@ class ObjectTag(object): """Type code used in API calls""" - TENSOR = 0 - CLOSURE = 1 - DATATYPE = 2 + TENSOR = 1 + CLOSURE = 2 + DATATYPE = 3 class Object(_ObjectBase): diff --git a/src/api/dsl_api.cc b/src/api/dsl_api.cc index 9b91d4fc91dd..89e999f73edb 100644 --- a/src/api/dsl_api.cc +++ b/src/api/dsl_api.cc @@ -92,7 +92,7 @@ struct APIAttrGetter : public AttrVisitor { found_ref_object = true; } } - void Visit(const char* key, runtime::Object* value) final { + void Visit(const char* key, runtime::ObjectRef* value) final { if (skey == key) { *ret = value[0]; found_ref_object = true; @@ -133,7 +133,7 @@ struct APIAttrDir : public AttrVisitor { void Visit(const char* key, runtime::NDArray* value) final { names->push_back(key); } - void Visit(const char* key, runtime::Object* value) final { + void Visit(const char* key, runtime::ObjectRef* value) final { names->push_back(key); } }; diff --git a/src/lang/reflection.cc b/src/lang/reflection.cc index bc3d2895b811..651312a949c4 100644 --- a/src/lang/reflection.cc +++ b/src/lang/reflection.cc @@ -54,7 +54,7 @@ inline Type String2Type(std::string s) { } using runtime::Object; -using runtime::ObjectCell; +using runtime::ObjectRef; // indexer to index all the ndoes class NodeIndexer : public AttrVisitor { @@ -63,8 +63,6 @@ class NodeIndexer : public AttrVisitor { std::vector node_list{nullptr}; std::unordered_map tensor_index; std::vector tensor_list; - std::unordered_map vm_obj_index; - std::vector vm_obj_list; void Visit(const char* key, double* value) final {} void Visit(const char* key, int64_t* value) final {} @@ -86,12 +84,8 @@ class NodeIndexer : public AttrVisitor { tensor_list.push_back(ptr); } - void Visit(const char* key, Object* value) final { - ObjectCell* ptr = value->ptr_.get(); - if (vm_obj_index.count(ptr)) return; - CHECK_EQ(vm_obj_index.size(), vm_obj_list.size()); - vm_obj_index[ptr] = vm_obj_list.size(); - vm_obj_list.push_back(ptr); + void Visit(const char* key, ObjectRef* value) final { + LOG(FATAL) << "Do not support json serialize non-node object"; } // make index of all the children of node @@ -177,7 +171,6 @@ class JSONAttrGetter : public AttrVisitor { public: const std::unordered_map* node_index_; const std::unordered_map* tensor_index_; - const std::unordered_map* vm_obj_index_; JSONNode* node_; void Visit(const char* key, double* value) final { @@ -212,9 +205,8 @@ class JSONAttrGetter : public AttrVisitor { node_->attrs[key] = std::to_string( tensor_index_->at(const_cast((*value).operator->()))); } - void Visit(const char* key, Object* value) final { - node_->attrs[key] = std::to_string( - vm_obj_index_->at(value->ptr_.get())); + void Visit(const char* key, ObjectRef* value) final { + LOG(FATAL) << "Do not support json serialize non-node object"; } // Get the node void Get(Node* node) { @@ -269,7 +261,6 @@ class JSONAttrSetter : public AttrVisitor { public: const std::vector >* node_list_; const std::vector* tensor_list_; - const std::vector* vm_obj_list_; JSONNode* node_; @@ -325,11 +316,8 @@ class JSONAttrSetter : public AttrVisitor { CHECK_LE(index, tensor_list_->size()); *value = tensor_list_->at(index); } - void Visit(const char* key, Object* value) final { - size_t index; - ParseValue(key, &index); - CHECK_LE(index, vm_obj_list_->size()); - *value = vm_obj_list_->at(index); + void Visit(const char* key, ObjectRef* value) final { + LOG(FATAL) << "Do not support json serialize non-node object"; } // set node to be current JSONNode void Set(Node* node) { @@ -508,8 +496,8 @@ class NodeAttrSetter : public AttrVisitor { void Visit(const char* key, runtime::NDArray* value) final { *value = GetAttr(key).operator runtime::NDArray(); } - void Visit(const char* key, Object* value) final { - *value = GetAttr(key).operator Object(); + void Visit(const char* key, ObjectRef* value) final { + *value = GetAttr(key).operator ObjectRef(); } private: diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 00d4fb4b6219..0cfae374ab2c 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -885,7 +885,7 @@ void VMCompiler::Compile(Module mod, // populate constants for (auto data : context_.constants) { - vm_->constants.push_back(Object::Tensor(data)); + vm_->constants.push_back(runtime::vm::Tensor(data)); } LibraryCodegen(); diff --git a/src/relay/backend/vm/deserializer.cc b/src/relay/backend/vm/deserializer.cc index 6cf76081de13..777282782e99 100644 --- a/src/relay/backend/vm/deserializer.cc +++ b/src/relay/backend/vm/deserializer.cc @@ -102,7 +102,7 @@ void Deserializer::DeserializeConstantSection() { for (size_t i = 0; i < size; i++) { runtime::NDArray constant; STREAM_CHECK(constant.Load(strm_), "constant"); - runtime::Object obj = runtime::Object::Tensor(constant); + runtime::ObjectRef obj = runtime::vm::Tensor(constant); vm_->constants.push_back(obj); } } diff --git a/src/relay/backend/vm/serializer.cc b/src/relay/backend/vm/serializer.cc index d6e44b4af1f8..0040ef9db470 100644 --- a/src/relay/backend/vm/serializer.cc +++ b/src/relay/backend/vm/serializer.cc @@ -98,8 +98,8 @@ std::string Serializer::Stats() const { // Get the number of constants and the shape of each of them. oss << " Constant shapes (# " << vm_->constants.size() << "): ["; for (const auto& it : vm_->constants) { - auto cell = it.AsTensor(); - CHECK(cell.operator->()); + auto* cell = it.as(); + CHECK(cell != nullptr); runtime::NDArray data = cell->data; const auto& shape = data.Shape(); @@ -175,7 +175,8 @@ void Serializer::SerializeGlobalSection() { void Serializer::SerializeConstantSection() { std::vector arrays; for (const auto& obj : vm_->constants) { - auto cell = obj.AsTensor(); + const auto* cell = obj.as(); + CHECK(cell != nullptr); runtime::NDArray data = cell->data; arrays.push_back(const_cast(data.operator->())); } diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index afc8ad9dcf6a..31218be4a6d4 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -930,7 +930,7 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { void Visit(const char* key, runtime::NDArray* value) final { LOG(FATAL) << "do not allow NDarray as argument"; } - void Visit(const char* key, runtime::Object* obj) final { + void Visit(const char* key, runtime::ObjectRef* obj) final { LOG(FATAL) << "do not allow Object as argument"; } diff --git a/src/runtime/object.cc b/src/runtime/object.cc new file mode 100644 index 000000000000..5248da00245a --- /dev/null +++ b/src/runtime/object.cc @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/runtime/object.cc + * \brief Object type management system. + */ +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { + +/*! \brief Type information */ +struct TypeInfo { + /*! \brief The current index. */ + uint32_t index{0}; + /*! \brief Index of the parent in the type hierachy */ + uint32_t parent_index{0}; + // NOTE: the indices in [index, index + num_reserved_slots) are + // reserved for the child-class of this type. + /*! \brief Total number of slots reserved for the type and its children. */ + uint32_t num_slots{0}; + /*! \brief number of allocated child slots. */ + uint32_t allocated_slots{0}; + /*! \brief Whether child can overflow. */ + bool child_slots_can_overflow{true}; + /*! \brief name of the type. */ + std::string name; +}; + +/*! + * \brief Type context that manages the type hierachy information. + */ +class TypeContext { + public: + // NOTE: this is a relatively slow path for child checking + // Most types are already checked by the fast-path via reserved slot checking. + bool DerivedFrom(uint32_t child_tindex, uint32_t parent_tindex) { + // invariance: child's type index is always bigger than its parent. + if (child_tindex < parent_tindex) return false; + if (child_tindex == parent_tindex) return true; + { + std::lock_guard lock(mutex_); + CHECK_LT(child_tindex, type_table_.size()); + while (child_tindex > parent_tindex) { + child_tindex = type_table_[child_tindex].parent_index; + } + } + return child_tindex == parent_tindex; + } + + uint32_t GetOrAllocRuntimeTypeIndex(const char* key, + uint32_t static_tindex, + uint32_t parent_tindex, + uint32_t num_child_slots, + bool child_slots_can_overflow) { + std::lock_guard lock(mutex_); + std::string skey = key; + auto it = type_key2index_.find(skey); + if (it != type_key2index_.end()) { + return it->second; + } + // try to allocate from parent's type table. + CHECK_LT(parent_tindex, type_table_.size()); + TypeInfo& pinfo = type_table_[parent_tindex]; + CHECK_EQ(pinfo.index, parent_tindex); + + // if parent cannot overflow, then this class cannot. + if (!pinfo.child_slots_can_overflow) { + child_slots_can_overflow = false; + } + + // total number of slots include the type itself. + uint32_t num_slots = num_child_slots + 1; + uint32_t allocated_tindex; + + if (static_tindex != TypeIndex::kDynamic) { + // statically assigned type + allocated_tindex = static_tindex; + CHECK_LT(static_tindex, type_table_.size()); + CHECK_EQ(type_table_[allocated_tindex].allocated_slots, 0U) + << "Conflicting static index " << static_tindex + << " between " << type_table_[allocated_tindex].name + << " and " + << key; + } else if (pinfo.allocated_slots + num_slots < pinfo.num_slots) { + // allocate the slot from parent's reserved pool + allocated_tindex = parent_tindex + pinfo.allocated_slots; + // update parent's state + pinfo.allocated_slots += num_slots; + } else { + CHECK(pinfo.child_slots_can_overflow) + << "Reach maximum number of sub-classes for " << pinfo.name; + // allocate new entries. + allocated_tindex = type_counter_; + type_counter_ += num_slots; + CHECK_LE(type_table_.size(), allocated_tindex); + type_table_.resize(allocated_tindex + 1, TypeInfo()); + } + CHECK_GT(allocated_tindex, parent_tindex); + // initialize the slot. + type_table_[allocated_tindex].index = allocated_tindex; + type_table_[allocated_tindex].parent_index = parent_tindex; + type_table_[allocated_tindex].num_slots = num_slots; + type_table_[allocated_tindex].allocated_slots = 1; + type_table_[allocated_tindex].child_slots_can_overflow = + child_slots_can_overflow; + type_table_[allocated_tindex].name = skey; + // update the key2index mapping. + type_key2index_[skey] = allocated_tindex; + return allocated_tindex; + } + + std::string TypeIndex2Key(uint32_t tindex) { + std::lock_guard lock(mutex_); + CHECK(tindex < type_table_.size() && + type_table_[tindex].allocated_slots != 0) + << "Unknown type index " << tindex; + return type_table_[tindex].name; + } + + uint32_t TypeKey2Index(const char* key) { + std::string skey = key; + auto it = type_key2index_.find(skey); + CHECK(it != type_key2index_.end()) + << "Cannot find type " << key; + return it->second; + } + + static TypeContext* Global() { + static TypeContext inst; + return &inst; + } + + private: + TypeContext() { + type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo()); + } + // mutex to avoid registration from multiple threads. + std::mutex mutex_; + std::atomic type_counter_{TypeIndex::kStaticIndexEnd}; + std::vector type_table_; + std::unordered_map type_key2index_; +}; + +uint32_t Object::GetOrAllocRuntimeTypeIndex(const char* key, + uint32_t static_tindex, + uint32_t parent_tindex, + uint32_t num_child_slots, + bool child_slots_can_overflow) { + return TypeContext::Global()->GetOrAllocRuntimeTypeIndex( + key, static_tindex, parent_tindex, num_child_slots, child_slots_can_overflow); +} + +bool Object::DerivedFrom(uint32_t parent_tindex) const { + return TypeContext::Global()->DerivedFrom( + this->type_index_, parent_tindex); +} + +std::string Object::TypeIndex2Key(uint32_t tindex) { + return TypeContext::Global()->TypeIndex2Key(tindex); +} + +uint32_t Object::TypeKey2Index(const char* key) { + return TypeContext::Global()->TypeKey2Index(key); +} +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vm/object.cc b/src/runtime/vm/object.cc index 5750ea83ea90..c20a1ce9de27 100644 --- a/src/runtime/vm/object.cc +++ b/src/runtime/vm/object.cc @@ -18,134 +18,110 @@ */ /*! - * Copyright (c) 2019 by Contributors - * \file object.cc - * \brief A managed object in the TVM runtime. + * \file src/runtime/vm/object.cc + * \brief VM related objects. */ - #include #include +#include +#include #include #include -#include #include "../runtime_base.h" namespace tvm { namespace runtime { +namespace vm { -std::ostream& operator<<(std::ostream& os, const ObjectTag& tag) { - switch (tag) { - case ObjectTag::kClosure: - os << "Closure"; - break; - case ObjectTag::kDatatype: - os << "Datatype"; - break; - case ObjectTag::kTensor: - os << "Tensor"; - break; - default: - LOG(FATAL) << "Invalid object tag: found " << static_cast(tag); - } - return os; -} - -Object Object::Tensor(const NDArray& data) { - ObjectPtr ptr = MakeObject(data); - return Object(ptr); -} - -Object Object::Datatype(size_t tag, const std::vector& fields) { - ObjectPtr ptr = MakeObject(tag, fields); - return Object(ptr); +Tensor::Tensor(NDArray data) { + auto ptr = make_object(); + ptr->data = std::move(data); + data_ = std::move(ptr); } -Object Object::Tuple(const std::vector& fields) { return Object::Datatype(0, fields); } - -Object Object::Closure(size_t func_index, const std::vector& free_vars) { - ObjectPtr ptr = MakeObject(func_index, free_vars); - return Object(ptr); +Datatype::Datatype(size_t tag, std::vector fields) { + auto ptr = make_object(); + ptr->tag = tag; + ptr->fields = std::move(fields); + data_ = std::move(ptr); } -ObjectPtr Object::AsTensor() const { - CHECK(ptr_.get()); - CHECK(ptr_.get()->tag == ObjectTag::kTensor); - return ptr_.As(); +Datatype Datatype::Tuple(std::vector fields) { + return Datatype(0, fields); } -ObjectPtr Object::AsDatatype() const { - CHECK(ptr_.get()); - CHECK(ptr_.get()->tag == ObjectTag::kDatatype); - return ptr_.As(); +Closure::Closure(size_t func_index, std::vector free_vars) { + auto ptr = make_object(); + ptr->func_index = func_index; + ptr->free_vars = std::move(free_vars); + data_ = std::move(ptr); } -ObjectPtr Object::AsClosure() const { - CHECK(ptr_.get()); - CHECK(ptr_.get()->tag == ObjectTag::kClosure); - return ptr_.As(); -} - -NDArray ToNDArray(const Object& obj) { - auto tensor = obj.AsTensor(); - return tensor->data; -} TVM_REGISTER_GLOBAL("_vmobj.GetTensorData") .set_body([](TVMArgs args, TVMRetValue* rv) { - Object obj = args[0]; - auto cell = obj.AsTensor(); + ObjectRef obj = args[0]; + const auto* cell = obj.as(); + CHECK(cell != nullptr); *rv = cell->data; }); TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeTag") .set_body([](TVMArgs args, TVMRetValue* rv) { - Object obj = args[0]; - auto cell = obj.AsDatatype(); - *rv = static_cast(cell->tag); + ObjectRef obj = args[0]; + const auto* cell = obj.as(); + CHECK(cell != nullptr); + *rv = static_cast(cell->tag); }); TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeNumberOfFields") .set_body([](TVMArgs args, TVMRetValue* rv) { - Object obj = args[0]; - auto cell = obj.AsDatatype(); - *rv = static_cast(cell->fields.size()); + ObjectRef obj = args[0]; + const auto* cell = obj.as(); + CHECK(cell != nullptr); + *rv = static_cast(cell->fields.size()); }); TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeFields") .set_body([](TVMArgs args, TVMRetValue* rv) { - Object obj = args[0]; + ObjectRef obj = args[0]; int idx = args[1]; - auto cell = obj.AsDatatype(); + const auto* cell = obj.as(); + CHECK(cell != nullptr); CHECK_LT(idx, cell->fields.size()); *rv = cell->fields[idx]; }); TVM_REGISTER_GLOBAL("_vmobj.Tensor") .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = Object::Tensor(args[0]); +*rv = Tensor(args[0].operator NDArray()); }); TVM_REGISTER_GLOBAL("_vmobj.Tuple") .set_body([](TVMArgs args, TVMRetValue* rv) { - std::vector fields; + std::vector fields; for (auto i = 0; i < args.size(); ++i) { fields.push_back(args[i]); } - *rv = Object::Tuple(fields); + *rv = Datatype::Tuple(fields); }); TVM_REGISTER_GLOBAL("_vmobj.Datatype") .set_body([](TVMArgs args, TVMRetValue* rv) { int itag = args[0]; size_t tag = static_cast(itag); - std::vector fields; + std::vector fields; for (int i = 1; i < args.size(); i++) { fields.push_back(args[i]); } - *rv = Object::Datatype(tag, fields); + *rv = Datatype(tag, fields); }); +TVM_REGISTER_OBJECT_TYPE(TensorObj); +TVM_REGISTER_OBJECT_TYPE(DatatypeObj); +TVM_REGISTER_OBJECT_TYPE(ClosureObj); +} // namespace vm } // namespace runtime } // namespace tvm @@ -153,6 +129,7 @@ using namespace tvm::runtime; int TVMGetObjectTag(TVMObjectHandle handle, int* tag) { API_BEGIN(); - *tag = static_cast(static_cast(handle)->tag); + int res = static_cast(static_cast(handle)->type_index()); + *tag = res; API_END(); } diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 5f59f6ed7f48..80e0ce57a8ae 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -96,7 +96,7 @@ void VirtualMachineDebug::Init(const std::vector& ctxs) { void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, - const std::vector& args) { + const std::vector& args) { auto ctx = VirtualMachine::GetParamsContext(); // warmup VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, diff --git a/src/runtime/vm/profiler/vm.h b/src/runtime/vm/profiler/vm.h index 99060328638d..447967cafeb0 100644 --- a/src/runtime/vm/profiler/vm.h +++ b/src/runtime/vm/profiler/vm.h @@ -45,7 +45,7 @@ class VirtualMachineDebug : public VirtualMachine { const std::shared_ptr& sptr_to_self) final; void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, - Index output_size, const std::vector& args) final; + Index output_size, const std::vector& args) final; ~VirtualMachineDebug() {} diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index ed12d77d80a8..7dea9bdb95ea 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file src/runtime/vm/vm.cc * \brief The Relay virtual machine. */ @@ -558,12 +557,12 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) { return os; } -Object CopyTo(Object src, const DLContext& ctx) { - if (src->tag == ObjectTag::kTensor) { - auto tensor = ToNDArray(src); +ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) { + if (const TensorObj* obj = src.as()) { + auto tensor = obj->data; if (tensor->ctx.device_type != ctx.device_type) { auto copy = tensor.CopyTo(ctx); - return Object::Tensor(copy); + return Tensor(copy); } else { return src; } @@ -585,7 +584,7 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, auto ctx = this->GetParamsContext(); // Prepare the func args - std::vector func_args(param_names.size()); + std::vector func_args(param_names.size()); std::vector empty_slots; for (size_t i = 0; i < param_names.size(); ++i) { @@ -599,7 +598,7 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, CHECK_EQ(empty_slots.size(), args.size() - 1) << "The number of provided parameters doesn't match the number of arguments"; for (int i = 1; i < args.size(); ++i) { - Object obj = CopyTo(args[i], ctx); + ObjectRef obj = CopyTo(args[i], ctx); func_args[empty_slots[i - 1]] = obj; } @@ -660,7 +659,7 @@ void VirtualMachine::LoadParams(const std::string& params) { for (size_t i = 0; i < size; i++) { NDArray arr; CHECK(arr.Load(strm)) << "Invalid parameter file"; - runtime::Object obj = runtime::Object::Tensor(arr); + ObjectRef obj = Tensor(arr); auto copy = CopyTo(obj, ctx); params_.emplace(std::make_pair(names[i], copy)); } @@ -682,7 +681,7 @@ Index VirtualMachine::PopFrame() { return call_stack_size; } -void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector& args) { +void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector& args) { DLOG(INFO) << "Invoking global " << func.name << " " << args.size(); PushFrame(func.params.size(), this->pc + 1, func); @@ -695,7 +694,7 @@ void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector& args) { +ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector& args) { DLOG(INFO) << "Executing Function: " << std::endl << func; InvokeGlobal(func, args); @@ -705,7 +704,7 @@ Object VirtualMachine::Invoke(const VMFunction& func, const std::vector& return return_register; } -Object VirtualMachine::Invoke(const std::string& name, const std::vector& args) { +ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector& args) { auto func_index = this->global_map[name]; DLOG(INFO) << "Invoke Global " << name << " at index " << func_index; return Invoke(this->functions[func_index], args); @@ -713,11 +712,11 @@ Object VirtualMachine::Invoke(const std::string& name, const std::vector void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, - const std::vector& args) { + const std::vector& args) { size_t arity = 0; for (Index i = 0; i < arg_count; i++) { - if (args[i].ptr_->tag == ObjectTag::kDatatype) { - arity += args[i].AsDatatype()->fields.size(); + if (const auto* obj = args[i].as()) { + arity += obj->fields.size(); } else { ++arity; } @@ -728,15 +727,16 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, runtime::TVMArgsSetter setter(values.data(), codes.data()); int idx = 0; for (Index i = 0; i < arg_count; i++) { - if (args[i].ptr_->tag == ObjectTag::kDatatype) { - auto dt_cell = args[i].AsDatatype(); + if (const auto* dt_cell = args[i].as()) { for (auto obj : dt_cell->fields) { - NDArray data = ToNDArray(obj); - setter(idx++, data); + const auto* tensor = obj.as(); + CHECK(tensor != nullptr); + setter(idx++, tensor->data); } } else { - NDArray data = ToNDArray(args[i]); - setter(idx++, data); + const auto* tensor = args[i].as(); + CHECK(tensor != nullptr); + setter(idx++, tensor->data); } } @@ -761,18 +761,20 @@ void VirtualMachine::Init(const std::vector& ctxs) { } } -inline void VirtualMachine::WriteRegister(Index r, const Object& val) { +inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) { frames.back().register_file[r] = val; } -inline Object VirtualMachine::ReadRegister(Index r) const { +inline ObjectRef VirtualMachine::ReadRegister(Index r) const { return frames.back().register_file[r]; } inline int32_t VirtualMachine::LoadScalarInt(Index r) const { int32_t result; const auto& obj = ReadRegister(r); - NDArray array = ToNDArray(obj).CopyTo({kDLCPU, 0}); + const auto* tensor = obj.as(); + CHECK(tensor != nullptr); + NDArray array = tensor->data.CopyTo({kDLCPU, 0}); if (array->dtype.bits <= 8) { result = reinterpret_cast(array->data)[0]; @@ -798,7 +800,7 @@ void VirtualMachine::RunLoop() { switch (instr.op) { case Opcode::Move: { - Object from_obj; + ObjectRef from_obj; from_obj = ReadRegister(instr.from); WriteRegister(instr.dst, from_obj); pc++; @@ -817,12 +819,12 @@ void VirtualMachine::RunLoop() { case Opcode::LoadConsti: { auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, {kDLCPU, 0}); reinterpret_cast(tensor->data)[0] = instr.load_consti.val; - WriteRegister(instr.dst, Object::Tensor(tensor)); + WriteRegister(instr.dst, Tensor(tensor)); pc++; goto main_loop; } case Opcode::Invoke: { - std::vector args; + std::vector args; for (Index i = 0; i < instr.num_args; ++i) { args.push_back(ReadRegister(instr.invoke_args_registers[i])); } @@ -833,7 +835,7 @@ void VirtualMachine::RunLoop() { case Opcode::InvokePacked: { const auto& func = packed_funcs[instr.packed_index]; const auto& arity = instr.arity; - std::vector args; + std::vector args; for (Index i = 0; i < arity; ++i) { args.push_back(ReadRegister(instr.packed_args[i])); } @@ -847,8 +849,9 @@ void VirtualMachine::RunLoop() { } case Opcode::InvokeClosure: { auto object = ReadRegister(instr.closure); - const auto& closure = object.AsClosure(); - std::vector args; + const auto* closure = object.as(); + + std::vector args; for (auto free_var : closure->free_vars) { args.push_back(free_var); } @@ -861,10 +864,10 @@ void VirtualMachine::RunLoop() { } case Opcode::GetField: { auto object = ReadRegister(instr.object); - CHECK(object->tag == ObjectTag::kDatatype) + const auto* tuple = object.as(); + CHECK(tuple != nullptr) << "Object is not data type object, register " << instr.object << ", Object tag " - << static_cast(object->tag); - const auto& tuple = object.AsDatatype(); + << object->type_index(); auto field = tuple->fields[instr.field_index]; WriteRegister(instr.dst, field); pc++; @@ -872,15 +875,15 @@ void VirtualMachine::RunLoop() { } case Opcode::GetTag: { auto object = ReadRegister(instr.get_tag.object); - CHECK(object->tag == ObjectTag::kDatatype) + const auto* data = object.as(); + CHECK(data != nullptr) << "Object is not data type object, register " << instr.get_tag.object << ", Object tag " - << static_cast(object->tag); - const auto& data = object.AsDatatype(); + << object->type_index(); auto tag = data->tag; auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0}); reinterpret_cast(tag_tensor->data)[0] = tag; - WriteRegister(instr.dst, Object::Tensor(tag_tensor)); + WriteRegister(instr.dst, Tensor(tag_tensor)); pc++; goto main_loop; } @@ -909,7 +912,7 @@ void VirtualMachine::RunLoop() { } auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, ctxs[0]); - auto obj = Object::Tensor(data); + auto obj = Tensor(data); WriteRegister(instr.dst, obj); pc++; goto main_loop; @@ -920,7 +923,9 @@ void VirtualMachine::RunLoop() { cpu_ctx.device_id = 0; auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); - NDArray shape_tensor = ToNDArray(shape_tensor_obj).CopyTo(cpu_ctx); + const auto* tensor = shape_tensor_obj.as(); + CHECK(tensor != nullptr); + NDArray shape_tensor = tensor->data.CopyTo(cpu_ctx); int64_t* dims = static_cast(shape_tensor->data); auto num_dims = shape_tensor->shape[0]; @@ -928,27 +933,27 @@ void VirtualMachine::RunLoop() { shape.assign(dims, dims + num_dims); auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, ctxs[0]); - auto obj = Object::Tensor(data); + auto obj = Tensor(data); WriteRegister(instr.dst, obj); pc++; goto main_loop; } case Opcode::AllocDatatype: { - std::vector fields; + std::vector fields; for (Index i = 0; i < instr.num_fields; ++i) { fields.push_back(ReadRegister(instr.datatype_fields[i])); } - Object obj = Object::Datatype(instr.constructor_tag, fields); + ObjectRef obj = Datatype(instr.constructor_tag, fields); WriteRegister(instr.dst, obj); pc++; goto main_loop; } case Opcode::AllocClosure: { - std::vector free_vars; + std::vector free_vars; for (Index i = 0; i < instr.num_freevar; i++) { free_vars.push_back(ReadRegister(instr.free_vars[i])); } - WriteRegister(instr.dst, Object::Closure(instr.func_index, free_vars)); + WriteRegister(instr.dst, Closure(instr.func_index, free_vars)); pc++; goto main_loop; } diff --git a/tests/cpp/object_protocol_test.cc b/tests/cpp/object_protocol_test.cc new file mode 100644 index 000000000000..9f3ce00f3b24 --- /dev/null +++ b/tests/cpp/object_protocol_test.cc @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +namespace tvm { +namespace test { + +using namespace tvm::runtime; + +class ObjBase : public Object { + public: + // dynamically allocate slow + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const uint32_t _type_child_slots = 1; + static constexpr const char* _type_key = "test.ObjBase"; + TVM_DECLARE_BASE_OBJECT_INFO(ObjBase, Object); +}; + +class ObjA : public ObjBase { + public: + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const uint32_t _type_child_slots = 0; + static constexpr const char* _type_key = "test.ObjA"; + TVM_DECLARE_BASE_OBJECT_INFO(ObjA, ObjBase); +}; + +class ObjB : public ObjBase { + public: + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "test.ObjB"; + TVM_DECLARE_BASE_OBJECT_INFO(ObjB, ObjBase); +}; + +class ObjAA : public ObjA { + public: + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "test.ObjAA"; + TVM_DECLARE_FINAL_OBJECT_INFO(ObjAA, ObjA); +}; + + +TVM_REGISTER_OBJECT_TYPE(ObjBase); +TVM_REGISTER_OBJECT_TYPE(ObjA); +TVM_REGISTER_OBJECT_TYPE(ObjB); +TVM_REGISTER_OBJECT_TYPE(ObjAA); + +} // namespace test +} // namespace tvm + +TEST(ObjectHierachy, Basic) { + using namespace tvm::runtime; + using namespace tvm::test; + + ObjectRef refA(make_object()); + CHECK_EQ(refA->type_index(), ObjA::type_index()); + CHECK(refA.as() != nullptr); + CHECK(refA.as() != nullptr); + CHECK(refA.as() != nullptr); + CHECK(refA.as() == nullptr); + CHECK(refA.as() == nullptr); + + ObjectRef refAA(make_object()); + CHECK_EQ(refAA->type_index(), ObjAA::type_index()); + CHECK(refAA.as() != nullptr); + CHECK(refAA.as() != nullptr); + CHECK(refAA.as() != nullptr); + CHECK(refAA.as() != nullptr); + CHECK(refAA.as() == nullptr); + + ObjectRef refB(make_object()); + CHECK_EQ(refB->type_index(), ObjB::type_index()); + CHECK(refB.as() != nullptr); + CHECK(refB.as() != nullptr); + CHECK(refB.as() == nullptr); + CHECK(refB.as() == nullptr); + CHECK(refB.as() != nullptr); +} + +int main(int argc, char ** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index f643f8ad1f0b..5289fe9f5411 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -582,7 +582,7 @@ def test_set_params(): mod["main"] = relay.Function([x, w, b], y) vm = relay.vm.compile(mod, 'llvm') vm.init(tvm.cpu()) - + x_np = np.random.uniform(size=(10, 5)).astype('float32') w_np = np.random.uniform(size=(6, 5)).astype('float32') b_np = np.random.uniform(size=(6,)).astype('float32')