This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

WIP: Support batch-reading for data-types with chunksize parameter #206

Closed
wants to merge 7 commits
Changes from 5 commits
6 changes: 4 additions & 2 deletions mlem/api/commands.py
@@ -58,6 +58,7 @@ def apply(
output: str = None,
link: bool = None,
external: bool = None,
batch: Optional[int] = None,
) -> Optional[Any]:
"""Apply provided model against provided data

@@ -85,7 +86,7 @@
resolved_method = PREDICT_METHOD_NAME
echo(EMOJI_APPLY + f"Applying `{resolved_method}` method...")
res = [
w.call_method(resolved_method, get_dataset_value(part))
w.call_method(resolved_method, get_dataset_value(part, batch))
Contributor

Here you should flatten all the batches if there are any. Here is an example of how I see it:
Suppose you have dataframe=[1,2,3,4] (I mean 1 column, 4 rows) saved to a csv file. You load its metadata without loading the value, let's say dt = DatasetMeta(dataset_type=DataFrameType(...), ...), and you call apply with data=[dt]. If the batch arg is not provided, what will happen is get_dataset_value will load the actual dataframe and in the end res = [w.call_method(..., dataframe([1,2,3,4]))]. But if you provide batch=2, dt.read_batch should be called. If you iterate through it, you will get 2 parts of the dataframe, and in the end res = [w.call_method(..., dataframe([1,2])), w.call_method(..., dataframe([3,4]))].

Contributor Author

👍🏻 Makes sense, I implemented something similar yesterday, but I couldn't sync the changes to this PR because the private fork no longer points to this repository.

for part in data
]
if output is None:
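
The flattening the reviewer describes above can be illustrated with a small standalone sketch. This is not the MLEM apply implementation; iter_parts and apply_model are hypothetical helpers, and the chunking here stands in for whatever DatasetMeta.read_batch ends up doing.

from typing import Iterator, List, Optional

import pandas as pd


def iter_parts(df: pd.DataFrame, batch: Optional[int]) -> Iterator[pd.DataFrame]:
    # Yield the whole frame, or consecutive chunks of `batch` rows if batch is given.
    if batch is None:
        yield df
    else:
        for start in range(0, len(df), batch):
            yield df.iloc[start:start + batch]


def apply_model(predict, df: pd.DataFrame, batch: Optional[int] = None) -> List:
    # Apply `predict` to each part before producing the next one,
    # then flatten the per-part results into a single list.
    results: List = []
    for part in iter_parts(df, batch):
        results.extend(predict(part))
    return results


frame = pd.DataFrame({"x": [1, 2, 3, 4]})
# With batch=2 the model sees [1, 2] and then [3, 4], matching the reviewer's example.
print(apply_model(lambda part: (part["x"] * 10).tolist(), frame, batch=2))

With batch=None the model is called once on the full frame, which mirrors the non-batch path in the diff above.
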
@@ -412,6 +413,7 @@ def import_object(
target_repo: Optional[str] = None,
target_fs: Optional[AbstractFileSystem] = None,
type_: Optional[str] = None,
batch: Optional[int] = None,
copy_data: bool = True,
external: bool = None,
link: bool = None,
@@ -426,7 +428,7 @@
if type_ not in ImportHook.__type_map__:
raise ValueError(f"Unknown import type {type_}")
meta = ImportHook.__type_map__[type_].process(
loc, copy_data=copy_data, modifier=modifier
loc, copy_data=copy_data, modifier=modifier, batch=batch
)
else:
meta = ImportAnalyzer.analyze(loc, copy_data=copy_data)
7 changes: 5 additions & 2 deletions mlem/api/utils.py
@@ -7,14 +7,17 @@
from mlem.core.objects import DatasetMeta, MlemMeta, ModelMeta


def get_dataset_value(dataset: Any) -> Any:
def get_dataset_value(dataset: Any, batch: Optional[int] = None) -> Any:
if isinstance(dataset, str):
return load(dataset)
if isinstance(dataset, DatasetMeta):
# TODO: https://github.com/iterative/mlem/issues/29
# fix discrepancies between model and data meta objects
if not hasattr(dataset.dataset, "data"):
dataset.load_value()
if batch:
dataset.load_batch_value(batch)
else:
dataset.load_value()
return dataset.data

# TODO: https://github.com/iterative/mlem/issues/29
11 changes: 10 additions & 1 deletion mlem/cli/apply.py
@@ -60,6 +60,9 @@ def apply(
# TODO: change ImportHook to MlemObject to support ext machinery
help=f"Specify how to read data file for import. Available types: {list_implementations(ImportHook)}",
),
batch: Optional[int] = Option(
None, "-b", "--batch", help="Batch size for reading data in batches."
),
link: bool = option_link,
external: bool = option_external,
json: bool = option_json,
@@ -83,7 +86,11 @@
with set_echo(None if json else ...):
if import_:
dataset = import_object(
data, repo=data_repo, rev=data_rev, type_=import_type
data,
repo=data_repo,
rev=data_rev,
type_=import_type,
batch=batch,
Contributor

For now let's ignore this branch and focus on the other one without import. Once we're done, we can discuss how to approach this (I am not sure myself).

)
else:
dataset = load_meta(
@@ -92,6 +99,7 @@
data_rev,
load_value=True,
Contributor

as per comments in metadata.py and objects.py, this should be load_meta(..., load_value=batch is None)

force_type=DatasetMeta,
batch=batch,
)
meta = load_meta(model, repo, rev, force_type=ModelMeta)

@@ -102,6 +110,7 @@
output=output,
link=link,
external=external,
batch=batch,
)
if output is None and json:
print(
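
The reviewer's load_value=batch is None suggestion boils down to: load the value eagerly only when no batch size is given, and otherwise defer reading until the data is consumed in chunks. A toy sketch of that pattern, with FakeDatasetMeta as a hypothetical stand-in rather than the real DatasetMeta API:

from dataclasses import dataclass
from typing import Iterator, List, Optional


@dataclass
class FakeDatasetMeta:
    # Hypothetical stand-in for DatasetMeta, only to show the eager-vs-deferred idea.
    rows: List[int]
    data: Optional[List[int]] = None

    def load_value(self) -> None:
        self.data = list(self.rows)  # eager, whole-dataset load

    def read_batch(self, batch: int) -> Iterator[List[int]]:
        for i in range(0, len(self.rows), batch):
            yield self.rows[i:i + batch]  # deferred, chunk-by-chunk load


def load_meta_like(rows: List[int], batch: Optional[int] = None) -> FakeDatasetMeta:
    meta = FakeDatasetMeta(rows)
    if batch is None:  # i.e. load_value=batch is None
        meta.load_value()
    return meta


print(load_meta_like([1, 2, 3, 4]).data)           # [1, 2, 3, 4]
print(load_meta_like([1, 2, 3, 4], batch=2).data)  # None: the value is read later, in chunks
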
11 changes: 9 additions & 2 deletions mlem/contrib/numpy.py
@@ -1,5 +1,5 @@
from types import ModuleType
from typing import Any, ClassVar, List, Optional, Tuple, Type, Union
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union

import numpy as np
from pydantic import BaseModel, conlist, create_model
@@ -172,7 +172,11 @@ class NumpyArrayWriter(DatasetWriter):
type: ClassVar[str] = "numpy"

def write(
self, dataset: DatasetType, storage: Storage, path: str
self,
dataset: DatasetType,
storage: Storage,
path: str,
writer_fmt_args: Optional[Dict[str, Any]] = None,
) -> Tuple[DatasetReader, Artifacts]:
with storage.open(path) as (f, art):
np.savez_compressed(f, **{DATA_KEY: dataset.data})
@@ -192,3 +196,6 @@ def read(self, artifacts: Artifacts) -> DatasetType:
with artifacts[DatasetWriter.art_name].open() as f:
data = np.load(f)[DATA_KEY]
return self.dataset_type.copy().bind(data)

def read_batch(self, artifacts: Artifacts, batch: int) -> DatasetType:
raise NotImplementedError
184 changes: 138 additions & 46 deletions mlem/contrib/pandas.py
@@ -3,6 +3,7 @@
import re
from abc import ABC
from dataclasses import dataclass
from functools import partial
from typing import (
IO,
Any,
@@ -44,7 +45,12 @@
DatasetType,
DatasetWriter,
)
from mlem.core.errors import DeserializationError, SerializationError
from mlem.core.errors import (
DatasetBatchLoadingJSONError,
DeserializationError,
SerializationError,
UnsupportedDatasetBatchLoadingType,
)
from mlem.core.import_objects import ExtImportHook
from mlem.core.meta_io import Location
from mlem.core.metadata import get_object_metadata
@@ -320,7 +326,7 @@ def get_writer(self, **kwargs):
if filename is not None:
_, ext = os.path.splitext(filename)
ext = ext.lstrip(".")
if ext in PANDAS_FORMATS:
if ext in get_pandas_formats():
fmt = ext
return PandasWriter(format=fmt)

@@ -386,6 +392,9 @@ class PandasFormat:
file_name: str = "data.pd"
string_buffer: bool = False

def update_variable(self, name: str, val: Any):
setattr(self, name, val)

def read(self, artifacts: Artifacts, **kwargs):
"""Read DataFrame"""
read_kwargs = {}
@@ -428,54 +437,88 @@ def read_csv_with_unnamed(*args, **kwargs):
return df.rename(unnamed, axis=1) # pylint: disable=no-member


def read_batch_csv_with_unnamed(*args, **kwargs):
df = None
df_iterator = pd.read_csv(*args, **kwargs)
unnamed = {}
for i, df_chunk in enumerate(df_iterator):
# Instantiate Pandas DataFrame with columns if it is the first chunk
if i == 0:
Contributor

This could be df is None instead of checking i == 0.

df = pd.DataFrame(columns=df_chunk.columns)
for col in df_chunk.columns:
if col.startswith("Unnamed: "):
unnamed[col] = ""
df = pd.concat([df, df_chunk], ignore_index=True)
Contributor

You still read the whole file into memory. The idea is to apply the model to each part before you read the next one.

Contributor Author

I see, sorry, I misunderstood. I'll refactor and implement an iterator for these batch functions.

if not unnamed:
return df
return df.rename(unnamed, axis=1) # pylint: disable=no-member
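
Following the reviewer's point, a chunk-wise CSV reader would yield each chunk rather than concatenating them, so the caller can apply the model to one chunk before the next is read. A minimal standalone generator sketch (iter_csv_batches is a hypothetical helper, not the MLEM function; the renaming mirrors read_csv_with_unnamed):

from typing import Iterator

import pandas as pd


def iter_csv_batches(path: str, batch: int, **kwargs) -> Iterator[pd.DataFrame]:
    # Yield chunks of `batch` rows, renaming auto-generated "Unnamed: N" columns to "".
    for chunk in pd.read_csv(path, chunksize=batch, **kwargs):
        unnamed = {col: "" for col in chunk.columns if col.startswith("Unnamed: ")}
        yield chunk.rename(unnamed, axis=1) if unnamed else chunk

Because the chunks are never accumulated, peak memory stays proportional to the batch size instead of the file size.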


def read_json_reset_index(*args, **kwargs):
return pd.read_json(*args, **kwargs).reset_index(drop=True)


def read_batch_reset_index(read_func: Callable, *args, **kwargs):
df = pd.DataFrame()
df_iterator = read_func(*args, **kwargs)
for i, df_chunk in enumerate(df_iterator):
# Instantiate Pandas DataFrame with columns if it is the first chunk
if i == 0:
df = pd.DataFrame(columns=df_chunk.columns)
df = pd.concat([df, df_chunk], ignore_index=True)
df = df.reset_index(drop=True)
return df


def read_html(*args, **kwargs):
# read_html returns list of dataframes
return pd.read_html(*args, **kwargs)[0]


PANDAS_FORMATS = {
Contributor Author

Refactored to make this a non-global variable to avoid needing hacks around tests for batch-reading.

def test_simple_batch_df(data, format):
    writer = PandasWriter(format=format)
    # Batch-reading JSON files requires line-delimited data
    if format == "json":
        writer.fmt.write_args = {"orient": "records", "lines": True}
    dataset_write_read_check(
        DatasetType.create(data), writer, PandasReader, pd.DataFrame.equals, batch=2
    )
    # Need reset if PANDAS_FORMATS is a global variable
    if format == "csv":
        writer.fmt.write_args = {"index": False}
        writer.fmt.read_args = {}
        writer.fmt.read_func = read_csv_with_unnamed
    if format == "json":
        writer.fmt.write_args = {"date_format": "iso", "date_unit": "ns"}
        writer.fmt.read_args = {}
        writer.fmt.read_func = read_json_reset_index

Copy link
Contributor

@mike0sv mike0sv Apr 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why you need to change something for the tests? This means you are testing something else, no?
And you don't need the function anyway. You can just do writer.fmt = PandasFormat(<whatever you need>) before the check, or maybe use fmt = writer.fmt.copy() and then change what you need.

"csv": PandasFormat(
read_csv_with_unnamed,
pd.DataFrame.to_csv,
file_name="data.csv",
write_args={"index": False},
),
"json": PandasFormat(
read_json_reset_index,
pd.DataFrame.to_json,
file_name="data.json",
write_args={"date_format": "iso", "date_unit": "ns"},
),
"html": PandasFormat(
read_html,
pd.DataFrame.to_html,
file_name="data.html",
write_args={"index": False},
string_buffer=True,
),
"excel": PandasFormat(
pd.read_excel,
pd.DataFrame.to_excel,
file_name="data.xlsx",
write_args={"index": False},
),
"parquet": PandasFormat(
pd.read_parquet, pd.DataFrame.to_parquet, file_name="data.parquet"
),
"feather": PandasFormat(
pd.read_feather, pd.DataFrame.to_feather, file_name="data.feather"
),
"pickle": PandasFormat( # TODO buffer closed error for some reason
pd.read_pickle, pd.DataFrame.to_pickle, file_name="data.pkl"
),
"strata": PandasFormat( # TODO int32 converts to int64 for some reason
pd.read_stata, pd.DataFrame.to_stata, write_args={"write_index": False}
),
}
def get_pandas_formats():
PANDAS_FORMATS = {
"csv": PandasFormat(
read_csv_with_unnamed,
pd.DataFrame.to_csv,
file_name="data.csv",
write_args={"index": False},
),
"json": PandasFormat(
read_json_reset_index,
pd.DataFrame.to_json,
file_name="data.json",
write_args={"date_format": "iso", "date_unit": "ns"},
),
"html": PandasFormat(
read_html,
pd.DataFrame.to_html,
file_name="data.html",
write_args={"index": False},
string_buffer=True,
),
"excel": PandasFormat(
pd.read_excel,
pd.DataFrame.to_excel,
file_name="data.xlsx",
write_args={"index": False},
),
"parquet": PandasFormat(
pd.read_parquet, pd.DataFrame.to_parquet, file_name="data.parquet"
),
"feather": PandasFormat(
pd.read_feather, pd.DataFrame.to_feather, file_name="data.feather"
),
"pickle": PandasFormat( # TODO buffer closed error for some reason
pd.read_pickle, pd.DataFrame.to_pickle, file_name="data.pkl"
),
"stata": PandasFormat( # TODO int32 converts to int64 for some reason
pd.read_stata,
pd.DataFrame.to_stata,
write_args={"write_index": False},
),
}

return PANDAS_FORMATS
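
The reviewer's alternative in the thread above (copy the format object before tweaking it) can be shown with a generic copy-before-modify sketch. It uses a toy dataclass rather than the real PandasFormat/PandasWriter, whose exact copy semantics may differ:

import dataclasses
from typing import Any, Dict, Optional


@dataclasses.dataclass
class ToyFormat:
    # Just enough fields to show the idea; not the real PandasFormat.
    file_name: str
    write_args: Optional[Dict[str, Any]] = None


DEFAULT_JSON = ToyFormat("data.json", {"date_format": "iso", "date_unit": "ns"})


def batch_friendly(fmt: ToyFormat) -> ToyFormat:
    # Return a modified copy; the shared default stays untouched,
    # so tests need no reset step afterwards.
    return dataclasses.replace(fmt, write_args={"orient": "records", "lines": True})


assert DEFAULT_JSON.write_args == {"date_format": "iso", "date_unit": "ns"}
assert batch_friendly(DEFAULT_JSON).write_args == {"orient": "records", "lines": True}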


class _PandasIO(BaseModel):
@@ -485,13 +528,13 @@ class _PandasIO(BaseModel):
def is_valid_format( # pylint: disable=no-self-argument
cls, value # noqa: B902
):
if value not in PANDAS_FORMATS:
if value not in get_pandas_formats():
raise ValueError(f"format {value} is not supported")
return value

@property
def fmt(self):
return PANDAS_FORMATS[self.format]
return get_pandas_formats()[self.format]


class PandasReader(_PandasIO, DatasetReader):
@@ -505,16 +548,38 @@ def read(self, artifacts: Artifacts) -> DatasetType:
self.dataset_type.align(self.fmt.read(artifacts))
)

def read_batch(self, artifacts: Artifacts, batch: int) -> DatasetType:
fmt = update_batch_args(self.format, self.fmt, batch)
Contributor

I think I get this thing above now. I'd say we just need a separate set of args for batch and non-batch reading, so you don't have to change the state every time.

Contributor Author

Refactored this bit - so there's no longer a need to update args based on batch/non-batch reads.


# Pandas supports batch-reading for JSON only if the JSON file is line-delimited
# https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html#line-delimited-json
if self.format == "json":
dataset_lines = sum(1 for line in open(artifacts["data"].uri))
Contributor

  1. You are reading the whole file here, but you actually need only 2 lines to know there are enough.
  2. If it is a one-line JSON, you still read the whole dataset, and then you actually read it again if everything is ok.
  3. It will fail if your dataset has just 1 row.

Contributor

Let's assume the dataset file is in the right format for now.

if dataset_lines <= 1:
raise DatasetBatchLoadingJSONError(
"Batch-loading Dataset of type JSON requires provided JSON file to be line-delimited."
)
return self.dataset_type.copy().bind(
self.dataset_type.align(fmt.read(artifacts))
)
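
A cheaper check along the lines of the reviewer's comments would read at most the first two lines instead of counting every line, and would not reject a one-row dataset. A hedged standalone heuristic (plain file access, not the MLEM Artifact API):

import json
from itertools import islice


def looks_line_delimited(path: str) -> bool:
    # Heuristic: the first line parses as a complete JSON value on its own.
    # Reads at most two lines, so it stays cheap even for huge files, and a
    # single-line file that is one complete JSON object counts as one row.
    with open(path, "r", encoding="utf-8") as f:
        lines = [line.strip() for line in islice(f, 2)]
    if not lines or not lines[0]:
        return False
    try:
        json.loads(lines[0])  # a line-delimited file has one full record per line
        return True
    except json.JSONDecodeError:
        return False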


class PandasWriter(DatasetWriter, _PandasIO):
"""DatasetWriter for pandas dataframes"""

type: ClassVar[str] = "pandas"

def write(
self, dataset: DatasetType, storage: Storage, path: str
self,
dataset: DatasetType,
storage: Storage,
path: str,
writer_fmt_args: Optional[Dict[str, Any]] = None,
) -> Tuple[DatasetReader, Artifacts]:
fmt = self.fmt
if writer_fmt_args:
for k, v in writer_fmt_args.items():
setattr(fmt, k, v)
art = fmt.write(dataset.data, storage, path)
if not isinstance(dataset, DataFrameType):
raise ValueError("Cannot write non-pandas Dataset")
@@ -524,7 +589,7 @@ def write(


class PandasImport(ExtImportHook):
EXTS: ClassVar = tuple(f".{k}" for k in PANDAS_FORMATS)
EXTS: ClassVar = tuple(f".{k}" for k in get_pandas_formats())
type: ClassVar = "pandas"

@classmethod
@@ -537,10 +602,13 @@ def process(
obj: Location,
copy_data: bool = True,
modifier: Optional[str] = None,
batch: Optional[int] = None,
**kwargs,
) -> MlemMeta:
ext = modifier or posixpath.splitext(obj.path)[1][1:]
fmt = PANDAS_FORMATS[ext]
fmt = get_pandas_formats()[ext]
if batch:
fmt = update_batch_args(ext, fmt, batch)
read_args = fmt.read_args or {}
read_args.update(kwargs)
with obj.open("rb") as f:
@@ -555,3 +623,27 @@ def process(
)
}
return meta


def update_batch_args(
type_: str, fmt: PandasFormat, batch: int
) -> PandasFormat:
# Check if batch reading is supported for the specified type_ format
if type_ == "csv":
fmt.read_func = read_batch_csv_with_unnamed
fmt.read_args = {"chunksize": batch}
elif type_ == "json":
fmt.read_func = partial(read_batch_reset_index, pd.read_json)
# JSON batch-reading requires line-delimited data, and orient to be records
fmt.read_args = {
"chunksize": batch,
"lines": True,
"orient": "records",
}
elif type_ == "stata":
fmt.read_func = partial(read_batch_reset_index, pd.read_stata)
fmt.read_args = {"chunksize": batch}
else:
raise UnsupportedDatasetBatchLoadingType(type_)

return fmt
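
For reference, the chunked reads that update_batch_args switches on come down to pandas' own chunksize iterators; a small standalone example of the underlying behaviour the CSV and JSON branches rely on (file names are illustrative):

import pandas as pd

# CSV: read_args={"chunksize": batch} turns pd.read_csv into an iterator of DataFrames.
pd.DataFrame({"x": [1, 2, 3, 4]}).to_csv("data.csv", index=False)
for chunk in pd.read_csv("data.csv", chunksize=2):
    print(len(chunk))  # 2, then 2

# JSON: chunked reading additionally needs line-delimited records.
pd.DataFrame({"x": [1, 2, 3, 4]}).to_json("data.json", orient="records", lines=True)
for chunk in pd.read_json("data.json", orient="records", lines=True, chunksize=2):
    print(len(chunk))  # 2, then 2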