Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TODO and relative import cleanup #24

Merged
merged 7 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions omlmd/cli.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
"""Command line interface for OMLMD."""

from __future__ import annotations

import logging
from pathlib import Path

import click
import cloup
import logging

from omlmd.helpers import Helper
from omlmd.model_metadata import deserialize_mdfile
from omlmd.provider import OMLMDRegistry

from .helpers import Helper
from .model_metadata import deserialize_mdfile

logger = logging.getLogger(__name__)

Expand All @@ -23,10 +23,6 @@
)


def get_OMLMDRegistry(plain_http: bool) -> OMLMDRegistry:
return OMLMDRegistry(insecure=plain_http)


@cloup.group()
def cli():
logging.basicConfig(level=logging.INFO)
Expand All @@ -45,7 +41,7 @@ def cli():
@click.option("--media-types", "-m", multiple=True, default=[])
def pull(plain_http: bool, target: str, output: Path, media_types: tuple[str]):
"""Pulls an OCI Artifact containing ML model and metadata, filtering if necessary."""
Helper(get_OMLMDRegistry(plain_http)).pull(target, output, media_types)
Helper.from_plain(plain_http).pull(target, output, media_types)


@cli.group()
Expand All @@ -58,15 +54,15 @@ def get():
@click.argument("target", required=True)
def config(plain_http: bool, target: str):
"""Outputs configuration of the given OCI Artifact for ML model and metadata."""
click.echo(Helper(get_OMLMDRegistry(plain_http)).get_config(target))
click.echo(Helper.from_plain(plain_http).get_config(target))


@cli.command()
@plain_http
@click.argument("targets", required=True, nargs=-1)
def crawl(plain_http: bool, targets: tuple[str]):
"""Crawls configuration for the given list of OCI Artifact for ML model and metadata."""
click.echo(Helper(get_OMLMDRegistry(plain_http)).crawl(targets))
click.echo(Helper.from_plain(plain_http).crawl(targets))


@cli.command()
Expand All @@ -83,15 +79,21 @@ def crawl(plain_http: bool, targets: tuple[str]):
"-m",
"--metadata",
type=click.Path(path_type=Path, exists=True, resolve_path=True),
help="Metadata file in JSON or YAML format"
help="Metadata file in JSON or YAML format",
),
cloup.option('--empty-metadata', help='Push with empty metadata', is_flag=True),
cloup.option("--empty-metadata", help="Push with empty metadata", is_flag=True),
constraint=cloup.constraints.require_one,
)
def push(plain_http: bool, target: str, path: Path, metadata: Path | None, empty_metadata: bool):
def push(
plain_http: bool,
target: str,
path: Path,
metadata: Path | None,
empty_metadata: bool,
):
"""Pushes an OCI Artifact containing ML model and metadata, supplying metadata from file as necessary"""

if empty_metadata:
logger.warning(f"Pushing to {target} with empty metadata.")
md = deserialize_mdfile(metadata) if metadata else {}
click.echo(Helper(get_OMLMDRegistry(plain_http)).push(target, path, **md))
click.echo(Helper.from_plain(plain_http).push(target, path, **md))
47 changes: 25 additions & 22 deletions omlmd/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,18 @@
import os
import urllib.request
from collections.abc import Sequence
from dataclasses import fields
from dataclasses import dataclass, field, fields
from pathlib import Path

