[AUTOTVM] End2End autotvm support for vta (apache#18)
* support tuning a whole network

* pass unit test

* update tune resnet

* update all
merrymercy authored and tmoreau89 committed Nov 25, 2018
1 parent 830c532 commit 7055803
Showing 20 changed files with 924 additions and 1,160 deletions.
13 changes: 12 additions & 1 deletion python/tvm/autotvm/measure/measure_methods.py
@@ -207,7 +207,18 @@ def set_task(self, task):
for x in arg_bufs]
func = build(s, arg_bufs, "llvm")
tvm_buf = [nd.array(x) for x in self.ref_input]
func(*tvm_buf)

def _run_func():
"""Run tvm function in a thread.
Because there is some issues with python multiprocessing and the thread pool in tvm
"""
func(*tvm_buf)

thread = threading.Thread(target=_run_func)
thread.start()
thread.join()
del thread

self.ref_output = [x.asnumpy() for x in tvm_buf]

def get_build_kwargs(self):
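The hunk above replaces the direct func(*tvm_buf) call with a short-lived worker thread, because of the issues between Python multiprocessing and the thread pool in tvm noted in the docstring. A minimal standalone sketch of the same pattern, with illustrative names that are not part of the patch:

import threading

def run_in_thread(func, *args):
    # Run func(*args) in a short-lived thread and re-raise any failure in the caller.
    result = {}

    def _worker():
        try:
            func(*args)
        except Exception as exc:
            result["error"] = exc

    thread = threading.Thread(target=_worker)
    thread.start()
    thread.join()
    if "error" in result:
        raise result["error"]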
201 changes: 65 additions & 136 deletions python/tvm/autotvm/task/nnvm_integration.py
@@ -5,6 +5,7 @@
"""
import warnings
import logging
import sys


from ... import tensor, placeholder, create_schedule, target as _target
@@ -49,9 +50,9 @@ def deserialize_args(args):
# Task extractor for nnvm graph
class TaskExtractEnv:
"""Global environment for extracting tuning tasks from nnvm graph"""
current = None
registered = False

def __init__(self):
def __init__(self, wanted_symbols):
import topi
import nnvm

@@ -83,46 +84,62 @@ def __init__(self):
topi.nn.dense: [topi.generic.schedule_dense],
}

self._register_tracing()
# support reflection for tracing
self.func_to_reflection = {
topi.nn.conv2d: lambda x: setattr(topi.nn, 'conv2d', x),
topi.nn.depthwise_conv2d_nchw: lambda x: setattr(topi.nn, 'depthwise_conv2d_nchw', x),
topi.nn.conv2d_transpose_nchw: lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x),
topi.nn.dense: lambda x: setattr(topi.nn, 'dense', x),
}


self.wanted_topi_funcs = []
for sym_name in wanted_symbols:
if sym_name in self.symbol2topi:
self.wanted_topi_funcs.extend(self.symbol2topi[sym_name])
else:
warnings.warn("Symbol %s is not tunable, ignored" % sym_name)

self._register_topi_task()
self.task_collection = []
self.wanted_topi_funcs = list(self.topi_to_task.keys())
self.modified_funcs = []

def _register_tracing(self):
"""Register tracing function to track the topi function call"""
# register topi compute for "tracing" target
for topi_compute in self.topi_to_task:
def __enter__(self):
self.task_collection = []
self.modified_funcs = []

for topi_compute in self.wanted_topi_funcs:
def _local_scope(compute_func):
"""start a scope to hold the local function in for loop"""

@compute_func.register("tracing", )
def _tracing_topi_compute(*args, **kwargs):
assert not kwargs, "Do not support extracting tuning tasks when" \
"kwargs is used in TOPI function call." \
def _tracing_wrapper(*args, **kwargs):
assert not kwargs, "Do not support extracting tuning tasks when " \
"kwargs is used in TOPI function call. " \
"Please modify it to use only positional args."

if compute_func in self.wanted_topi_funcs: # record this call
key = (self.topi_to_task[compute_func], serialize_args(args))
if key not in self.task_collection:
self.task_collection.append(key)
key = (self.topi_to_task[compute_func], serialize_args(args))
if key not in self.task_collection:
self.task_collection.append(key)

return compute_func(*args, **kwargs)

self.func_to_reflection[topi_compute](_tracing_wrapper)
self.modified_funcs.append(topi_compute)

return compute_func.fdefault(*args)
_local_scope(topi_compute)

# register topi schedule for "tracing" target
for topi_compute in self.topi_to_task:
for topi_schedule in self.topi_to_schedule[topi_compute]:
def _local_scope_(schedule_func):
"""start a scope to hold the local function in for loop"""
return self

@schedule_func.register("tracing", )
def _tracing_topi_compute(outs):
outs = [outs] if isinstance(outs, tensor.Tensor) else outs
return create_schedule([x.op for x in outs])
_local_scope_(topi_schedule)
def __exit__(self, exc_type, exc_val, exc_tb):
# revert modification
for func in self.modified_funcs:
self.func_to_reflection[func](func)
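
The __enter__/__exit__ pair above patches each wanted TOPI compute with a recording wrapper and uses the func_to_reflection setters to restore the originals on exit, so task extraction no longer needs a special "tracing" target. A generic, hedged sketch of that patch-and-restore pattern (not TOPI-specific, names are illustrative):

class RecordingPatch:
    """Temporarily replace module attributes with call-recording wrappers."""

    def __init__(self, module, names):
        self.module = module
        self.names = names
        self.calls = []
        self._originals = {}

    def __enter__(self):
        for name in self.names:
            original = getattr(self.module, name)
            self._originals[name] = original

            def _wrapper(*args, __orig=original, __name=name, **kwargs):
                # record the call, then forward to the original function
                self.calls.append((__name, args))
                return __orig(*args, **kwargs)

            setattr(self.module, name, _wrapper)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # always restore the originals, mirroring the revert loop above
        for name, original in self._originals.items():
            setattr(self.module, name, original)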

def _register_topi_task(self):
"""register tuning wrapper for topi function"""
if TaskExtractEnv.registered:
return
TaskExtractEnv.registered = True
import topi

# Tuning wrapper for topi functions
@@ -175,17 +192,6 @@ def _topi_nn_dense(*args, **kwargs):
return s, [data, weight, bias, C]
return s, [data, weight, C]

def reset(self, wanted_topi_funcs):
"""Reset task collections
Parameters
----------
wanted_topi_funcs: List of function
The topi function to be extracted
"""
self.task_collection = []
self.wanted_topi_funcs = wanted_topi_funcs

def get_tasks(self):
"""Get collected tasks
@@ -196,25 +202,11 @@ def get_tasks(self):
"""
return self.task_collection

@staticmethod
def get():
"""Get the single instance of TaskExtractEnv
Returns
-------
env: TaskExtractEnv
The single instance of TaskExtractEnv
"""
if not TaskExtractEnv.current:
TaskExtractEnv.current = TaskExtractEnv()
return TaskExtractEnv.current


def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
""" Extract tuning tasks from a nnvm graph.
This function collects tuning tasks by building the graph
with a "tracing" target and tracing all the calls to topi.
This function collects tuning tasks by building the graph and tracing all the calls to topi.
Parameters
----------
Expand All @@ -237,97 +229,34 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
collected tasks
"""
import nnvm.compiler
import topi

env = TaskExtractEnv.get()
env = TaskExtractEnv(symbols)

topi_funcs = []
for sym_name in symbols:
if sym_name in env.symbol2topi:
topi_funcs.extend(env.symbol2topi[sym_name])
else:
warnings.warn("Symbol %s is not tunable, ignored" % sym_name)
with env:
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True

# run compiler to collect all TOPI calls during compilation
env.reset(topi_funcs)
# run compiler to collect all TOPI calls during compilation
nnvm.compiler.engine.clear_cache()
nnvm.compiler.build(graph, target=target, shape=shape, dtype=dtype)
nnvm.compiler.engine.clear_cache()

# disable logger temporarily
old_state = logger.disabled
logger.disabled = True

# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target = _target.create("llvm -device=tracing")
nnvm.compiler.engine.clear_cache()
nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)

logger.disabled = old_state
logger.disabled = old_state

# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
tasks.append(create(task_name, args,
target=target, target_host=target_host,
template_key='direct'))
try:
tsk = create(task_name, args,
target=target, target_host=target_host,
template_key='direct')
tasks.append(tsk)
except topi.InvalidShapeError:
print("shape error")

return tasks


def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_host=None):
""" Extract tuning tasks from multiple nnvm graphs.
This function is the multiple graph version of extract_from_graph
Parameters
----------
graphs : List of Graph
The list of graphs to tune
shapes : List of dict of str to tuple
The input shape to the graph
dtypes : List of str or dict of str to str
The input types to the graph
target: tvm.target.Target
The compilation target
symbols : Array of nnvm.symbol
Array of nnvm symbols want to be tuned
target_host: tvm.target.Target
The host compilation target
Returns
-------
task: Array of autotvm.task.Task
collected tasks
"""
import nnvm.compiler

env = TaskExtractEnv.get()

topi_funcs = []
for sym_name in symbols:
if sym_name in env.symbol2topi:
topi_funcs.extend(env.symbol2topi[sym_name])
else:
warnings.warn("Symbol %s is not tunable, ignored" % sym_name)

# run compiler to collect all TOPI calls during compilation
env.reset(topi_funcs)

# disable logger temporarily
old_state = logger.disabled
logger.disabled = True

# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target = _target.create("llvm -device=tracing")

nnvm.compiler.engine.clear_cache()
for graph, shape, dtype in zip(graphs, shapes, dtypes):
nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)

logger.disabled = old_state

# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
tasks.append(create(task_name, args,
target=target, target_host=target_host,
template_key='direct'))

return tasks
def extract_from_multiple_graph(graph, shape, dtype, target, symbols, target_host=None):
pass
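
With the context manager in place, extract_from_graph simply compiles the graph inside "with env:" and returns a task for every recorded TOPI call. A hedged usage sketch; the autotvm.task re-export, the resnet workload, and the choice of symbols are assumptions based on the signature shown above, not part of this diff:

import nnvm
import nnvm.testing
from tvm import autotvm

# an example workload; any nnvm graph works the same way
net, params = nnvm.testing.resnet.get_workload(num_layers=18, batch_size=1)

tasks = autotvm.task.extract_from_graph(
    net, shape={"data": (1, 3, 224, 224)}, dtype="float32",
    target="llvm", symbols=(nnvm.sym.conv2d,))

for tsk in tasks:
    print(tsk.name, tsk.args)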
7 changes: 7 additions & 0 deletions python/tvm/target.py
@@ -473,6 +473,13 @@ def rasp(options=None):
return arm_cpu('rasp3b', options)


def vta(model='unknown', options=None):
opts = ["-device=vta", '-keys=cpu', '-model=%s' % model]
opts = _merge_opts(opts, options)
ret = _api_internal._TargetCreate("ext_dev", *opts)
return ret


def create(target_str):
"""Get a target given target string.
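The new vta() helper builds an ext_dev target whose option string carries the device, key, and model flags. A small usage sketch ("sim" is only an example model string, not defined by this patch):

import tvm

tgt = tvm.target.vta(model="sim")
# expected to print something like: ext_dev -device=vta -keys=cpu -model=sim
print(tgt)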
11 changes: 8 additions & 3 deletions src/codegen/build_module.cc
@@ -39,6 +39,7 @@ Target CreateTarget(const std::string& target_name,

std::string libs_flag = "-libs=";
std::string device_flag = "-device=";
std::string keys_flag = "-keys=";
for (auto& item : options) {
t->options_array.push_back(ir::StringImm::make(item));

@@ -50,12 +51,16 @@
}
} else if (item.find(device_flag) == 0) {
t->device_name = item.substr(device_flag.length());
t->keys_array.push_back(ir::StringImm::make(t->device_name));
} else if (item.find(keys_flag) == 0) {
std::stringstream ss(item.substr(keys_flag.length()));
std::string key_item;
while (std::getline(ss, key_item, ',')) {
t->keys_array.push_back(ir::StringImm::make(key_item));
}
}
}

if (t->device_name.length() > 0) {
t->keys_array.push_back(ir::StringImm::make(t->device_name));
}
t->device_type = kDLCPU;
t->thread_warp_size = 1;
if (target_name == "llvm") {
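On the C++ side, -keys= is now split on commas into individual entries of keys_array, and the -device= name is still appended as a key afterwards. A hedged Python-side illustration, assuming tvm.target.create and the Target.keys property behave as in mainline TVM of this period:

import tvm

# -keys= accepts a comma-separated list; -device= still contributes its own key
tgt = tvm.target.create("llvm -device=mydev -keys=alpha,beta")
# expected to include 'alpha', 'beta' and 'mydev' (plus the default 'cpu' key for llvm)
print(list(tgt.keys))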
4 changes: 4 additions & 0 deletions topi/python/topi/__init__.py
@@ -34,6 +34,10 @@
from . import image
from . import sparse
from . import hls

# some shortcuts
from .util import InvalidShapeError

# not import testing by default
# because testing can have extra deps that are not necessary
# we can import them from test cases explicitly
4 changes: 4 additions & 0 deletions topi/python/topi/util.py
@@ -6,6 +6,10 @@
import tvm
from . import tag

class InvalidShapeError(ValueError):
"""Invalid shape for a topi function. i.e. call winograd template for non-3x3 kernel)"""
pass

def traverse_inline(s, final_op, callback):
"""Traverse computation graph and do auto inline
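InvalidShapeError gives TOPI templates a dedicated exception to raise when a workload's shape is unsupported, which the new try/except in extract_from_graph above catches and skips. A toy, hedged sketch of the intended usage (the guard function below is illustrative only):

import topi

def winograd_guard(kernel_h, kernel_w):
    # mimic how a winograd template might reject unsupported kernel shapes
    if (kernel_h, kernel_w) != (3, 3):
        raise topi.InvalidShapeError(
            "winograd template only supports 3x3 kernels, got %dx%d" % (kernel_h, kernel_w))

try:
    winograd_guard(5, 5)
except topi.InvalidShapeError as err:
    print("skipped:", err)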
3 changes: 2 additions & 1 deletion vta/python/vta/__init__.py
@@ -18,6 +18,7 @@
# to maintain minimum dependency on the board
if sys.argv[0] not in ("-c", "-m"):
from . import top
from .build_module import build_config, lower, build
from . import graph

from .build_module import build_config, lower, build, vta_autotvm_build_func
from .ptr_alias import reinterpret
