diff --git a/mlem/api/commands.py b/mlem/api/commands.py
index 388898bb..88e26ca2 100644
--- a/mlem/api/commands.py
+++ b/mlem/api/commands.py
@@ -58,6 +58,7 @@ def apply(
     output: str = None,
     link: bool = None,
     external: bool = None,
+    batch: Optional[int] = None,
 ) -> Optional[Any]:
     """Apply provided model against provided data
 
@@ -85,7 +86,7 @@ def apply(
         resolved_method = PREDICT_METHOD_NAME
     echo(EMOJI_APPLY + f"Applying `{resolved_method}` method...")
     res = [
-        w.call_method(resolved_method, get_dataset_value(part))
+        w.call_method(resolved_method, get_dataset_value(part, batch))
         for part in data
     ]
     if output is None:
@@ -412,6 +413,7 @@ def import_object(
     target_repo: Optional[str] = None,
     target_fs: Optional[AbstractFileSystem] = None,
     type_: Optional[str] = None,
+    batch: Optional[int] = None,
     copy_data: bool = True,
     external: bool = None,
     link: bool = None,
@@ -426,7 +428,7 @@ def import_object(
         if type_ not in ImportHook.__type_map__:
             raise ValueError(f"Unknown import type {type_}")
         meta = ImportHook.__type_map__[type_].process(
-            loc, copy_data=copy_data, modifier=modifier
+            loc, copy_data=copy_data, modifier=modifier, batch=batch
         )
     else:
         meta = ImportAnalyzer.analyze(loc, copy_data=copy_data)
diff --git a/mlem/api/utils.py b/mlem/api/utils.py
index eac9d1fe..be1b04b5 100644
--- a/mlem/api/utils.py
+++ b/mlem/api/utils.py
@@ -7,14 +7,17 @@
 from mlem.core.objects import DatasetMeta, MlemMeta, ModelMeta
 
 
-def get_dataset_value(dataset: Any) -> Any:
+def get_dataset_value(dataset: Any, batch: Optional[int] = None) -> Any:
     if isinstance(dataset, str):
         return load(dataset)
     if isinstance(dataset, DatasetMeta):
         # TODO: https://github.com/iterative/mlem/issues/29
         # fix discrepancies between model and data meta objects
         if not hasattr(dataset.dataset, "data"):
-            dataset.load_value()
+            if batch:
+                dataset.load_batch_value(batch)
+            else:
+                dataset.load_value()
         return dataset.data
 
     # TODO: https://github.com/iterative/mlem/issues/29
diff --git a/mlem/cli/apply.py b/mlem/cli/apply.py
index 7c670bde..8bd11a96 100644
--- a/mlem/cli/apply.py
+++ b/mlem/cli/apply.py
@@ -60,6 +60,9 @@ def apply(
         # TODO: change ImportHook to MlemObject to support ext machinery
         help=f"Specify how to read data file for import. Available types: {list_implementations(ImportHook)}",
     ),
+    batch: Optional[int] = Option(
+        None, "-b", "--batch", help="Batch size for reading data in batches."
+    ),
     link: bool = option_link,
     external: bool = option_external,
     json: bool = option_json,
@@ -83,7 +86,11 @@ def apply(
     with set_echo(None if json else ...):
         if import_:
             dataset = import_object(
-                data, repo=data_repo, rev=data_rev, type_=import_type
+                data,
+                repo=data_repo,
+                rev=data_rev,
+                type_=import_type,
+                batch=batch,
             )
         else:
             dataset = load_meta(
@@ -92,6 +99,7 @@
                 data_rev,
                 load_value=True,
                 force_type=DatasetMeta,
+                batch=batch,
             )
         meta = load_meta(model, repo, rev, force_type=ModelMeta)
 
@@ -102,6 +110,7 @@ def apply(
             output=output,
             link=link,
             external=external,
+            batch=batch,
         )
         if output is None and json:
             print(
diff --git a/mlem/contrib/numpy.py b/mlem/contrib/numpy.py
index 141c3603..aea37701 100644
--- a/mlem/contrib/numpy.py
+++ b/mlem/contrib/numpy.py
@@ -1,5 +1,5 @@
 from types import ModuleType
-from typing import Any, ClassVar, List, Optional, Tuple, Type, Union
+from typing import Any, ClassVar, Iterator, List, Optional, Tuple, Type, Union
 
 import numpy as np
 from pydantic import BaseModel, conlist, create_model
@@ -192,3 +192,6 @@ def read(self, artifacts: Artifacts) -> DatasetType:
         with artifacts[DatasetWriter.art_name].open() as f:
             data = np.load(f)[DATA_KEY]
         return self.dataset_type.copy().bind(data)
+
+    def read_batch(self, artifacts: Artifacts, batch: int) -> Iterator:
+        raise NotImplementedError
diff --git a/mlem/contrib/pandas.py b/mlem/contrib/pandas.py
index 1bca0ada..9a7245ce 100644
--- a/mlem/contrib/pandas.py
+++ b/mlem/contrib/pandas.py
@@ -9,6 +9,7 @@
     Callable,
     ClassVar,
     Dict,
+    Iterator,
     List,
     Optional,
     Tuple,
@@ -44,7 +45,11 @@
     DatasetType,
     DatasetWriter,
 )
-from mlem.core.errors import DeserializationError, SerializationError
+from mlem.core.errors import (
+    DeserializationError,
+    SerializationError,
+    UnsupportedDatasetBatchLoadingType,
+)
 from mlem.core.import_objects import ExtImportHook
 from mlem.core.meta_io import Location
 from mlem.core.metadata import get_object_metadata
@@ -386,6 +391,9 @@ class PandasFormat:
     file_name: str = "data.pd"
     string_buffer: bool = False
 
+    def update_variable(self, name: str, val: Any):
+        setattr(self, name, val)
+
     def read(self, artifacts: Artifacts, **kwargs):
         """Read DataFrame"""
         read_kwargs = {}
@@ -472,12 +480,49 @@ def read_html(*args, **kwargs):
     "pickle": PandasFormat(  # TODO buffer closed error for some reason
         pd.read_pickle, pd.DataFrame.to_pickle, file_name="data.pkl"
     ),
-    "strata": PandasFormat(  # TODO int32 converts to int64 for some reason
-        pd.read_stata, pd.DataFrame.to_stata, write_args={"write_index": False}
+    "stata": PandasFormat(  # TODO int32 converts to int64 for some reason
+        pd.read_stata,
+        pd.DataFrame.to_stata,
+        write_args={"write_index": False},
     ),
 }
 
 
+def get_pandas_batch_formats(batch: int):
+    PANDAS_FORMATS = {
+        "csv": PandasFormat(
+            pd.read_csv,
+            pd.DataFrame.to_csv,
+            file_name="data.csv",
+            write_args={"index": False},
+            read_args={"chunksize": batch},
+        ),
+        "json": PandasFormat(
+            pd.read_json,
+            pd.DataFrame.to_json,
+            file_name="data.json",
+            write_args={
+                "date_format": "iso",
+                "date_unit": "ns",
+                "orient": "records",
+                "lines": True,
+            },
+            # Pandas supports batch-reading for JSON only if the JSON file is line-delimited
+            # and orient is set to "records"
+            # https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html#line-delimited-json
+            read_args={"chunksize": batch, "orient": "records", "lines": True},
+        ),
+        "stata": PandasFormat(  # TODO int32 converts to int64 for some reason
+            pd.read_stata,
+            pd.DataFrame.to_stata,
+            write_args={"write_index": False},
read_args={"chunksize": batch}, + ), + } + + return PANDAS_FORMATS + + class _PandasIO(BaseModel): format: str @@ -505,6 +550,29 @@ def read(self, artifacts: Artifacts) -> DatasetType: self.dataset_type.align(self.fmt.read(artifacts)) ) + def read_batch(self, artifacts: Artifacts, batch: int) -> Iterator: + batch_formats = get_pandas_batch_formats(batch) + if self.format not in batch_formats: + raise UnsupportedDatasetBatchLoadingType(self.format) + fmt = batch_formats[self.format] + + read_kwargs = {} + if fmt.read_args: + read_kwargs.update(fmt.read_args) + with artifacts[DatasetWriter.art_name].open() as f: + iter_df = fmt.read_func(f, **read_kwargs) + for df in iter_df: + unnamed = {} + for col in df.columns: + if col.startswith("Unnamed: "): + unnamed[col] = "" + if unnamed: + df = df.rename(unnamed, axis=1) + + yield self.dataset_type.copy().bind( + self.dataset_type.align(df) + ) + class PandasWriter(DatasetWriter, _PandasIO): """DatasetWriter for pandas dataframes""" @@ -537,10 +605,13 @@ def process( obj: Location, copy_data: bool = True, modifier: Optional[str] = None, + batch: Optional[int] = None, **kwargs, ) -> MlemMeta: ext = modifier or posixpath.splitext(obj.path)[1][1:] fmt = PANDAS_FORMATS[ext] + if batch: + fmt = get_pandas_batch_formats(batch)[ext] read_args = fmt.read_args or {} read_args.update(kwargs) with obj.open("rb") as f: diff --git a/mlem/core/dataset_type.py b/mlem/core/dataset_type.py index b59236ca..0625d962 100644 --- a/mlem/core/dataset_type.py +++ b/mlem/core/dataset_type.py @@ -3,7 +3,17 @@ """ import builtins from abc import ABC, abstractmethod -from typing import Any, ClassVar, Dict, List, Optional, Sized, Tuple, Type +from typing import ( + Any, + ClassVar, + Dict, + Iterator, + List, + Optional, + Sized, + Tuple, + Type, +) from pydantic import BaseModel from pydantic.main import create_model @@ -404,6 +414,10 @@ class Config: def read(self, artifacts: Artifacts) -> DatasetType: raise NotImplementedError + @abstractmethod + def read_batch(self, artifacts: Artifacts, batch: int) -> Iterator: + raise NotImplementedError + class DatasetWriter(MlemObject): """""" diff --git a/mlem/core/errors.py b/mlem/core/errors.py index 709b58a6..343926cc 100644 --- a/mlem/core/errors.py +++ b/mlem/core/errors.py @@ -70,6 +70,21 @@ class MlemObjectNotLoadedError(ValueError, MlemError): """Thrown if model or dataset value is not loaded""" +class UnsupportedDatasetBatchLoadingType(ValueError, MlemError): + """Thrown if batch loading of dataset with unsupported file type is called""" + + _message = "Batch-loading Dataset of type '{dataset_type}' is currently not supported. Please remove batch parameter." + + def __init__( + self, + dataset_type, + ) -> None: + + self.dataset_type = dataset_type + self.message = self._message.format(dataset_type=dataset_type) + super().__init__(self.message) + + class WrongMethodError(ValueError, MlemError): """Thrown if wrong method name for model is provided""" diff --git a/mlem/core/metadata.py b/mlem/core/metadata.py index 885d0989..f4c64743 100644 --- a/mlem/core/metadata.py +++ b/mlem/core/metadata.py @@ -107,6 +107,7 @@ def load( path: str, repo: Optional[str] = None, rev: Optional[str] = None, + batch: Optional[int] = None, follow_links: bool = True, ) -> Any: """Load python object saved by MLEM @@ -115,6 +116,7 @@ def load( path (str): Path to the object. Could be local path or path inside a git repo. repo (Optional[str], optional): URL to repo if object is located there. 
         rev (Optional[str], optional): revision, could be git commit SHA, branch name or tag.
+        batch (Optional[int], optional): batch size; required when batch-reading the dataset.
         follow_links (bool, optional): If object we read is a MLEM link, whether to load the
             actual object link points to. Defaults to True.
 
@@ -127,6 +129,7 @@
         rev=rev,
         follow_links=follow_links,
         load_value=True,
+        batch=batch,
     )
     return meta.get_value()
 
@@ -142,6 +145,7 @@ def load_meta(
     follow_links: bool = True,
     load_value: bool = False,
     fs: Optional[AbstractFileSystem] = None,
+    batch: Optional[int] = None,
     *,
     force_type: Literal[None] = None,
 ) -> MlemMeta:
@@ -156,6 +160,7 @@ def load_meta(
     follow_links: bool = True,
     load_value: bool = False,
     fs: Optional[AbstractFileSystem] = None,
+    batch: Optional[int] = None,
     *,
     force_type: Optional[Type[T]] = None,
 ) -> T:
@@ -169,6 +174,7 @@ def load_meta(
     follow_links: bool = True,
     load_value: bool = False,
     fs: Optional[AbstractFileSystem] = None,
+    batch: Optional[int] = None,
     *,
     force_type: Optional[Type[T]] = None,
 ) -> T:
@@ -199,6 +205,9 @@ def load_meta(
         follow_links=follow_links,
     )
     if load_value:
+        if isinstance(meta, DatasetMeta) and batch:
+            meta.load_batch_value(batch)
+            return meta  # type: ignore[return-value]
         meta.load_value()
     if not isinstance(meta, cls):
         raise WrongMetaType(meta, force_type)
diff --git a/mlem/core/objects.py b/mlem/core/objects.py
index 2b5c81c8..a90cb921 100644
--- a/mlem/core/objects.py
+++ b/mlem/core/objects.py
@@ -668,6 +668,9 @@ def write_value(self) -> Artifacts:
     def load_value(self):
         self.dataset = self.reader.read(self.relative_artifacts)
 
+    def load_batch_value(self, batch: int):
+        self.dataset = self.reader.read_batch(self.relative_artifacts, batch)  # type: ignore
+
     def get_value(self):
         return self.data
 
diff --git a/tests/conftest.py b/tests/conftest.py
index d6069a6b..b05d1673 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,7 +2,7 @@
 import posixpath
 import tempfile
 from pathlib import Path
-from typing import Any, Callable, Type
+from typing import Any, Callable, Iterator, Type
 
 import git
 import pandas as pd
@@ -18,6 +18,7 @@
 from mlem.api import init, save
 from mlem.constants import PREDICT_ARG_NAME, PREDICT_METHOD_NAME
 from mlem.contrib.fastapi import FastAPIServer
+from mlem.contrib.pandas import PandasReader, get_pandas_batch_formats
 from mlem.contrib.sklearn import SklearnModel
 from mlem.core.artifacts import LOCAL_STORAGE, FSSpecStorage, LocalArtifact
 from mlem.core.dataset_type import DatasetReader, DatasetType, DatasetWriter
@@ -328,6 +329,43 @@ def dataset_write_read_check(
     assert new.data == dataset.data
 
 
+def dataset_write_read_batch_check(
+    dataset: DatasetType,
+    format: str,
+    reader_type: Type[DatasetReader] = None,
+    custom_eq: Callable[[Any, Any], bool] = None,
+):
+    with tempfile.TemporaryDirectory() as tmpdir:
+        BATCH_SIZE = 2
+        storage = LOCAL_STORAGE
+
+        fmt = get_pandas_batch_formats(BATCH_SIZE)[format]
+        art = fmt.write(dataset.data, storage, posixpath.join(tmpdir, "data"))
+        reader = PandasReader(dataset_type=dataset, format=format)
+        artifacts = {"data": art}
+        if reader_type is not None:
+            assert isinstance(reader, reader_type)
+
+        df_iterable: Iterator = reader.read_batch(artifacts, BATCH_SIZE)
+        df = None
+        col_types = None
+        while True:
+            try:
+                chunk = next(df_iterable)
+                if df is None:
+                    df = pd.DataFrame(columns=chunk.columns, dtype=col_types)
+                    col_types = {
+                        chunk.columns[idx]: chunk.dtypes[idx]
+                        for idx in range(len(chunk.columns))
+                    }
+                    df = df.astype(dtype=col_types)
+                df = pd.concat([df, chunk.data], ignore_index=True)
+            except StopIteration:
+                break
+
+        assert custom_eq(df, dataset.data)
+
+
 def check_model_type_common_interface(
     model_type: ModelType,
     data_type: DatasetType,
diff --git a/tests/contrib/test_pandas.py b/tests/contrib/test_pandas.py
index 22e4102c..ea94c6e8 100644
--- a/tests/contrib/test_pandas.py
+++ b/tests/contrib/test_pandas.py
@@ -29,7 +29,11 @@
 from mlem.core.meta_io import MLEM_EXT
 from mlem.core.metadata import load, save
 from mlem.core.objects import DatasetMeta
-from tests.conftest import dataset_write_read_check, long
+from tests.conftest import (
+    dataset_write_read_batch_check,
+    dataset_write_read_check,
+    long,
+)
 
 PD_DATA_FRAME = pd.DataFrame(
     [
@@ -75,7 +79,12 @@ def pandas_assert(actual: pd.DataFrame, expected: pd.DataFrame):
 
 @pytest.fixture
 def data():
-    return pd.DataFrame([{"a": 1, "b": 3, "c": 5}, {"a": 2, "b": 4, "c": 6}])
+    return pd.DataFrame(
+        [
+            {"a": 1, "b": 3, "c": 5},
+            {"a": 2, "b": 4, "c": 6},
+        ]
+    )
 
 
 @pytest.fixture
@@ -120,6 +129,24 @@ def test_simple_df(data, format):
     )
 
 
+@for_all_formats(
+    exclude=[  # Following file formats do not support Pandas chunksize parameter
+        "html",
+        "excel",
+        "parquet",
+        "feather",
+        "pickle",
+    ]
+)
+def test_simple_batch_df(data, format):
+    dataset_write_read_batch_check(
+        DatasetType.create(data),
+        format,
+        PandasReader,
+        pd.DataFrame.equals,
+    )
+
+
 @for_all_formats
 def test_with_index(data, format):
     writer = PandasWriter(format=format)
@@ -146,7 +173,7 @@ def test_with_multiindex(data, format):
     exclude=[
         "excel",  # Excel does not support datetimes with timezones
         "parquet",  # Casting from timestamp[ns] to timestamp[ms] would lose data
-        "strata",  # Data type datetime64[ns, UTC] not supported.
+        "stata",  # Data type datetime64[ns, UTC] not supported.
     ]
 )
 @pytest.mark.parametrize(
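
Usage note (reviewer addition, not part of the patch): a minimal sketch of the batch-reading flow introduced above, mirroring dataset_write_read_batch_check in tests/conftest.py. The sample frame, the /tmp path, and the batch size of 2 are illustrative assumptions, not values taken from the diff.

    import pandas as pd

    from mlem.contrib.pandas import PandasReader, get_pandas_batch_formats
    from mlem.core.artifacts import LOCAL_STORAGE
    from mlem.core.dataset_type import DatasetType

    # Build a small frame and its MLEM dataset type.
    df = pd.DataFrame([{"a": 1, "b": 3, "c": 5}, {"a": 2, "b": 4, "c": 6}])
    dataset = DatasetType.create(df)

    # Write it with a batch-capable format, then stream it back chunk by chunk;
    # "/tmp/batch-demo/data" is a hypothetical local path.
    fmt = get_pandas_batch_formats(2)["csv"]
    art = fmt.write(df, LOCAL_STORAGE, "/tmp/batch-demo/data")
    reader = PandasReader(dataset_type=dataset, format="csv")
    for chunk in reader.read_batch({"data": art}, 2):
        print(chunk.data)  # each chunk is a DatasetType bound to at most 2 rows

The same batch size is what the new -b/--batch CLI option and load_meta(..., batch=...) thread through to read_batch in the hunks above.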