Skip to content

Commit

Permalink
feat: experimaestro classes as datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed May 31, 2024
1 parent bc2f562 commit 8240a9b
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 31 deletions.
12 changes: 9 additions & 3 deletions src/datamaestro/annotations/agreement.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import logging
from datamaestro.definitions import DatasetAnnotation, AbstractDataset, hook
from typing import Optional
from datamaestro.definitions import AbstractDataset, hook


@hook("pre-use")
def useragreement(definition: AbstractDataset, message, id=None):
def useragreement(definition: AbstractDataset, message: str, id: Optional[str] = None):
"""Asks for a user-agreement
:param definition: The dataset for which the agreement is asked
:param message: The agreement text
:param id: The ID of the agreement (default to the dataset ID)
"""
# Skip agreement when testing
if definition.context.running_test:
return
Expand Down
27 changes: 18 additions & 9 deletions src/datamaestro/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,20 @@ def repository(self, repositoryid):
if repositoryid is None:
return None

l = [
entry_points = [
x
for x in pkg_resources.iter_entry_points(
"datamaestro.repositories", repositoryid
)
]
if not l:
if not entry_points:
raise Exception("No datasets repository named %s", repositoryid)
if len(l) > 1:
if len(entry_points) > 1:
raise Exception(
"Too many datasets repository named %s (%d)" % (repositoryid, len(l))
"Too many datasets repository named %s (%d)"
% (repositoryid, len(entry_points))
)
return l[0].load()(self)
return entry_points[0].load()(self)

@property
def running_test(self):
Expand Down Expand Up @@ -175,7 +176,6 @@ def getPaths(hasher):
if dlpath.is_file():
logging.debug("Using cached file %s for %s", dlpath, url)
else:

logging.info("Downloading %s", url)
tmppath = dlpath.with_suffix(".tmp")

Expand All @@ -188,7 +188,7 @@ def getPaths(hasher):

def ask(self, question: str, options: Dict[str, str]):
"""Ask a question to the user"""
print(question)
print(question) # noqa: T201
answer = None
while answer not in options:
answer = input().strip().lower()
Expand Down Expand Up @@ -268,6 +268,7 @@ def _getdoc(self):

def __iter__(self) -> Iterable["AbstractDataset"]:
from .definitions import DatasetWrapper
from datamaestro.data import Base

# Iterates over defined symbols
for key, value in self.module.__dict__.items():
Expand All @@ -276,10 +277,18 @@ def __iter__(self) -> Iterable["AbstractDataset"]:
# Ensure it comes from the module
if self.module.__name__ == value.t.__module__:
yield value
elif (
inspect.isclass(value)
and issubclass(value, Base)
and hasattr(value, "__dataset__")
):
if self.module.__name__ == value.__module__:
yield value.__dataset__


class Repository:
"""A repository regroup a set of datasets and their corresponding specific handlers (downloading, filtering, etc.)"""
"""A repository regroup a set of datasets and their corresponding specific
handlers (downloading, filtering, etc.)"""

def __init__(self, context: Context):
"""Initialize a new repository
Expand Down Expand Up @@ -315,7 +324,7 @@ def version(cls):
try:
return get_distribution(cls.__module__).version
except DistributionNotFound:
__version__ = None
return None

def __repr__(self):
return "Repository(%s)" % self.basedir
Expand Down
68 changes: 51 additions & 17 deletions src/datamaestro/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ class AbstractDataset(AbstractData):
"""

name: Optional[str] = None
"""The name of the dataset"""

url: Optional[str] = None
"""The URL of the dataset"""

doi: Optional[str] = None
"""The DOI of this dataset"""

def __init__(self, repository: Optional["Repository"]):
super().__init__()
Expand All @@ -136,6 +143,7 @@ def __init__(self, repository: Optional["Repository"]):

# Associated resources
self.resources: Dict[str, "Download"] = {}
self.ordered_resources = []

# Hooks
# pre-use: before returning the dataset object
Expand Down Expand Up @@ -194,13 +202,15 @@ def setDataIDs(self, data: Config, id: str):
def download(self, force=False):
"""Download all the necessary resources"""
success = True
for key, resource in self.resources.items():
logging.info("Materializing %d resources", len(self.ordered_resources))
for resource in self.ordered_resources:
try:
resource.download(force)
except Exception:
logging.error("Could not download resource %s", key)
logging.error("Could not download resource %s", resource)
traceback.print_exc()
success = False
break
return success

@staticmethod
Expand Down Expand Up @@ -249,6 +259,7 @@ class DatasetWrapper(AbstractDataset):
def __init__(self, annotation, t: type):
self.t = t
self.base = annotation.base
self.config = None
assert self.base is not None, f"Could not set the Config type for {t}"

repository, components = DataDefinition.repository_relpath(t)
Expand Down Expand Up @@ -323,7 +334,18 @@ def __getattr__(self, key):
"""Returns a pointer to a potential attribute"""
return FutureAttr(self, [key])

def download(self, force=False):
if self.base is self.t:
self._prepare()
return super().download(force=force)

def _prepare(self, download=False) -> "Base":
if self.config is not None:
return self.config

if self.base is self.t:
self.config = self.base.__create_dataset__(self)

if download:
for hook in self.hooks["pre-download"]:
hook(self)
Expand All @@ -333,23 +355,23 @@ def _prepare(self, download=False) -> "Base":
for hook in self.hooks["pre-use"]:
hook(self)

resources = {key: value.prepare() for key, value in self.resources.items()}
dict = self.t(**resources)
if dict is None:
name = self.t.__name__
filename = inspect.getfile(self.t)
raise Exception(
f"The dataset method {name} defined in "
f"{filename} returned a null object"
)

# Construct the object
data = self.base(**dict)
if self.config is None:
resources = {key: value.prepare() for key, value in self.resources.items()}
dict = self.t(**resources)
if dict is None:
name = self.t.__name__
filename = inspect.getfile(self.t)
raise Exception(
f"The dataset method {name} defined in "
f"{filename} returned a null object"
)
self.config = self.base(**dict)

# Set the ids
self.setDataIDs(data, self.id)
self.setDataIDs(self.config, self.id)

return data
return self.config

@property
def _path(self) -> Path:
Expand Down Expand Up @@ -496,14 +518,26 @@ def __init__(
def __call__(self, t):
try:
if self.base is None:
# Get type from return annotation
self.base = t.__annotations__["return"]
from datamaestro.data import Base

if inspect.isclass(t) and issubclass(t, Base):
self.base = t
else:
# Get type from return annotation
try:
self.base = t.__annotations__["return"]
except KeyError:
logging.warning("No return annotation in %s", t)
raise
object.__getattribute__(t, "__datamaestro__")
raise AssertionError("@data should only be called once")
except AttributeError:
pass

dw = DatasetWrapper(self, t)
t.__dataset__ = dw
if inspect.isclass(t) and issubclass(t, Base):
return t
return dw


Expand Down
33 changes: 31 additions & 2 deletions src/datamaestro/download/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Union
from abc import ABC, abstractmethod
from datamaestro.definitions import AbstractDataset, DatasetAnnotation
from datamaestro.utils import deprecated
from attrs import define


def initialized(method):
Expand All @@ -15,7 +17,12 @@ def wrapper(self, *args, **kwargs):
return wrapper


class Download(DatasetAnnotation, ABC):
@define(kw_only=True)
class SetupOptions:
pass


class Resource(DatasetAnnotation, ABC):
"""
Base class for all download handlers
"""
Expand All @@ -24,13 +31,16 @@ def __init__(self, varname: str):
self.varname = varname
# Ensures that the object is initialized
self._post = False
self.definition = None

def annotate(self, dataset: AbstractDataset):
assert self.definition is None
# Register has a resource download
if self.varname in dataset.resources:
raise AssertionError("Name %s already declared as a resource", self.varname)

dataset.resources[self.varname] = self
dataset.ordered_resources.append(self)
self.definition = dataset

@property
Expand All @@ -53,10 +63,29 @@ def prepare(self):
"""Prepares the dataset"""
...

def setup(
self,
dataset: Union[AbstractDataset],
options: SetupOptions = None,
):
"""Direct way to setup the resource (no annotation)"""
self(dataset)
return self.prepare()


# Keeps downwards compatibility
Download = Resource


class reference(Download):
def __init__(self, varname, reference):
def __init__(self, varname=None, reference=None):
"""References another dataset
:param varname: The name of the variable
:param reference: Another dataset
"""
super().__init__(varname)
assert reference is not None, "Reference cannot be null"
self.reference = reference

def prepare(self):
Expand Down

0 comments on commit 8240a9b

Please sign in to comment.