Update VoidStructInfo & globalvar #4

Merged 2 commits on Dec 15, 2022
4 changes: 0 additions & 4 deletions python/tvm/script/parser/relax/parser.py
@@ -157,12 +157,8 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None:
params_sinfo.append(param_sinfo)
params.append(relax.Var(arg.arg, param_shape, param_type))

# TODO(relax-team): remove the following line when fixing ret_shape issue in block builder
ret_shape = relax.RuntimeDepShape()

func_signature = relax.Function.create_unchecked(params, None, ret_type, ret_shape)
global_var = I.decl_function(node.name, func_signature)
relax.expr._update_struct_info(global_var, relax.FuncStructInfo(params_sinfo, ret_sinfo))
self.var_table.add(node.name, global_var)


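With the placeholder return shape gone, the GlobalVar's signature is carried entirely by the `FuncStructInfo` registered via `relax.expr._update_struct_info`. A minimal sketch of such a signature, assuming the struct-info constructors used elsewhere in this diff (hypothetical shapes, not from the PR):

```python
from tvm import relax

# Signature of a hypothetical (R.Tensor((2, 3), "float32")) -> void function.
param_sinfo = relax.TensorStructInfo(relax.ShapeExpr([2, 3]), "float32")
ret_sinfo = relax.TupleStructInfo([])  # void returns are empty tuples after this PR
func_sinfo = relax.FuncStructInfo([param_sinfo], ret_sinfo)
```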
12 changes: 10 additions & 2 deletions src/printer/relax_script_printer.cc
@@ -581,7 +581,6 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function&
param << Print(var) << PrintVarAnnotation(var);
params.push_back(param);
}
print_symbolic_shape_as_str_ = false;

if (is_global) {
ICHECK(symbolic_vars_.empty());
@@ -591,9 +590,18 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function&
doc << "@R.function" << Doc::NewLine();
doc << "def " << name << "(" << Doc::Concat(params, Doc::Text(", ")) << ")";
if (func->ret_type.defined()) {
doc << " -> " << Print(func->ret_type);
doc << " -> ";
if (const relax::DynTensorTypeNode* tty = func->ret_type.as<relax::DynTensorTypeNode>()) {
doc << PrintTensorAnnotation(GetRef<DynTensorType>(tty), func->ret_shape);
} else if (const TupleTypeNode* tty = func->ret_type.as<TupleTypeNode>()) {
doc << PrintTupleAnnotation(GetRef<TupleType>(tty), func->ret_shape);
} else {
doc << Print(func->ret_type);
}
}
doc << ":" << Doc::NewLine(4);
print_symbolic_shape_as_str_ = false;


// Step 3: print function attr
Doc header_attr;
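`PrintFunctionDef` used to print only the static return type, and it reset `print_symbolic_shape_as_str_` before the return annotation was emitted; now tensor and tuple returns fold `ret_shape` into the annotation, and the reset happens after the signature line so symbolic dims there still print as strings. Illustratively (a hypothetical function, not taken from this PR), a `DynTensorType(2, "float32")` return with a known `(2, 3)` shape now round-trips with its full annotation:

```python
@R.function
def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
    return x
```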
5 changes: 5 additions & 0 deletions src/relax/analysis/struct_info_analysis.cc
@@ -92,6 +92,11 @@ class LegacyShapeDeriver : public StructInfoFunctor<Optional<Expr>(const StructI
}

Optional<Expr> VisitStructInfo_(const TupleStructInfoNode* op) final {
if (op->fields.empty()) {
// Corner case to prevent infinite recursion.
return Tuple(Array<Expr>());
}

bool valid = true;
Array<Expr> fields = op->fields.Map([this, &valid](const StructInfo& sinfo) {
Optional<Expr> shape = this->VisitStructInfo(sinfo);
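The guard keeps `LegacyShapeDeriver` from looping on a void value: an empty `TupleStructInfo` now yields an empty tuple expression up front. A toy restatement of the tuple rule (plain Python stubs, not the TVM API; the real infinite recursion also involves downstream normalization, which this sketch omits):

```python
from typing import List, Optional

class SInfo:
    """Stub standing in for relax::StructInfo."""

class TupleSInfo(SInfo):
    def __init__(self, fields: List[SInfo]):
        self.fields = fields

def derive_shape(sinfo: SInfo) -> Optional[tuple]:
    """Mirror of the tuple case in LegacyShapeDeriver::VisitStructInfo_."""
    if isinstance(sinfo, TupleSInfo):
        if not sinfo.fields:
            return ()  # corner case: void value, terminate immediately
        parts = [derive_shape(f) for f in sinfo.fields]
        # Only a tuple whose fields all have derivable shapes is valid.
        return None if any(p is None for p in parts) else tuple(parts)
    return None  # other struct-info kinds are omitted from this sketch

assert derive_shape(TupleSInfo([])) == ()
```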
2 changes: 1 addition & 1 deletion src/relax/ir/block_builder.cc
@@ -428,7 +428,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&
if (!normalized->IsInstance<OpNode>()) {
ICHECK(normalized->struct_info_.defined())
<< "The struct_info_ of an Expr except OpNode after "
"normalization must not be nullptr. However, this Expr does not have checked_type_: "
"normalization must not be nullptr. However, this Expr does not have struct_info_: "
<< normalized;
}

28 changes: 22 additions & 6 deletions src/relax/ir/expr.cc
@@ -454,12 +454,14 @@ Function::Function(Array<Var> params, Expr body, Type ret_type, Expr ret_shape,
param_sinfo.push_back(GetStructInfo(param));
}

StructInfo ret_info;
StructInfo ret_info = GetStructInfo(body);

if (ret_type.defined()) {
ret_info = StructInfoFromTypeLegacyShapeHint(ret_type, ret_shape);
} else {
ret_info = GetStructInfo(body);
StructInfo given_info = StructInfoFromTypeLegacyShapeHint(ret_type, ret_shape);
CHECK(IsBaseOf(given_info, ret_info))
<< "relax.Function requires the deduced body->struct_info to be a subtype of the "
"annotated struct_info but meet body->struct_info: "
<< ret_info << ", ret_info: " << given_info;
}

FuncStructInfo func_sinfo(param_sinfo, ret_info);
@@ -468,8 +470,8 @@
ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
n->params = std::move(params);
n->body = std::move(body);
n->ret_type = std::move(ret_type);
n->ret_shape = std::move(ret_shape);
n->ret_type = GetStaticType(ret_info);
n->ret_shape = GetLegacyShapeHint(ret_info).value_or(ret_shape);
n->struct_info_ = func_sinfo;
n->checked_type_ = GetStaticType(func_sinfo);
n->attrs = std::move(attrs);
@@ -485,15 +487,29 @@ TVM_REGISTER_GLOBAL("relax.Function")

Function Function::CreateUnchecked(Array<Var> params, Expr body, Type ret_type, Expr ret_shape,
DictAttrs attrs, Span span) {
// TODO(@Hzfengsy): revisit `CreateUnchecked` after the parser_v1 removed

Array<StructInfo> param_sinfo;

for (Var param : params) {
ICHECK(param->checked_type_.defined())
<< "relax.Function requires params to contain checked_type_.";
param_sinfo.push_back(GetStructInfo(param));
}

StructInfo ret_info;

if (ret_type.defined()) {
ret_info = StructInfoFromTypeLegacyShapeHint(ret_type, ret_shape);
} else {
ret_info = FuncStructInfo::OpaqueFunc();
}

// set the fields
ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
n->params = std::move(params);
n->body = std::move(body);
n->struct_info_ = FuncStructInfo(param_sinfo, ret_info);
n->ret_type = std::move(ret_type);
n->ret_shape = std::move(ret_shape);
n->attrs = std::move(attrs);
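Two behavioral changes in `Function::Function`: return info is now always deduced from the body, and an explicit annotation must be a supertype of that deduction (`IsBaseOf(given_info, ret_info)`); the normalized `ret_type`/`ret_shape` fields are then re-derived from the resulting struct info. A sketch of the subtype relation being enforced, assuming a dtype-only `TensorStructInfo` constructor and the LCA helper exercised in the analysis tests below:

```python
import tvm
from tvm import relax

# Annotation: float32 tensor of unknown rank. Deduced from the body: a
# concrete (2, 3) float32 tensor. The deduction is a subtype of the
# annotation, so the new CHECK passes.
annotated = relax.TensorStructInfo(dtype="float32")
deduced = relax.TensorStructInfo(relax.ShapeExpr([2, 3]), "float32")

# IsBaseOf(a, b) in C++ corresponds to: the least common ancestor of a and b
# is a itself (struct_info_lca is an assumed Python binding name).
lca = relax.analysis.struct_info_lca(annotated, deduced)
tvm.ir.assert_structural_equal(lca, annotated)
```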
24 changes: 12 additions & 12 deletions src/relax/op/op.cc
@@ -55,7 +55,7 @@ bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs) {
}

StructInfo ReturnVoidStructInfo(const Call& call, const BlockBuilder& ctx) {
return TensorStructInfo({});
return TupleStructInfo(Array<StructInfo>());
}

StructInfo ReturnObjectStructInfo(const Call& call, const BlockBuilder& ctx) {
@@ -142,21 +142,21 @@ TVM_REGISTER_GLOBAL("relax.op.print").set_body_typed(MakePrint);

// can't actually name it assert or else Python will consider it a syntax error

Type InferAssertType(const Call& call, DiagnosticContext diag_ctx) {
StructInfo InferAssertStructInfo(const Call& call, const BlockBuilder& ctx) {
// Ensure that the condition argument is a boolean scalar.
// Also permitted is a tensor with unknown shape and unknown dtype
// (checked dynamically in that case). Returns void.
if (call->args.size() < 1) {
diag_ctx.EmitFatal(Diagnostic::Error(call->span)
<< "Assert must have at least one argument (the condition).");
ctx->ReportFatal(Diagnostic::Error(call->span)
<< "Assert must have at least one argument (the condition).");
}
Type arg_type = call->args[0]->checked_type();
if (!IsBoolScalarType(arg_type)) {
diag_ctx.EmitFatal(Diagnostic::Error(call->span)
<< "The argument to assert must be a boolean scalar type, but received "
<< arg_type);
ctx->ReportFatal(Diagnostic::Error(call->span)
<< "The argument to assert must be a boolean scalar type, but received "
<< arg_type);
}
return VoidType();
return ReturnVoidStructInfo(call, ctx);
}

TVM_REGISTER_NODE_TYPE(AssertOpAttrs);
@@ -167,7 +167,7 @@ RELAY_REGISTER_OP("relax.assert_op")
.add_argument("vals", "Array<Expr>",
"The first value is used as the assertion condition. The others are used as "
"format arguments if there is an error.")
.set_attr<FInferType>("FInferType", InferAssertType)
.set_attr<FInferStructInfo>("FInferStructInfo", InferAssertStructInfo)
.set_attr<FCallPacked>("FCallPacked", "relax.run.assert_op");

Expr MakeAssertOp(Expr condition, Array<Expr> vals, std::string format) {
@@ -277,9 +277,9 @@ TVM_REGISTER_GLOBAL("relax.op.memory.alloc_storage").set_body_typed(MakeAllocSto
StructInfo InferStructInfoMemAllocTensor(const Call& call, const BlockBuilder& ctx) {
const auto* attrs = call->attrs.as<MemAllocTensorAttrs>();
ICHECK(attrs != nullptr) << "must be MemAllocTensorAttrs, but got " << call->attrs->GetTypeKey();
ICHECK(call->args[0].as<ShapeExprNode>())
<< "must be ShapeExpr, but got " << call->args[0]->GetTypeKey();
return TensorStructInfo(call->args[0], attrs->dtype);
ICHECK(GetStructInfoAs<ShapeStructInfoNode>(call->args[1]))
<< "must be a Expr of ShapeStructInfo, but got " << call->args[1]->GetTypeKey();
return TensorStructInfo(call->args[1], attrs->dtype);
}

RELAY_REGISTER_OP("relax.memory.alloc_tensor")
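Void results are now modeled as an empty `TupleStructInfo` rather than a malformed `TensorStructInfo({})`, and the assert op migrates from the legacy `FInferType` to `FInferStructInfo`. A quick check of what void maps to statically, using `get_static_type` as in the analysis tests below:

```python
import tvm
from tvm import relax

void_sinfo = relax.TupleStructInfo([])
tvm.ir.assert_structural_equal(
    relax.analysis.get_static_type(void_sinfo), relax.TupleType([])
)
```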
17 changes: 17 additions & 0 deletions src/script/ir_builder/ir/ir.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/ir/module.h>
#include <tvm/relax/analysis.h>
#include <tvm/runtime/registry.h>
#include <tvm/script/ir_builder/ir/ir.h>

@@ -39,6 +40,22 @@ GlobalVar DeclFunction(const String& func_name, const Optional<BaseFunc>& func_s
CHECK(!frame->global_var_map.count(func_name))
<< "ValueError: function " << func_name << " already exists";
GlobalVar gv = GlobalVar(func_name);
if (func_signature.defined()) {
const BaseFunc& func = func_signature.value();
if (func->struct_info_.defined()) {
gv->struct_info_ = tvm::relax::GetStructInfo(func);
} else if (const auto* prim_func = func.as<tvm::tir::PrimFuncNode>()) {
// NOTE: use a slightly different struct info than checked type
// in PrimFunc so handle can turn into Tensor.
// TODO(relax-team): add fine-grained PrimFunc struct info signature generation.
gv->struct_info_ = tvm::relax::FuncStructInfo::OpaqueFunc(
tvm::relax::StructInfoFromType(prim_func->ret_type));
} else {
LOG(FATAL) << "Unsupported function: " << func;
}
} else {
gv->struct_info_ = tvm::relax::FuncStructInfo::OpaqueFunc();
}
CHECK(frame->functions.find(gv) == frame->functions.end())
<< "ValueError: function " << func_name << " has already been defined.";
frame->global_var_map.Set(func_name, gv);
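`DeclFunction` now attaches struct info to the `GlobalVar` at declaration time, so forward references to a not-yet-defined function are already typed. The branch structure: reuse a Relax function's own struct info; give a `PrimFunc` an opaque-function info that keeps only the return type (so a `handle` can later refine into a `Tensor`); fall back to a fully opaque function when no signature is supplied. A toy restatement of that dispatch (plain Python, hypothetical names, not the TVM API):

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class GlobalVarInfo:
    kind: str                       # "relax" | "opaque_with_ret" | "opaque"
    ret_type: Optional[Any] = None

def info_for_declared_function(signature: Optional[Any]) -> GlobalVarInfo:
    """Sketch of the new branches in DeclFunction (src/script/ir_builder/ir/ir.cc)."""
    if signature is None:
        return GlobalVarInfo("opaque")              # no signature: fully opaque
    if getattr(signature, "struct_info", None) is not None:
        return GlobalVarInfo("relax")               # Relax function: reuse as-is
    if getattr(signature, "is_prim_func", False):
        # PrimFunc: opaque callable, but preserve the return type.
        return GlobalVarInfo("opaque_with_ret", signature.ret_type)
    raise ValueError(f"Unsupported function: {signature!r}")
```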
2 changes: 0 additions & 2 deletions src/script/ir_builder/relax/ir.cc
@@ -120,8 +120,6 @@ void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo) {
frame->ret_sinfo = ret_sinfo;
frame->ret_type = GetStaticType(ret_sinfo);
frame->ret_shape = GetLegacyShapeHint(ret_sinfo);
// TODO(@Hzfengsy): remove it
frame->ret_shape = tvm::relax::RuntimeDepShape();
}

void FuncRetValue(const tvm::relax::Expr& value) {
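With the temporary `RuntimeDepShape` override removed, both legacy frame fields are derived from the declared return struct info. A sketch of the `ret_type` half of that derivation via `get_static_type` (the `GetLegacyShapeHint` half has no Python binding shown in this PR, so it is omitted):

```python
from tvm import relax

ret_sinfo = relax.TensorStructInfo(relax.ShapeExpr([2, 3]), "float32")
ret_type = relax.analysis.get_static_type(ret_sinfo)
assert isinstance(ret_type, relax.DynTensorType)  # static side of the annotation
```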
17 changes: 10 additions & 7 deletions tests/python/relax/test_analysis_struct_info_analysis.py
@@ -16,11 +16,11 @@
# under the License.

"""Tests analysis functions of struct info"""
import pytest

import tvm
import tvm.testing
from tvm import relax as rx
from tvm import tir
from tvm import relax as rx, TVMError


def test_get_static_type_basic():
@@ -40,9 +40,9 @@ def test_get_static_type_shape():
s2 = rx.ShapeStructInfo([1, n + 1, m])
s3 = rx.ShapeStructInfo(ndim=2)

tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s2), rx.ShapeType())
tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s2), rx.ShapeType(ndim=3))

tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s3), rx.ShapeType())
tvm.ir.assert_structural_equal(rx.analysis.get_static_type(s3), rx.ShapeType(ndim=2))


def test_get_static_type_tensor():
@@ -68,7 +68,7 @@ def test_get_static_type_tuple():
rx.TupleType(
[
rx.TupleType([rx.DynTensorType(ndim=3, dtype="int64"), rx.ObjectType()]),
rx.ShapeType(),
rx.ShapeType(ndim=3),
]
),
)
@@ -123,8 +123,7 @@ def test_erase_to_well_defined_shape():

def test_erase_to_well_defined_tensor():
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
rshape = rx.Var("shape", type_annotation=rx.ShapeType())
rx.expr._update_struct_info(rshape, rx.ShapeStructInfo(ndim=2))
rshape = rx.Var("shape", type_annotation=rx.ShapeType(ndim=2))
s0 = rx.TensorStructInfo(rshape, dtype="int32")

# undefined
@@ -382,3 +381,7 @@ def fn_info_erased():
_check_lca(fopaque0(), fopaque1(), fopaque0())
_check_lca(fopaque0(), fn_info_shape(1), fopaque0())
_check_lca(fopaque2(), fn_info_shape(1), fopaque2())


if __name__ == "__main__":
tvm.testing.main()
14 changes: 7 additions & 7 deletions tests/python/relax/test_tvmscript_parser.py
@@ -201,8 +201,8 @@ def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")):
y1 = R.match_shape(y, (n,))
return (m, n * 2)

x = relax.Var("x", None, DynTensorType(-1, "float32"))
y = relax.Var("y", None, DynTensorType(-1, "float32"))
x = relax.Var("x", RuntimeDepShape(), DynTensorType(-1, "float32"))
y = relax.Var("y", RuntimeDepShape(), DynTensorType(-1, "float32"))
m = tir.Var("m", dtype="int64")
n = tir.Var("n", dtype="int64")
bb = relax.BlockBuilder()
@@ -237,7 +237,7 @@ def foo(x: R.Tensor("float32", ndim=2)):
x0 = R.match_shape(x, (n, m))
return (x0, (n + 1, m, 1))

x = relax.Var("x", None, DynTensorType(2, "float32"))
x = relax.Var("x", RuntimeDepShape(), DynTensorType(2, "float32"))
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
bb = relax.BlockBuilder()
with bb.function("foo", (x,)):
@@ -256,7 +256,7 @@ def foo(x: R.Tensor("float32", ndim=2)):
t1 = (x, (n, m), t0)
return t1

x = relax.Var("x", None, DynTensorType(2, "float32"))
x = relax.Var("x", RuntimeDepShape(), DynTensorType(2, "float32"))
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
bb = relax.BlockBuilder()
with bb.function("foo", (x,)):
@@ -493,9 +493,9 @@ def _check_type_shape(binding, expected_type, expected_shape):
relax.DynTensorType(ndim=2, dtype="float32"),
relax.ShapeExpr([tvm.tir.IntImm("int64", 32), m]),
)
_check_type_shape(bindings[1], relax.DynTensorType(dtype=""), None)
_check_type_shape(bindings[2], relax.DynTensorType(ndim=2, dtype=""), None)
_check_type_shape(bindings[3], relax.DynTensorType(dtype=""), None)
_check_type_shape(bindings[1], relax.DynTensorType(dtype=""), RuntimeDepShape())
_check_type_shape(bindings[2], relax.DynTensorType(ndim=2, dtype=""), RuntimeDepShape())
_check_type_shape(bindings[3], relax.DynTensorType(dtype=""), RuntimeDepShape())
_check_type_shape(bindings[4], relax.ShapeType(), None)
_check_type_shape(
bindings[5],
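All of the test updates above follow one pattern: a `Var` of unknown shape is annotated with an explicit `RuntimeDepShape()` instead of `None`, both at construction time and in what `_check_type_shape` expects back. In isolation (a sketch; the import paths are assumed from the tests' usage):

```python
from tvm import relax
from tvm.relax import DynTensorType, RuntimeDepShape

# Unknown-rank float32 tensor whose shape is only known at runtime,
# mirroring the updated relax.Var constructions in the tests above.
x = relax.Var("x", RuntimeDepShape(), DynTensorType(-1, "float32"))
```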