From 67048cac5b4865329441acec9075ca8d6840ef37 Mon Sep 17 00:00:00 2001
From: Philipp van Kempen
Date: Wed, 17 Jan 2024 22:41:28 +0100
Subject: [PATCH] tflite frontend: extract model info from tflite file if definition.yml missing

---
 mlonmcu/models/frontend.py | 82 +++++++++++++++++++++++++++++++-------
 1 file changed, 68 insertions(+), 14 deletions(-)

diff --git a/mlonmcu/models/frontend.py b/mlonmcu/models/frontend.py
index ff6c7efa6..2c35cc0a0 100644
--- a/mlonmcu/models/frontend.py
+++ b/mlonmcu/models/frontend.py
@@ -145,6 +145,10 @@ def gen_ref_data_fmt(self):
     def inference(self, model: Model, input_data: Dict[str, np.array]):
         raise NotImplementedError
 
+    def extract_model_info(self, model: Model):
+        """Extract input/output names, shapes, types and quantization details from the model file."""
+        raise NotImplementedError
+
     def supports_formats(self, ins=None, outs=None):
         """Returs true if the frontend can handle at least one combination of input and output formats."""
         assert ins is not None or outs is not None, "Please provide a list of input formats, outputs formats or both"
@@ -269,6 +273,31 @@ def process_metadata(self, model, cfg=None):
                     flattened = {f"{backend}.{key}": value for key, value in backend_options[backend].items()}
                     cfg.update(flattened)
 
+        if len(input_shapes) > 0:
+            assert len(input_types) in [len(input_shapes), 0]
+            input_names = list(input_shapes.keys())
+        elif len(input_types) > 0:
+            input_names = list(input_types.keys())
+        else:
+            input_names = []
+
+        if metadata is None:
+            # Without a definition.yml, fall back to the information stored in the model file.
+            try:
+                (
+                    input_names,
+                    input_shapes,
+                    input_types,
+                    input_quant_details,
+                    output_names,
+                    output_shapes,
+                    output_types,
+                    output_quant_details,
+                ) = self.extract_model_info(model)
+            except NotImplementedError:
+                # Not every frontend can inspect its models; keep the defaults gathered above.
+                pass
+
         # Detect model support code (Allow overwrite in metadata YAML)
         support_path = model_dir / "support"
         if support_path.is_dir():
@@ -294,13 +323,6 @@ def process_metadata(self, model, cfg=None):
         if len(output_types) > 0:
             cfg.update({f"{model.name}.output_types": output_types})  # flattened version
 
-        if len(input_shapes) > 0:
-            assert len(input_types) in [len(input_shapes), 0]
-            input_names = list(input_shapes.keys())
-        elif len(input_shapes) > 0:
-            input_names = list(input_types.keys())
-        else:
-            input_names = []
         if len(output_shapes) > 0:
             assert len(output_types) in [len(output_shapes), 0]
             output_names = list(output_shapes.keys())
@@ -330,10 +352,10 @@ def process_metadata(self, model, cfg=None):
                         print("ii", ii)
                         assert input_name in input_types, f"Unknown dtype for input: {input_name}"
                         dtype = input_types[input_name]
-                        quant = input_quant_details.get(name, None)
+                        quant = input_quant_details.get(input_name, None)
                         if quant:
                             _, _, ty = quant
-                            dtype = ty
+                            # dtype = ty
                         assert input_name in input_shapes, f"Unknown shape for input: {input_name}"
                         shape = input_shapes[input_name]
                         if self.gen_data_fill_mode == "zeros":
@@ -342,7 +364,7 @@ def process_metadata(self, model, cfg=None):
                             arr = np.ones(shape, dtype=dtype)
                         elif self.gen_data_fill_mode == "random":
                             if "float" in dtype:
-                                arr = np.rand(*shape).astype(dtype)
+                                arr = np.random.rand(*shape).astype(dtype)
                             elif "int" in dtype:
                                 arr = np.random.randint(np.iinfo(dtype).min, np.iinfo(dtype).max, size=shape, dtype=dtype)
                             else:
@@ -350,9 +372,8 @@ def process_metadata(self, model, cfg=None):
                         else:
                             assert False
                         data[input_name] = arr
+                    assert len(data) > 0
                     inputs_data.append(data)
-            elif self.gen_data_fill_mode == "ones":
-                raise NotImplementedError
             elif self.gen_data_fill_mode == "file":
                 if self.gen_data_file == "auto":
                     len(in_paths) > 0
@@ -387,6 +408,7 @@ def process_metadata(self, model, cfg=None):
                 else:
                     raise RuntimeError(f"Unsupported ext: {ext}")
                 # print("temp", temp)
+                assert len(temp) > 0
                 for i in range(min(self.gen_data_number, len(temp))):
                     print("i", i)
                     assert i in temp
@@ -396,7 +418,7 @@ def process_metadata(self, model, cfg=None):
                         assert ii in temp[i]
                         assert input_name in input_types, f"Unknown dtype for input: {input_name}"
                         dtype = input_types[input_name]
-                        quant = input_quant_details.get(name, None)
+                        quant = input_quant_details.get(input_name, None)
                         if quant:
                             _, _, ty = quant
                             dtype = ty
@@ -487,7 +509,7 @@ def process_metadata(self, model, cfg=None):
                         assert ii in temp[i]
                         assert output_name in output_types, f"Unknown dtype for output: {output_name}"
                         dtype = output_types[output_name]
-                        dequant = output_quant_details.get(name, None)
+                        dequant = output_quant_details.get(output_name, None)
                         if dequant:
                             _, _, ty = dequant
                             dtype = ty
@@ -691,6 +713,38 @@ def analyze_enable(self):
     def analyze_script(self):
         return self.config["analyze_script"]
 
+    def extract_model_info(self, model: Model):
+        """Read the input/output tensor metadata directly from the TFLite model file.
+
+        Returns the tuple (input_names, input_shapes, input_types, input_quant_details,
+        output_names, output_shapes, output_types, output_quant_details). Names are lists,
+        the other items are dicts keyed by tensor name. Quantization details have the form
+        [scale, zero_point, "float32"] and are only recorded for quantized tensors.
+        """
+        import tensorflow as tf
+
+        def collect(details):
+            names = []
+            shapes = {}
+            types = {}
+            quant_details = {}
+            for detail in details:
+                name = str(detail["name"])
+                names.append(name)
+                shapes[name] = detail["shape"].tolist()
+                types[name] = np.dtype(detail["dtype"]).name
+                scale, zero_point = detail.get("quantization", (0.0, 0))
+                # The interpreter reports a scale of 0 for tensors that are not quantized.
+                if scale != 0:
+                    quant_details[name] = [scale, zero_point, "float32"]
+            return names, shapes, types, quant_details
+
+        model_path = str(model.paths[0])
+        interpreter = tf.lite.Interpreter(model_path=model_path)
+        inputs_info = collect(interpreter.get_input_details())
+        outputs_info = collect(interpreter.get_output_details())
+        return (*inputs_info, *outputs_info)
+
     def inference(self, model: Model, input_data: Dict[str, np.array], quant=False, dequant=False):
         import tensorflow as tf
         model_path = str(model.paths[0])