Skip to content

Commit

Permalink
[Frontend] [Feature] add new relay frontend and relayviz feature
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilippvK committed Apr 25, 2022
1 parent bb834c0 commit e18b5e0
Show file tree
Hide file tree
Showing 9 changed files with 177 additions and 6 deletions.
3 changes: 3 additions & 0 deletions mlonmcu/feature/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def get_frontend_config(self, frontend):
def add_frontend_config(self, frontend, config):
config.update(self.get_frontend_config(frontend))

def update_formats(self, frontend, input_formats, output_formats):
pass


class FrameworkFeature(FeatureBase):
"""Framework related feature"""
Expand Down
37 changes: 37 additions & 0 deletions mlonmcu/feature/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from mlonmcu.utils import is_power_of_two
from mlonmcu.config import str2bool
from mlonmcu.artifact import ArtifactFormat
from .feature import (
BackendFeature,
FrameworkFeature,
Expand Down Expand Up @@ -522,6 +523,42 @@ def get_backend_config(self, backend):
}
)

@register_feature("relayviz")
class Relayviz(FrontendFeature):
"""Visualize TVM relay models."""


DEFAULTS = {
**FeatureBase.DEFAULTS,
"plotter": "term",
}

def __init__(self, config=None):
super().__init__("relayviz", config=config)

@property
def plotter(self):
return self.config.get("plotter", None)

def get_frontend_config(self, frontend):
assert (
frontend in ["relay"]
), f"Unsupported feature '{self.name}' for frontend '{frontend}'"
return filter_none(
{
f"{frontend}.visualize_graph": self.enabled,
f"{frontend}.relayviz_plotter": self.plotter,
}
)

def update_formats(self, frontend, input_formats, output_formats):
assert (
frontend in ["relay"]
), f"Unsupported feature '{self.name}' for frontend '{frontend}'"
if self.enabled:
output_formats.append(ArtifactFormat.TEXT)



@register_feature("autotuned")
class Autotuned(BackendFeature):
Expand Down
3 changes: 3 additions & 0 deletions mlonmcu/flow/tvm/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,5 +244,8 @@ def load_model(self, model):
model_buf = handle.read()
self.model_info = get_tflite_model_info(model_buf)
self.input_shapes = {tensor.name: tensor.shape for tensor in self.model_info.in_tensors}
elif fmt == ModelFormats.RELAY:
# Warning: the wrapper generateion does currently not work because of the missing possibility to get the relay models input names and shapes
self.model_format = "relay"
else:
raise RuntimeError(f"Unsupported model format '{fmt.name}' for backend '{self.name}'")
12 changes: 12 additions & 0 deletions mlonmcu/flow/tvm/backend/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def __init__(self, t, fix_names=False):
dtype = type_lookup[t.Type()]
super().__init__(name, shape, dtype)

class RelayTensorInfo(TensorInfo):
def __init__(self, t, fix_names=False):
pass


class ModelInfo:
def __init__(self, in_tensors, out_tensors, fix_names=False):
Expand All @@ -79,7 +83,15 @@ def __init__(self, model, fix_names=False):
super().__init__(in_tensors, out_tensors)


class RelayModelInfo(ModelInfo):
def __init__(self, text, fix_names=False):
pass

def get_tflite_model_info(model_buf):
tflite_model = tflite.Model.GetRootAsModel(model_buf, 0)
model_info = TfLiteModelInfo(tflite_model)
return model_info

def get_relay_model_info(mod_text):
model_info = RelayModelInfo(mod_text)
return model_info
3 changes: 2 additions & 1 deletion mlonmcu/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
# limitations under the License.
#
from mlonmcu.models.lookup import print_summary
from .frontend import TfLiteFrontend, PackedFrontend, ONNXFrontend
from .frontend import TfLiteFrontend, PackedFrontend, ONNXFrontend, RelayFrontend

SUPPORTED_FRONTENDS = {
"tflite": TfLiteFrontend,
"relay": RelayFrontend,
"packed": PackedFrontend,
"onnx": ONNXFrontend,
} # TODO: use registry instead
Expand Down
117 changes: 114 additions & 3 deletions mlonmcu/models/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import tempfile
import multiprocessing
from pathlib import Path
from abc import ABC, abstractmethod

from mlonmcu.feature.features import get_matching_features
from mlonmcu.models.model import ModelFormats
from mlonmcu.feature.type import FeatureType
from mlonmcu.config import filter_config
from mlonmcu.config import filter_config, str2bool
from mlonmcu.artifact import Artifact, ArtifactFormat
from mlonmcu.setup import utils

Expand Down Expand Up @@ -87,13 +89,14 @@ def process_features(self, features):
return []
features = get_matching_features(features, FeatureType.FRONTEND)
for feature in features:
assert ( # If this assertion occurs, continue with the next frontend instea dof failing
assert ( # If this assertion occurs, continue with the next frontend instead of failing
# (TODO: create custom exception type)
feature.name
in self.FEATURES
), f"Incompatible feature: {feature.name}"
# Instead we might introduce self.compatible and set it to true at this line
feature.add_frontend_config(self.name, self.config)
feature.update_formats(self.name, self.input_formats, self.output_formats)
return features

@abstractmethod
Expand Down Expand Up @@ -241,7 +244,115 @@ def __init__(self, features=None, config=None):
config=config,
)

