Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
Remove update shape/type
Browse files Browse the repository at this point in the history
  • Loading branch information
Tianqi Chen authored and tqchen committed Dec 16, 2022
1 parent 2fcfea4 commit 3662ef1
Show file tree
Hide file tree
Showing 20 changed files with 59 additions and 124 deletions.
17 changes: 0 additions & 17 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -895,23 +895,6 @@ class ExternFunc : public BaseFunc {
TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode);
};

/*!
* \brief Update the type of an Expr.
* \param expr The Expr whose type to be updated.
* \param type The type assigned to the checked_type_ of \p expr.
* \note We ensure idempotence, that is we can only update the checked_type_ of an Expr if it's
* nullptr.
*/
void UpdateType(Expr expr, Type type);

/*!
* \brief Update the shape of an Expr.
* \param expr The Expr whose shape to be updated.
* \param shape The shape assigned to the shape_ of \p expr.
* \note We ensure idempotence, that is we can only update the shape_ of an Expr if it's nullptr.
*/
void UpdateShape(Expr expr, Optional<ObjectRef> shape);

} // namespace relax
} // namespace tvm

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_legacy_shape_hint(sinfo: StructInfo) -> Optional[Expr]:
ret : Type
The corresponding shape.
"""
return _ffi_api.GetLegacyShapeHint(sinfo)
return _ffi_api.GetLegacyShapeHint(sinfo) # type: ignore


def erase_to_well_defined(
Expand Down
20 changes: 8 additions & 12 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
"""The expression nodes of Relax."""
from typing import Any, List, Optional, Union
import typing
import numpy as _np # type: ignore

import tvm
import tvm._ffi
import numpy as _np # type: ignore
from tvm.runtime import ndarray as _nd
import tvm.relax

from tvm._ffi import base as _base
from .. import relay
Expand Down Expand Up @@ -452,9 +453,12 @@ def const(

if not dtype:
# when dtype is None: int maps to "int32", float maps to "float32"
dtype = {_np.dtype("int64"): _np.int32, _np.dtype("float64"): _np.float32}.get(
dtype = { # type: ignore
_np.dtype("int64"): _np.int32, # type: ignore
_np.dtype("float64"): _np.float32, # type: ignore
}.get(
value.dtype, None
)
) # type: ignore

if isinstance(value, (_np.ndarray, _np.generic)):
if dtype is not None:
Expand All @@ -472,13 +476,5 @@ def te_tensor(value: Expr, name: str = "rxplaceholder"):
return _ffi_api.TETensor(value, name) # type: ignore


def _update_type(expr: Expr, type: Type) -> None:
_ffi_api.UpdateType(expr, type) # type: ignore


def _update_shape(expr: Expr, shape: Optional[tvm.runtime.Object]) -> None:
_ffi_api.UpdateShape(expr, shape) # type: ignore


def _update_struct_info(expr: Expr, struct_info: Optional["StructInfo"]) -> None:
def _update_struct_info(expr: Expr, struct_info: Optional["tvm.relax.StructInfo"]) -> None:
_ffi_api.UpdateStructInfo(expr, struct_info) # type: ignore
5 changes: 3 additions & 2 deletions python/tvm/relax/expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .expr import Call, If, TupleGetItem
from .expr import Binding, MatchShape, VarBinding
from .expr import BindingBlock, DataflowBlock
from .struct_info import StructInfo
from ..relay import Id
from ..ir.module import IRModule
from .block_builder import BlockBuilder
Expand Down Expand Up @@ -1444,7 +1445,7 @@ def lookup_binding(self, var: Var) -> Optional[Expr]:
# Using self._outer() to ref _PyExprMutator
return _ffi_api.PyExprMutatorLookupBinding(self._outer(), var) # type: ignore

def with_struct_info(self, var: Var, struct_info: "StructInfo") -> Var:
def with_struct_info(self, var: Var, struct_info: StructInfo) -> Var:
"""Create a new var with specified shape and type if the original var's shape or type does
not match with the specified ones.
Expand All @@ -1461,4 +1462,4 @@ def with_struct_info(self, var: Var, struct_info: "StructInfo") -> Var:
The var filled with shape and type.
"""
# Using self._outer() to ref _PyExprMutator
return _ffi_api.PyExprMutatorWithStructInfo(self._outer(), var, sinfo) # type: ignore
return _ffi_api.PyExprMutatorWithStructInfo(self._outer(), var, struct_info) # type: ignore
24 changes: 12 additions & 12 deletions python/tvm/relax/struct_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,12 @@ def __init__(
self, values: Optional[List[PrimExpr]] = None, ndim: int = -1, span: Span = None
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ShapeStructInfo, values, ndim, span
) # type: ignore
_ffi_api.ShapeStructInfo, values, ndim, span # type: ignore
)


@tvm._ffi.register_object("relax.TensorStructInfo")
class TensorStructInfo(StructInfo):
shape: Optional[Expr]
dtype: tvm.DataType
ndim: int
span: Span

"""StructInfo of a Tensor value.
Parameters
Expand All @@ -126,6 +121,11 @@ class TensorStructInfo(StructInfo):
The number of dimensions of the tensor.
"""

shape: Optional[Expr]
dtype: tvm.DataType
ndim: int
span: Span

def __init__(
self,
shape: Union[Optional[Expr], List[PrimExpr]] = None,
Expand All @@ -137,8 +137,8 @@ def __init__(
shape = ShapeExpr(shape)

self.__init_handle_by_constructor__(
_ffi_api.TensorStructInfo, shape, dtype, ndim, span
) # type: ignore
_ffi_api.TensorStructInfo, shape, dtype, ndim, span # type: ignore
)


@tvm._ffi.register_object("relax.TupleStructInfo")
Expand Down Expand Up @@ -170,8 +170,8 @@ class FuncStructInfo(StructInfo):

def __init__(self, params: List[StructInfo], ret: StructInfo, span: Span = None) -> None:
self.__init_handle_by_constructor__(
_ffi_api.FuncStructInfo, params, ret, span
) # type: ignore
_ffi_api.FuncStructInfo, params, ret, span # type: ignore
)

@staticmethod
def opaque_func(
Expand All @@ -198,4 +198,4 @@ def opaque_func(
-------
info: FuncStructInfo
"""
return _ffi_api.FuncStructInfoOpaqueFunc(ret, derive_func, span)
return _ffi_api.FuncStructInfoOpaqueFunc(ret, derive_func, span) # type: ignore
6 changes: 2 additions & 4 deletions python/tvm/relax/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,14 @@
@tvm._ffi.register_object("relax.ShapeType")
class ShapeType(Type):
"""The type of shape in Relax.
Parameters
----------
ndim : Optional[int]
The size of the shape.
"""

