Skip to content

Commit

Permalink
[ADD][FEAT](cli) add summary command
Browse files Browse the repository at this point in the history
  • Loading branch information
wbenbihi committed Aug 23, 2022
1 parent 9fd9e19 commit 7c975af
Showing 1 changed file with 69 additions and 0 deletions.
69 changes: 69 additions & 0 deletions cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,5 +157,74 @@ def plot(verbose, shapes, types, names, expand, activation, dpi, input, output):
logger.error("Operation aborted!")


@click.command()
@click.option(
"--verbose/--no-verbose",
"-v",
default=False,
help="Activate Logs",
type=bool,
)
@click.option(
"--expand/--no-expand",
"-e",
default=False,
help="Whether to expand the nested models. (default: false)",
type=bool,
)
@click.option(
"--trainable/--no-trainable",
"-t",
default=False,
help="Whether to show if a layer is trainable. (default: false)",
type=bool,
)
@click.argument("input")
@click.argument("output")
def summary(verbose, expand, trainable, input, output):
"""Create a summary image of model Graph
NOTES:\n
To work this command requires `pydot` and graphviz to be installed
pydot:\n
`$ pip install pydot`\n
graphviz:\n
see instructions at https://graphviz.gitlab.io/download/\n
"""
if verbose:
logger.debug(f"input:\t {input}")
logger.debug(f"output:\t {output}")
try:
if verbose:
logger.info(f"Reading config from {input}...")
config = HTFConfig.parse_obj(
HTFConfigParser.parse(filename=input, verbose=verbose)
)
obj: HTFObjectReference[_HTFModelHandler] = config.model.object
if verbose:
logger.info("Building Graph...")
model_handler = obj.init(config=config.model, verbose=verbose)
model_handler()
if verbose:
logger.info("Generating Summary...")
lines = []
model_handler._model.summary(
print_fn=(lambda x: lines.append(x)),
expand_nested=expand,
show_trainable=trainable,
)
if verbose:
logger.info("Writing Summary...")
with open(output, "w") as f:
f.write("\n".join(lines))
if verbose:
logger.success("Operation completed!")
except Exception as e:
if verbose:
logger.exception(e)
logger.error("Operation aborted!")


model.add_command(log)
model.add_command(plot)
model.add_command(summary)

0 comments on commit 7c975af

Please sign in to comment.