from omlmd.constants import (
from .constants import (
FILENAME_METADATA_JSON,
FILENAME_METADATA_YAML,
MIME_APPLICATION_CONFIG,
MIME_APPLICATION_MLMODEL,
)
from omlmd.listener import Event, Listener, PushEvent
from omlmd.model_metadata import ModelMetadata
from omlmd.provider import OMLMDRegistry

from .listener import Event, Listener, PushEvent
from .model_metadata import ModelMetadata
from .provider import OMLMDRegistry

logger = logging.getLogger(__name__)

Expand All @@ -27,20 +26,18 @@ def download_file(uri: str):
return file_name


@dataclass
isinyaaa marked this conversation as resolved.
Show resolved Hide resolved
class Helper:
_listeners: list[Listener] = []

def __init__(self, registry: OMLMDRegistry | None = None):
if registry is None:
self._registry = OMLMDRegistry(
insecure=True
) # TODO: this is a bit limiting when used from CLI, to be refactored
else:
self._registry = registry
registry: OMLMDRegistry = (
isinyaaa marked this conversation as resolved.
Show resolved Hide resolved
field( # TODO: this is a bit limiting when used from CLI, to be refactored
isinyaaa marked this conversation as resolved.
Show resolved Hide resolved
default_factory=lambda: OMLMDRegistry(insecure=True)
)
)
_listeners: list[Listener] = field(default_factory=list)

@property
def registry(self):
return self._registry
@classmethod
def from_plain(cls, insecure: bool):
isinyaaa marked this conversation as resolved.
Show resolved Hide resolved
return cls(OMLMDRegistry(insecure=insecure))

def push(
self,
Expand Down Expand Up @@ -95,14 +92,20 @@ def push(
]
try:
# print(target, files, model_metadata.to_annotations_dict())
result = self._registry.push(
result = self.registry.push(
target=target,
files=files,
manifest_annotations=model_metadata.to_annotations_dict(),
manifest_config=manifest_cfg,
do_chunked=True,
)
self.notify_listeners(PushEvent(target, model_metadata))
self.notify_listeners(
PushEvent(
result.headers["Docker-Content-Digest"],
isinyaaa marked this conversation as resolved.
Show resolved Hide resolved
target,
model_metadata,
)
)
return result
finally:
if owns_meta_files:
Expand All @@ -112,10 +115,10 @@ def push(
def pull(
self, target: str, outdir: Path | str, media_types: Sequence[str] | None = None
):
self._registry.download_layers(target, outdir, media_types)
self.registry.download_layers(target, outdir, media_types)

def get_config(self, target: str) -> str:
return f'{{"reference":"{target}", "config": {self._registry.get_config(target)} }}' # this assumes OCI Manifest.Config later is JSON (per std spec)
return f'{{"reference":"{target}", "config": {self.registry.get_config(target)} }}' # this assumes OCI Manifest.Config later is JSON (per std spec)

def crawl(self, targets: Sequence[str]) -> str:
configs = map(self.get_config, targets)
Expand Down
17 changes: 9 additions & 8 deletions omlmd/listener.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import typing as t
from abc import ABC, abstractmethod
from typing import Any
from dataclasses import dataclass

from omlmd.model_metadata import ModelMetadata
from .model_metadata import ModelMetadata


class Listener(ABC):
Expand All @@ -12,19 +13,19 @@ class Listener(ABC):
"""

@abstractmethod
def update(self, source: Any, event: Event) -> None:
def update(self, source: t.Any, event: Event) -> None:
"""
Receive update event.
"""
pass


class Event:
class Event(ABC):
pass


@dataclass
class PushEvent(Event):
def __init__(self, target: str, metadata: ModelMetadata):
# TODO: cannot just receive yet the push sha, waiting for: https://github.com/oras-project/oras-py/pull/146 in a release.
self.target = target
self.metadata = metadata
sha: str
isinyaaa marked this conversation as resolved.
Show resolved Hide resolved
target: str
metadata: ModelMetadata
38 changes: 6 additions & 32 deletions omlmd/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,15 @@
import os
import tempfile

import oras.defaults
import oras.oci
import oras.provider
import oras.schemas
import oras.utils
from oras import provider
from oras.decorator import ensure_container
from oras.provider import container_type
from oras.defaults import annotation_title as ANNOTATION_TITLE
from oras.utils import sanitize_path

logger = logging.getLogger(__name__)


class OMLMDRegistry(oras.provider.Registry):
class OMLMDRegistry(provider.Registry):
@ensure_container
def download_layers(self, package, download_dir, media_types):
"""
Expand All @@ -33,8 +30,8 @@ def download_layers(self, package, download_dir, media_types):
or len(media_types) == 0
or layer["mediaType"] in media_types
):
artifact = layer["annotations"]["org.opencontainers.image.title"]
outfile = oras.utils.sanitize_path(
artifact = layer["annotations"][ANNOTATION_TITLE]
outfile = sanitize_path(
download_dir, os.path.join(download_dir, artifact)
)
path = self.download_blob(package, layer["digest"], outfile)
Expand Down Expand Up @@ -74,26 +71,3 @@ def get_config(self, package) -> str:
os.rmdir(temp_dir)
# print("Temporary directory and its contents have been removed.")
raise RuntimeError("Unable to locate config layer")

@ensure_container
def get_manifest_response(
self,
container: container_type,
allowed_media_type: list | None = None,
refresh_headers: bool = True,
) -> dict:
"""
like get_manifest but return response,
temporary until https://github.com/oras-project/oras-py/pull/146 in a release.
"""
if not allowed_media_type:
allowed_media_type = [oras.defaults.default_manifest_media_type]
headers = {"Accept": ";".join(allowed_media_type)}

if not refresh_headers:
headers.update(self.headers)

get_manifest = f"{self.prefix}://{container.manifest_url()}" # type: ignore
response = self.do_request(get_manifest, "GET", headers=headers)
self._check_200_response(response)
return response
11 changes: 6 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ scikit-learn = "^1.5.0"
ipykernel = "^6.29.4"
nbconvert = "^7.16.4"
markdown-it-py = "^3.0.0"
model-registry = "^0.2.4a1"
model-registry = ">=0.2.9,<0.3.0"
ruff = "^0.6.1"
mypy = "^1.11.1"
types-pyyaml = "^6.0.12.20240808"
Expand All @@ -50,7 +50,9 @@ target-version = "py39"
respect-gitignore = true

[tool.ruff.lint.per-file-ignores]
"*.ipynb" = ["E402"] # exclude https://docs.astral.sh/ruff/rules/module-import-not-at-top-of-file/#notebook-behavior from linting, especially for demos.
"*.ipynb" = [
"E402",
] # exclude https://docs.astral.sh/ruff/rules/module-import-not-at-top-of-file/#notebook-behavior from linting, especially for demos.

[tool.mypy]
python_version = "3.9"
Expand Down
19 changes: 12 additions & 7 deletions tests/test_e2e_model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
def from_oci_to_kfmr(
model_registry: ModelRegistry, push_event: PushEvent, sha: str
) -> RegisteredModel:
assert push_event.metadata.name
assert push_event.metadata.model_format_name
assert push_event.metadata.model_format_version
rm = model_registry.register_model(
name=push_event.metadata.name,
uri=f"oci-artifact://{push_event.target}",
Expand All @@ -35,16 +38,13 @@ def test_e2e_model_registry_scenario1(tmp_path, target):
)

class ListenerForModelRegistry(Listener):
sha = None
rm = None
sha: str
rm: RegisteredModel

def update(self, source: Helper, event: Event) -> None:
if isinstance(event, PushEvent):
self.sha = source.registry.get_manifest_response(event.target).headers[
"Docker-Content-Digest"
]
print(self.sha)
self.rm = from_oci_to_kfmr(model_registry, event, self.sha)
self.sha = event.sha
self.rm = from_oci_to_kfmr(model_registry, event, event.sha)

listener = ListenerForModelRegistry()
omlmd = Helper()
Expand All @@ -67,15 +67,18 @@ def update(self, source: Helper, event: Event) -> None:
v = quote(listener.sha)

rm = model_registry.get_registered_model("mnist")
assert rm
assert rm.id == listener.rm.id
assert rm.name == "mnist"

mv = model_registry.get_model_version("mnist", v)
assert mv
assert mv.description == "Lorem ipsum"
assert mv.author == "John Doe"
assert mv.custom_properties == {"accuracy": 0.987}

ma = model_registry.get_model_artifact("mnist", v)
assert ma
assert ma.uri == f"oci-artifact://{target}"

# curl http://localhost:5001/v2/testorgns/ml-model-artifact/manifests/v1 -H "Accept: application/vnd.oci.image.manifest.v1+json" --verbose
Expand Down Expand Up @@ -112,7 +115,9 @@ def test_e2e_model_registry_scenario2(tmp_path, target):

_ = model_registry.get_registered_model(lookup_name)
model_version = model_registry.get_model_version(lookup_name, lookup_version)
assert model_version
model_artifact = model_registry.get_model_artifact(lookup_name, lookup_version)
assert model_artifact

file_from_mr = download_file(model_artifact.uri)

Expand Down
Loading
Loading