
commit 6eb6951
parent 6f98313
Author: mike0sv
Date: May 19, 2022

    fix batch and dataset types
Showing 8 changed files with 70 additions and 57 deletions.
16 changes: 4 additions & 12 deletions mlem/api/commands.py
@@ -101,13 +101,8 @@ def apply(
             return res[0]
         return res
     if len(res) == 1:
-        return save(
-            res[0], output, repo=target_repo, external=external, index=index
-        )
-
-    raise NotImplementedError(
-        "Saving several input data objects is not implemented yet"
-    )
+        res = res[0]
+    return save(res, output, repo=target_repo, external=external, index=index)
 
 
 def apply_remote(
@@ -152,11 +147,8 @@ def apply_remote(
             return res[0]
         return res
     if len(res) == 1:
-        return save(res[0], output, repo=target_repo, index=index)
-
-    raise NotImplementedError(
-        "Saving several input data objects is not implemented yet"
-    )
+        res = res[0]
+    return save(res, output, repo=target_repo, index=index)
 
 
 def clone(
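
Note: with this change a single result is unwrapped before saving, and multiple results fall through to save() as a list instead of raising NotImplementedError. A minimal sketch of the call this unblocks (the paths are hypothetical):

    from mlem.api import apply, load

    # One input dataset -> the single prediction batch is saved directly;
    # several inputs -> the list of results is saved as a list dataset.
    apply("models/rf", "datasets/iris", method="predict", output="preds")
    predictions = load("preds")
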
2 changes: 1 addition & 1 deletion mlem/api/utils.py
@@ -9,7 +9,7 @@
 
 def get_dataset_value(dataset: Any, batch_size: Optional[int] = None) -> Any:
     if isinstance(dataset, str):
-        return load(dataset)
+        return load(dataset, batch_size=batch_size)
     if isinstance(dataset, MlemDataset):
         # TODO: https://github.com/iterative/mlem/issues/29
         # fix discrepancies between model and data meta objects
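
Note: batch_size was previously dropped whenever the dataset was given as a string reference, so batched apply loaded everything eagerly. A sketch of the fixed path (the dataset location is hypothetical):

    from mlem.api.utils import get_dataset_value

    # load() now receives batch_size and can return a batch iterator
    # instead of materializing the full dataset in memory.
    batches = get_dataset_value("datasets/large", batch_size=1000)
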
3 changes: 2 additions & 1 deletion mlem/core/base.py
@@ -163,12 +163,13 @@ def build_mlem_object(
     **kwargs,
 ):
     not_links, links = parse_links(model, str_conf or [])
+    if model.__is_root__:
+        kwargs[model.__config__.type_field] = subtype
     return build_model(
         model,
         str_conf=not_links,
         file_conf=file_conf,
         conf=conf,
-        **{model.__config__.type_field: subtype},
         **kwargs,
         **links,
     )
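
Note: the type discriminator is now injected only for polymorphic root types, and through kwargs rather than a separate dict expansion, so building a concrete (non-root) type no longer receives an unexpected type field. The rule, restated as a small sketch (duck-typed model argument; type_field is typically "type"):

    def inject_type_field(model, subtype, kwargs):
        # Only a polymorphic root needs the discriminator to pick a subtype;
        # concrete classes are built as-is.
        if model.__is_root__:
            kwargs[model.__config__.type_field] = subtype
        return kwargs
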
23 changes: 13 additions & 10 deletions mlem/core/dataset_type.py
@@ -285,9 +285,9 @@ def write(
             res[str(i)] = art
             readers.append(elem_reader)
 
-        return ListReader(
-            dataset_type=dataset, readers=readers
-        ), flatdict.FlatterDict(res, delimiter="/")
+        return ListReader(dataset_type=dataset, readers=readers), dict(
+            flatdict.FlatterDict(res, delimiter="/")
+        )
 
 
 class ListReader(DatasetReader):
@@ -392,7 +392,7 @@ def write(
 
         return (
             _TupleLikeDatasetReader(dataset_type=dataset, readers=readers),
-            flatdict.FlatterDict(res, delimiter="/"),
+            dict(flatdict.FlatterDict(res, delimiter="/")),
         )
 
 
@@ -460,9 +460,12 @@ def process(cls, obj, **kwargs) -> DatasetType:
         if not py_types.intersection(
             PrimitiveType.PRIMITIVES
         ):  # py_types is guaranteed to be singleton set here
-            return TupleLikeListDatasetType(
-                items=[DatasetAnalyzer.analyze(o) for o in obj]
-            )
+            items_types = [DatasetAnalyzer.analyze(o) for o in obj]
+            first, *others = items_types
+            for other in others:
+                if first != other:
+                    return TupleLikeListDatasetType(items=items_types)
+            return ListDatasetType(dtype=first, size=len(obj))
 
         # optimization for large lists of same primitive type elements
         return ListDatasetType(
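
Note: the analyzer no longer assumes a non-primitive list is heterogeneous; it compares the analyzed item types and falls back to the tuple-like type only when they differ. Hedged examples of the outcome (assuming dicts analyze to a dict dataset type, as elsewhere in this module):

    from mlem.core.dataset_type import (
        DatasetAnalyzer,
        ListDatasetType,
        TupleLikeListDatasetType,
    )

    # Items analyze to the same type -> compact list type with one dtype.
    assert isinstance(
        DatasetAnalyzer.analyze([{"a": 1}, {"a": 2}]), ListDatasetType
    )
    # Same python type, different analyzed types -> per-item tuple-like type.
    assert isinstance(
        DatasetAnalyzer.analyze([{"a": 1}, {"b": "x"}]), TupleLikeListDatasetType
    )
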
@@ -552,9 +555,9 @@ def write(
             )
             res[key] = art
             readers[key] = dtype_reader
-        return DictReader(
-            dataset_type=dataset, item_readers=readers
-        ), flatdict.FlatterDict(res, delimiter="/")
+        return DictReader(dataset_type=dataset, item_readers=readers), dict(
+            flatdict.FlatterDict(res, delimiter="/")
+        )
 
 
 class DictReader(DatasetReader):
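
Note: all three writers now return a plain dict of flattened artifact paths, dict(flatdict.FlatterDict(...)), instead of the FlatterDict itself, which is what makes keys like "0/data" directly addressable; the test updates below assert on exactly those keys. The flattening, shown in isolation:

    import flatdict

    nested = {"0": {"data": "artifact-0"}, "1": {"data": "artifact-1"}}
    flat = dict(flatdict.FlatterDict(nested, delimiter="/"))
    assert flat == {"0/data": "artifact-0", "1/data": "artifact-1"}
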
5 changes: 4 additions & 1 deletion mlem/core/objects.py
@@ -690,7 +690,10 @@ def load_value(self):
         self.dataset = self.reader.read(self.relative_artifacts)
 
     def read_batch(self, batch_size: int) -> Iterator[DatasetType]:
-        assert isinstance(self.reader, DatasetReader)
+        if self.reader is None:
+            raise MlemObjectNotSavedError(
+                "Cannot read batch from not saved dataset"
+            )
         return self.reader.read_batch(self.relative_artifacts, batch_size)
 
     def get_value(self):
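
Note: an unsaved dataset has no reader, and the old bare assert produced an opaque AssertionError; the failure mode is now explicit. A hedged sketch (from_data as the in-memory constructor):

    from mlem.core.errors import MlemObjectNotSavedError
    from mlem.core.objects import MlemDataset

    ds = MlemDataset.from_data([1, 2, 3])  # in-memory; reader is None until saved
    try:
        ds.read_batch(batch_size=2)
    except MlemObjectNotSavedError:
        pass  # clear error instead of a bare AssertionError
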
2 changes: 1 addition & 1 deletion mlem/polydantic/core.py
@@ -67,7 +67,7 @@ def validate(cls, value):
         """Polymorphic magic goes here"""
         if isinstance(value, cls):
             return value
-        if not cls.__is_root__:
+        if not cls.__is_root__ and cls.__config__.type_field not in value:
             return super().validate(value)
         if isinstance(value, str):
             value = {cls.__config__.type_field: value}
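
Note: a non-root subclass used to claim any payload via super().validate(), even one whose type field named a different subtype; the added check lets such payloads go through polymorphic resolution instead. Roughly (a comment-only sketch; "type" stands in for type_field):

    # validating {"type": "bar", ...} against concrete subclass Foo:
    #   before: super().validate(value) -> silently parsed as Foo
    #   after:  falls through to the type-field dispatch -> Bar is built
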
14 changes: 12 additions & 2 deletions tests/cli/test_apply.py
@@ -9,7 +9,10 @@
 from sklearn.tree import DecisionTreeClassifier
 
 from mlem.api import load, save
+from mlem.core.dataset_type import ListDatasetType
 from mlem.core.errors import MlemRootNotFound
+from mlem.core.metadata import load_meta
+from mlem.core.objects import MlemDataset
 from mlem.runtime.client import HTTPClient
 from tests.conftest import MLEM_TEST_REPO, long, need_test_repo_auth
 
@@ -69,6 +72,7 @@ def test_apply_batch(runner, model_path_batch, data_path_batch):
         path = posixpath.join(dir, "data")
         result = runner.invoke(
             [
+                "--tb",
                 "apply",
                 model_path_batch,
                 data_path_batch,
@@ -82,8 +86,12 @@ def test_apply_batch(runner, model_path_batch, data_path_batch):
             ],
         )
         assert result.exit_code == 0, (result.output, result.exception)
-        predictions = load(path)
-        assert isinstance(predictions, ndarray)
+        predictions_meta = load_meta(
+            path, load_value=True, force_type=MlemDataset
+        )
+        assert isinstance(predictions_meta.dataset, ListDatasetType)
+        predictions = predictions_meta.get_value()
+        assert isinstance(predictions, list)
 
 
 def test_apply_with_import(runner, model_meta_saved_single, tmp_path_factory):
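
Note: the batch test now inspects the saved result through its metadata, which exposes both the dataset type and the value. The same pattern outside the test (the path is hypothetical):

    from mlem.core.dataset_type import ListDatasetType
    from mlem.core.metadata import load_meta
    from mlem.core.objects import MlemDataset

    meta = load_meta("preds", load_value=True, force_type=MlemDataset)
    assert isinstance(meta.dataset, ListDatasetType)  # batched results are lists now
    predictions = meta.get_value()
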
@@ -94,6 +102,7 @@ def test_apply_with_import(runner, model_meta_saved_single, tmp_path_factory):
         path = posixpath.join(dir, "data")
         result = runner.invoke(
             [
+                "--tb",
                 "apply",
                 model_meta_saved_single.loc.uri,
                 data_path,
@@ -122,6 +131,7 @@ def test_apply_batch_with_import(
         path = posixpath.join(dir, "data")
         result = runner.invoke(
             [
+                "--tb",
                 "apply",
                 model_meta_saved_single.loc.uri,
                 data_path,
62 changes: 33 additions & 29 deletions tests/core/test_dataset_type.py
@@ -3,6 +3,7 @@
 
 from mlem.core.dataset_type import (
     DatasetAnalyzer,
+    DatasetReader,
     DatasetType,
     DictDatasetType,
     DictReader,
@@ -11,6 +12,7 @@
     PrimitiveReader,
     PrimitiveType,
     TupleDatasetType,
+    TupleLikeListDatasetType,
     _TupleLikeDatasetReader,
     _TupleLikeDatasetWriter,
 )
@@ -93,11 +95,11 @@ def test_list_source():
     )
 
     assert list(artifacts.keys()) == [f"{x}/data" for x in range(len(l_value))]
-    assert artifacts["0"]["data"].uri.endswith("data/0")
-    assert artifacts["1"]["data"].uri.endswith("data/1")
-    assert artifacts["2"]["data"].uri.endswith("data/2")
-    assert artifacts["3"]["data"].uri.endswith("data/3")
-    assert artifacts["4"]["data"].uri.endswith("data/4")
+    assert artifacts["0/data"].uri.endswith("data/0")
+    assert artifacts["1/data"].uri.endswith("data/1")
+    assert artifacts["2/data"].uri.endswith("data/2")
+    assert artifacts["3/data"].uri.endswith("data/3")
+    assert artifacts["4/data"].uri.endswith("data/4")
 
 
 def test_tuple():
@@ -150,16 +152,22 @@ def test_tuple_source():
         "4/data",
         "5/data",
     ]
-    assert list(artifacts["1"].keys()) == [
-        f"{x}/data" for x in range(len(t_value[1]))
-    ]
-    assert artifacts["0"]["data"].uri.endswith("data/0")
-    assert artifacts["1"]["0"]["data"].uri.endswith("data/1/0")
-    assert artifacts["1"]["1"]["data"].uri.endswith("data/1/1")
-    assert artifacts["2"]["data"].uri.endswith("data/2")
-    assert artifacts["3"]["data"].uri.endswith("data/3")
-    assert artifacts["4"]["data"].uri.endswith("data/4")
-    assert artifacts["5"]["data"].uri.endswith("data/5")
+    assert artifacts["0/data"].uri.endswith("data/0")
+    assert artifacts["1/0/data"].uri.endswith("data/1/0")
+    assert artifacts["1/1/data"].uri.endswith("data/1/1")
+    assert artifacts["2/data"].uri.endswith("data/2")
+    assert artifacts["3/data"].uri.endswith("data/3")
+    assert artifacts["4/data"].uri.endswith("data/4")
+    assert artifacts["5/data"].uri.endswith("data/5")
+
+
+def test_tuple_reader():
+    dataset_type = TupleLikeListDatasetType(items=[])
+    assert dataset_type.dict()["type"] == "tuple_like_list"
+    reader = _TupleLikeDatasetReader(dataset_type=dataset_type, readers=[])
+    new_reader = parse_obj_as(DatasetReader, reader.dict())
+    res = new_reader.read({})
+    assert res.data == []
 
 
 def test_mixed_list_source():
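
Note: the new test_tuple_reader covers the serialize/deserialize round-trip for readers: reader.dict() embeds the type discriminator, and pydantic's parse_obj_as rebuilds the right subclass. The same round-trip for a simpler reader (a sketch; field names as used in this test file):

    from pydantic import parse_obj_as
    from mlem.core.dataset_type import (
        DatasetReader,
        PrimitiveReader,
        PrimitiveType,
    )

    reader = PrimitiveReader(dataset_type=PrimitiveType(ptype="int"))
    restored = parse_obj_as(DatasetReader, reader.dict())  # dispatched by "type"
    assert isinstance(restored, PrimitiveReader)
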
@@ -181,16 +189,13 @@ def test_mixed_list_source():
         "4/data",
         "5/data",
     ]
-    assert list(artifacts["1"].keys()) == [
-        f"{x}/data" for x in range(len(t_value[1]))
-    ]
-    assert artifacts["0"]["data"].uri.endswith("data/0")
-    assert artifacts["1"]["0"]["data"].uri.endswith("data/1/0")
-    assert artifacts["1"]["1"]["data"].uri.endswith("data/1/1")
-    assert artifacts["2"]["data"].uri.endswith("data/2")
-    assert artifacts["3"]["data"].uri.endswith("data/3")
-    assert artifacts["4"]["data"].uri.endswith("data/4")
-    assert artifacts["5"]["data"].uri.endswith("data/5")
+    assert artifacts["0/data"].uri.endswith("data/0")
+    assert artifacts["1/0/data"].uri.endswith("data/1/0")
+    assert artifacts["1/1/data"].uri.endswith("data/1/1")
+    assert artifacts["2/data"].uri.endswith("data/2")
+    assert artifacts["3/data"].uri.endswith("data/3")
+    assert artifacts["4/data"].uri.endswith("data/4")
+    assert artifacts["5/data"].uri.endswith("data/5")
 
 
 def test_dict():
@@ -238,7 +243,6 @@ def custom_assert(x, y):
     )
 
     assert list(artifacts.keys()) == ["1/data", "2/data", "3/1/data"]
-    assert list(artifacts["3"].keys()) == ["1/data"]
-    assert artifacts["1"]["data"].uri.endswith("data/1")
-    assert artifacts["2"]["data"].uri.endswith("data/2")
-    assert artifacts["3"]["1"]["data"].uri.endswith("data/3/1")
+    assert artifacts["1/data"].uri.endswith("data/1")
+    assert artifacts["2/data"].uri.endswith("data/2")
+    assert artifacts["3/1/data"].uri.endswith("data/3/1")
