Skip to content

Commit

Permalink
RelayModelInfo: support non-tensor outputs
Browse files Browse the repository at this point in the history
Thanks to @jokap11
  • Loading branch information
PhilippvK committed Dec 7, 2023
1 parent fd8d1fa commit d4338a9
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions mlonmcu/flow/tvm/backend/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def parse_relay_main(line):
output_tensors_str = re.compile(r"-> (.+) {").findall(line)
# The following depends on InferType annocations
if len(output_tensors_str) > 0:
output_tensor_strs = re.compile(r"Tensor\[\([\di]+(?:, [\di]+)*\), [a-zA-Z0-9_]+\]").findall(
output_tensor_strs = re.compile(r"Tensor\[\([\di]+(?:, [\di]+)*\), [a-zA-Z0-9_]+\]|(?:u?int\d+)").findall(
output_tensors_str[0]
)

Expand All @@ -156,10 +156,15 @@ def parse_relay_main(line):

for i, output_name in enumerate(output_tensor_names):
res = re.compile(r"Tensor\[\(([\di]+(?:, [\di]+)*)\), ([a-zA-Z0-9_]+)\]").match(output_tensor_strs[i])
if res is None:
res = re.compile(r"(u?int\d+)").match(output_tensor_strs[i])
assert res is not None
groups = res.groups()
assert len(groups) == 2
output_shape_str, output_type = groups
assert len(groups) in [1, 2]
if len(groups) == 2:
output_shape_str, output_type = groups
elif len(groups) == 1:
output_shape_str, output_type = "1, 1", groups[0]
output_shape = shape_from_str(output_shape_str)
output_tensor = TensorInfo(output_name, output_shape, output_type)
output_tensors.append(output_tensor)
Expand Down

0 comments on commit d4338a9

Please sign in to comment.