# TODO: ModelFormats.OTHER as placeholder for visualization artifacts

class RelayFrontend(SimpleFrontend):

FEATURES = Frontend.FEATURES + ["relayviz"]

DEFAULTS = {**Frontend.DEFAULTS, "visualize_graph": False, "relayviz_plotter": "term"}

REQUIRED = Frontend.REQUIRED + ["tvm.build_dir", "tvm.pythonpath"]

def __init__(self, features=None, config=None):
super().__init__(
"relay",
ModelFormats.RELAY,
features=features,
config=config,
)

@property
def visualize_graph(self):
return str2bool(self.config["visualize_graph"])

@property
def relayviz_plotter(self):
return self.config["relayviz_plotter"]

@property
def tvm_build_dir(self):
return self.config["tvm.build_dir"]

@property
def tvm_pythonpath(self):
return self.config["tvm.pythonpath"]

def produce_artifacts(self, model):
print("produce_artifacts")
assert len(self.input_formats) == len(model.paths) == 1
artifacts = []

name = model.name
path = model.paths[0]
ext = self.input_formats[0].extension
print("name", name, "path", path, "ext", ext)
with open(path, "rb") as handle: # TODO: is an onnx model raw data or text?
raw = handle.read()
print("fff", f"{name}.{ext}")
artifacts.append(Artifact(f"{name}.{ext}", raw=raw, fmt=ArtifactFormat.RAW))

if not self.visualize_graph:
assert len(self.output_formats) == 1
else:
assert len(self.output_formats) == 2

def _relayviz(in_file, out_file, plotter_name, env={}):
print("relayviz", plotter_name)
import sys
import os
sys.path.append(env["PYTHONPATH"])
os.environ["TVM_LIBRARY_PATH"] = env["TVM_LIBRARY_PATH"]
from tvm import parser
from tvm.contrib import relay_viz
from tvm.contrib.relay_viz.terminal import TermPlotter
from tvm.contrib.relay_viz.dot import DotPlotter
if plotter_name == "term":
plotter_cls = TermPlotter
elif plotter_name == "dot":
plotter_cls = DotPlotter
else:
raise RuntimeError(f"Invalid plotter name: {plotter_name}")

with open(in_file, "r", encoding="utf-8") as relay_text:
text = relay_text.read()

mod = parser.fromtext(text)

plotter_inst = plotter_cls()
viz = relay_viz.RelayVisualizer(mod, plotter=plotter_inst)
out_file_base = os.path.splitext(out_file)[0]
viz.render(filename=out_file_base)

in_file = model.paths[0]
ext = "txt" if self.relayviz_plotter == "term" else "pdf"
with tempfile.TemporaryDirectory() as tmpdirname:
out_file = str(Path(tmpdirname) / f"relayviz.{ext}")
proc = multiprocessing.Process(target=_relayviz, args=[in_file, out_file, self.relayviz_plotter], kwargs={"env": {"PYTHONPATH": self.tvm_pythonpath, "TVM_LIBRARY_PATH": self.tvm_build_dir}})
proc.start()
proc.join()

if self.relayviz_plotter == "term":
with open(out_file, "r") as handle:
relayviz_text = handle.read()

relayviz_artifact = Artifact(
f"relayviz.{ext}",
content=relayviz_text,
fmt=ArtifactFormat.TEXT,
)
else:
with open(out_file, "rb") as handle:
relayviz_data = handle.read()

relayviz_artifact = Artifact(
f"relayviz.{ext}",
raw=relayviz_data,
fmt=ArtifactFormat.RAW,
)
artifacts.append(relayviz_artifact)


return artifacts


class PackedFrontend(Frontend): # Inherit from TFLiteFrontend? -> how to do constructor?
Expand Down
4 changes: 2 additions & 2 deletions mlonmcu/models/lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def list_models(directory, depth=1, formats=None, config=None):
[Path(directory) / f"{main_model}.{ext}"],
config=main_config,
alt=main_model,
formats=[ModelFormats.TFLITE],
formats=[fmt],
)
)

Expand All @@ -122,7 +122,7 @@ def list_models(directory, depth=1, formats=None, config=None):
submodel,
[Path(directory) / f"{submodel}.{ext}"],
config=submodel_config,
formats=[ModelFormats.TFLITE],
formats=[fmt],
)
)

Expand Down
1 change: 1 addition & 0 deletions mlonmcu/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def from_extension(cls, ext):
PACKED = ModelFormat(2, ["tflm"])
IPYNB = ModelFormat(3, ["ipynb"])
ONNX = ModelFormat(4, ["onnx"])
RELAY = ModelFormat(5, ["relay"])


def parse_metadata_from_path(path):
Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,6 @@ reedsolo>=1.5.3,<=1.5.4
bitstring>=3.1.6
ecdsa>=0.16.0
construct==2.10.54

# relayviz
graphviz

0 comments on commit e18b5e0

Please sign in to comment.