tflite frontend: extract model info from tflite file if definition.yml missing
PhilippvK committed Apr 11, 2024
1 parent c6b20aa commit 67048ca
Showing 1 changed file with 55 additions and 14 deletions.
69 changes: 55 additions & 14 deletions mlonmcu/models/frontend.py
@@ -145,6 +145,9 @@ def gen_ref_data_fmt(self):
     def inference(self, model: Model, input_data: Dict[str, np.array]):
         raise NotImplementedError
 
+    def extract_model_info(self, model: Model):
+        raise NotImplementedError
+
     def supports_formats(self, ins=None, outs=None):
         """Returns true if the frontend can handle at least one combination of input and output formats."""
         assert ins is not None or outs is not None, "Please provide a list of input formats, output formats or both"
@@ -269,6 +272,17 @@ def process_metadata(self, model, cfg=None):
                 flattened = {f"{backend}.{key}": value for key, value in backend_options[backend].items()}
                 cfg.update(flattened)
 
+        if len(input_shapes) > 0:
+            assert len(input_types) in [len(input_shapes), 0]
+            input_names = list(input_shapes.keys())
+        elif len(input_types) > 0:
+            input_names = list(input_types.keys())
+        else:
+            input_names = []
+
+        if metadata is None:
+            input_names, input_shapes, input_types, input_quant_details, output_names, output_shapes, output_types, output_quant_details = self.extract_model_info(model)
+
         # Detect model support code (Allow overwrite in metadata YAML)
         support_path = model_dir / "support"
         if support_path.is_dir():
@@ -294,13 +308,6 @@ def process_metadata(self, model, cfg=None):
         if len(output_types) > 0:
             cfg.update({f"{model.name}.output_types": output_types})
         # flattened version
-        if len(input_shapes) > 0:
-            assert len(input_types) in [len(input_shapes), 0]
-            input_names = list(input_shapes.keys())
-        elif len(input_shapes) > 0:
-            input_names = list(input_types.keys())
-        else:
-            input_names = []
         if len(output_shapes) > 0:
             assert len(output_types) in [len(output_shapes), 0]
             output_names = list(output_shapes.keys())
@@ -330,10 +337,10 @@ def process_metadata(self, model, cfg=None):
                     print("ii", ii)
                     assert input_name in input_types, f"Unknown dtype for input: {input_name}"
                     dtype = input_types[input_name]
-                    quant = input_quant_details.get(name, None)
+                    quant = input_quant_details.get(input_name, None)
                     if quant:
                         _, _, ty = quant
-                        dtype = ty
+                        # dtype = ty
                     assert input_name in input_shapes, f"Unknown shape for input: {input_name}"
                     shape = input_shapes[input_name]
                     if self.gen_data_fill_mode == "zeros":
@@ -342,17 +349,16 @@ def process_metadata(self, model, cfg=None):
                         arr = np.ones(shape, dtype=dtype)
                     elif self.gen_data_fill_mode == "random":
                         if "float" in dtype:
-                            arr = np.rand(*shape).astype(dtype)
+                            arr = np.random.rand(*shape).astype(dtype)
                         elif "int" in dtype:
                             arr = np.random.randint(np.iinfo(dtype).min, np.iinfo(dtype).max, size=shape, dtype=dtype)
                         else:
                             assert False
                     else:
                         assert False
                     data[input_name] = arr
+                assert len(data) > 0
                 inputs_data.append(data)
-        elif self.gen_data_fill_mode == "ones":
-            raise NotImplementedError
         elif self.gen_data_fill_mode == "file":
             if self.gen_data_file == "auto":
                 assert len(in_paths) > 0
@@ -387,6 +393,7 @@ def process_metadata(self, model, cfg=None):
                 else:
                     raise RuntimeError(f"Unsupported ext: {ext}")
             # print("temp", temp)
+            assert len(temp) > 0
             for i in range(min(self.gen_data_number, len(temp))):
                 print("i", i)
                 assert i in temp
@@ -396,7 +403,7 @@ def process_metadata(self, model, cfg=None):
                     assert ii in temp[i]
                     assert input_name in input_types, f"Unknown dtype for input: {input_name}"
                     dtype = input_types[input_name]
-                    quant = input_quant_details.get(name, None)
+                    quant = input_quant_details.get(input_name, None)
                     if quant:
                         _, _, ty = quant
                         dtype = ty
@@ -487,7 +494,7 @@ def process_metadata(self, model, cfg=None):
                     assert ii in temp[i]
                     assert output_name in output_types, f"Unknown dtype for output: {output_name}"
                     dtype = output_types[output_name]
-                    dequant = output_quant_details.get(name, None)
+                    dequant = output_quant_details.get(output_name, None)
                     if dequant:
                         _, _, ty = dequant
                         dtype = ty
@@ -691,6 +698,40 @@ def analyze_enable(self):
     def analyze_script(self):
         return self.config["analyze_script"]
 
+    def extract_model_info(self, model: Model):
+        import tensorflow as tf
+        model_path = str(model.paths[0])
+        interpreter = tf.lite.Interpreter(model_path=model_path)
+        input_details = interpreter.get_input_details()
+        output_details = interpreter.get_output_details()
+        input_names = []
+        input_shapes = {}
+        input_types = {}
+        input_quant_details = {}
+        output_names = []
+        output_shapes = {}
+        output_types = {}
+        output_quant_details = {}
+        for inp in input_details:
+            name = str(inp["name"])
+            input_names.append(name)
+            input_shapes[name] = inp["shape"].tolist()
+            input_types[name] = np.dtype(inp["dtype"]).name
+            if "quantization" in inp:
+                scale, zero_point = inp["quantization"]
+                quant = [scale, zero_point, "float32"]
+                input_quant_details[name] = quant
+        for outp in output_details:
+            name = str(outp["name"])
+            output_names.append(name)
+            output_shapes[name] = outp["shape"].tolist()
+            output_types[name] = np.dtype(outp["dtype"]).name
+            if "quantization" in outp:
+                scale, zero_point = outp["quantization"]
+                quant = [scale, zero_point, "float32"]
+                output_quant_details[name] = quant
+        return input_names, input_shapes, input_types, input_quant_details, output_names, output_shapes, output_types, output_quant_details
+
     def inference(self, model: Model, input_data: Dict[str, np.array], quant=False, dequant=False):
         import tensorflow as tf
         model_path = str(model.paths[0])
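For context, the fallback added by this commit can be exercised standalone. Below is a minimal sketch of the same TFLite introspection outside the frontend class; the model path is a placeholder, and the dequantization comment at the end only illustrates how a collected [scale, zero_point, "float32"] triple is conventionally applied, neither being part of this commit.

import numpy as np
import tensorflow as tf

# Placeholder path; any float or quantized .tflite model works here.
interpreter = tf.lite.Interpreter(model_path="model.tflite")

# Same per-tensor details the new extract_model_info() consumes.
for detail in interpreter.get_input_details() + interpreter.get_output_details():
    name = str(detail["name"])
    shape = detail["shape"].tolist()          # e.g. [1, 96, 96, 1]
    dtype = np.dtype(detail["dtype"]).name    # e.g. "int8" or "float32"
    scale, zero_point = detail["quantization"]
    print(name, shape, dtype, (scale, zero_point))

# Illustrative use of a [scale, zero_point, "float32"] triple (assumption,
# not from this commit): dequantize int8 data back to float32 via
#   arr_f32 = (arr_i8.astype("float32") - zero_point) * scale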
