Address PR comments and fix linting
terryyylim committed May 9, 2022
1 parent bfce0c0 commit a8791f1
Showing 12 changed files with 172 additions and 38 deletions.
6 changes: 3 additions & 3 deletions mlem/api/commands.py
@@ -60,7 +60,7 @@ def apply(
     target_repo: str = None,
     link: bool = None,
     external: bool = None,
-    batch: Optional[int] = None,
+    batch_size: Optional[int] = None,
 ) -> Optional[Any]:
     """Apply provided model against provided data
@@ -87,10 +87,10 @@ def apply(
     except WrongMethodError:
         resolved_method = PREDICT_METHOD_NAME
     echo(EMOJI_APPLY + f"Applying `{resolved_method}` method...")
-    if batch:
+    if batch_size:
         res: Any = []
         for part in data:
-            batch_dataset = get_dataset_value(part, batch)
+            batch_dataset = get_dataset_value(part, batch_size)
             for chunk in batch_dataset:
                 preds = w.call_method(resolved_method, chunk.data)
                 res += [*preds]
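For context, a minimal sketch of the renamed parameter as used from the Python API. The model/data paths here are hypothetical, and passing `method` is assumed to mirror the CLI's `-m` flag:

```python
from mlem.api import apply

# Hypothetical paths to a model and dataset previously saved with MLEM.
# With batch_size set, the data is read and scored in chunks instead of
# being loaded into memory in full.
predictions = apply("my-model", "my-data", method="predict", batch_size=100)
```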
9 changes: 4 additions & 5 deletions mlem/api/utils.py
@@ -7,17 +7,16 @@
 from mlem.core.objects import DatasetMeta, MlemMeta, ModelMeta


-def get_dataset_value(dataset: Any, batch: Optional[int] = None) -> Any:
+def get_dataset_value(dataset: Any, batch_size: Optional[int] = None) -> Any:
     if isinstance(dataset, str):
         return load(dataset)
     if isinstance(dataset, DatasetMeta):
         # TODO: https://github.com/iterative/mlem/issues/29
         # fix discrepancies between model and data meta objects
         if not hasattr(dataset.dataset, "data"):
-            if batch:
-                return dataset.read_batch(batch)
-            else:
-                dataset.load_value()
+            if batch_size:
+                return dataset.read_batch(batch_size)
+            dataset.load_value()
         return dataset.data

     # TODO: https://github.com/iterative/mlem/issues/29
13 changes: 8 additions & 5 deletions mlem/cli/apply.py
@@ -55,8 +55,11 @@ def apply(
         # TODO: change ImportHook to MlemObject to support ext machinery
         help=f"Specify how to read data file for import. Available types: {list_implementations(ImportHook)}",
     ),
-    batch: Optional[int] = Option(
-        None, "-b", "--batch", help="Batch size for reading data in batches."
+    batch_size: Optional[int] = Option(
+        None,
+        "-b",
+        "--batch_size",
+        help="Batch size for reading data in batches.",
     ),
     link: bool = option_link,
     external: bool = option_external,
@@ -80,7 +83,7 @@ def apply(

     with set_echo(None if json else ...):
         if import_:
-            if batch:
+            if batch_size:
                 raise UnsupportedDatasetBatchLoading(
                     "Batch data loading is currently not supported for loading data on-the-fly"
                 )
@@ -92,7 +95,7 @@ def apply(
                 data,
                 data_repo,
                 data_rev,
-                load_value=batch is None,
+                load_value=batch_size is None,
                 force_type=DatasetMeta,
             )
         meta = load_meta(model, repo, rev, force_type=ModelMeta)
@@ -104,7 +107,7 @@ def apply(
             output=output,
             link=link,
             external=external,
-            batch=batch,
+            batch_size=batch_size,
         )
         if output is None and json:
            print(
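With the rename, the short flag stays `-b` while the long form becomes `--batch_size`; the batch tests added below exercise exactly this path with an invocation along the lines of `mlem apply my-model my-data -m predict -o preds --no-link -b 5` (paths hypothetical).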
2 changes: 1 addition & 1 deletion mlem/contrib/lightgbm.py
@@ -115,7 +115,7 @@ def read(self, artifacts: Artifacts) -> DatasetType:
         )

     def read_batch(
-        self, artifacts: Artifacts, batch: int
+        self, artifacts: Artifacts, batch_size: int
     ) -> Iterator[DatasetType]:
         raise NotImplementedError

4 changes: 2 additions & 2 deletions mlem/contrib/numpy.py
@@ -189,7 +189,7 @@ def read(self, artifacts: Artifacts) -> DatasetType:
         return self.dataset_type.copy().bind(data)

     def read_batch(
-        self, artifacts: Artifacts, batch: int
+        self, artifacts: Artifacts, batch_size: int
     ) -> Iterator[DatasetType]:
         raise NotImplementedError

@@ -222,6 +222,6 @@ def read(self, artifacts: Artifacts) -> DatasetType:
         return self.dataset_type.copy().bind(data)

     def read_batch(
-        self, artifacts: Artifacts, batch: int
+        self, artifacts: Artifacts, batch_size: int
     ) -> Iterator[DatasetType]:
         raise NotImplementedError
18 changes: 11 additions & 7 deletions mlem/contrib/pandas.py
@@ -529,14 +529,14 @@ def read_html(*args, **kwargs):
 }


-def get_pandas_batch_formats(batch: int):
+def get_pandas_batch_formats(batch_size: int):
     PANDAS_FORMATS = {
         "csv": PandasFormat(
             pd.read_csv,
             pd.DataFrame.to_csv,
             file_name="data.csv",
             write_args={"index": False},
-            read_args={"chunksize": batch},
+            read_args={"chunksize": batch_size},
         ),
         "json": PandasFormat(
             pd.read_json,
@@ -551,13 +551,17 @@ def get_pandas_batch_formats(batch: int):
             # Pandas supports batch-reading for JSON only if the JSON file is line-delimited
             # and orient to be records
             # https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html#line-delimited-json
-            read_args={"chunksize": batch, "orient": "records", "lines": True},
+            read_args={
+                "chunksize": batch_size,
+                "orient": "records",
+                "lines": True,
+            },
         ),
         "stata": PandasFormat(  # TODO int32 converts to int64 for some reason
             pd.read_stata,
             pd.DataFrame.to_stata,
             write_args={"write_index": False},
-            read_args={"chunksize": batch},
+            read_args={"chunksize": batch_size},
         ),
     }

@@ -599,7 +603,7 @@ def read(self, artifacts: Artifacts) -> DatasetType:
         return self.dataset_type.copy().bind(data)

     def read_batch(
-        self, artifacts: Artifacts, batch: int
+        self, artifacts: Artifacts, batch_size: int
     ) -> Iterator[DatasetType]:
         raise NotImplementedError

@@ -633,9 +637,9 @@ def read(self, artifacts: Artifacts) -> DatasetType:
         )

     def read_batch(
-        self, artifacts: Artifacts, batch: int
+        self, artifacts: Artifacts, batch_size: int
     ) -> Iterator[DatasetType]:
-        batch_formats = get_pandas_batch_formats(batch)
+        batch_formats = get_pandas_batch_formats(batch_size)
         if self.format not in batch_formats:
             raise UnsupportedDatasetBatchLoadingType(self.format)
         fmt = batch_formats[self.format]
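The batch formats build directly on pandas' chunked readers: passing `chunksize` makes `pd.read_csv` (and, for line-delimited records, `pd.read_json`) return an iterator of DataFrames instead of a single frame. A self-contained sketch of that underlying mechanism, independent of MLEM:

```python
import pandas as pd

df = pd.DataFrame({"a": range(10), "b": range(10, 20)})
df.to_csv("data.csv", index=False)

# With chunksize set, read_csv returns an iterator that yields DataFrames
# of at most `chunksize` rows each.
for chunk in pd.read_csv("data.csv", chunksize=4):
    print(len(chunk))  # prints 4, 4, 2
```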
2 changes: 1 addition & 1 deletion mlem/contrib/torch.py
@@ -101,7 +101,7 @@ def read(self, artifacts: Artifacts) -> DatasetType:
         return self.dataset_type.copy().bind(data)

     def read_batch(
-        self, artifacts: Artifacts, batch: int
+        self, artifacts: Artifacts, batch_size: int
     ) -> Iterator[DatasetType]:
         raise NotImplementedError

10 changes: 5 additions & 5 deletions mlem/core/dataset_type.py
@@ -126,7 +126,7 @@ def read(self, artifacts: Artifacts) -> DatasetType:

     @abstractmethod
     def read_batch(
-        self, artifacts: Artifacts, batch: int
+        self, artifacts: Artifacts, batch_size: int
     ) -> Iterator[DatasetType]:
         raise NotImplementedError

@@ -219,7 +219,7 @@ def read(self, artifacts: Artifacts) -> DatasetType:
         return self.dataset_type.copy().bind(data)

     def read_batch(
-        self, artifacts: Artifacts, batch: int
+        self, artifacts: Artifacts, batch_size: int
     ) -> Iterator[DatasetType]:
         raise NotImplementedError

@@ -301,7 +301,7 @@ def read(self, artifacts: Artifacts) -> DatasetType:
         return self.dataset_type.copy().bind(data_list)

     def read_batch(
-        self, artifacts: Artifacts, batch: int
+        self, artifacts: Artifacts, batch_size: int
     ) -> Iterator[DatasetType]:
         raise NotImplementedError

@@ -404,7 +404,7 @@ def read(self, artifacts: Artifacts) -> DatasetType:
         return self.dataset_type.copy().bind(data_list)

     def read_batch(
-        self, artifacts: Artifacts, batch: int
+        self, artifacts: Artifacts, batch_size: int
     ) -> Iterator[DatasetType]:
         raise NotImplementedError

@@ -564,7 +564,7 @@ def read(self, artifacts: Artifacts) -> DatasetType:
         return self.dataset_type.copy().bind(data_dict)

     def read_batch(
-        self, artifacts: Artifacts, batch: int
+        self, artifacts: Artifacts, batch_size: int
     ) -> Iterator[DatasetType]:
         raise NotImplementedError

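These readers all raise `NotImplementedError` for now; only the pandas reader above implements batching. The contract an implementation has to satisfy is simply "yield successive fixed-size chunks". A schematic illustration of that pattern in plain Python (not MLEM code):

```python
from typing import Iterator, List, Sequence

def read_in_batches(data: Sequence[int], batch_size: int) -> Iterator[List[int]]:
    # Yield successive chunks of batch_size items; the last may be smaller.
    for start in range(0, len(data), batch_size):
        yield list(data[start:start + batch_size])

chunks = list(read_in_batches(list(range(10)), 4))
assert [len(c) for c in chunks] == [4, 4, 2]
```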
8 changes: 4 additions & 4 deletions mlem/core/metadata.py
@@ -107,7 +107,7 @@ def load(
     path: str,
     repo: Optional[str] = None,
     rev: Optional[str] = None,
-    batch: Optional[int] = None,
+    batch_size: Optional[int] = None,
     follow_links: bool = True,
 ) -> Any:
     """Load python object saved by MLEM
@@ -127,10 +127,10 @@ def load(
         repo=repo,
         rev=rev,
         follow_links=follow_links,
-        load_value=batch is None,
+        load_value=batch_size is None,
     )
-    if isinstance(meta, DatasetMeta) and batch:
-        return meta.read_batch(batch)
+    if isinstance(meta, DatasetMeta) and batch_size:
+        return meta.read_batch(batch_size)
     return meta.get_value()

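With this change, `load` returns an iterator of dataset chunks when `batch_size` is given for a saved dataset, and the fully loaded value otherwise. A minimal sketch, with a hypothetical dataset path and a placeholder `process` function:

```python
from mlem.api import load

# "my-data" is a hypothetical path to a dataset saved with MLEM.
# Each chunk is a DatasetType whose .data holds that batch's values.
for chunk in load("my-data", batch_size=100):
    process(chunk.data)  # process() stands in for user code
```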
4 changes: 2 additions & 2 deletions mlem/core/objects.py
@@ -677,9 +677,9 @@ def write_value(self) -> Artifacts:
     def load_value(self):
         self.dataset = self.reader.read(self.relative_artifacts)

-    def read_batch(self, batch: int) -> Iterator[DatasetType]:
+    def read_batch(self, batch_size: int) -> Iterator[DatasetType]:
         assert isinstance(self.reader, DatasetReader)
-        return self.reader.read_batch(self.relative_artifacts, batch)
+        return self.reader.read_batch(self.relative_artifacts, batch_size)

     def get_value(self):
         return self.data
82 changes: 81 additions & 1 deletion tests/cli/test_apply.py
@@ -4,9 +4,11 @@
 import pytest
 from numpy import ndarray
+from pandas import DataFrame
 from sklearn.datasets import load_iris
 from sklearn.tree import DecisionTreeClassifier

-from mlem.api import load
+from mlem.api import load, save
 from mlem.core.errors import MlemRootNotFound
 from mlem.runtime.client.base import HTTPClient
 from tests.conftest import MLEM_TEST_REPO, long, need_test_repo_auth
@@ -38,6 +40,52 @@ def test_apply(runner, model_path, data_path):
     assert isinstance(predictions, ndarray)


+@pytest.fixture
+def model_train_batch():
+    train, target = load_iris(return_X_y=True)
+    train = DataFrame(train)
+    train.columns = train.columns.astype(str)
+    model = DecisionTreeClassifier().fit(train, target)
+    return model, train
+
+
+@pytest.fixture
+def model_path_batch(model_train_batch, tmp_path_factory):
+    path = os.path.join(tmp_path_factory.getbasetemp(), "saved-model")
+    model, train = model_train_batch
+    save(model, path, tmp_sample_data=train, link=False)
+    yield path
+
+
+@pytest.fixture
+def data_path_batch(model_train_batch, tmpdir_factory):
+    temp_dir = str(tmpdir_factory.mktemp("saved-data") / "data")
+    save(model_train_batch[1], temp_dir, link=False)
+    yield temp_dir
+
+
+def test_apply_batch(runner, model_path_batch, data_path_batch):
+    with tempfile.TemporaryDirectory() as dir:
+        path = posixpath.join(dir, "data")
+        result = runner.invoke(
+            [
+                "apply",
+                model_path_batch,
+                data_path_batch,
+                "-m",
+                "predict",
+                "-o",
+                path,
+                "--no-link",
+                "-b",
+                "5",
+            ],
+        )
+        assert result.exit_code == 0, (result.output, result.exception)
+        predictions = load(path)
+        assert isinstance(predictions, ndarray)
+
+
 def test_apply_with_import(runner, model_meta_saved_single, tmp_path_factory):
     data_path = os.path.join(tmp_path_factory.getbasetemp(), "import_data")
     load_iris(return_X_y=True, as_frame=True)[0].to_csv(data_path, index=False)
@@ -64,6 +112,38 @@ def test_apply_with_import(runner, model_meta_saved_single, tmp_path_factory):
     assert isinstance(predictions, ndarray)


+def test_apply_batch_with_import(
+    runner, model_meta_saved_single, tmp_path_factory
+):
+    data_path = os.path.join(tmp_path_factory.getbasetemp(), "import_data")
+    load_iris(return_X_y=True, as_frame=True)[0].to_csv(data_path, index=False)
+
+    with tempfile.TemporaryDirectory() as dir:
+        path = posixpath.join(dir, "data")
+        result = runner.invoke(
+            [
+                "apply",
+                model_meta_saved_single.loc.uri,
+                data_path,
+                "-m",
+                "predict",
+                "-o",
+                path,
+                "--no-link",
+                "--import",
+                "--it",
+                "pandas[csv]",
+                "-b",
+                "2",
+            ],
+        )
+        assert result.exit_code == 1, (result.output, result.exception)
+        assert (
+            "Batch data loading is currently not supported for loading data on-the-fly"
+            in result.output
+        )
+
+
 def test_apply_no_output(runner, model_path, data_path):
     result = runner.invoke(
         ["apply", model_path, data_path, "-m", "predict", "--no-link"],