Skip to content

Commit

Permalink
introduce new feature: GenRefLabels 2
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilippvK committed May 27, 2024
1 parent 71e7280 commit c64bdc3
Showing 1 changed file with 115 additions and 4 deletions.
119 changes: 115 additions & 4 deletions mlonmcu/models/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ class Frontend(ABC):
"gen_ref_data_mode": None,
"gen_ref_data_file": None,
"gen_ref_data_fmt": None,
"gen_ref_labels": False,
"gen_ref_labels_mode": None,
"gen_ref_labels_file": None,
"gen_ref_labels_fmt": None,
}

REQUIRED = set()
Expand Down Expand Up @@ -142,6 +146,27 @@ def gen_ref_data_fmt(self):
assert value in ["npy", "npz"]
return value

@property
def gen_ref_labels(self):
    """Whether generation of reference labels is enabled.

    Reads ``config["gen_ref_labels"]``. Values that already are ``bool`` or
    ``int`` are returned unchanged; anything else is converted via ``str2bool``.
    """
    flag = self.config["gen_ref_labels"]
    if isinstance(flag, (bool, int)):
        return flag
    return str2bool(flag)

@property
def gen_ref_labels_mode(self):
    """Return the configured source of the reference labels.

    Returns:
        ``"file"`` or ``"model"``, as stored in ``config["gen_ref_labels_mode"]``.

    Raises:
        AssertionError: If the configured value is anything else (including
            the unset default ``None``).
    """
    value = self.config["gen_ref_labels_mode"]
    supported = ["file", "model"]
    # The message names the offending key so a missing/invalid setting is
    # diagnosable from the traceback alone.
    assert value in supported, f"Invalid value for gen_ref_labels_mode: {value!r} (expected one of {supported})"
    return value

@property
def gen_ref_labels_file(self):
    """Return the raw ``config["gen_ref_labels_file"]`` entry without validation."""
    labels_file = self.config["gen_ref_labels_file"]
    return labels_file

@property
def gen_ref_labels_fmt(self):
    """Return the configured file format for the generated reference labels.

    Returns:
        One of ``"npy"``, ``"npz"``, ``"txt"`` or ``"csv"``, as stored in
        ``config["gen_ref_labels_fmt"]``.

    Raises:
        AssertionError: If the configured value is anything else (including
            the unset default ``None``).
    """
    value = self.config["gen_ref_labels_fmt"]
    supported = ["npy", "npz", "txt", "csv"]
    # The message names the offending key so a missing/invalid setting is
    # diagnosable from the traceback alone.
    assert value in supported, f"Invalid value for gen_ref_labels_fmt: {value!r} (expected one of {supported})"
    return value

