From 90a4ee43565a5ca3c8d692a41169efafa34c0bec Mon Sep 17 00:00:00 2001 From: Terence Lim Date: Sat, 23 Apr 2022 19:16:25 +0800 Subject: [PATCH 1/7] Add support for CSV and JSON Dataset batch reading --- mlem/api/commands.py | 6 +- mlem/api/utils.py | 7 +- mlem/cli/apply.py | 11 ++- mlem/contrib/numpy.py | 11 ++- mlem/contrib/pandas.py | 180 ++++++++++++++++++++++++++++---------- mlem/core/dataset_type.py | 10 ++- mlem/core/errors.py | 19 ++++ mlem/core/metadata.py | 9 ++ mlem/core/objects.py | 3 + 9 files changed, 202 insertions(+), 54 deletions(-) diff --git a/mlem/api/commands.py b/mlem/api/commands.py index 388898bb..88e26ca2 100644 --- a/mlem/api/commands.py +++ b/mlem/api/commands.py @@ -58,6 +58,7 @@ def apply( output: str = None, link: bool = None, external: bool = None, + batch: Optional[int] = None, ) -> Optional[Any]: """Apply provided model against provided data @@ -85,7 +86,7 @@ def apply( resolved_method = PREDICT_METHOD_NAME echo(EMOJI_APPLY + f"Applying `{resolved_method}` method...") res = [ - w.call_method(resolved_method, get_dataset_value(part)) + w.call_method(resolved_method, get_dataset_value(part, batch)) for part in data ] if output is None: @@ -412,6 +413,7 @@ def import_object( target_repo: Optional[str] = None, target_fs: Optional[AbstractFileSystem] = None, type_: Optional[str] = None, + batch: Optional[int] = None, copy_data: bool = True, external: bool = None, link: bool = None, @@ -426,7 +428,7 @@ def import_object( if type_ not in ImportHook.__type_map__: raise ValueError(f"Unknown import type {type_}") meta = ImportHook.__type_map__[type_].process( - loc, copy_data=copy_data, modifier=modifier + loc, copy_data=copy_data, modifier=modifier, batch=batch ) else: meta = ImportAnalyzer.analyze(loc, copy_data=copy_data) diff --git a/mlem/api/utils.py b/mlem/api/utils.py index eac9d1fe..be1b04b5 100644 --- a/mlem/api/utils.py +++ b/mlem/api/utils.py @@ -7,14 +7,17 @@ from mlem.core.objects import DatasetMeta, MlemMeta, ModelMeta -def get_dataset_value(dataset: Any) -> Any: +def get_dataset_value(dataset: Any, batch: Optional[int] = None) -> Any: if isinstance(dataset, str): return load(dataset) if isinstance(dataset, DatasetMeta): # TODO: https://github.com/iterative/mlem/issues/29 # fix discrepancies between model and data meta objects if not hasattr(dataset.dataset, "data"): - dataset.load_value() + if batch: + dataset.load_batch_value(batch) + else: + dataset.load_value() return dataset.data # TODO: https://github.com/iterative/mlem/issues/29 diff --git a/mlem/cli/apply.py b/mlem/cli/apply.py index 7c670bde..8bd11a96 100644 --- a/mlem/cli/apply.py +++ b/mlem/cli/apply.py @@ -60,6 +60,9 @@ def apply( # TODO: change ImportHook to MlemObject to support ext machinery help=f"Specify how to read data file for import. Available types: {list_implementations(ImportHook)}", ), + batch: Optional[int] = Option( + None, "-b", "--batch", help="Batch size for reading data in batches." 
+ ), link: bool = option_link, external: bool = option_external, json: bool = option_json, @@ -83,7 +86,11 @@ def apply( with set_echo(None if json else ...): if import_: dataset = import_object( - data, repo=data_repo, rev=data_rev, type_=import_type + data, + repo=data_repo, + rev=data_rev, + type_=import_type, + batch=batch, ) else: dataset = load_meta( @@ -92,6 +99,7 @@ def apply( data_rev, load_value=True, force_type=DatasetMeta, + batch=batch, ) meta = load_meta(model, repo, rev, force_type=ModelMeta) @@ -102,6 +110,7 @@ def apply( output=output, link=link, external=external, + batch=batch, ) if output is None and json: print( diff --git a/mlem/contrib/numpy.py b/mlem/contrib/numpy.py index 141c3603..7c6524a4 100644 --- a/mlem/contrib/numpy.py +++ b/mlem/contrib/numpy.py @@ -1,5 +1,5 @@ from types import ModuleType -from typing import Any, ClassVar, List, Optional, Tuple, Type, Union +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union import numpy as np from pydantic import BaseModel, conlist, create_model @@ -172,7 +172,11 @@ class NumpyArrayWriter(DatasetWriter): type: ClassVar[str] = "numpy" def write( - self, dataset: DatasetType, storage: Storage, path: str + self, + dataset: DatasetType, + storage: Storage, + path: str, + writer_fmt_args: Optional[Dict[str, Any]] = None, ) -> Tuple[DatasetReader, Artifacts]: with storage.open(path) as (f, art): np.savez_compressed(f, **{DATA_KEY: dataset.data}) @@ -192,3 +196,6 @@ def read(self, artifacts: Artifacts) -> DatasetType: with artifacts[DatasetWriter.art_name].open() as f: data = np.load(f)[DATA_KEY] return self.dataset_type.copy().bind(data) + + def read_batch(self, artifacts: Artifacts, batch: int) -> DatasetType: + raise NotImplementedError diff --git a/mlem/contrib/pandas.py b/mlem/contrib/pandas.py index 1bca0ada..4217e66a 100644 --- a/mlem/contrib/pandas.py +++ b/mlem/contrib/pandas.py @@ -44,7 +44,12 @@ DatasetType, DatasetWriter, ) -from mlem.core.errors import DeserializationError, SerializationError +from mlem.core.errors import ( + DatasetBatchLoadingJSONError, + DeserializationError, + SerializationError, + UnsupportedDatasetBatchLoadingType, +) from mlem.core.import_objects import ExtImportHook from mlem.core.meta_io import Location from mlem.core.metadata import get_object_metadata @@ -320,7 +325,7 @@ def get_writer(self, **kwargs): if filename is not None: _, ext = os.path.splitext(filename) ext = ext.lstrip(".") - if ext in PANDAS_FORMATS: + if ext in get_pandas_formats(): fmt = ext return PandasWriter(format=fmt) @@ -386,6 +391,9 @@ class PandasFormat: file_name: str = "data.pd" string_buffer: bool = False + def update_variable(self, name: str, val: Any): + setattr(self, name, val) + def read(self, artifacts: Artifacts, **kwargs): """Read DataFrame""" read_kwargs = {} @@ -428,54 +436,88 @@ def read_csv_with_unnamed(*args, **kwargs): return df.rename(unnamed, axis=1) # pylint: disable=no-member +def read_batch_csv_with_unnamed(*args, **kwargs): + df = None + df_iterator = pd.read_csv(*args, **kwargs) + unnamed = {} + for i, df_chunk in enumerate(df_iterator): + # Instantiate Pandas DataFrame if it is the first chunk + if i == 0: + df = pd.DataFrame(columns=df_chunk.columns) + for col in df_chunk.columns: + if col.startswith("Unnamed: "): + unnamed[col] = "" + df = pd.concat([df, df_chunk], ignore_index=True) + if not unnamed: + return df + return df.rename(unnamed, axis=1) # pylint: disable=no-member + + def read_json_reset_index(*args, **kwargs): return pd.read_json(*args, 
**kwargs).reset_index(drop=True) +def read_batch_json_reset_index(*args, **kwargs): + df = None + df_iterator = pd.read_json(*args, **kwargs) + for i, df_chunk in enumerate(df_iterator): + # Instantiate Pandas DataFrame if it is the first chunk + if i == 0: + df = pd.DataFrame(columns=df_chunk.columns) + df = pd.concat([df, df_chunk], ignore_index=True) + df = df.reset_index(drop=True) + return df + + def read_html(*args, **kwargs): # read_html returns list of dataframes return pd.read_html(*args, **kwargs)[0] -PANDAS_FORMATS = { - "csv": PandasFormat( - read_csv_with_unnamed, - pd.DataFrame.to_csv, - file_name="data.csv", - write_args={"index": False}, - ), - "json": PandasFormat( - read_json_reset_index, - pd.DataFrame.to_json, - file_name="data.json", - write_args={"date_format": "iso", "date_unit": "ns"}, - ), - "html": PandasFormat( - read_html, - pd.DataFrame.to_html, - file_name="data.html", - write_args={"index": False}, - string_buffer=True, - ), - "excel": PandasFormat( - pd.read_excel, - pd.DataFrame.to_excel, - file_name="data.xlsx", - write_args={"index": False}, - ), - "parquet": PandasFormat( - pd.read_parquet, pd.DataFrame.to_parquet, file_name="data.parquet" - ), - "feather": PandasFormat( - pd.read_feather, pd.DataFrame.to_feather, file_name="data.feather" - ), - "pickle": PandasFormat( # TODO buffer closed error for some reason - pd.read_pickle, pd.DataFrame.to_pickle, file_name="data.pkl" - ), - "strata": PandasFormat( # TODO int32 converts to int64 for some reason - pd.read_stata, pd.DataFrame.to_stata, write_args={"write_index": False} - ), -} +def get_pandas_formats(): + PANDAS_FORMATS = { + "csv": PandasFormat( + read_csv_with_unnamed, + pd.DataFrame.to_csv, + file_name="data.csv", + write_args={"index": False}, + ), + "json": PandasFormat( + read_json_reset_index, + pd.DataFrame.to_json, + file_name="data.json", + write_args={"date_format": "iso", "date_unit": "ns"}, + ), + "html": PandasFormat( + read_html, + pd.DataFrame.to_html, + file_name="data.html", + write_args={"index": False}, + string_buffer=True, + ), + "excel": PandasFormat( + pd.read_excel, + pd.DataFrame.to_excel, + file_name="data.xlsx", + write_args={"index": False}, + ), + "parquet": PandasFormat( + pd.read_parquet, pd.DataFrame.to_parquet, file_name="data.parquet" + ), + "feather": PandasFormat( + pd.read_feather, pd.DataFrame.to_feather, file_name="data.feather" + ), + "pickle": PandasFormat( # TODO buffer closed error for some reason + pd.read_pickle, pd.DataFrame.to_pickle, file_name="data.pkl" + ), + "stata": PandasFormat( # TODO int32 converts to int64 for some reason + pd.read_stata, + pd.DataFrame.to_stata, + write_args={"write_index": False}, + ), + } + + return PANDAS_FORMATS class _PandasIO(BaseModel): @@ -485,13 +527,13 @@ class _PandasIO(BaseModel): def is_valid_format( # pylint: disable=no-self-argument cls, value # noqa: B902 ): - if value not in PANDAS_FORMATS: + if value not in get_pandas_formats(): raise ValueError(f"format {value} is not supported") return value @property def fmt(self): - return PANDAS_FORMATS[self.format] + return get_pandas_formats()[self.format] class PandasReader(_PandasIO, DatasetReader): @@ -505,6 +547,21 @@ def read(self, artifacts: Artifacts) -> DatasetType: self.dataset_type.align(self.fmt.read(artifacts)) ) + def read_batch(self, artifacts: Artifacts, batch: int) -> DatasetType: + fmt = update_batch_args(self.format, self.fmt, batch) + + # Pandas supports batch-reading for JSON only if the JSON file is line-delimited + # 
https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html#line-delimited-json + if self.format == "json": + dataset_lines = sum(1 for line in open(artifacts["data"].uri)) + if dataset_lines <= 1: + raise DatasetBatchLoadingJSONError( + "Batch-loading Dataset of type JSON requires provided JSON file to be line-delimited." + ) + return self.dataset_type.copy().bind( + self.dataset_type.align(fmt.read(artifacts)) + ) + class PandasWriter(DatasetWriter, _PandasIO): """DatasetWriter for pandas dataframes""" @@ -512,9 +569,16 @@ class PandasWriter(DatasetWriter, _PandasIO): type: ClassVar[str] = "pandas" def write( - self, dataset: DatasetType, storage: Storage, path: str + self, + dataset: DatasetType, + storage: Storage, + path: str, + writer_fmt_args: Optional[Dict[str, Any]] = None, ) -> Tuple[DatasetReader, Artifacts]: fmt = self.fmt + if writer_fmt_args: + for k, v in writer_fmt_args.items(): + setattr(fmt, k, v) art = fmt.write(dataset.data, storage, path) if not isinstance(dataset, DataFrameType): raise ValueError("Cannot write non-pandas Dataset") @@ -524,7 +588,7 @@ def write( class PandasImport(ExtImportHook): - EXTS: ClassVar = tuple(f".{k}" for k in PANDAS_FORMATS) + EXTS: ClassVar = tuple(f".{k}" for k in get_pandas_formats()) type: ClassVar = "pandas" @classmethod @@ -537,10 +601,13 @@ def process( obj: Location, copy_data: bool = True, modifier: Optional[str] = None, + batch: Optional[int] = None, **kwargs, ) -> MlemMeta: ext = modifier or posixpath.splitext(obj.path)[1][1:] - fmt = PANDAS_FORMATS[ext] + fmt = get_pandas_formats()[ext] + if batch: + fmt = update_batch_args(ext, fmt, batch) read_args = fmt.read_args or {} read_args.update(kwargs) with obj.open("rb") as f: @@ -555,3 +622,24 @@ def process( ) } return meta + + +def update_batch_args( + type_: str, fmt: PandasFormat, batch: int +) -> PandasFormat: + # Check if batch reading is supported for specified _type format + if type_ == "csv": + fmt.read_func = read_batch_csv_with_unnamed + fmt.read_args = {"chunksize": batch} + elif type_ == "json": + fmt.read_func = read_batch_json_reset_index + # JSON batch-reading requires line-delimited data, and orient to be records + fmt.read_args = { + "chunksize": batch, + "lines": True, + "orient": "records", + } + else: + raise UnsupportedDatasetBatchLoadingType(type_) + + return fmt diff --git a/mlem/core/dataset_type.py b/mlem/core/dataset_type.py index b59236ca..9f49b367 100644 --- a/mlem/core/dataset_type.py +++ b/mlem/core/dataset_type.py @@ -404,6 +404,10 @@ class Config: def read(self, artifacts: Artifacts) -> DatasetType: raise NotImplementedError + @abstractmethod + def read_batch(self, artifacts: Artifacts, batch: int) -> DatasetType: + raise NotImplementedError + class DatasetWriter(MlemObject): """""" @@ -416,6 +420,10 @@ class Config: @abstractmethod def write( - self, dataset: DatasetType, storage: Storage, path: str + self, + dataset: DatasetType, + storage: Storage, + path: str, + writer_fmt_args: Optional[Dict[str, Any]] = None, ) -> Tuple[DatasetReader, Artifacts]: raise NotImplementedError diff --git a/mlem/core/errors.py b/mlem/core/errors.py index 709b58a6..a7426f05 100644 --- a/mlem/core/errors.py +++ b/mlem/core/errors.py @@ -70,6 +70,25 @@ class MlemObjectNotLoadedError(ValueError, MlemError): """Thrown if model or dataset value is not loaded""" +class UnsupportedDatasetBatchLoadingType(ValueError, MlemError): + """Thrown if batch loading of dataset with unsupported file type is called""" + + _message = "Batch-loading Dataset of type '{dataset_type}' is 
currently not supported. Please remove batch parameter." + + def __init__( + self, + dataset_type, + ) -> None: + + self.dataset_type = dataset_type + self.message = self._message.format(dataset_type=dataset_type) + super().__init__(self.message) + + +class DatasetBatchLoadingJSONError(ValueError, MlemError): + """Thrown if batch loading of JSON dataset is not line-delimited""" + + class WrongMethodError(ValueError, MlemError): """Thrown if wrong method name for model is provided""" diff --git a/mlem/core/metadata.py b/mlem/core/metadata.py index 885d0989..f4c64743 100644 --- a/mlem/core/metadata.py +++ b/mlem/core/metadata.py @@ -107,6 +107,7 @@ def load( path: str, repo: Optional[str] = None, rev: Optional[str] = None, + batch: Optional[int] = None, follow_links: bool = True, ) -> Any: """Load python object saved by MLEM @@ -115,6 +116,7 @@ def load( path (str): Path to the object. Could be local path or path inside a git repo. repo (Optional[str], optional): URL to repo if object is located there. rev (Optional[str], optional): revision, could be git commit SHA, branch name or tag. + batch (Optional[int], optional): batch, required if performing batch-reading of Dataset. follow_links (bool, optional): If object we read is a MLEM link, whether to load the actual object link points to. Defaults to True. @@ -127,6 +129,7 @@ def load( rev=rev, follow_links=follow_links, load_value=True, + batch=batch, ) return meta.get_value() @@ -142,6 +145,7 @@ def load_meta( follow_links: bool = True, load_value: bool = False, fs: Optional[AbstractFileSystem] = None, + batch: Optional[int] = None, *, force_type: Literal[None] = None, ) -> MlemMeta: @@ -156,6 +160,7 @@ def load_meta( follow_links: bool = True, load_value: bool = False, fs: Optional[AbstractFileSystem] = None, + batch: Optional[int] = None, *, force_type: Optional[Type[T]] = None, ) -> T: @@ -169,6 +174,7 @@ def load_meta( follow_links: bool = True, load_value: bool = False, fs: Optional[AbstractFileSystem] = None, + batch: Optional[int] = None, *, force_type: Optional[Type[T]] = None, ) -> T: @@ -199,6 +205,9 @@ def load_meta( follow_links=follow_links, ) if load_value: + if isinstance(meta, DatasetMeta) and batch: + meta.load_batch_value(batch) + return meta # type: ignore[return-value] meta.load_value() if not isinstance(meta, cls): raise WrongMetaType(meta, force_type) diff --git a/mlem/core/objects.py b/mlem/core/objects.py index 2b5c81c8..a90cb921 100644 --- a/mlem/core/objects.py +++ b/mlem/core/objects.py @@ -668,6 +668,9 @@ def write_value(self) -> Artifacts: def load_value(self): self.dataset = self.reader.read(self.relative_artifacts) + def load_batch_value(self, batch: int): + self.dataset = self.reader.read_batch(self.relative_artifacts, batch) # type: ignore + def get_value(self): return self.data From 83a05f6a8b081d62e1017a64bf121c883f80ad3e Mon Sep 17 00:00:00 2001 From: Terence Lim Date: Sat, 23 Apr 2022 19:17:28 +0800 Subject: [PATCH 2/7] Add tests for batch reading workflows --- tests/conftest.py | 35 +++++++++++++++++++-- tests/contrib/test_pandas.py | 59 ++++++++++++++++++++++++++++++++--- tests/core/test_dataset_io.py | 4 +-- 3 files changed, 89 insertions(+), 9 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index d6069a6b..d42fbc33 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ import posixpath import tempfile from pathlib import Path -from typing import Any, Callable, Type +from typing import Any, Callable, Optional, Type import git import pandas as pd @@ -21,6 +21,7 @@ from 
mlem.contrib.sklearn import SklearnModel from mlem.core.artifacts import LOCAL_STORAGE, FSSpecStorage, LocalArtifact from mlem.core.dataset_type import DatasetReader, DatasetType, DatasetWriter +from mlem.core.errors import UnsupportedDatasetBatchLoadingType from mlem.core.meta_io import MLEM_EXT, get_fs from mlem.core.metadata import load_meta from mlem.core.model import Argument, ModelType, Signature @@ -305,18 +306,23 @@ def dataset_write_read_check( reader_type: Type[DatasetReader] = None, custom_eq: Callable[[Any, Any], bool] = None, custom_assert: Callable[[Any, Any], Any] = None, + batch: Optional[int] = None, + writer_args: Any = None, ): with tempfile.TemporaryDirectory() as tmpdir: writer = writer or dataset.get_writer() storage = LOCAL_STORAGE reader, artifacts = writer.write( - dataset, storage, posixpath.join(tmpdir, "data") + dataset, storage, posixpath.join(tmpdir, "data"), writer_args ) if reader_type is not None: assert isinstance(reader, reader_type) - new = reader.read(artifacts) + if batch: + new = reader.read_batch(artifacts, batch) + else: + new = reader.read(artifacts) assert dataset == new if custom_assert is not None: @@ -328,6 +334,29 @@ def dataset_write_read_check( assert new.data == dataset.data +def dataset_write_read_batch_unsupported( + dataset: DatasetType, + batch: int, + writer: DatasetWriter = None, + reader_type: Type[DatasetReader] = None, +): + with tempfile.TemporaryDirectory() as tmpdir: + writer = writer or dataset.get_writer() + + storage = LOCAL_STORAGE + reader, artifacts = writer.write( + dataset, storage, posixpath.join(tmpdir, "data") + ) + if reader_type is not None: + assert isinstance(reader, reader_type) + + with pytest.raises( + UnsupportedDatasetBatchLoadingType, + match="Batch-loading Dataset of type .*", + ): + reader.read_batch(artifacts, batch) + + def check_model_type_common_interface( model_type: ModelType, data_type: DatasetType, diff --git a/tests/contrib/test_pandas.py b/tests/contrib/test_pandas.py index 22e4102c..199d3b18 100644 --- a/tests/contrib/test_pandas.py +++ b/tests/contrib/test_pandas.py @@ -13,12 +13,12 @@ from mlem.api.commands import import_object from mlem.contrib.pandas import ( - PANDAS_FORMATS, DataFrameType, PandasConfig, PandasReader, PandasWriter, SeriesType, + get_pandas_formats, pd_type_from_string, python_type_from_pd_string_repr, python_type_from_pd_type, @@ -29,7 +29,11 @@ from mlem.core.meta_io import MLEM_EXT from mlem.core.metadata import load, save from mlem.core.objects import DatasetMeta -from tests.conftest import dataset_write_read_check, long +from tests.conftest import ( + dataset_write_read_batch_unsupported, + dataset_write_read_check, + long, +) PD_DATA_FRAME = pd.DataFrame( [ @@ -105,7 +109,7 @@ def series_df_type(series_data): def for_all_formats(exclude: Union[List[str], Callable] = None): ex = exclude if isinstance(exclude, list) else [] - formats = [name for name in PANDAS_FORMATS if name not in ex] + formats = [name for name in get_pandas_formats() if name not in ex] mark = pytest.mark.parametrize("format", formats) if isinstance(exclude, list): return mark @@ -120,6 +124,53 @@ def test_simple_df(data, format): ) +@for_all_formats( + exclude=[ # Following file formats do not support chunksize parameter + "html", + "excel", + "parquet", + "feather", + "pickle", + "stata", # TODO: Add support for stata + ] +) +def test_simple_batch_df(data, format): + writer = PandasWriter(format=format) + # Batch-reading JSON files require line-delimited data + writer_args = None + if format == 
"json": + writer_args = { + "write_args": { + "date_format": "iso", + "date_unit": "ns", + "orient": "records", + "lines": True, + } + } + dataset_write_read_check( + DatasetType.create(data), + writer, + PandasReader, + pd.DataFrame.equals, + batch=2, + writer_args=writer_args, + ) + + +@for_all_formats( + exclude=[ # Following file formats do not support chunksize parameter + "csv", + "json", + "stata", # TODO: Add support for stata + ] +) +def test_unsupported_batch_df(data, format): + writer = PandasWriter(format=format) + dataset_write_read_batch_unsupported( + DatasetType.create(data), 2, writer, PandasReader + ) + + @for_all_formats def test_with_index(data, format): writer = PandasWriter(format=format) @@ -146,7 +197,7 @@ def test_with_multiindex(data, format): exclude=[ "excel", # Excel does not support datetimes with timezones "parquet", # Casting from timestamp[ns] to timestamp[ms] would lose data - "strata", # Data type datetime64[ns, UTC] not supported. + "stata", # Data type datetime64[ns, UTC] not supported. ] ) @pytest.mark.parametrize( diff --git a/tests/core/test_dataset_io.py b/tests/core/test_dataset_io.py index 0c69d30e..d5f4eb8e 100644 --- a/tests/core/test_dataset_io.py +++ b/tests/core/test_dataset_io.py @@ -3,7 +3,7 @@ import pytest from mlem.contrib.numpy import NumpyArrayReader, NumpyArrayWriter -from mlem.contrib.pandas import PANDAS_FORMATS, PandasReader, PandasWriter +from mlem.contrib.pandas import PandasReader, PandasWriter, get_pandas_formats from mlem.core.artifacts import FSSpecStorage from mlem.core.dataset_type import DatasetType @@ -27,7 +27,7 @@ def test_numpy_read_write(): assert np.array_equal(dataset2.data, data) -@pytest.mark.parametrize("format", list(PANDAS_FORMATS.keys())) +@pytest.mark.parametrize("format", list(get_pandas_formats().keys())) def test_pandas_read_write(format): data = pd.DataFrame([{"a": 1, "b": 2}]) dataset = DatasetType.create(data) From 5c3ed282eeb5c11c0cea82f4027029934f197dff Mon Sep 17 00:00:00 2001 From: Terence Lim Date: Sun, 24 Apr 2022 13:25:34 +0800 Subject: [PATCH 3/7] Add support for Stata Dataset batch reading --- mlem/contrib/pandas.py | 15 +++++++++++++++ tests/contrib/test_pandas.py | 7 +++---- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/mlem/contrib/pandas.py b/mlem/contrib/pandas.py index 4217e66a..16118d6d 100644 --- a/mlem/contrib/pandas.py +++ b/mlem/contrib/pandas.py @@ -469,6 +469,18 @@ def read_batch_json_reset_index(*args, **kwargs): return df +def read_batch_stata_reset_index(*args, **kwargs): + df = None + df_iterator = pd.read_stata(*args, **kwargs) + for i, df_chunk in enumerate(df_iterator): + # Instantiate Pandas DataFrame if it is the first chunk + if i == 0: + df = pd.DataFrame(columns=df_chunk.columns) + df = pd.concat([df, df_chunk], ignore_index=True) + df = df.reset_index(drop=True) + return df + + def read_html(*args, **kwargs): # read_html returns list of dataframes return pd.read_html(*args, **kwargs)[0] @@ -639,6 +651,9 @@ def update_batch_args( "lines": True, "orient": "records", } + elif type_ == "stata": + fmt.read_func = read_batch_stata_reset_index + fmt.read_args = {"chunksize": batch} else: raise UnsupportedDatasetBatchLoadingType(type_) diff --git a/tests/contrib/test_pandas.py b/tests/contrib/test_pandas.py index 199d3b18..ea1f030c 100644 --- a/tests/contrib/test_pandas.py +++ b/tests/contrib/test_pandas.py @@ -125,13 +125,12 @@ def test_simple_df(data, format): @for_all_formats( - exclude=[ # Following file formats do not support chunksize parameter + 
exclude=[ # Following file formats do not support Pandas chunksize parameter "html", "excel", "parquet", "feather", "pickle", - "stata", # TODO: Add support for stata ] ) def test_simple_batch_df(data, format): @@ -158,10 +157,10 @@ def test_simple_batch_df(data, format): @for_all_formats( - exclude=[ # Following file formats do not support chunksize parameter + exclude=[ # Following file formats support Pandas chunksize parameter "csv", "json", - "stata", # TODO: Add support for stata + "stata", ] ) def test_unsupported_batch_df(data, format): From a1e3a06c69b770ef9099d4be15cbebeb680cebaf Mon Sep 17 00:00:00 2001 From: Terence Lim Date: Sun, 24 Apr 2022 14:53:22 +0800 Subject: [PATCH 4/7] Remove code duplication using partial function --- mlem/contrib/pandas.py | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/mlem/contrib/pandas.py b/mlem/contrib/pandas.py index 16118d6d..27af09d8 100644 --- a/mlem/contrib/pandas.py +++ b/mlem/contrib/pandas.py @@ -3,6 +3,7 @@ import re from abc import ABC from dataclasses import dataclass +from functools import partial from typing import ( IO, Any, @@ -457,21 +458,9 @@ def read_json_reset_index(*args, **kwargs): return pd.read_json(*args, **kwargs).reset_index(drop=True) -def read_batch_json_reset_index(*args, **kwargs): - df = None - df_iterator = pd.read_json(*args, **kwargs) - for i, df_chunk in enumerate(df_iterator): - # Instantiate Pandas DataFrame if it is the first chunk - if i == 0: - df = pd.DataFrame(columns=df_chunk.columns) - df = pd.concat([df, df_chunk], ignore_index=True) - df = df.reset_index(drop=True) - return df - - -def read_batch_stata_reset_index(*args, **kwargs): - df = None - df_iterator = pd.read_stata(*args, **kwargs) +def read_batch_reset_index(read_func: Callable, *args, **kwargs): + df = pd.DataFrame() + df_iterator = read_func(*args, **kwargs) for i, df_chunk in enumerate(df_iterator): # Instantiate Pandas DataFrame if it is the first chunk if i == 0: @@ -644,7 +633,7 @@ def update_batch_args( fmt.read_func = read_batch_csv_with_unnamed fmt.read_args = {"chunksize": batch} elif type_ == "json": - fmt.read_func = read_batch_json_reset_index + fmt.read_func = partial(read_batch_reset_index, pd.read_json) # JSON batch-reading requires line-delimited data, and orient to be records fmt.read_args = { "chunksize": batch, @@ -652,7 +641,7 @@ def update_batch_args( "orient": "records", } elif type_ == "stata": - fmt.read_func = read_batch_stata_reset_index + fmt.read_func = partial(read_batch_reset_index, pd.read_stata) fmt.read_args = {"chunksize": batch} else: raise UnsupportedDatasetBatchLoadingType(type_) From 9e012d1ff717c3ce0d9e75d95391c2125c2660f2 Mon Sep 17 00:00:00 2001 From: Terence Lim Date: Sun, 24 Apr 2022 15:01:11 +0800 Subject: [PATCH 5/7] Add clearer comments --- mlem/contrib/pandas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlem/contrib/pandas.py b/mlem/contrib/pandas.py index 27af09d8..e5fb1ed6 100644 --- a/mlem/contrib/pandas.py +++ b/mlem/contrib/pandas.py @@ -442,7 +442,7 @@ def read_batch_csv_with_unnamed(*args, **kwargs): df_iterator = pd.read_csv(*args, **kwargs) unnamed = {} for i, df_chunk in enumerate(df_iterator): - # Instantiate Pandas DataFrame if it is the first chunk + # Instantiate Pandas DataFrame with columns if it is the first chunk if i == 0: df = pd.DataFrame(columns=df_chunk.columns) for col in df_chunk.columns: @@ -462,7 +462,7 @@ def read_batch_reset_index(read_func: Callable, *args, **kwargs): df = pd.DataFrame() 
df_iterator = read_func(*args, **kwargs) for i, df_chunk in enumerate(df_iterator): - # Instantiate Pandas DataFrame if it is the first chunk + # Instantiate Pandas DataFrame with columns if it is the first chunk if i == 0: df = pd.DataFrame(columns=df_chunk.columns) df = pd.concat([df, df_chunk], ignore_index=True) From 365ca098ee8f29e1bfbd51ba162e9ade8e55d446 Mon Sep 17 00:00:00 2001 From: Terence Lim Date: Tue, 26 Apr 2022 01:16:44 +0800 Subject: [PATCH 6/7] Address PR comments --- mlem/contrib/numpy.py | 8 +- mlem/contrib/pandas.py | 147 ++++++++++++++++------------------ mlem/core/dataset_type.py | 6 +- mlem/core/errors.py | 4 - tests/conftest.py | 41 ++++++++-- tests/contrib/test_pandas.py | 23 ++---- tests/core/test_dataset_io.py | 4 +- 7 files changed, 115 insertions(+), 118 deletions(-) diff --git a/mlem/contrib/numpy.py b/mlem/contrib/numpy.py index 7c6524a4..69342e6e 100644 --- a/mlem/contrib/numpy.py +++ b/mlem/contrib/numpy.py @@ -1,5 +1,5 @@ from types import ModuleType -from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union +from typing import Any, ClassVar, List, Optional, Tuple, Type, Union import numpy as np from pydantic import BaseModel, conlist, create_model @@ -172,11 +172,7 @@ class NumpyArrayWriter(DatasetWriter): type: ClassVar[str] = "numpy" def write( - self, - dataset: DatasetType, - storage: Storage, - path: str, - writer_fmt_args: Optional[Dict[str, Any]] = None, + self, dataset: DatasetType, storage: Storage, path: str ) -> Tuple[DatasetReader, Artifacts]: with storage.open(path) as (f, art): np.savez_compressed(f, **{DATA_KEY: dataset.data}) diff --git a/mlem/contrib/pandas.py b/mlem/contrib/pandas.py index e5fb1ed6..c92dcc74 100644 --- a/mlem/contrib/pandas.py +++ b/mlem/contrib/pandas.py @@ -46,7 +46,6 @@ DatasetWriter, ) from mlem.core.errors import ( - DatasetBatchLoadingJSONError, DeserializationError, SerializationError, UnsupportedDatasetBatchLoadingType, @@ -326,7 +325,7 @@ def get_writer(self, **kwargs): if filename is not None: _, ext = os.path.splitext(filename) ext = ext.lstrip(".") - if ext in get_pandas_formats(): + if ext in PANDAS_FORMATS: fmt = ext return PandasWriter(format=fmt) @@ -475,46 +474,78 @@ def read_html(*args, **kwargs): return pd.read_html(*args, **kwargs)[0] -def get_pandas_formats(): +PANDAS_FORMATS = { + "csv": PandasFormat( + read_csv_with_unnamed, + pd.DataFrame.to_csv, + file_name="data.csv", + write_args={"index": False}, + ), + "json": PandasFormat( + read_json_reset_index, + pd.DataFrame.to_json, + file_name="data.json", + write_args={"date_format": "iso", "date_unit": "ns"}, + ), + "html": PandasFormat( + read_html, + pd.DataFrame.to_html, + file_name="data.html", + write_args={"index": False}, + string_buffer=True, + ), + "excel": PandasFormat( + pd.read_excel, + pd.DataFrame.to_excel, + file_name="data.xlsx", + write_args={"index": False}, + ), + "parquet": PandasFormat( + pd.read_parquet, pd.DataFrame.to_parquet, file_name="data.parquet" + ), + "feather": PandasFormat( + pd.read_feather, pd.DataFrame.to_feather, file_name="data.feather" + ), + "pickle": PandasFormat( # TODO buffer closed error for some reason + pd.read_pickle, pd.DataFrame.to_pickle, file_name="data.pkl" + ), + "stata": PandasFormat( # TODO int32 converts to int64 for some reason + pd.read_stata, + pd.DataFrame.to_stata, + write_args={"write_index": False}, + ), +} + + +def get_pandas_batch_formats(batch: int): PANDAS_FORMATS = { "csv": PandasFormat( - read_csv_with_unnamed, + read_batch_csv_with_unnamed, pd.DataFrame.to_csv, 
file_name="data.csv", write_args={"index": False}, + read_args={"chunksize": batch}, ), "json": PandasFormat( - read_json_reset_index, + partial(read_batch_reset_index, pd.read_json), pd.DataFrame.to_json, file_name="data.json", - write_args={"date_format": "iso", "date_unit": "ns"}, - ), - "html": PandasFormat( - read_html, - pd.DataFrame.to_html, - file_name="data.html", - write_args={"index": False}, - string_buffer=True, - ), - "excel": PandasFormat( - pd.read_excel, - pd.DataFrame.to_excel, - file_name="data.xlsx", - write_args={"index": False}, - ), - "parquet": PandasFormat( - pd.read_parquet, pd.DataFrame.to_parquet, file_name="data.parquet" - ), - "feather": PandasFormat( - pd.read_feather, pd.DataFrame.to_feather, file_name="data.feather" - ), - "pickle": PandasFormat( # TODO buffer closed error for some reason - pd.read_pickle, pd.DataFrame.to_pickle, file_name="data.pkl" + write_args={ + "date_format": "iso", + "date_unit": "ns", + "orient": "records", + "lines": True, + }, + # Pandas supports batch-reading for JSON only if the JSON file is line-delimited + # and orient to be records + # https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html#line-delimited-json + read_args={"chunksize": batch, "orient": "records", "lines": True}, ), "stata": PandasFormat( # TODO int32 converts to int64 for some reason - pd.read_stata, + partial(read_batch_reset_index, pd.read_stata), pd.DataFrame.to_stata, write_args={"write_index": False}, + read_args={"chunksize": batch}, ), } @@ -528,13 +559,13 @@ class _PandasIO(BaseModel): def is_valid_format( # pylint: disable=no-self-argument cls, value # noqa: B902 ): - if value not in get_pandas_formats(): + if value not in PANDAS_FORMATS: raise ValueError(f"format {value} is not supported") return value @property def fmt(self): - return get_pandas_formats()[self.format] + return PANDAS_FORMATS[self.format] class PandasReader(_PandasIO, DatasetReader): @@ -549,16 +580,11 @@ def read(self, artifacts: Artifacts) -> DatasetType: ) def read_batch(self, artifacts: Artifacts, batch: int) -> DatasetType: - fmt = update_batch_args(self.format, self.fmt, batch) - - # Pandas supports batch-reading for JSON only if the JSON file is line-delimited - # https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html#line-delimited-json - if self.format == "json": - dataset_lines = sum(1 for line in open(artifacts["data"].uri)) - if dataset_lines <= 1: - raise DatasetBatchLoadingJSONError( - "Batch-loading Dataset of type JSON requires provided JSON file to be line-delimited." 
- ) + batch_formats = get_pandas_batch_formats(batch) + if self.format not in batch_formats: + raise UnsupportedDatasetBatchLoadingType(self.format) + fmt = batch_formats[self.format] + return self.dataset_type.copy().bind( self.dataset_type.align(fmt.read(artifacts)) ) @@ -570,16 +596,9 @@ class PandasWriter(DatasetWriter, _PandasIO): type: ClassVar[str] = "pandas" def write( - self, - dataset: DatasetType, - storage: Storage, - path: str, - writer_fmt_args: Optional[Dict[str, Any]] = None, + self, dataset: DatasetType, storage: Storage, path: str ) -> Tuple[DatasetReader, Artifacts]: fmt = self.fmt - if writer_fmt_args: - for k, v in writer_fmt_args.items(): - setattr(fmt, k, v) art = fmt.write(dataset.data, storage, path) if not isinstance(dataset, DataFrameType): raise ValueError("Cannot write non-pandas Dataset") @@ -589,7 +608,7 @@ def write( class PandasImport(ExtImportHook): - EXTS: ClassVar = tuple(f".{k}" for k in get_pandas_formats()) + EXTS: ClassVar = tuple(f".{k}" for k in PANDAS_FORMATS) type: ClassVar = "pandas" @classmethod @@ -606,9 +625,9 @@ def process( **kwargs, ) -> MlemMeta: ext = modifier or posixpath.splitext(obj.path)[1][1:] - fmt = get_pandas_formats()[ext] + fmt = PANDAS_FORMATS[ext] if batch: - fmt = update_batch_args(ext, fmt, batch) + fmt = get_pandas_batch_formats(batch)[ext] read_args = fmt.read_args or {} read_args.update(kwargs) with obj.open("rb") as f: @@ -623,27 +642,3 @@ def process( ) } return meta - - -def update_batch_args( - type_: str, fmt: PandasFormat, batch: int -) -> PandasFormat: - # Check if batch reading is supported for specified _type format - if type_ == "csv": - fmt.read_func = read_batch_csv_with_unnamed - fmt.read_args = {"chunksize": batch} - elif type_ == "json": - fmt.read_func = partial(read_batch_reset_index, pd.read_json) - # JSON batch-reading requires line-delimited data, and orient to be records - fmt.read_args = { - "chunksize": batch, - "lines": True, - "orient": "records", - } - elif type_ == "stata": - fmt.read_func = partial(read_batch_reset_index, pd.read_stata) - fmt.read_args = {"chunksize": batch} - else: - raise UnsupportedDatasetBatchLoadingType(type_) - - return fmt diff --git a/mlem/core/dataset_type.py b/mlem/core/dataset_type.py index 9f49b367..23e093c2 100644 --- a/mlem/core/dataset_type.py +++ b/mlem/core/dataset_type.py @@ -420,10 +420,6 @@ class Config: @abstractmethod def write( - self, - dataset: DatasetType, - storage: Storage, - path: str, - writer_fmt_args: Optional[Dict[str, Any]] = None, + self, dataset: DatasetType, storage: Storage, path: str ) -> Tuple[DatasetReader, Artifacts]: raise NotImplementedError diff --git a/mlem/core/errors.py b/mlem/core/errors.py index a7426f05..343926cc 100644 --- a/mlem/core/errors.py +++ b/mlem/core/errors.py @@ -85,10 +85,6 @@ def __init__( super().__init__(self.message) -class DatasetBatchLoadingJSONError(ValueError, MlemError): - """Thrown if batch loading of JSON dataset is not line-delimited""" - - class WrongMethodError(ValueError, MlemError): """Thrown if wrong method name for model is provided""" diff --git a/tests/conftest.py b/tests/conftest.py index d42fbc33..ad0f418b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ import posixpath import tempfile from pathlib import Path -from typing import Any, Callable, Optional, Type +from typing import Any, Callable, Type import git import pandas as pd @@ -18,6 +18,7 @@ from mlem.api import init, save from mlem.constants import PREDICT_ARG_NAME, PREDICT_METHOD_NAME from mlem.contrib.fastapi 
import FastAPIServer +from mlem.contrib.pandas import PandasReader, get_pandas_batch_formats from mlem.contrib.sklearn import SklearnModel from mlem.core.artifacts import LOCAL_STORAGE, FSSpecStorage, LocalArtifact from mlem.core.dataset_type import DatasetReader, DatasetType, DatasetWriter @@ -306,23 +307,49 @@ def dataset_write_read_check( reader_type: Type[DatasetReader] = None, custom_eq: Callable[[Any, Any], bool] = None, custom_assert: Callable[[Any, Any], Any] = None, - batch: Optional[int] = None, - writer_args: Any = None, ): with tempfile.TemporaryDirectory() as tmpdir: writer = writer or dataset.get_writer() storage = LOCAL_STORAGE reader, artifacts = writer.write( - dataset, storage, posixpath.join(tmpdir, "data"), writer_args + dataset, storage, posixpath.join(tmpdir, "data") ) if reader_type is not None: assert isinstance(reader, reader_type) - if batch: - new = reader.read_batch(artifacts, batch) + new = reader.read(artifacts) + + assert dataset == new + if custom_assert is not None: + custom_assert(new.data, dataset.data) else: - new = reader.read(artifacts) + if custom_eq is not None: + assert custom_eq(new.data, dataset.data) + else: + assert new.data == dataset.data + + +def dataset_write_read_batch_check( + dataset: DatasetType, + format: str, + reader_type: Type[DatasetReader] = None, + custom_eq: Callable[[Any, Any], bool] = None, + custom_assert: Callable[[Any, Any], Any] = None, +): + with tempfile.TemporaryDirectory() as tmpdir: + # writer = writer or dataset.get_writer() + BATCH_SIZE = 2 + storage = LOCAL_STORAGE + + fmt = get_pandas_batch_formats(BATCH_SIZE)[format] + art = fmt.write(dataset.data, storage, posixpath.join(tmpdir, "data")) + reader = PandasReader(dataset_type=dataset, format=format) + artifacts = {"data": art} + if reader_type is not None: + assert isinstance(reader, reader_type) + + new = reader.read_batch(artifacts, BATCH_SIZE) assert dataset == new if custom_assert is not None: diff --git a/tests/contrib/test_pandas.py b/tests/contrib/test_pandas.py index ea1f030c..3c4e4af6 100644 --- a/tests/contrib/test_pandas.py +++ b/tests/contrib/test_pandas.py @@ -13,12 +13,12 @@ from mlem.api.commands import import_object from mlem.contrib.pandas import ( + PANDAS_FORMATS, DataFrameType, PandasConfig, PandasReader, PandasWriter, SeriesType, - get_pandas_formats, pd_type_from_string, python_type_from_pd_string_repr, python_type_from_pd_type, @@ -30,6 +30,7 @@ from mlem.core.metadata import load, save from mlem.core.objects import DatasetMeta from tests.conftest import ( + dataset_write_read_batch_check, dataset_write_read_batch_unsupported, dataset_write_read_check, long, @@ -109,7 +110,7 @@ def series_df_type(series_data): def for_all_formats(exclude: Union[List[str], Callable] = None): ex = exclude if isinstance(exclude, list) else [] - formats = [name for name in get_pandas_formats() if name not in ex] + formats = [name for name in PANDAS_FORMATS if name not in ex] mark = pytest.mark.parametrize("format", formats) if isinstance(exclude, list): return mark @@ -134,25 +135,11 @@ def test_simple_df(data, format): ] ) def test_simple_batch_df(data, format): - writer = PandasWriter(format=format) - # Batch-reading JSON files require line-delimited data - writer_args = None - if format == "json": - writer_args = { - "write_args": { - "date_format": "iso", - "date_unit": "ns", - "orient": "records", - "lines": True, - } - } - dataset_write_read_check( + dataset_write_read_batch_check( DatasetType.create(data), - writer, + format, PandasReader, 
pd.DataFrame.equals, - batch=2, - writer_args=writer_args, ) diff --git a/tests/core/test_dataset_io.py b/tests/core/test_dataset_io.py index d5f4eb8e..0c69d30e 100644 --- a/tests/core/test_dataset_io.py +++ b/tests/core/test_dataset_io.py @@ -3,7 +3,7 @@ import pytest from mlem.contrib.numpy import NumpyArrayReader, NumpyArrayWriter -from mlem.contrib.pandas import PandasReader, PandasWriter, get_pandas_formats +from mlem.contrib.pandas import PANDAS_FORMATS, PandasReader, PandasWriter from mlem.core.artifacts import FSSpecStorage from mlem.core.dataset_type import DatasetType @@ -27,7 +27,7 @@ def test_numpy_read_write(): assert np.array_equal(dataset2.data, data) -@pytest.mark.parametrize("format", list(get_pandas_formats().keys())) +@pytest.mark.parametrize("format", list(PANDAS_FORMATS.keys())) def test_pandas_read_write(format): data = pd.DataFrame([{"a": 1, "b": 2}]) dataset = DatasetType.create(data) From 8d089f59977fe733dee66526cdc0b81b79f18034 Mon Sep 17 00:00:00 2001 From: Terence Lim Date: Wed, 27 Apr 2022 00:21:36 +0800 Subject: [PATCH 7/7] Implement batch dataset read iterator --- mlem/contrib/numpy.py | 4 +-- mlem/contrib/pandas.py | 58 +++++++++++++----------------------- mlem/core/dataset_type.py | 14 +++++++-- tests/conftest.py | 56 ++++++++++++---------------------- tests/contrib/test_pandas.py | 22 ++++---------- 5 files changed, 60 insertions(+), 94 deletions(-) diff --git a/mlem/contrib/numpy.py b/mlem/contrib/numpy.py index 69342e6e..aea37701 100644 --- a/mlem/contrib/numpy.py +++ b/mlem/contrib/numpy.py @@ -1,5 +1,5 @@ from types import ModuleType -from typing import Any, ClassVar, List, Optional, Tuple, Type, Union +from typing import Any, ClassVar, Iterator, List, Optional, Tuple, Type, Union import numpy as np from pydantic import BaseModel, conlist, create_model @@ -193,5 +193,5 @@ def read(self, artifacts: Artifacts) -> DatasetType: data = np.load(f)[DATA_KEY] return self.dataset_type.copy().bind(data) - def read_batch(self, artifacts: Artifacts, batch: int) -> DatasetType: + def read_batch(self, artifacts: Artifacts, batch: int) -> Iterator: raise NotImplementedError diff --git a/mlem/contrib/pandas.py b/mlem/contrib/pandas.py index c92dcc74..9a7245ce 100644 --- a/mlem/contrib/pandas.py +++ b/mlem/contrib/pandas.py @@ -3,13 +3,13 @@ import re from abc import ABC from dataclasses import dataclass -from functools import partial from typing import ( IO, Any, Callable, ClassVar, Dict, + Iterator, List, Optional, Tuple, @@ -436,39 +436,10 @@ def read_csv_with_unnamed(*args, **kwargs): return df.rename(unnamed, axis=1) # pylint: disable=no-member -def read_batch_csv_with_unnamed(*args, **kwargs): - df = None - df_iterator = pd.read_csv(*args, **kwargs) - unnamed = {} - for i, df_chunk in enumerate(df_iterator): - # Instantiate Pandas DataFrame with columns if it is the first chunk - if i == 0: - df = pd.DataFrame(columns=df_chunk.columns) - for col in df_chunk.columns: - if col.startswith("Unnamed: "): - unnamed[col] = "" - df = pd.concat([df, df_chunk], ignore_index=True) - if not unnamed: - return df - return df.rename(unnamed, axis=1) # pylint: disable=no-member - - def read_json_reset_index(*args, **kwargs): return pd.read_json(*args, **kwargs).reset_index(drop=True) -def read_batch_reset_index(read_func: Callable, *args, **kwargs): - df = pd.DataFrame() - df_iterator = read_func(*args, **kwargs) - for i, df_chunk in enumerate(df_iterator): - # Instantiate Pandas DataFrame with columns if it is the first chunk - if i == 0: - df = 
pd.DataFrame(columns=df_chunk.columns) - df = pd.concat([df, df_chunk], ignore_index=True) - df = df.reset_index(drop=True) - return df - - def read_html(*args, **kwargs): # read_html returns list of dataframes return pd.read_html(*args, **kwargs)[0] @@ -520,14 +491,14 @@ def read_html(*args, **kwargs): def get_pandas_batch_formats(batch: int): PANDAS_FORMATS = { "csv": PandasFormat( - read_batch_csv_with_unnamed, + pd.read_csv, pd.DataFrame.to_csv, file_name="data.csv", write_args={"index": False}, read_args={"chunksize": batch}, ), "json": PandasFormat( - partial(read_batch_reset_index, pd.read_json), + pd.read_json, pd.DataFrame.to_json, file_name="data.json", write_args={ @@ -542,7 +513,7 @@ def get_pandas_batch_formats(batch: int): read_args={"chunksize": batch, "orient": "records", "lines": True}, ), "stata": PandasFormat( # TODO int32 converts to int64 for some reason - partial(read_batch_reset_index, pd.read_stata), + pd.read_stata, pd.DataFrame.to_stata, write_args={"write_index": False}, read_args={"chunksize": batch}, @@ -579,15 +550,28 @@ def read(self, artifacts: Artifacts) -> DatasetType: self.dataset_type.align(self.fmt.read(artifacts)) ) - def read_batch(self, artifacts: Artifacts, batch: int) -> DatasetType: + def read_batch(self, artifacts: Artifacts, batch: int) -> Iterator: batch_formats = get_pandas_batch_formats(batch) if self.format not in batch_formats: raise UnsupportedDatasetBatchLoadingType(self.format) fmt = batch_formats[self.format] - return self.dataset_type.copy().bind( - self.dataset_type.align(fmt.read(artifacts)) - ) + read_kwargs = {} + if fmt.read_args: + read_kwargs.update(fmt.read_args) + with artifacts[DatasetWriter.art_name].open() as f: + iter_df = fmt.read_func(f, **read_kwargs) + for df in iter_df: + unnamed = {} + for col in df.columns: + if col.startswith("Unnamed: "): + unnamed[col] = "" + if unnamed: + df = df.rename(unnamed, axis=1) + + yield self.dataset_type.copy().bind( + self.dataset_type.align(df) + ) class PandasWriter(DatasetWriter, _PandasIO): diff --git a/mlem/core/dataset_type.py b/mlem/core/dataset_type.py index 23e093c2..0625d962 100644 --- a/mlem/core/dataset_type.py +++ b/mlem/core/dataset_type.py @@ -3,7 +3,17 @@ """ import builtins from abc import ABC, abstractmethod -from typing import Any, ClassVar, Dict, List, Optional, Sized, Tuple, Type +from typing import ( + Any, + ClassVar, + Dict, + Iterator, + List, + Optional, + Sized, + Tuple, + Type, +) from pydantic import BaseModel from pydantic.main import create_model @@ -405,7 +415,7 @@ def read(self, artifacts: Artifacts) -> DatasetType: raise NotImplementedError @abstractmethod - def read_batch(self, artifacts: Artifacts, batch: int) -> DatasetType: + def read_batch(self, artifacts: Artifacts, batch: int) -> Iterator: raise NotImplementedError diff --git a/tests/conftest.py b/tests/conftest.py index ad0f418b..b05d1673 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ import posixpath import tempfile from pathlib import Path -from typing import Any, Callable, Type +from typing import Any, Callable, Iterator, Type import git import pandas as pd @@ -22,7 +22,6 @@ from mlem.contrib.sklearn import SklearnModel from mlem.core.artifacts import LOCAL_STORAGE, FSSpecStorage, LocalArtifact from mlem.core.dataset_type import DatasetReader, DatasetType, DatasetWriter -from mlem.core.errors import UnsupportedDatasetBatchLoadingType from mlem.core.meta_io import MLEM_EXT, get_fs from mlem.core.metadata import load_meta from mlem.core.model import Argument, 
ModelType, Signature @@ -335,10 +334,8 @@ def dataset_write_read_batch_check( format: str, reader_type: Type[DatasetReader] = None, custom_eq: Callable[[Any, Any], bool] = None, - custom_assert: Callable[[Any, Any], Any] = None, ): with tempfile.TemporaryDirectory() as tmpdir: - # writer = writer or dataset.get_writer() BATCH_SIZE = 2 storage = LOCAL_STORAGE @@ -349,39 +346,24 @@ def dataset_write_read_batch_check( if reader_type is not None: assert isinstance(reader, reader_type) - new = reader.read_batch(artifacts, BATCH_SIZE) - - assert dataset == new - if custom_assert is not None: - custom_assert(new.data, dataset.data) - else: - if custom_eq is not None: - assert custom_eq(new.data, dataset.data) - else: - assert new.data == dataset.data - - -def dataset_write_read_batch_unsupported( - dataset: DatasetType, - batch: int, - writer: DatasetWriter = None, - reader_type: Type[DatasetReader] = None, -): - with tempfile.TemporaryDirectory() as tmpdir: - writer = writer or dataset.get_writer() - - storage = LOCAL_STORAGE - reader, artifacts = writer.write( - dataset, storage, posixpath.join(tmpdir, "data") - ) - if reader_type is not None: - assert isinstance(reader, reader_type) - - with pytest.raises( - UnsupportedDatasetBatchLoadingType, - match="Batch-loading Dataset of type .*", - ): - reader.read_batch(artifacts, batch) + df_iterable: Iterator = reader.read_batch(artifacts, BATCH_SIZE) + df = None + col_types = None + while True: + try: + chunk = next(df_iterable) + if df is None: + df = pd.DataFrame(columns=chunk.columns, dtype=col_types) + col_types = { + chunk.columns[idx]: chunk.dtypes[idx] + for idx in range(len(chunk.columns)) + } + df = df.astype(dtype=col_types) + df = pd.concat([df, chunk.data], ignore_index=True) + except StopIteration: + break + + assert custom_eq(df, dataset.data) def check_model_type_common_interface( diff --git a/tests/contrib/test_pandas.py b/tests/contrib/test_pandas.py index 3c4e4af6..ea94c6e8 100644 --- a/tests/contrib/test_pandas.py +++ b/tests/contrib/test_pandas.py @@ -31,7 +31,6 @@ from mlem.core.objects import DatasetMeta from tests.conftest import ( dataset_write_read_batch_check, - dataset_write_read_batch_unsupported, dataset_write_read_check, long, ) @@ -80,7 +79,12 @@ def pandas_assert(actual: pd.DataFrame, expected: pd.DataFrame): @pytest.fixture def data(): - return pd.DataFrame([{"a": 1, "b": 3, "c": 5}, {"a": 2, "b": 4, "c": 6}]) + return pd.DataFrame( + [ + {"a": 1, "b": 3, "c": 5}, + {"a": 2, "b": 4, "c": 6}, + ] + ) @pytest.fixture @@ -143,20 +147,6 @@ def test_simple_batch_df(data, format): ) -@for_all_formats( - exclude=[ # Following file formats support Pandas chunksize parameter - "csv", - "json", - "stata", - ] -) -def test_unsupported_batch_df(data, format): - writer = PandasWriter(format=format) - dataset_write_read_batch_unsupported( - DatasetType.create(data), 2, writer, PandasReader - ) - - @for_all_formats def test_with_index(data, format): writer = PandasWriter(format=format)
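
Note (illustration only, not part of the patch series): a minimal sketch of how
the batch reader introduced above could be driven directly, assuming the final
commit is applied. `meta.reader`, `relative_artifacts`, `read_batch`, and the
`.data` attribute on each yielded chunk all come from the hunks above; the
dataset path and batch size are hypothetical, and the dataset is assumed to have
been saved in one of the batch-capable formats (csv, json, stata).

    import pandas as pd

    from mlem.core.metadata import load_meta
    from mlem.core.objects import DatasetMeta

    # Load only the metadata of a previously saved pandas dataset (hypothetical path).
    meta = load_meta("data/my_dataset", force_type=DatasetMeta)

    # read_batch yields DataFrameType chunks of up to 2 rows each; every chunk
    # exposes its DataFrame via .data, mirroring dataset_write_read_batch_check
    # in tests/conftest.py.
    chunks = meta.reader.read_batch(meta.relative_artifacts, 2)

    # Reassemble the full frame from the yielded chunks.
    df = pd.concat([chunk.data for chunk in chunks], ignore_index=True)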