Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Apr 4, 2022
1 parent fd423ef commit 75c61da
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 9 deletions.
2 changes: 1 addition & 1 deletion python/tvm/script/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,7 @@ def transform_block(self, block: ast.Block) -> relax.SeqExpr:
blocks.append(relax.BindingBlock(current_block, self.to_tvm_span(stmt.span)))
current_block = []
blocks.append(parsed_stmt)
elif isinstance(parsed_stmt, relax.Function) or isinstance(parsed_stmt, tir.PrimFunc):
elif isinstance(parsed_stmt, (relax.Function, tir.PrimFunc)):
func_var = self.decl_var(stmt.name, None, None, stmt.span)
current_block.append(
relax.VarBinding(func_var, parsed_stmt, self.to_tvm_span(stmt.span))
Expand Down
7 changes: 1 addition & 6 deletions src/relax/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,7 @@ tvm::Array<GlobalVar> RecGlobalVars(const Expr& expr) { return VarVisitor().RecG

TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars);

TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body([](TVMArgs args, TVMRetValue* ret) {
ObjectRef x = args[0];
if (x.as<ExprNode>()) {
*ret = BoundVars(Downcast<Expr>(x));
}
});
TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars);

TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars);

Expand Down
2 changes: 1 addition & 1 deletion src/relax/transform/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

/*!
* \file tvm/relax/backend/vm/lambda_lift.cc
* \file tvm/relax/transform/lambda_lift.cc
* \brief Lift local functions into global functions.
*/

Expand Down
2 changes: 1 addition & 1 deletion tests/python/relax/test_lambda_lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def while_loop(i: Tensor[(), "int32"], s: Tensor[(2, 3), "float32"]):
# Perform Lamda Lifting
after = transform.LambdaLift()(before)
assert len(after.functions) == 2
assert_structural_equal(after["lifted_func_0"], expected["lifted_func_0"], map_free_vars=True)
assert_structural_equal(after, expected, map_free_vars=True)
_check_save_roundtrip(after)


Expand Down

0 comments on commit 75c61da

Please sign in to comment.