
Support batch-reading for data-types with chunksize parameter (#221)
* Implement batch dataset iterator reader

* Add non-implemented batch functions

* Address PR comments and fix linting

* Fix incorrect naming

Co-authored-by: Mikhail Sveshnikov <[email protected]>
terryyylim and mike0sv authored May 15, 2022

1 parent e93d927 commit 00b0162
Showing 13 changed files with 376 additions and 16 deletions.
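The user-facing change, in brief: `apply` gains an optional `batch_size` argument, mirrored by a `-b/--batch_size` option on the CLI. A minimal sketch of the API path, assuming the existing apply(model, data, method=...) signature (paths and the 512 value are illustrative):

    from mlem.api import apply
    from mlem.core.metadata import load_meta
    from mlem.core.objects import MlemDataset, MlemModel

    # Load metadata only: with batch_size set, the dataset value must stay
    # unloaded so it can be streamed in chunks (hence load_value=False).
    model = load_meta("my-model", force_type=MlemModel)
    data = load_meta("my-data", load_value=False, force_type=MlemDataset)

    # Predictions are computed chunk by chunk and returned as one numpy array.
    preds = apply(model, data, method="predict", batch_size=512)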
19 changes: 15 additions & 4 deletions mlem/api/commands.py
@@ -5,6 +5,7 @@
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Type, Union

import numpy as np
from fsspec import AbstractFileSystem
from fsspec.implementations.local import LocalFileSystem

@@ -59,6 +60,7 @@ def apply(
target_repo: str = None,
index: bool = None,
external: bool = None,
batch_size: Optional[int] = None,
) -> Optional[Any]:
"""Apply provided model against provided data
@@ -85,10 +87,19 @@ def apply(
except WrongMethodError:
resolved_method = PREDICT_METHOD_NAME
echo(EMOJI_APPLY + f"Applying `{resolved_method}` method...")
-    res = [
-        w.call_method(resolved_method, get_dataset_value(part))
-        for part in data
-    ]
+    if batch_size:
+        res: Any = []
+        for part in data:
+            batch_dataset = get_dataset_value(part, batch_size)
+            for chunk in batch_dataset:
+                preds = w.call_method(resolved_method, chunk.data)
+                res += [*preds]
+        res = [np.array(res)]
+    else:
+        res = [
+            w.call_method(resolved_method, get_dataset_value(part))
+            for part in data
+        ]
if output is None:
if len(res) == 1:
return res[0]
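When batch_size is set, each chunk yielded by get_dataset_value is a DatasetType wrapping one slice of the data, so predictions are computed per chunk and flattened into one list (res += [*preds] is equivalent to res.extend(preds)), then wrapped in a one-element list holding a numpy array so the result matches the shape of the non-batched branch. A standalone sketch of that accumulation (model and chunks are stand-ins):

    import numpy as np

    def predict_in_batches(model, chunks):
        # Collect per-chunk predictions into one flat numpy array.
        res = []
        for chunk in chunks:  # each chunk is a DatasetType
            res.extend(model.predict(chunk.data))
        return np.array(res)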
4 changes: 3 additions & 1 deletion mlem/api/utils.py
@@ -7,13 +7,15 @@
from mlem.core.objects import MlemDataset, MlemModel, MlemObject


-def get_dataset_value(dataset: Any) -> Any:
+def get_dataset_value(dataset: Any, batch_size: Optional[int] = None) -> Any:
     if isinstance(dataset, str):
         return load(dataset)
     if isinstance(dataset, MlemDataset):
         # TODO: https://github.com/iterative/mlem/issues/29
         # fix discrepancies between model and data meta objects
         if not hasattr(dataset.dataset, "data"):
+            if batch_size:
+                return dataset.read_batch(batch_size)
             dataset.load_value()
         return dataset.data
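get_dataset_value now has two return shapes: without batch_size it materializes and returns the full value; with batch_size (and an unloaded MlemDataset) it returns dataset.read_batch(batch_size), a lazy Iterator[DatasetType]. A short sketch of the two call styles (dataset and process are hypothetical stand-ins):

    value = get_dataset_value(dataset)         # full value, loaded eagerly
    chunks = get_dataset_value(dataset, 1000)  # Iterator[DatasetType], lazy
    for chunk in chunks:
        process(chunk.data)                    # each chunk wraps a real value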

14 changes: 13 additions & 1 deletion mlem/cli/apply.py
@@ -21,6 +21,7 @@
option_target_repo,
)
from mlem.core.dataset_type import DatasetAnalyzer
from mlem.core.errors import UnsupportedDatasetBatchLoading
from mlem.core.import_objects import ImportHook
from mlem.core.metadata import load_meta
from mlem.core.objects import MlemDataset, MlemModel
@@ -54,6 +55,12 @@ def apply(
# TODO: change ImportHook to MlemObject to support ext machinery
help=f"Specify how to read data file for import. Available types: {list_implementations(ImportHook)}",
),
batch_size: Optional[int] = Option(
None,
"-b",
"--batch_size",
help="Batch size for reading data in batches.",
),
index: bool = option_index,
external: bool = option_external,
json: bool = option_json,
@@ -76,6 +83,10 @@ def apply(

with set_echo(None if json else ...):
if import_:
if batch_size:
raise UnsupportedDatasetBatchLoading(
"Batch data loading is currently not supported for loading data on-the-fly"
)
dataset = import_object(
data, repo=data_repo, rev=data_rev, type_=import_type
)
@@ -84,7 +95,7 @@ def apply(
data,
data_repo,
data_rev,
-        load_value=True,
+        load_value=batch_size is None,
force_type=MlemDataset,
)
meta = load_meta(model, repo, rev, force_type=MlemModel)
@@ -96,6 +107,7 @@ def apply(
output=output,
index=index,
external=external,
batch_size=batch_size,
)
if output is None and json:
print(
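On the command line the same path is exercised with the new option, e.g. `mlem apply my-model my-data -b 512` (names here are illustrative). Note that batch_size is rejected together with --import, matching the UnsupportedDatasetBatchLoading error above: batch loading is currently not supported for data imported on-the-fly.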
7 changes: 6 additions & 1 deletion mlem/contrib/lightgbm.py
@@ -1,7 +1,7 @@
import os
import posixpath
import tempfile
-from typing import Any, ClassVar, List, Optional, Tuple, Type
+from typing import Any, ClassVar, Iterator, List, Optional, Tuple, Type

import lightgbm as lgb
from pydantic import BaseModel
@@ -114,6 +114,11 @@ def read(self, artifacts: Artifacts) -> DatasetType:
)
)

def read_batch(
self, artifacts: Artifacts, batch_size: int
) -> Iterator[DatasetType]:
raise NotImplementedError


class LightGBMModelIO(ModelIO):
"""
12 changes: 11 additions & 1 deletion mlem/contrib/numpy.py
@@ -1,5 +1,5 @@
from types import ModuleType
-from typing import Any, ClassVar, List, Optional, Tuple, Type, Union
+from typing import Any, ClassVar, Iterator, List, Optional, Tuple, Type, Union

import numpy as np
from pydantic import BaseModel, conlist, create_model
@@ -188,6 +188,11 @@ def read(self, artifacts: Artifacts) -> DatasetType:
data = self.dataset_type.actual_type(res)
return self.dataset_type.copy().bind(data)

def read_batch(
self, artifacts: Artifacts, batch_size: int
) -> Iterator[DatasetType]:
raise NotImplementedError


class NumpyArrayWriter(DatasetWriter):
"""DatasetWriter implementation for numpy ndarray"""
@@ -215,3 +220,8 @@ def read(self, artifacts: Artifacts) -> DatasetType:
with artifacts[DatasetWriter.art_name].open() as f:
data = np.load(f)[DATA_KEY]
return self.dataset_type.copy().bind(data)

def read_batch(
self, artifacts: Artifacts, batch_size: int
) -> Iterator[DatasetType]:
raise NotImplementedError
79 changes: 78 additions & 1 deletion mlem/contrib/pandas.py
@@ -9,6 +9,7 @@
Callable,
ClassVar,
Dict,
Iterator,
List,
Optional,
Tuple,
@@ -44,7 +45,11 @@
DatasetType,
DatasetWriter,
)
-from mlem.core.errors import DeserializationError, SerializationError
+from mlem.core.errors import (
+    DeserializationError,
+    SerializationError,
+    UnsupportedDatasetBatchLoadingType,
+)
from mlem.core.import_objects import ExtImportHook
from mlem.core.meta_io import Location
from mlem.core.objects import MlemDataset, MlemObject
@@ -523,6 +528,45 @@ def read_html(*args, **kwargs):
}


def get_pandas_batch_formats(batch_size: int):
PANDAS_FORMATS = {
"csv": PandasFormat(
pd.read_csv,
pd.DataFrame.to_csv,
file_name="data.csv",
write_args={"index": False},
read_args={"chunksize": batch_size},
),
"json": PandasFormat(
pd.read_json,
pd.DataFrame.to_json,
file_name="data.json",
write_args={
"date_format": "iso",
"date_unit": "ns",
"orient": "records",
"lines": True,
},
# Pandas supports batch-reading for JSON only if the JSON file is line-delimited
# and orient is set to "records"
# https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html#line-delimited-json
read_args={
"chunksize": batch_size,
"orient": "records",
"lines": True,
},
),
"stata": PandasFormat( # TODO int32 converts to int64 for some reason
pd.read_stata,
pd.DataFrame.to_stata,
write_args={"write_index": False},
read_args={"chunksize": batch_size},
),
}

return PANDAS_FORMATS


class _PandasIO(BaseModel):
format: str

@@ -557,6 +601,11 @@ def read(self, artifacts: Artifacts) -> DatasetType:
data.index.name = None
return self.dataset_type.copy().bind(data)

def read_batch(
self, artifacts: Artifacts, batch_size: int
) -> Iterator[DatasetType]:
raise NotImplementedError


class PandasSeriesWriter(DatasetWriter, _PandasIO):
"""DatasetWriter for pandas series"""
@@ -586,6 +635,34 @@ def read(self, artifacts: Artifacts) -> DatasetType:
self.dataset_type.align(self.fmt.read(artifacts))
)

def read_batch(
self, artifacts: Artifacts, batch_size: int
) -> Iterator[DatasetType]:
batch_formats = get_pandas_batch_formats(batch_size)
if self.format not in batch_formats:
raise UnsupportedDatasetBatchLoadingType(self.format)
fmt = batch_formats[self.format]

read_kwargs = {}
if fmt.read_args:
read_kwargs.update(fmt.read_args)
with artifacts[DatasetWriter.art_name].open() as f:
iter_df = fmt.read_func(f, **read_kwargs)
for df in iter_df:
if self.format == "csv":
unnamed = {}
for col in df.columns:
if col.startswith("Unnamed: "):
unnamed[col] = ""
if unnamed:
df = df.rename(unnamed, axis=1)
else:
df = df.reset_index(drop=True)

yield self.dataset_type.copy().bind(
self.dataset_type.align(df)
)


class PandasWriter(DatasetWriter, _PandasIO):
"""DatasetWriter for pandas dataframes"""
7 changes: 6 additions & 1 deletion mlem/contrib/torch.py
@@ -1,4 +1,4 @@
-from typing import Any, ClassVar, Optional, Tuple
+from typing import Any, ClassVar, Iterator, Optional, Tuple

import torch

@@ -100,6 +100,11 @@ def read(self, artifacts: Artifacts) -> DatasetType:
data = torch.load(f)
return self.dataset_type.copy().bind(data)

def read_batch(
self, artifacts: Artifacts, batch_size: int
) -> Iterator[DatasetType]:
raise NotImplementedError


class TorchModelIO(ModelIO):
"""
38 changes: 37 additions & 1 deletion mlem/core/dataset_type.py
@@ -4,7 +4,17 @@
import builtins
import posixpath
from abc import ABC, abstractmethod
-from typing import Any, ClassVar, Dict, List, Optional, Sized, Tuple, Type
+from typing import (
+    Any,
+    ClassVar,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Sized,
+    Tuple,
+    Type,
+)

import flatdict
from pydantic import BaseModel
@@ -116,6 +126,12 @@ class Config:
def read(self, artifacts: Artifacts) -> DatasetType:
raise NotImplementedError

@abstractmethod
def read_batch(
self, artifacts: Artifacts, batch_size: int
) -> Iterator[DatasetType]:
raise NotImplementedError


class DatasetWriter(MlemABC):
""""""
@@ -204,6 +220,11 @@ def read(self, artifacts: Artifacts) -> DatasetType:
data = self.dataset_type.to_type(res)
return self.dataset_type.copy().bind(data)

def read_batch(
self, artifacts: Artifacts, batch_size: int
) -> Iterator[DatasetType]:
raise NotImplementedError


class ListDatasetType(DatasetType, DatasetSerializer):
"""
@@ -281,6 +302,11 @@ def read(self, artifacts: Artifacts) -> DatasetType:
data_list.append(elem_dtype.data)
return self.dataset_type.copy().bind(data_list)

def read_batch(
self, artifacts: Artifacts, batch_size: int
) -> Iterator[DatasetType]:
raise NotImplementedError


class _TupleLikeDatasetType(DatasetType, DatasetSerializer):
"""
@@ -379,6 +405,11 @@ def read(self, artifacts: Artifacts) -> DatasetType:
data_list = self.dataset_type.actual_type(data_list)
return self.dataset_type.copy().bind(data_list)

def read_batch(
self, artifacts: Artifacts, batch_size: int
) -> Iterator[DatasetType]:
raise NotImplementedError


class TupleLikeListDatasetType(_TupleLikeDatasetType):
"""
@@ -534,6 +565,11 @@ def read(self, artifacts: Artifacts) -> DatasetType:
data_dict[key] = v_dataset_type.data
return self.dataset_type.copy().bind(data_dict)

def read_batch(
self, artifacts: Artifacts, batch_size: int
) -> Iterator[DatasetType]:
raise NotImplementedError


#
#
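read_batch now sits next to read in the abstract DatasetReader contract: implementations receive the stored Artifacts plus a batch_size and yield DatasetType chunks, each bound via dataset_type.copy().bind(...). A toy illustration of the contract (MyListReader and its line-based format are hypothetical; real readers follow the pandas implementation above):

    from typing import ClassVar, Iterator

    from mlem.core.artifacts import Artifacts
    from mlem.core.dataset_type import DatasetReader, DatasetType, DatasetWriter

    class MyListReader(DatasetReader):
        type: ClassVar[str] = "my_list"  # hypothetical registry name

        def read(self, artifacts: Artifacts) -> DatasetType:
            # Read the whole file as one value.
            with artifacts[DatasetWriter.art_name].open() as f:
                return self.dataset_type.copy().bind([line.strip() for line in f])

        def read_batch(
            self, artifacts: Artifacts, batch_size: int
        ) -> Iterator[DatasetType]:
            with artifacts[DatasetWriter.art_name].open() as f:
                batch = []
                for line in f:
                    batch.append(line.strip())
                    if len(batch) == batch_size:
                        yield self.dataset_type.copy().bind(batch)
                        batch = []
                if batch:  # trailing partial chunk
                    yield self.dataset_type.copy().bind(batch)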
19 changes: 19 additions & 0 deletions mlem/core/errors.py
@@ -70,6 +70,25 @@ class MlemObjectNotLoadedError(ValueError, MlemError):
"""Thrown if model or dataset value is not loaded"""


class UnsupportedDatasetBatchLoadingType(ValueError, MlemError):
"""Thrown if batch loading of dataset with unsupported file type is called"""

_message = "Batch-loading Dataset of type '{dataset_type}' is currently not supported. Please remove batch parameter."

def __init__(
self,
dataset_type,
) -> None:

self.dataset_type = dataset_type
self.message = self._message.format(dataset_type=dataset_type)
super().__init__(self.message)


class UnsupportedDatasetBatchLoading(MlemError):
"""Thrown if batch loading of dataset is called for import workflow"""


class WrongMethodError(ValueError, MlemError):
"""Thrown if wrong method name for model is provided"""
