Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Relax IRBuilder #4

Merged
merged 28 commits into from
Aug 31, 2021
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
bb206e0
Add initial IRBuilder.
YuchenJin Aug 9, 2021
ef12293
Add function output to irbuilder; update based on new AST.
YuchenJin Aug 10, 2021
d456f08
Add call method; clean up bindings
YuchenJin Aug 11, 2021
ecbdb84
Add test.
YuchenJin Aug 11, 2021
faa98da
Add multifuction test
mikepapadim Aug 12, 2021
0a3388c
Move implementation to C++; infer shape and type
YuchenJin Aug 18, 2021
b75f2c1
update op python hook
YuchenJin Aug 18, 2021
bae7e2d
More tests and bug fix
YuchenJin Aug 18, 2021
01b8789
Add comments.
YuchenJin Aug 19, 2021
a92bc78
Update shape/type inference.
YuchenJin Aug 20, 2021
078b814
Restructure code; add python type hint.
YuchenJin Aug 20, 2021
0ca2bcf
Cleanup code.
YuchenJin Aug 21, 2021
79240a3
Rebase; address comments.
YuchenJin Aug 24, 2021
9707ad5
Add call intrinsic.
YuchenJin Aug 24, 2021
bde251a
nits.
YuchenJin Aug 24, 2021
6e8a929
Remove call op.
YuchenJin Aug 25, 2021
73423d6
Migrate scope to C++ using tvm::With.
YuchenJin Aug 26, 2021
9d9431d
Address naming.
YuchenJin Aug 26, 2021
c93ebd4
Add GetBlocks API.
YuchenJin Aug 27, 2021
da53880
Unify EmitOutput APIs; add more comments.
YuchenJin Aug 27, 2021
5facded
Remove shape and type deduction code.
YuchenJin Aug 27, 2021
9c30bf6
Also remove the shape/type attr interface.
YuchenJin Aug 27, 2021
672b3c7
Address comments.
YuchenJin Aug 28, 2021
e503e48
Differentiate global and local function.
YuchenJin Aug 30, 2021
3565113
Reset counter after building func/block.
YuchenJin Aug 30, 2021
fbb3f45
Rebase.
YuchenJin Aug 30, 2021
1f7a4ab
Remove shape infer builtin.
YuchenJin Aug 31, 2021
b4f5a26
Return from void function as empty tuple.
YuchenJin Aug 31, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ class TupleType : public Type {
inline Type VoidType() { return TupleType::Empty(); }

/*!
* \brief Check whether the tyep represents void.
* \brief Check whether the type represents void.
* \return The check result.
*/
inline bool IsVoidType(const Type& type) {
Expand Down
197 changes: 197 additions & 0 deletions include/tvm/relax/ir_builder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/ir_builder.h
* \brief The utility for constructing Relax AST.
*/
#ifndef TVM_RELAX_IR_BUILDER_H_
#define TVM_RELAX_IR_BUILDER_H_

#include <tvm/ir/expr.h>
#include <tvm/relax/expr.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>
#include <tvm/support/with.h>

namespace tvm {
namespace relax {

using relay::Call;

class IRBuilder;

/*!
* \brief The state of Relax function node being built.
*/
struct RelaxFunction {
/*! \brief The function name. */
Optional<GlobalVar> func_name = NullOpt;
/*! \brief The function parameters. */
Array<Var> params;
/*! \brief The bindings in the function. */
std::vector<Binding> bindings;
/*! \brief The binding blocks in the function. */
std::vector<BindingBlock> binding_blocks;
/*! \brief The return of the function. */
Expr ret;
/*! \brief The FunctionNode being built. */
Function func;
};

/*!
* \brief A builder that provides APIs to build Relax AST.
*/
class IRBuilderNode : public Object {
public:
/*!
* \brief Fill the function name and parameters.
*/
void FillFuncNameParam(const Array<Var>& params, const std::string& func_name);
/*!
* \brief Build a function node.
*/
void BuildFunction();
/*!
* \brief Build a binding block.
*/
void BuildBlock();
/*!
* \brief Emit a call node.
* \param call The CallNode to be emitted.
* \return The variable being created and binded to \p call.
*/
Var Emit(const Call& call);
/*!
* \brief Generate an output for the current dataflow block or function.
* \param output The output variable of the block/function.
* \return The variable being binded to \p ouput.
*/
Var EmitOutput(const Expr& output);
/*!
* \brief Get the function being built.
*/
Function Get();
/*!
* \brief Get binding blocks being built.
*/
std::vector<BindingBlock> GetBlocks();
/*!
* \brief Create a IRBuilder.
* \return The created IRBuilder.
*/
TVM_DLL static IRBuilder Create();

void VisitAttrs(AttrVisitor* v) {}

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.IRBuilder";
TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, Object);

private:
/*! \brief The state of the function currently being built. */
RelaxFunction func;
YuchenJin marked this conversation as resolved.
Show resolved Hide resolved
/*! \brief A flag tracking if currently inside a dataflow block or not. */
bool is_dataflow = false;
/*! \brief A global variable counter for naming global variables. */
int global_var_counter = 0;
/*! \brief A dataflow variable counter for naming dataflow variables. */
int dataflow_var_counter = 0;
YuchenJin marked this conversation as resolved.
Show resolved Hide resolved
};

class IRBuilder : public ObjectRef {
public:
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(IRBuilder, ObjectRef, IRBuilderNode);
};

/*! \brief Auxiliary scope for building Relax function node,
* similar to python's with syntax.
*
* \code
* {
* With<FunctionScope> scope(ir_builder);
* // build function node.
* }
*/
class FunctionScopeNode : public Object {
public:
IRBuilder ir_builder;
void VisitAttrs(AttrVisitor* v) { v->Visit("ir_builder", &ir_builder); }

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.FunctionScope";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionScopeNode, Object);
};

class FunctionScope : public ObjectRef {
public:
TVM_DLL FunctionScope(IRBuilder ib);
TVM_DEFINE_OBJECT_REF_METHODS(FunctionScope, ObjectRef, FunctionScopeNode);
class Internal;

private:
// Classes to get the Python `with` like syntax.
friend class Internal;
friend class With<FunctionScope>;
// The entry of a function scope.
TVM_DLL void EnterWithScope();
// The exit of a function scope.
TVM_DLL void ExitWithScope();
};

/*! \brief Auxiliary scope for building Relax dataflow block,
* similar to python's with syntax.
*
* \code
* {
* With<DataflowScope> scope(ir_builder);
* // build dataflow block.
* }
*/
class DataflowScopeNode : public Object {
public:
IRBuilder ir_builder;
void VisitAttrs(AttrVisitor* v) { v->Visit("ir_builder", &ir_builder); }

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.DataflowScope";
TVM_DECLARE_FINAL_OBJECT_INFO(DataflowScopeNode, Object);
};

class DataflowScope : public ObjectRef {
public:
TVM_DLL DataflowScope(IRBuilder ib);
TVM_DEFINE_OBJECT_REF_METHODS(DataflowScope, ObjectRef, DataflowScopeNode);
class Internal;

private:
// Classes to get the Python `with` like syntax.
friend class Internal;
friend class With<DataflowScope>;
// The entry of a dataflow scope.
TVM_DLL void EnterWithScope();
// The exit of a dataflow scope.
TVM_DLL void ExitWithScope();
};

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_IR_BUILDER_H_
11 changes: 6 additions & 5 deletions include/tvm/relax/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ namespace relax {

class ShapeTypeNode : public TypeNode {
public:

void VisitAttrs(tvm::AttrVisitor* v) {
}
void VisitAttrs(tvm::AttrVisitor* v) {}

bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const {
return true;
Expand All @@ -64,10 +62,9 @@ class ShapeType : public Type {
const ShapeTypeNode* get() const {
return operator->();
}
using ContainerType = ShapeTypeNode;
using ContainerType = ShapeTypeNode;
};


class DynTensorTypeNode : public BaseTensorTypeNode {
public:
/*!
Expand All @@ -92,6 +89,10 @@ class DynTensorTypeNode : public BaseTensorTypeNode {
hash_reduce(dtype);
}

inline bool IsUnknownRank() const { return rank == -1; }

inline bool IsUnknownDtype() const { return dtype.is_void(); }

static constexpr const char* _type_key = "relax.DynTensorType";
TVM_DECLARE_FINAL_OBJECT_INFO(DynTensorTypeNode, BaseTensorTypeNode);
};
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from . import ty
from . import vm
from . import op
from . import ir_builder
from . import op


# Expr
Expand Down Expand Up @@ -56,3 +58,6 @@

# Operator
from .op.base import call_dps

# IRBuilder
IRBuilder = ir_builder.IRBuilder
Loading