From b092b3dd8858325eee6786ce6a6e54f14f131356 Mon Sep 17 00:00:00 2001 From: brimoor Date: Tue, 10 Sep 2024 01:58:17 -0400 Subject: [PATCH 1/3] upgrades to support remote zoo models --- eta/core/models.py | 159 +++++++++++++++++++++++++++++++-------------- 1 file changed, 111 insertions(+), 48 deletions(-) diff --git a/eta/core/models.py b/eta/core/models.py index 4aa7f552..adb06399 100644 --- a/eta/core/models.py +++ b/eta/core/models.py @@ -1038,12 +1038,17 @@ class Model(Serializable): Attributes: base_name: the base name of the model (no version info) - base_filename: the base filename of the model (if any, no version info) + base_filename: the base filename or directory of the model (if any) + (no version info) + subdir: the model's subdirectory (if any) manager: the ModelManager instance that describes the remote storage location of the models_dir (if any) - version: the version of the model (if any) + author (optional): the author of the model + version: (optional) the model version + url (optional): the URL where the model is hosted + source (optional): the source of the model + license (optional): the license under which the model is distributed description: the description of the model (if any) - source: the source of the model (if any) size_bytes: the size of the model on disk (if any) default_deployment_config_dict: a dictionary representation of an `eta.core.learning.ModelConfig` describing the recommended settings @@ -1061,10 +1066,14 @@ def __init__( self, base_name, base_filename=None, + subdir=None, manager=None, + author=None, version=None, - description=None, + url=None, source=None, + license=None, + description=None, size_bytes=None, default_deployment_config_dict=None, requirements=None, @@ -1076,10 +1085,14 @@ def __init__( Args: base_name: the base name of the model base_filename (optional): the base filename for the model + subdir: the model's subdirectory (if any) manager (optional): the ModelManager for the model + author (optional): the author of the model version: (optional) the model version + url (optional): the URL where the model is hosted + source (optional): the source of the model + license (optional): the license under which the model is distributed description: (optional) the description of the model - source: (optional) the source of the model size_bytes: (optional) the size of the model on disk default_deployment_config_dict: (optional) a dictionary representation of an `eta.core.learning.ModelConfig` describing @@ -1090,10 +1103,14 @@ def __init__( """ self.base_name = base_name self.base_filename = base_filename + self.subdir = subdir self.manager = manager + self.author = author self.version = version or None - self.description = description + self.url = url self.source = source + self.license = license + self.description = description self.size_bytes = size_bytes self.default_deployment_config_dict = default_deployment_config_dict self.requirements = requirements @@ -1112,14 +1129,19 @@ def name(self): @property def filename(self): """The version-aware filename of the model.""" - if not self.has_version: - return self.base_filename - if self.base_filename is None: return None - base, ext = os.path.splitext(self.base_filename) - return base + "-v" + self.version + ext + if self.has_version: + base, ext = os.path.splitext(self.base_filename) + filename = base + "-v" + self.version + ext + else: + filename = self.base_filename + + if self.subdir is not None: + filename = os.path.join(self.subdir, filename) + + return filename @property def has_manager(self): @@ -1383,17 +1405,11 @@ def parse_name(name): Returns: base_name: the base name of the model version: the version of the model, or None if no version was found - - Raises: - ModelError: if the model name was invalid """ - chunks = name.split("@") + chunks = name.rsplit("@", 1) if len(chunks) == 1: return name, None - if chunks[1] == "" or len(chunks) > 2: - raise ModelError("Invalid model name '%s'" % name) - return chunks[0], chunks[1] @staticmethod @@ -1406,7 +1422,7 @@ def has_version_str(name): Returns: True/False """ - return bool(Model.parse_name(name)[1]) + return Model.parse_name(name)[1] is not None def attributes(self): """Returns a list of class attributes to be serialized. @@ -1417,9 +1433,12 @@ def attributes(self): return [ "base_name", "base_filename", + "author", "version", - "description", + "url", "source", + "license", + "description", "size_bytes", "manager", "default_deployment_config_dict", @@ -1429,11 +1448,12 @@ def attributes(self): ] @classmethod - def from_dict(cls, d): + def from_dict(cls, d, subdir=None): """Constructs a Model from a JSON dictionary. Args: d: a JSON dictionary + subdir (optional): a subdirectory for the model Returns: a Model instance @@ -1453,10 +1473,14 @@ def from_dict(cls, d): return cls( d["base_name"], base_filename=d.get("base_filename", None), + subdir=subdir, manager=manager, + author=d.get("author", None), version=d.get("version", None), - description=d.get("description", None), + url=d.get("url", None), source=d.get("source", None), + license=d.get("license", None), + description=d.get("description", None), size_bytes=d.get("size_bytes", None), default_deployment_config_dict=d.get( "default_deployment_config_dict", None @@ -1472,70 +1496,103 @@ class ModelsManifest(Serializable): _MODEL_CLS = Model - def __init__(self, models=None): + def __init__(self, models=None, name=None, url=None): """Creates a ModelsManifest instance. Args: models: a list of Model instances + name (optional): a name for the manifest + url (optional): the source location of the manifest """ - self.models = models or [] + if models is None: + models = [] + + if name is not None: + subdir = os.path.join(*name.split("/")) + for model in models: + model.subdir = subdir + else: + subdir = None + + self.models = models + self.name = name + self.url = url + self._subdir = subdir def __iter__(self): return iter(self.models) - def add_model(self, model): + @property + def subdir(self): + return self._subdir + + def add_model(self, model, error_level=0): """Adds the given model to the manifest. Args: model: a Model instance + error_level: the error level to use, defined as: - Raises: - ModelError: if the model conflicts with an existing model in the - manifest + 0: raise error if the model cannot be added + 1: log warning if the model cannot be added + 2: ignore models that cannot be added """ if self.has_model_with_name(model.name): - raise ModelError( + error_msg = ( "Manifest already contains model called '%s'" % model.name ) + etau.handle_error(ModelError(error_msg), error_level) + return - if model.filename is not None and self.has_model_with_filename( - model.filename - ): - raise ModelError( + if self.has_model_with_filename(model): + error_msg = ( "Manifest already contains model with filename '%s'" - % (model.filename) + % model.filename ) + etau.handle_error(ModelError(error_msg), error_level) + return if self.has_model_with_name(model.base_name): - raise ModelError( + error_msg = ( "Manifest already contains a versionless model called '%s', " - "so a versioned model is not allowed" % model.base_name - ) + "so a versioned model is not allowed" + ) % model.base_name + etau.handle_error(ModelError(error_msg), error_level) + return self.models.append(model) - def remove_model(self, name): + def remove_model(self, name, error_level=0): """Removes the model with the given name from the ModelsManifest. Args: name: the name of the model + error_level: the error level to use, defined as: - Raises: - ModelError: if the model was not found + 0: raise error if the model cannot be added + 1: log warning if the model cannot be added + 2: ignore models that cannot be added """ if not self.has_model_with_name(name): - raise ModelError("Manifest does not contain model '%s'" % name) + error_msg = "Manifest does not contain model '%s'" % name + etau.handle_error(ModelError(error_msg), error_level) + return self.models = [model for model in self.models if model.name != name] - def merge(self, models_manifest): + def merge(self, models_manifest, error_level=0): """Merges the models manifest into this one. Args: models_manifest: a ModelsManifest + error_level: the error level to use, defined as: + + 0: raise error if the model cannot be added + 1: log warning if the model cannot be added + 2: ignore models that cannot be added """ for model in models_manifest: - self.add_model(model) + self.add_model(model, error_level=error_level) def get_model_with_name(self, name): """Gets the model with the given name. @@ -1593,17 +1650,20 @@ def has_model_with_name(self, name): """ return any(name == model.name for model in self.models) - def has_model_with_filename(self, filename): - """Determines whether this manifest contains a model with the given + def has_model_with_filename(self, model): + """Determines whether this manifest contains a model with a conflicting filename. Args: - filename: the filename + model: a Model instance Returns: True/False """ - return any(filename == model.filename for model in self.models) + if model.filename is None: + return False + + return any(model.filename == m.filename for m in self.models) @staticmethod def make_manifest_path(models_dir): @@ -1664,7 +1724,10 @@ def from_dict(cls, d): Returns: a ModelsManifest """ - return cls(models=[cls._MODEL_CLS.from_dict(md) for md in d["models"]]) + models = [cls._MODEL_CLS.from_dict(md) for md in d.get("models", [])] + name = d.get("name", None) + url = d.get("url", None) + return cls(models=models, name=name, url=url) class ModelManager(Configurable, Serializable): From 4b053736ef28e2c1559149bb4e4572ae9ac8facb Mon Sep 17 00:00:00 2001 From: brimoor Date: Tue, 10 Sep 2024 02:02:20 -0400 Subject: [PATCH 2/3] linting --- eta/core/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/eta/core/models.py b/eta/core/models.py index adb06399..caf63b9b 100644 --- a/eta/core/models.py +++ b/eta/core/models.py @@ -1587,8 +1587,8 @@ def merge(self, models_manifest, error_level=0): models_manifest: a ModelsManifest error_level: the error level to use, defined as: - 0: raise error if the model cannot be added - 1: log warning if the model cannot be added + 0: raise error if a model cannot be added + 1: log warning if a model cannot be added 2: ignore models that cannot be added """ for model in models_manifest: From d72657d56e6139956e8ad9b7c79a1f2d08cfd2cb Mon Sep 17 00:00:00 2001 From: brimoor Date: Mon, 16 Sep 2024 09:15:08 -0400 Subject: [PATCH 3/3] bumping version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2112648b..1fb1cc2c 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ from wheel.bdist_wheel import bdist_wheel -VERSION = "0.12.7" +VERSION = "0.13.0" class BdistWheelCustom(bdist_wheel):