def __init__(self,
ndim: int = -1,
span: Span = None) -> None:
def __init__(self, ndim: int = -1, span: Span = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.ShapeType, ndim, span) # type: ignore


Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin, wrong-import-order
# pylint: disable=redefined-builtin, wrong-import-order, no-member, invalid-name
"""IRBuilder for Relax dialect"""

import functools
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/script/parser/relax/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ def __call__(
return RxTuple(fields)
else:
fields = list(fields)
for i in range(len(fields)):
if callable(fields[i]):
fields[i] = fields[i]()
for i, x in enumerate(fields):
if callable(x):
fields[i] = x()
if all([isinstance(f, StructInfo) for f in fields]):
return relax.TupleStructInfo(fields)
else:
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/script/parser/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ...ir_builder import relax as R
from ...ir_builder.base import IRBuilder
from .._core import Parser, dispatch, doc
from .entry import MatchShapePair, Tensor
from .entry import MatchShapePair


def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
Expand Down Expand Up @@ -108,6 +108,7 @@ def eval_shape_annotation(
return None


# pylint: disable=inconsistent-return-statements
def eval_type_annotation(
self: Parser, node: Union[doc.Expression, doc.expr]
) -> Tuple[Type, Optional[Expr], StructInfo]:
Expand Down
1 change: 0 additions & 1 deletion src/printer/relax_script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,6 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function&
doc << ":" << Doc::NewLine(4);
print_symbolic_shape_as_str_ = false;


// Step 3: print function attr
Doc header_attr;
if (func->attrs.defined()) {
Expand Down
30 changes: 8 additions & 22 deletions src/relax/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< node->false_branch << ")";
});

// Eager composition
Tuple::Tuple(tvm::Array<relay::Expr> fields, Span span) {
ObjectPtr<TupleNode> n = make_object<TupleNode>();
n->fields = std::move(fields);
Expand Down Expand Up @@ -292,11 +293,16 @@ TVM_REGISTER_NODE_TYPE(VarNode);
Var::Var(Id vid, Optional<Expr> shape_annotation, Optional<Type> type_annotation, Span span) {
ObjectPtr<VarNode> n = make_object<VarNode>();
n->vid = std::move(vid);
// invariant for transition: always require a type annotation if a shape annotation is provided.
if (shape_annotation) {
ICHECK(type_annotation) << "Var requires type annoation if we provide shape ann";
}
if (type_annotation) {
n->struct_info_ = StructInfoFromTypeLegacyShapeHint(type_annotation.value(), shape_annotation);
StructInfo sinfo = StructInfoFromTypeLegacyShapeHint(type_annotation.value(), shape_annotation);
n->struct_info_ = sinfo;
n->checked_type_ = std::move(type_annotation.value());
n->shape_ = GetLegacyShapeHint(sinfo);
}
n->shape_ = std::move(shape_annotation);
n->span = std::move(span);
data_ = std::move(n);
}
Expand Down Expand Up @@ -573,25 +579,5 @@ TVM_REGISTER_GLOBAL("relax.ExternFunc").set_body_typed([](String global_symbol,
return ExternFunc(global_symbol, span);
});

void UpdateType(Expr expr, Type type) {
ICHECK(!expr->checked_type_.defined() || tvm::StructuralEqual()(expr->checked_type_, type))
<< "the checked_type_ of the Expr to be updated must be nullptr for idempotency";
expr->checked_type_ = type;
}

TVM_REGISTER_GLOBAL("relax.UpdateType").set_body_typed([](Expr expr, Type type) {
UpdateType(expr, type);
});

void UpdateShape(Expr expr, Optional<ObjectRef> shape) {
ICHECK(!expr->shape_.defined())
<< "the shape_ of the Expr to be updated must be nullptr for idempotency";
expr->shape_ = shape;
}

TVM_REGISTER_GLOBAL("relax.UpdateShape").set_body_typed([](Expr expr, Optional<ObjectRef> shape) {
UpdateShape(expr, shape);
});

} // namespace relax
} // namespace tvm
38 changes: 2 additions & 36 deletions src/relax/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -642,43 +642,9 @@ BindingBlock ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) {
return builder_->EndBlock();
}