def inference(self, model: Model, input_data: Dict[str, np.array], quant=False, dequant=False):
    """Run ``model`` on one set of inputs; to be implemented by concrete frontends.

    This base implementation only fixes the call signature. ``quant`` and
    ``dequant`` are accepted because ``generate_ref_labels`` calls
    ``self.inference(model, input_data, quant=False, dequant=True)``; without
    them a frontend lacking an override would fail with ``TypeError`` instead of
    the intended ``NotImplementedError``.

    Args:
        model: The model to execute.
        input_data: Input arrays (annotated as a dict keyed by ``str``).
        quant: NOTE(review): presumably toggles quantization of the inputs;
            the exact semantics are defined by overriding frontends - confirm.
        dequant: NOTE(review): presumably toggles dequantization of the
            outputs; the exact semantics are defined by overriding frontends - confirm.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError

Expand Down Expand Up @@ -402,13 +427,58 @@ def generate_output_ref_data(
data[output_name] = arr
outputs_data.append(data)
else:
assert self.gen_data_file is not None, "Missing value for gen_data_file"
assert self.gen_ref_data_file is not None, "Missing value for gen_ref_data_file"
file = Path(self.gen_data_file)
assert file.is_file(), f"File not found: {file}"
raise NotImplementedError
else:
raise RuntimeError(f"unsupported fill_mode: {self.gen_ref_data_mode}")
return outputs_data

def generate_ref_labels(
    self, inputs_data, model, out_labels_paths, output_names, output_types, output_shapes, output_quant_details
):
    """Produce one reference class label per input sample.

    The source of the labels is selected by ``self.gen_ref_labels_mode``:

    * ``"model"``: every sample is passed to ``self.inference`` (with
      ``quant=False, dequant=True``) and the argmax of the single model
      output is used as label.
    * ``"file"``: labels are read from a CSV file with the columns ``i`` and
      ``label_idx``. Rows are sorted by ``i`` and the first
      ``len(inputs_data)`` values of ``label_idx`` are returned.

    Args:
        inputs_data: Sequence of input samples.
        model: Model forwarded to ``self.inference`` (``"model"`` mode only).
        out_labels_paths: Candidate label files. Only consulted when
            ``self.gen_ref_labels_file`` is ``"auto"``, in which case exactly
            one entry is required.
        output_names: Unused here.
        output_types: Unused here.
        output_shapes: Unused here.
        output_quant_details: Unused here.

    Returns:
        List with one label index per input sample.

    Raises:
        AssertionError: If label generation is disabled, no inputs are given
            in ``"model"`` mode, the model has more than one output, the labels
            file is missing/ambiguous, or the CSV lacks columns or rows.
        NotImplementedError: If the labels file is not a ``.csv`` file.
        RuntimeError: If ``self.gen_ref_labels_mode`` is not supported.
    """
    assert self.gen_ref_labels
    mode = self.gen_ref_labels_mode
    labels = []
    if mode == "model":
        assert len(inputs_data) > 0
        for input_data in inputs_data:
            output_data = self.inference(model, input_data, quant=False, dequant=True)
            assert len(output_data) == 1, "Does not support multi-output classification"
            # Unwrap the only entry of the outputs mapping.
            output_data = output_data[list(output_data)[0]]
            labels.append(np.argmax(output_data))

    elif mode == "file":
        labels_file = self.gen_ref_labels_file
        if labels_file == "auto":
            assert len(out_labels_paths) > 0, "labels_paths is empty"
            assert len(out_labels_paths) == 1, "Expected exactly one labels file"
            file = Path(out_labels_paths[0])
            logger.debug("Using labels file: %s", file)
        else:
            assert labels_file is not None, "Missing value for gen_ref_labels_file"
            file = Path(labels_file)
        assert file.is_file(), f"File not found: {file}"
        ext = file.suffix
        assert len(ext) > 1
        fmt = ext[1:].lower()
        if fmt == "csv":
            # Imported lazily so pandas is only needed when a CSV is actually read.
            import pandas as pd

            labels_df = pd.read_csv(file, sep=",")
            assert "i" in labels_df.columns, "Labels file is missing column: i"
            assert "label_idx" in labels_df.columns, "Labels file is missing column: label_idx"
            logger.debug("len(inputs_data)=%d, len(labels_df)=%d", len(inputs_data), len(labels_df))
            assert len(inputs_data) <= len(labels_df), "Labels file provides fewer labels than inputs"
            # Order by sample index so labels line up with inputs_data.
            labels_df.sort_values("i", inplace=True)
            labels = list(labels_df["label_idx"].astype(int))[: len(inputs_data)]
        else:
            raise NotImplementedError(f"Fmt not supported: {fmt}")
    else:
        raise RuntimeError(f"unsupported gen_ref_labels_mode: {mode}")
    return labels

def generate_model_info(
self,
input_names,
Expand Down Expand Up @@ -454,6 +524,7 @@ def process_metadata(self, model, cfg=None):
metadata = model.metadata
in_paths = []
out_paths = []
labels_paths = []
input_shapes = {}
output_shapes = {}
input_types = {}
Expand Down Expand Up @@ -521,19 +592,33 @@ def process_metadata(self, model, cfg=None):
out_path.is_dir()
), f"Output data directory defined in model metadata does not exist: {out_path}"
out_paths.append(out_path)
if self.gen_ref_labels and self.gen_ref_labels_mode == "file" and self.gen_ref_labels_file == "auto":
if "test_labels_file" in outp:
labels_file = Path(outp["test_labels_file"])
labels_path = model_dir / labels_file
assert (
labels_path.is_file()
), f"Labels file defined in model metadata does not exist: {labels_path}"
labels_paths.append(labels_path)
else:
fallback_in_path = model_dir / "input"
if fallback_in_path.is_dir():
in_paths.append(fallback_in_path)
fallback_out_path = model_dir / "output"
if fallback_out_path.is_dir():
out_paths.append(fallback_out_path)
fallback_labels_path = model_dir / "output_labels.csv"
if fallback_labels_path.is_file():
labels_paths.append(fallback_labels_path)
if model.inputs_path:
logger.info("Overriding default model input data with user path")
in_paths = [model.inputs_path]
if model.outputs_path:
logger.info("Overriding default model output data with user path")
out_paths = [model.outputs_path]
if model.output_labels_path: # TODO
logger.info("Overriding default model output labels with user path")
labels_paths = [model.output_labels_path]

if metadata is not None and "backends" in metadata:
assert cfg is not None
Expand Down Expand Up @@ -586,6 +671,8 @@ def process_metadata(self, model, cfg=None):
cfg.update({"mlif.output_data_path": out_paths})
# cfg.update({"espidf.output_data_path": out_paths})
# cfg.update({"zephyr.output_data_path": out_paths})
if len(labels_paths) > 0:
cfg.update({"mlif.output_labels_path": labels_paths})
if len(input_shapes) > 0:
cfg.update({f"{model.name}.input_shapes": input_shapes})
if len(output_shapes) > 0:
Expand Down Expand Up @@ -657,10 +744,34 @@ def process_metadata(self, model, cfg=None):
else:
raise RuntimeError(f"Unsupported fmt: {fmt}")
assert raw
outputs_data_artifact = Artifact(
outputs_ref_artifact = Artifact(
f"outputs_ref.{fmt}", raw=raw, fmt=ArtifactFormat.BIN, flags=("outputs_ref", fmt)
)
artifacts.append(outputs_data_artifact)
artifacts.append(outputs_ref_artifact)
if self.gen_ref_labels:
labels_ref = self.generate_ref_labels(
inputs_data, model, labels_paths, output_names, output_types, output_shapes, output_quant_details
)
fmt = self.gen_ref_labels_fmt
if fmt == "npy":
with tempfile.TemporaryDirectory() as tmpdirname:
tempfilename = Path(tmpdirname) / "labels.npy"
np.save(tempfilename, labels_ref)
with open(tempfilename, "rb") as f:
raw = f.read()
elif fmt == "npz":
raise NotImplementedError
elif fmt == "txt":
raise NotImplementedError
elif fmt == "csv":
raise NotImplementedError
else:
raise RuntimeError(f"Unsupported fmt: {fmt}")
assert raw
labels_ref_artifact = Artifact(
f"labels_ref.{fmt}", raw=raw, fmt=ArtifactFormat.BIN, flags=("labels_ref", fmt)
)
artifacts.append(labels_ref_artifact)
return artifacts

def generate(self, model) -> Tuple[dict, dict]:
Expand Down Expand Up @@ -762,7 +873,7 @@ def produce_artifacts(self, model):
# TODO: frontend parsed metadata instead of lookup.py?
# TODO: how to find inout_data?
class TfLiteFrontend(SimpleFrontend):
FEATURES = Frontend.FEATURES | {"visualize", "split_layers", "tflite_analyze", "gen_data", "gen_ref_data"}
FEATURES = Frontend.FEATURES | {"visualize", "split_layers", "tflite_analyze", "gen_data", "gen_ref_data", "gen_ref_labels"}

DEFAULTS = {
**Frontend.DEFAULTS,
Expand Down

0 comments on commit c64bdc3

Please sign in to comment.