Skip to content

Commit

Permalink
Support dump ir for each pass (apache#693) (apache#791)
Browse files Browse the repository at this point in the history
* Support dump ir for each pass(apache#693)

* expose DumpIR

* fix comments

* fix comments
  • Loading branch information
xqdan authored and sergei-mironov committed Aug 8, 2018
1 parent 71b7bb5 commit 1d5d345
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 3 deletions.
89 changes: 87 additions & 2 deletions python/tvm/build_module.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,96 @@
"""
from __future__ import absolute_import as _abs
import warnings
import types

from . import api
from . import tensor
from . import schedule
from . import expr
from . import ir_pass
from . import stmt as _stmt
from . import container
from . import module
from . import codegen
from . import ndarray
from . import target as _target

class DumpIR(object):
"""Dump IR for each pass.
With it, you can dump ir just like gcc/llvm.
How to use:
-----------
.. code-block:: python
with tvm.build_config(dump_pass_ir=True)
run()
"""
scope_level = 0
def __init__(self):
self._pass_id = 0
self._recover_list = []

def decorate(self, func):
''' decorate the pass function'''
def dump(*args, **kwargs):
'''dump function'''
retv = func(*args, **kwargs)
if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)):
return retv
pname = str(self._pass_id) + "_" + func.func_name + "_ir.cc"
with open(pname, "a") as f:
out = retv.body if isinstance(retv, container.LoweredFunc) else retv
f.write(str(out))
if isinstance(retv, container.Array):
for x in retv:
out = x.body if isinstance(x, container.LoweredFunc) else x
f.write("---------%s\n%s\n-----------\n"%(x.name, str(out)))
self._pass_id += 1
return retv
return dump

def decorate_irpass(self):
'''decorate ir_pass and ScheduleOps'''
self._old_sgpass = schedule.ScheduleOps
schedule.ScheduleOps = self.decorate(schedule.ScheduleOps)
vset = vars(ir_pass)
k = v = 0
def recover():
vset[k] = v
for k, v in vset.items():
self._recover_list.append(recover)
vset[k] = self.decorate(v) if isinstance(v, types.FunctionType) else v

def decorate_custompass(self):
''' decorate add_lower_pass pass in BuildConfig'''
cfg = BuildConfig.current
self._old_custom_pass = cfg.add_lower_pass
custom_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
pass_list = [(x[0], self.decorate(x[1])) for x in custom_pass]
BuildConfig.current.add_lower_pass = pass_list

def enter(self):
'''only decorate outermost nest'''
if DumpIR.scope_level > 0:
return
self.decorate_irpass()
self.decorate_custompass()
self._pass_id = 0
DumpIR.scope_level += 1

def exit(self):
'''recover outermost nest'''
if DumpIR.scope_level > 1:
return
# recover decorated functions
for f in self._recover_list:
f()
schedule.ScheduleOps = self._old_sgpass
BuildConfig.current.add_lower_pass = self._old_custom_pass
DumpIR.scope_level -= 1

class BuildConfig(object):
"""Configuration scope to set a build config option.
Expand All @@ -37,10 +115,12 @@ class BuildConfig(object):
"data_alignment": -1,
"restricted_func": True,
"double_buffer_split_loop": 1,
"add_lower_pass": None
"add_lower_pass": None,
"dump_pass_ir": False
}
def __init__(self, **kwargs):
self._old_scope = None
self._dump_ir = DumpIR()
for k, _ in kwargs.items():
if k not in BuildConfig.defaults:
raise ValueError(
Expand All @@ -59,10 +139,14 @@ def __enter__(self):
attr.update(self._attr)
self._attr = attr
BuildConfig.current = self
if self.dump_pass_ir is True:
self._dump_ir.enter()
return self

def __exit__(self, ptype, value, trace):
assert self._old_scope
if self.dump_pass_ir is True:
self._dump_ir.exit()
BuildConfig.current = self._old_scope


Expand Down Expand Up @@ -115,6 +199,8 @@ def build_config(**kwargs):
phase contains an integer on which optimization pass we apply the pass.
Additional lowering passes to be applied before make_api.
dump_pass_ir: dump ir of each pass into file idx_passname_ir.cc, default=False
Returns
-------
config: BuildConfig
Expand Down Expand Up @@ -247,7 +333,6 @@ def lower(sch,
return stmt
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)


def build(sch,
args=None,
target=None,
Expand Down
19 changes: 18 additions & 1 deletion tests/python/unittest/test_pass_unroll.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import tvm
import os

def test_unroll_loop():
dtype = 'int64'
Expand All @@ -24,4 +25,20 @@ def test_unroll_loop():


if __name__ == "__main__":
test_unroll_loop()
with tvm.build_config(dump_pass_ir=True):
test_unroll_loop()

def end_with(*suffix):
ends = suffix
def run(s):
f = map(s.endswith, ends)
if True in f: return s
return run

file_list = os.listdir('./')
cc_file = end_with('.cc')
cc_file = filter(cc_file, file_list)
assert len(cc_file) == 3
for i in cc_file:
os.remove(i)

0 comments on commit 1d5d345

Please sign in to comment.