Var ExprMutator::VisitVarDef_(const DataflowVarNode* var) {
bool shape_unchanged = true;
Expr new_shape;
if (var->shape_) {
new_shape = this->VisitExpr(Downcast<Expr>(var->shape_.value()));
shape_unchanged &= new_shape.same_as(var->shape_);
}

if (shape_unchanged) {
return GetRef<Var>(var);
} else {
Var new_var = DataflowVar(var->vid, NullOpt, var->checked_type_, var->span);
UpdateShape(new_var, new_shape);

this->var_remap_[var->vid] = new_var;
return new_var;
}
}
Var ExprMutator::VisitVarDef_(const DataflowVarNode* var) { return GetRef<Var>(var); }

Var ExprMutator::VisitVarDef_(const VarNode* var) {
bool shape_unchanged = true;
Expr new_shape;
if (var->shape_) {
new_shape = this->VisitExpr(Downcast<Expr>(var->shape_.value()));
shape_unchanged &= new_shape.same_as(var->shape_);
}

if (shape_unchanged) {
return GetRef<Var>(var);
} else {
Var new_var = Var(var->vid, NullOpt, var->checked_type_, var->span);
UpdateShape(new_var, new_shape);

this->var_remap_[var->vid] = new_var;
return new_var;
}
}
Var ExprMutator::VisitVarDef_(const VarNode* var) { return GetRef<Var>(var); }

void ExprMutator::VisitBinding(const Binding& binding) {
if (const auto* node = binding.as<VarBindingNode>()) {
Expand Down
1 change: 0 additions & 1 deletion src/relax/op/tensor/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
namespace tvm {
namespace relax {


Type InferTypeEwiseFMA(const Call& call, DiagnosticContext diag_ctx) {
if (call->args.size() != 3) {
diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "EwiseFMA op should have 3 arguments");
Expand Down
2 changes: 0 additions & 2 deletions src/relax/transform/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,6 @@ class FunctionCreator : public ExprMutator {
function_ = Function(/*params=*/params_, //
/*body=*/body, //
Type(), Expr(),
///*ret_type=*/body->checked_type_,
///*ret_shape=*/RuntimeDepShape(),
/*attrs=*/DictAttrs(attrs));
}

Expand Down
5 changes: 3 additions & 2 deletions src/relax/transform/to_non_dataflow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
*/
#include <tvm/relax/attrs/memory.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>
#include <tvm/relax/type.h>
#include <tvm/tir/op.h>
Expand All @@ -33,8 +34,8 @@ class ToNonDFMutator : public ExprMutator {
public:
Var VisitVarDef(const Var& var) final {
if (var.as<DataflowVarNode>()) {
Var new_var = Var(var->vid, NullOpt, var->checked_type_, var->span);
UpdateShape(new_var, var->shape_);
Var new_var = Var(var->vid, NullOpt, NullOpt, var->span);
UpdateStructInfo(new_var, GetStructInfo(var));
this->var_remap_[var->vid] = new_var;
return new_var;
}
Expand Down
4 changes: 3 additions & 1 deletion tests/python/relax/test_analysis_well_formed.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def test_symbolic_var():


def test_symbolic_var_invalid_type():
with pytest.raises(tvm.TVMError, match="the value in ShapeStructInfo can only have dtype of int64"):
with pytest.raises(
tvm.TVMError, match="the value in ShapeStructInfo can only have dtype of int64"
):
dim = tir.Var("dim", "float32")
type_anno = rx.DynTensorType(ndim=1, dtype="float32")
y = rx.Var("y", [dim], type_anno)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relax/test_ast_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def test_shape_of():
assert 'Var(name_hint="v0")' in s0_str

shape_anno = [96, 54]
v1 = rx.Var("v1", shape_anno)
v1 = rx.Var("v1", shape_anno, rx.DynTensorType(ndim=2))
s1 = v1.shape
s1_str = dump_ast(s1)
assert s1_str.startswith("ShapeExpr("), s1_str
Expand Down
Loading

0 comments on commit 3662ef1

Please sign in to comment.