Skip to content

Commit

Permalink
[Refactor] Rename Datatype to ADT (#4156)
Browse files Browse the repository at this point in the history
We think it will reduce the confusion with the meaning.

https://discuss.tvm.ai/t/discuss-consider-rename-vm-datatype/4339
  • Loading branch information
wweic authored and icemelon committed Oct 20, 2019
1 parent 3c4b7cc commit 32aad56
Show file tree
Hide file tree
Showing 13 changed files with 69 additions and 69 deletions.
10 changes: 5 additions & 5 deletions docs/dev/virtual_machine.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ AllocTensor
Allocate a tensor value of the appropriate shape (stored in `shape_register`) and `dtype`. The result
is saved to register `dst`.

AllocDatatype
AllocADT
^^^^^^^^^^^^^
**Arguments**:
::
Expand Down Expand Up @@ -176,7 +176,7 @@ GetTagi
RegName object
RegName dst

Get the object tag for Datatype object in register `object`. And saves the reult to register `dst`.
Get the object tag for ADT object in register `object`. And saves the reult to register `dst`.

Fatal
^^^^^
Expand Down Expand Up @@ -251,9 +251,9 @@ Currently, we support 3 types of objects: tensors, data types, and closures.

::

VMObject VMTensor(const tvm::runtime::NDArray& data);
VMObject VMDatatype(size_t tag, const std::vector<VMObject>& fields);
VMObject VMClosure(size_t func_index, std::vector<VMObject> free_vars);
Object Tensor(const tvm::runtime::NDArray& data);
Object ADT(size_t tag, const std::vector<Object>& fields);
Object Closure(size_t func_index, std::vector<Object> free_vars);


Stack and State
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ enum TypeIndex {
kRoot = 0,
kVMTensor = 1,
kVMClosure = 2,
kVMDatatype = 3,
kVMADT = 3,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
Expand Down
24 changes: 12 additions & 12 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,31 +57,31 @@ class Tensor : public ObjectRef {


/*! \brief An object representing a structure or enumeration. */
class DatatypeObj : public Object {
class ADTObj : public Object {
public:
/*! \brief The tag representing the constructor used. */
size_t tag;
/*! \brief The fields of the structure. */
std::vector<ObjectRef> fields;

static constexpr const uint32_t _type_index = TypeIndex::kVMDatatype;
static constexpr const char* _type_key = "vm.Datatype";
TVM_DECLARE_FINAL_OBJECT_INFO(DatatypeObj, Object);
static constexpr const uint32_t _type_index = TypeIndex::kVMADT;
static constexpr const char* _type_key = "vm.ADT";
TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object);
};

/*! \brief reference to data type. */
class Datatype : public ObjectRef {
/*! \brief reference to algebraic data type objects. */
class ADT : public ObjectRef {
public:
Datatype(size_t tag, std::vector<ObjectRef> fields);
ADT(size_t tag, std::vector<ObjectRef> fields);

/*!
* \brief construct a tuple object.
* \param fields The fields of the tuple.
* \return The constructed tuple type.
*/
static Datatype Tuple(std::vector<ObjectRef> fields);
static ADT Tuple(std::vector<ObjectRef> fields);

TVM_DEFINE_OBJECT_REF_METHODS(Datatype, ObjectRef, DatatypeObj);
TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj);
};

/*! \brief An object representing a closure. */
Expand Down Expand Up @@ -129,7 +129,7 @@ enum class Opcode {
InvokePacked = 4U,
AllocTensor = 5U,
AllocTensorReg = 6U,
AllocDatatype = 7U,
AllocADT = 7U,
AllocClosure = 8U,
GetField = 9U,
If = 10U,
Expand Down Expand Up @@ -237,7 +237,7 @@ struct Instruction {
/*! \brief The register to project from. */
RegName object;
} get_tag;
struct /* AllocDatatype Operands */ {
struct /* AllocADT Operands */ {
/*! \brief The datatype's constructor tag. */
Index constructor_tag;
/*! \brief The number of fields to store in the datatype. */
Expand Down Expand Up @@ -294,7 +294,7 @@ struct Instruction {
* \param dst The register name of the destination.
* \return The allocate instruction tensor.
*/
static Instruction AllocDatatype(Index tag, Index num_fields, const std::vector<RegName>& fields,
static Instruction AllocADT(Index tag, Index num_fields, const std::vector<RegName>& fields,
RegName dst);
/*! \brief Construct an allocate closure instruction.
* \param func_index The index of the function table.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from .interpreter import Executor

Tensor = _obj.Tensor
Datatype = _obj.Datatype
ADT = _obj.ADT

def _convert(arg, cargs):
if isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
Expand Down
20 changes: 10 additions & 10 deletions python/tvm/relay/backend/vmobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ def asnumpy(self):
return self.data.asnumpy()


@register_object("vm.Datatype")
class Datatype(Object):
"""Datatype object.
@register_object("vm.ADT")
class ADT(Object):
"""Algebatic data type(ADT) object.
Parameters
----------
tag : int
The tag of datatype.
The tag of ADT.
fields : list[Object] or tuple[Object]
The source tuple.
Expand All @@ -77,22 +77,22 @@ def __init__(self, tag, fields):
for f in fields:
assert isinstance(f, Object)
self.__init_handle_by_constructor__(
_vmobj.Datatype, tag, *fields)
_vmobj.ADT, tag, *fields)

@property
def tag(self):
return _vmobj.GetDatatypeTag(self)
return _vmobj.GetADTTag(self)

def __getitem__(self, idx):
return getitem_helper(
self, _vmobj.GetDatatypeFields, len(self), idx)
self, _vmobj.GetADTFields, len(self), idx)

def __len__(self):
return _vmobj.GetDatatypeNumberOfFields(self)
return _vmobj.GetADTNumberOfFields(self)


def tuple_object(fields):
"""Create a datatype object from source tuple.
"""Create a ADT object from source tuple.
Parameters
----------
Expand All @@ -101,7 +101,7 @@ def tuple_object(fields):
Returns
-------
ret : Datatype
ret : ADT
The created object.
"""
for f in fields:
Expand Down
8 changes: 4 additions & 4 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
DLOG(INFO) << "VMCompiler::Emit: instr=" << instr;
CHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op;
switch (instr.op) {
case Opcode::AllocDatatype:
case Opcode::AllocADT:
case Opcode::AllocTensor:
case Opcode::AllocTensorReg:
case Opcode::GetField:
Expand Down Expand Up @@ -287,7 +287,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
}

// TODO(@jroesch): use correct tag
Emit(Instruction::AllocDatatype(
Emit(Instruction::AllocADT(
0,
tuple->fields.size(),
fields_registers,
Expand Down Expand Up @@ -626,7 +626,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
for (size_t i = arity - return_count; i < arity; ++i) {
fields_registers.push_back(unpacked_arg_regs[i]);
}
Emit(Instruction::AllocDatatype(0, return_count, fields_registers, NewRegister()));
Emit(Instruction::AllocADT(0, return_count, fields_registers, NewRegister()));
}
}

Expand Down Expand Up @@ -659,7 +659,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
}
} else if (auto constructor_node = op.as<ConstructorNode>()) {
auto constructor = GetRef<Constructor>(constructor_node);
Emit(Instruction::AllocDatatype(constructor->tag, call_node->args.size(), args_registers,
Emit(Instruction::AllocADT(constructor->tag, call_node->args.size(), args_registers,
NewRegister()));
} else if (auto var_node = op.as<VarNode>()) {
VisitExpr(GetRef<Var>(var_node));
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
fields.push_back(instr.dst);
break;
}
case Opcode::AllocDatatype: {
case Opcode::AllocADT: {
// Number of fields = 3 + instr.num_fields
fields.assign({instr.constructor_tag, instr.num_fields, instr.dst});

Expand Down Expand Up @@ -551,7 +551,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {

return Instruction::AllocTensorReg(shape_register, dtype, dst);
}
case Opcode::AllocDatatype: {
case Opcode::AllocADT: {
// Number of fields = 3 + instr.num_fields
DCHECK_GE(instr.fields.size(), 3U);
DCHECK_EQ(instr.fields.size(), 3U + static_cast<size_t>(instr.fields[1]));
Expand All @@ -561,7 +561,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
RegName dst = instr.fields[2];
std::vector<Index> fields = ExtractFields(instr.fields, 3, num_fields);

return Instruction::AllocDatatype(constructor_tag, num_fields, fields, dst);
return Instruction::AllocADT(constructor_tag, num_fields, fields, dst);
}
case Opcode::AllocClosure: {
// Number of fields = 3 + instr.num_freevar
Expand Down
28 changes: 14 additions & 14 deletions src/runtime/vm/object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ Tensor::Tensor(NDArray data) {
data_ = std::move(ptr);
}

Datatype::Datatype(size_t tag, std::vector<ObjectRef> fields) {
auto ptr = make_object<DatatypeObj>();
ADT::ADT(size_t tag, std::vector<ObjectRef> fields) {
auto ptr = make_object<ADTObj>();
ptr->tag = tag;
ptr->fields = std::move(fields);
data_ = std::move(ptr);
}

Datatype Datatype::Tuple(std::vector<ObjectRef> fields) {
return Datatype(0, fields);
ADT ADT::Tuple(std::vector<ObjectRef> fields) {
return ADT(0, fields);
}

Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
Expand All @@ -66,28 +66,28 @@ TVM_REGISTER_GLOBAL("_vmobj.GetTensorData")
*rv = cell->data;
});

TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeTag")
TVM_REGISTER_GLOBAL("_vmobj.GetADTTag")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto* cell = obj.as<DatatypeObj>();
const auto* cell = obj.as<ADTObj>();
CHECK(cell != nullptr);
*rv = static_cast<int64_t>(cell->tag);
});

TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeNumberOfFields")
TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto* cell = obj.as<DatatypeObj>();
const auto* cell = obj.as<ADTObj>();
CHECK(cell != nullptr);
*rv = static_cast<int64_t>(cell->fields.size());
});


TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeFields")
TVM_REGISTER_GLOBAL("_vmobj.GetADTFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
int idx = args[1];
const auto* cell = obj.as<DatatypeObj>();
const auto* cell = obj.as<ADTObj>();
CHECK(cell != nullptr);
CHECK_LT(idx, cell->fields.size());
*rv = cell->fields[idx];
Expand All @@ -104,22 +104,22 @@ TVM_REGISTER_GLOBAL("_vmobj.Tuple")
for (auto i = 0; i < args.size(); ++i) {
fields.push_back(args[i]);
}
*rv = Datatype::Tuple(fields);
*rv = ADT::Tuple(fields);
});

TVM_REGISTER_GLOBAL("_vmobj.Datatype")
TVM_REGISTER_GLOBAL("_vmobj.ADT")
.set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
std::vector<ObjectRef> fields;
for (int i = 1; i < args.size(); i++) {
fields.push_back(args[i]);
}
*rv = Datatype(tag, fields);
*rv = ADT(tag, fields);
});

TVM_REGISTER_OBJECT_TYPE(TensorObj);
TVM_REGISTER_OBJECT_TYPE(DatatypeObj);
TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj);
} // namespace vm
} // namespace runtime
Expand Down
Loading

0 comments on commit 32aad56

Please sign in to comment.