Skip to content

Commit

Permalink
Tf2 (apache#3)
Browse files Browse the repository at this point in the history
* InferType fix

* Skip annotation passes for non-main funcs

Co-authored-by: Rohan Mukherjee <[email protected]>
  • Loading branch information
anijain2305 and rohanmukh authored Sep 30, 2020
1 parent d5f9bb5 commit 488ffe5
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
26 changes: 24 additions & 2 deletions python/tvm/relay/tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,8 @@ def is_valid_subgraph(func):
# Remove invalid subgraphs
for subgraph in mod.get_global_vars():
name = subgraph.name_hint
if mod[name].attrs and hasattr(mod[name].attrs, "SkipOptimization") and mod[name].attrs["SkipOptimization"] == 1:
continue
if not mod[name].attrs or mod[name].attrs["Compiler"] != compiler:
continue
if not is_valid_subgraph(mod[name]):
Expand All @@ -741,6 +743,8 @@ def is_valid_subgraph(func):
subgraph_with_macs = []
for subgraph in mod.get_global_vars():
name = subgraph.name_hint
if mod[name].attrs and hasattr(mod[name].attrs, "SkipOptimization") and mod[name].attrs["SkipOptimization"] == 1:
continue
if not mod[name].attrs or mod[name].attrs["Compiler"] != compiler:
continue
num_macs = relay.analysis.get_total_mac_number(mod[name])
Expand Down Expand Up @@ -817,8 +821,24 @@ def EnableTrt(mod, params=None, trt_version=None, use_implicit_batch=True,
transform.FoldConstant(),
LegalizeLayoutTranformPass(),
transform.InferType(),
tvm.transform.PrintIR("A1"),
transform.AnnotateTarget('tensorrt'),
tvm.transform.PrintIR("A1")])
with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)

def _set_optimization_attr(mod, skip_optimization=1):
    """Tag every non-main function in the module with a "SkipOptimization" attr.

    Parameters
    ----------
    mod : tvm.IRModule
        The module whose functions will be tagged in place.
    skip_optimization : int, optional
        Value stored in the "SkipOptimization" function attribute
        (1 = skip annotation passes for this function, 0 = clear the flag).
        Checked by PruneSubgraphs via ``attrs["SkipOptimization"] == 1``.

    Returns
    -------
    tvm.IRModule
        The same module, with updated function attributes.
    """
    # Prepare the mod such that all functions except main are tagged SkipOptimization
    gvs = mod.get_global_vars()
    for gv in gvs:
        func = mod[gv]
        name = gv.name_hint
        # Only the "main" entry point is left untagged so the annotation
        # passes (AnnotateTarget / MergeCompilerRegions) operate on it alone.
        if name != 'main':
            # with_attr returns a new function; IRModule is updated in place.
            new_func = func.with_attr("SkipOptimization",
                                      tvm.tir.IntImm("int32", skip_optimization))
            mod.update_func(gv, new_func)
    return mod

mod = _set_optimization_attr(mod, 1)
seq = tvm.transform.Sequential([transform.AnnotateTarget('tensorrt'),
tvm.transform.PrintIR("A2"),
transform.MergeCompilerRegions(),
tvm.transform.PrintIR("A3"),
Expand All @@ -827,6 +847,8 @@ def EnableTrt(mod, params=None, trt_version=None, use_implicit_batch=True,
with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
mod = PruneSubgraphs(mod, use_implicit_batch=use_implicit_batch, prune_no_macs=prune_subgraphs)
mod = _set_optimization_attr(mod, 0)

# Set environment variables used to communicate with TensorRT module.
os.environ["TVM_TENSORRT_MAX_WORKSPACE_SIZE"] = str(max_workspace_size)
os.environ["TVM_TENSORRT_USE_IMPLICIT_BATCH"] = str(int(use_implicit_batch))
Expand Down
1 change: 1 addition & 0 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
arg_target = op_expr_to_target_[arg];
compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
} else {
//else if (arg.as<VarNode>() != nullptr || arg.as<ConstantNode>() != nullptr){
// Input vars.
compiler_ends.push_back(arg);
}
Expand Down

0 comments on commit 488ffe5

Please sign in to comment.