From 7c975af263c4e07b69eeea59d8906f29f773a1b4 Mon Sep 17 00:00:00 2001 From: wbenbihi Date: Tue, 23 Aug 2022 17:57:42 +0800 Subject: [PATCH] [ADD][FEAT](cli) add summary command --- cli/model.py | 69 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/cli/model.py b/cli/model.py index 28905f1..6501191 100644 --- a/cli/model.py +++ b/cli/model.py @@ -157,5 +157,74 @@ def plot(verbose, shapes, types, names, expand, activation, dpi, input, output): logger.error("Operation aborted!") +@click.command() +@click.option( + "--verbose/--no-verbose", + "-v", + default=False, + help="Activate Logs", + type=bool, +) +@click.option( + "--expand/--no-expand", + "-e", + default=False, + help="Whether to expand the nested models. (default: false)", + type=bool, +) +@click.option( + "--trainable/--no-trainable", + "-t", + default=False, + help="Whether to show if a layer is trainable. (default: false)", + type=bool, +) +@click.argument("input") +@click.argument("output") +def summary(verbose, expand, trainable, input, output): + """Create a summary image of model Graph + + NOTES:\n + To work this command requires `pydot` and graphviz to be installed + pydot:\n + `$ pip install pydot`\n + graphviz:\n + see instructions at https://graphviz.gitlab.io/download/\n + """ + if verbose: + logger.debug(f"input:\t {input}") + logger.debug(f"output:\t {output}") + try: + if verbose: + logger.info(f"Reading config from {input}...") + config = HTFConfig.parse_obj( + HTFConfigParser.parse(filename=input, verbose=verbose) + ) + obj: HTFObjectReference[_HTFModelHandler] = config.model.object + if verbose: + logger.info("Building Graph...") + model_handler = obj.init(config=config.model, verbose=verbose) + model_handler() + if verbose: + logger.info("Generating Summary...") + lines = [] + model_handler._model.summary( + print_fn=(lambda x: lines.append(x)), + expand_nested=expand, + show_trainable=trainable, + ) + if verbose: + logger.info("Writing Summary...") + with open(output, "w") as f: + f.write("\n".join(lines)) + if verbose: + logger.success("Operation completed!") + except Exception as e: + if verbose: + logger.exception(e) + logger.error("Operation aborted!") + + model.add_command(log) model.add_command(plot) +model.add_command(summary)