Skip to content

Commit

Permalink
[Docs] Add more docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp v. K committed Mar 21, 2022
1 parent ee7773f commit 65c6d3b
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 13 deletions.
24 changes: 19 additions & 5 deletions mlonmcu/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@

from mlonmcu.setup import utils

# class ModelLibraryFormatPlus:
# pass

# TODO: offer pack/unpack/flatten methods for mlf
# TODO: implement restore methods
# TODO: decide if inheritance based scheme would fit better
# TODO: add artifact flags and lookup utility to find best match


class ArtifactFormat(Enum): # TODO: ArtifactType, ArtifactKind?
"""Enumeration of artifact types."""

UNKNOWN = 0
SOURCE = 1
TEXT = 2
Expand Down Expand Up @@ -64,6 +64,7 @@ def __init__(
):
# TODO: Allow to store filenames as well as raw data
self.name = name
# TODO: too many attributes...
self.content = content
self.path = path
self.data = data
Expand All @@ -75,25 +76,37 @@ def __init__(

@property
def exported(self):
"""Returns true if the artifact was writtem to disk."""
return bool(self.path is not None)

def validate(self):
"""Checker for artifact attributes for the given format."""
if self.fmt in [ArtifactFormat.TEXT, ArtifactFormat.SOURCE]:
assert self.content is not None
elif self.fmt in [ArtifactFormat.RAW, ArtifactFormat.BIN]:
assert self.raw is not None
elif self.fmt in [ArtifactFormat.MLF]:
assert self.raw is not None # TODO: load it via tvm?
assert self.raw is not None
elif self.fmt in [ArtifactFormat.PATH]:
assert self.path is not None
else:
raise NotImplementedError

def export(self, dest, extract=False):
"""Export the artifact to a given path (file or directory) and update its path.
Arguments
---------
dest : str
Path of the destination.
extract : bool
If archive: extract to destination.
"""
filename = Path(dest) / self.name
if self.fmt in [ArtifactFormat.TEXT, ArtifactFormat.SOURCE]:
assert not extract, "extract option is only available for ArtifactFormat.MLF"
with open(filename, "w") as handle:
with open(filename, "w", encoding="utf-8") as handle:
handle.write(self.content)
elif self.fmt in [ArtifactFormat.RAW, ArtifactFormat.BIN]:
assert not extract, "extract option is only available for ArtifactFormat.MLF"
Expand All @@ -113,6 +126,7 @@ def export(self, dest, extract=False):
self.path = filename if self.path is None else self.path

def print_summary(self):
"""Utility to print information about an artifact to the cmdline."""
print("Format:", self.fmt)
print("Optional: ", self.optional)
if self.fmt in [ArtifactFormat.TEXT, ArtifactFormat.SOURCE]:
Expand Down
52 changes: 48 additions & 4 deletions mlonmcu/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,64 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Collection of utilities to manage MLonMCU configs."""

from mlonmcu.feature.type import FeatureType
from mlonmcu.logging import get_logger

logger = get_logger()


def remove_config_prefix(config, prefix, skip=[]):
def remove_config_prefix(config, prefix, skip=None):
"""Iterate over keys in dict and remove given prefix.
Arguments
---------
config : dict
The configuration data.
prefix : str
The prefix to remove.
skip : List[str], optional
A list of keys which should not be altered.
Returns
-------
ret : dict
The transformed configuration.
"""
if skip is None:
skip = []

def helper(key):
return key.split(f"{prefix}.")[-1]

return {helper(key): value for key, value in config.items() if f"{prefix}." in key and key not in skip}


def filter_config(config, prefix, defaults, required_keys):
"""Filter the global config for a given component prefix.
Arguments
---------
config : dict
The configuration data.
prefix : str
The prefix for the component.
defaults : dict
The default values used if not overwritten by user.
required_keys : list
The required keys for the component.
Returns
-------
cfg : dict
The filteres configuration.
Raises
------
AssertionError: If a required key is missing.
"""
cfg = remove_config_prefix(config, prefix, skip=required_keys)
for required in required_keys:
value = None
Expand All @@ -53,7 +97,7 @@ def filter_config(config, prefix, defaults, required_keys):


def resolve_required_config(
required_keys, features=None, targets=None, config=None, cache=None
required_keys, features=None, config=None, cache=None
): # TODO: add framework, backend, and frontends as well?
"""Utility which iterates over a set of given config keys and
resolves their values using the passed config and/or cache.
Expand All @@ -75,7 +119,7 @@ def resolve_required_config(
"""

def get_cache_flags(features, targets):
def get_cache_flags(features):
result = {}
if features:
for feature in features:
Expand All @@ -84,7 +128,7 @@ def get_cache_flags(features, targets):
return result

ret = {}
cache_flags = get_cache_flags(features, targets)
cache_flags = get_cache_flags(features)
for key in required_keys:
if config is None or key not in config:
assert cache is not None, "No dependency cache was provided. Either provide a cache or config."
Expand Down
25 changes: 21 additions & 4 deletions mlonmcu/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Definitions of the Report class used by MLonMCU sessions and runs."""
from pathlib import Path
import pandas as pd

Expand All @@ -27,17 +28,28 @@


class Report:
"""Report class wrapped around multiple pandas dataframes."""

def __init__(self):
self.pre_df = pd.DataFrame()
self.main_df = pd.DataFrame()
self.post_df = pd.DataFrame()

@property
def df(self):
"""Combine the three internal dataframes to a large one and return in."""
# TODO: handle this properly by either adding NAN or use a single set(pre=, post=, main=) method
return pd.concat([self.pre_df, self.main_df, self.post_df], axis=1)

def export(self, path):
"""Export the report to a file.
Arguments
---------
path : str
Destination path.
"""
ext = Path(path).suffix[1:]
assert ext in SUPPORTED_FMTS, f"Unsupported report format: {ext}"
parent = Path(path).parent
Expand All @@ -54,25 +66,30 @@ def export(self, path):
# self.df = self.df.append(*args, **kwargs, ignore_index=True)

def set_pre(self, data):
"""Setter for the left third of the dataframe."""
self.pre_df = pd.DataFrame.from_records(data).reset_index(drop=True)

def set_post(self, data):
"""Setter for the right third of the dataframe."""
self.post_df = pd.DataFrame.from_records(data).reset_index(drop=True)

def set_main(self, data):
"""Setter for the center part of the dataframe."""
self.main_df = pd.DataFrame.from_records(data).reset_index(drop=True)

def set(self, pre=[], main=[], post=[]):
def set(self, pre=None, main=None, post=None):
"""Setter for the dataframe."""
size = len(pre)
self.set_pre(pre)
self.set_pre(pre if pre is not None else {})
if len(main) != size:
assert len(main) == 0
# self.set_main(pd.Series())
else:
self.set_main(main)
self.set_post(post)
self.set_main(main if main is not None else {})
self.set_post(post if post is not None else {})

def add(self, reports):
"""Helper function to append a line to an existing report."""
if not isinstance(reports, list):
reports = [reports]
for report in reports:
Expand Down

0 comments on commit 65c6d3b

Please sign in to comment.