WIP: Support batch-reading for data-types with chunksize parameter #206
@@ -60,6 +60,9 @@ def apply(
        # TODO: change ImportHook to MlemObject to support ext machinery
        help=f"Specify how to read data file for import. Available types: {list_implementations(ImportHook)}",
    ),
    batch: Optional[int] = Option(
        None, "-b", "--batch", help="Batch size for reading data in batches."
    ),
    link: bool = option_link,
    external: bool = option_external,
    json: bool = option_json,
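For context, once this option lands an invocation would presumably look like `mlem apply mymodel mydata -b 1000`, where `mymodel` and `mydata` are placeholder paths and 1000 is the number of rows read per chunk.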
@@ -83,7 +86,11 @@ def apply(
    with set_echo(None if json else ...):
        if import_:
            dataset = import_object(
                data, repo=data_repo, rev=data_rev, type_=import_type
                data,
                repo=data_repo,
                rev=data_rev,
                type_=import_type,
                batch=batch,

> for now let's ignore this branch and focus on the other one without import. once we're done, we can discuss how to approach this (I am not sure myself)

            )
        else:
            dataset = load_meta(
@@ -92,6 +99,7 @@ def apply(
                data_rev,
                load_value=True,

> as per comments in …

                force_type=DatasetMeta,
                batch=batch,
            )
        meta = load_meta(model, repo, rev, force_type=ModelMeta)
@@ -102,6 +110,7 @@ def apply(
            output=output,
            link=link,
            external=external,
            batch=batch,
        )
    if output is None and json:
        print(
@@ -3,6 +3,7 @@
import re
from abc import ABC
from dataclasses import dataclass
from functools import partial
from typing import (
    IO,
    Any,
@@ -44,7 +45,12 @@
    DatasetType,
    DatasetWriter,
)
from mlem.core.errors import DeserializationError, SerializationError
from mlem.core.errors import (
    DatasetBatchLoadingJSONError,
    DeserializationError,
    SerializationError,
    UnsupportedDatasetBatchLoadingType,
)
from mlem.core.import_objects import ExtImportHook
from mlem.core.meta_io import Location
from mlem.core.metadata import get_object_metadata
@@ -320,7 +326,7 @@ def get_writer(self, **kwargs):
        if filename is not None:
            _, ext = os.path.splitext(filename)
            ext = ext.lstrip(".")
            if ext in PANDAS_FORMATS:
            if ext in get_pandas_formats():
                fmt = ext
        return PandasWriter(format=fmt)
@@ -386,6 +392,9 @@ class PandasFormat:
    file_name: str = "data.pd"
    string_buffer: bool = False

    def update_variable(self, name: str, val: Any):
        setattr(self, name, val)

    def read(self, artifacts: Artifacts, **kwargs):
        """Read DataFrame"""
        read_kwargs = {}
@@ -428,54 +437,88 @@ def read_csv_with_unnamed(*args, **kwargs):
    return df.rename(unnamed, axis=1)  # pylint: disable=no-member

def read_batch_csv_with_unnamed(*args, **kwargs):
    df = None
    df_iterator = pd.read_csv(*args, **kwargs)
    unnamed = {}
    for i, df_chunk in enumerate(df_iterator):
        # Instantiate Pandas DataFrame with columns if it is the first chunk
        if i == 0:

> this may be …

            df = pd.DataFrame(columns=df_chunk.columns)
            for col in df_chunk.columns:
                if col.startswith("Unnamed: "):
                    unnamed[col] = ""
        df = pd.concat([df, df_chunk], ignore_index=True)

> You still read the whole file into memory. The idea is to apply the model to each part before you read the next one.

> I see, sorry, I misunderstood. I'll refactor and implement an iterator for these batch functions.

    if not unnamed:
        return df
    return df.rename(unnamed, axis=1)  # pylint: disable=no-member
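Following the thread above, a minimal sketch of the iterator direction (the function name and exact shape are assumptions, not the final implementation): yield each renamed chunk so the caller can apply the model to one chunk before the next one is read.

```python
def iter_batch_csv_with_unnamed(*args, **kwargs):
    # Assumed generator variant: kwargs must include `chunksize`, so
    # pd.read_csv returns an iterator of DataFrame chunks.
    for df_chunk in pd.read_csv(*args, **kwargs):
        unnamed = {
            col: "" for col in df_chunk.columns if col.startswith("Unnamed: ")
        }
        yield df_chunk.rename(unnamed, axis=1) if unnamed else df_chunk
```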
def read_json_reset_index(*args, **kwargs):
    return pd.read_json(*args, **kwargs).reset_index(drop=True)


def read_batch_reset_index(read_func: Callable, *args, **kwargs):
    df = pd.DataFrame()
    df_iterator = read_func(*args, **kwargs)
    for i, df_chunk in enumerate(df_iterator):
        # Instantiate Pandas DataFrame with columns if it is the first chunk
        if i == 0:
            df = pd.DataFrame(columns=df_chunk.columns)
        df = pd.concat([df, df_chunk], ignore_index=True)
    df = df.reset_index(drop=True)
    return df


def read_html(*args, **kwargs):
    # read_html returns list of dataframes
    return pd.read_html(*args, **kwargs)[0]

PANDAS_FORMATS = {

> Refactored to make this a non-global variable, to avoid needing hacks around tests for batch-reading.

> Not sure why you need to change something for tests? That means you are testing something else, no?

    "csv": PandasFormat(
        read_csv_with_unnamed,
        pd.DataFrame.to_csv,
        file_name="data.csv",
        write_args={"index": False},
    ),
    "json": PandasFormat(
        read_json_reset_index,
        pd.DataFrame.to_json,
        file_name="data.json",
        write_args={"date_format": "iso", "date_unit": "ns"},
    ),
    "html": PandasFormat(
        read_html,
        pd.DataFrame.to_html,
        file_name="data.html",
        write_args={"index": False},
        string_buffer=True,
    ),
    "excel": PandasFormat(
        pd.read_excel,
        pd.DataFrame.to_excel,
        file_name="data.xlsx",
        write_args={"index": False},
    ),
    "parquet": PandasFormat(
        pd.read_parquet, pd.DataFrame.to_parquet, file_name="data.parquet"
    ),
    "feather": PandasFormat(
        pd.read_feather, pd.DataFrame.to_feather, file_name="data.feather"
    ),
    "pickle": PandasFormat(  # TODO buffer closed error for some reason
        pd.read_pickle, pd.DataFrame.to_pickle, file_name="data.pkl"
    ),
    "strata": PandasFormat(  # TODO int32 converts to int64 for some reason
        pd.read_stata, pd.DataFrame.to_stata, write_args={"write_index": False}
    ),
}
def get_pandas_formats():
    PANDAS_FORMATS = {
        "csv": PandasFormat(
            read_csv_with_unnamed,
            pd.DataFrame.to_csv,
            file_name="data.csv",
            write_args={"index": False},
        ),
        "json": PandasFormat(
            read_json_reset_index,
            pd.DataFrame.to_json,
            file_name="data.json",
            write_args={"date_format": "iso", "date_unit": "ns"},
        ),
        "html": PandasFormat(
            read_html,
            pd.DataFrame.to_html,
            file_name="data.html",
            write_args={"index": False},
            string_buffer=True,
        ),
        "excel": PandasFormat(
            pd.read_excel,
            pd.DataFrame.to_excel,
            file_name="data.xlsx",
            write_args={"index": False},
        ),
        "parquet": PandasFormat(
            pd.read_parquet, pd.DataFrame.to_parquet, file_name="data.parquet"
        ),
        "feather": PandasFormat(
            pd.read_feather, pd.DataFrame.to_feather, file_name="data.feather"
        ),
        "pickle": PandasFormat(  # TODO buffer closed error for some reason
            pd.read_pickle, pd.DataFrame.to_pickle, file_name="data.pkl"
        ),
        "stata": PandasFormat(  # TODO int32 converts to int64 for some reason
            pd.read_stata,
            pd.DataFrame.to_stata,
            write_args={"write_index": False},
        ),
    }

    return PANDAS_FORMATS
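For context on the thread above: with a module-level dict, `update_batch_args` (further down in this diff) mutates a shared `PandasFormat` instance, so batch `read_args` would leak into later non-batch reads and between tests. A rough illustration of the concern, assuming the old module-level variant:

```python
# Illustration only, against the old module-level PANDAS_FORMATS dict.
fmt = PANDAS_FORMATS["csv"]          # every caller gets the same object
fmt.read_args = {"chunksize": 2}     # a batch read mutates it in place...
# ...so a later non-batch CSV read (or another test) still sees chunksize=2
# unless the state is reset; rebuilding the dict per call avoids this.
```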
class _PandasIO(BaseModel):

@@ -485,13 +528,13 @@ class _PandasIO(BaseModel):
    def is_valid_format(  # pylint: disable=no-self-argument
        cls, value  # noqa: B902
    ):
        if value not in PANDAS_FORMATS:
        if value not in get_pandas_formats():
            raise ValueError(f"format {value} is not supported")
        return value

    @property
    def fmt(self):
        return PANDAS_FORMATS[self.format]
        return get_pandas_formats()[self.format]


class PandasReader(_PandasIO, DatasetReader):
@@ -505,16 +548,38 @@ def read(self, artifacts: Artifacts) -> DatasetType:
            self.dataset_type.align(self.fmt.read(artifacts))
        )

    def read_batch(self, artifacts: Artifacts, batch: int) -> DatasetType:
        fmt = update_batch_args(self.format, self.fmt, batch)

> I think I get this thing above now. I'd say we just need a separate set of args for batch and non-batch reading, so you don't have to change the state every time.

> Refactored this bit, so there's no longer a need to update args based on batch/non-batch reads.

        # Pandas supports batch-reading for JSON only if the JSON file is line-delimited
        # https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html#line-delimited-json
        if self.format == "json":
            dataset_lines = sum(1 for line in open(artifacts["data"].uri))

> let's assume the dataset file is in the right format for now

            if dataset_lines <= 1:
                raise DatasetBatchLoadingJSONError(
                    "Batch-loading Dataset of type JSON requires provided JSON file to be line-delimited."
                )
        return self.dataset_type.copy().bind(
            self.dataset_type.align(fmt.read(artifacts))
        )
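For reference, what the line-delimited requirement means in plain pandas (a standalone sketch, not the MLEM API; `data.json` is a placeholder path):

```python
import pandas as pd

# Line-delimited, "records"-oriented JSON: one object per line, e.g.
#   {"x": 1}
#   {"x": 2}
# pandas only supports chunked JSON reading in this layout.
reader = pd.read_json("data.json", lines=True, orient="records", chunksize=2)
for chunk in reader:  # each chunk is a DataFrame of up to 2 rows
    print(len(chunk))
```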
class PandasWriter(DatasetWriter, _PandasIO):
    """DatasetWriter for pandas dataframes"""

    type: ClassVar[str] = "pandas"

    def write(
        self, dataset: DatasetType, storage: Storage, path: str
        self,
        dataset: DatasetType,
        storage: Storage,
        path: str,
        writer_fmt_args: Optional[Dict[str, Any]] = None,
    ) -> Tuple[DatasetReader, Artifacts]:
        fmt = self.fmt
        if writer_fmt_args:
            for k, v in writer_fmt_args.items():
                setattr(fmt, k, v)
        art = fmt.write(dataset.data, storage, path)
        if not isinstance(dataset, DataFrameType):
            raise ValueError("Cannot write non-pandas Dataset")
@@ -524,7 +589,7 @@ def write(


class PandasImport(ExtImportHook):
    EXTS: ClassVar = tuple(f".{k}" for k in PANDAS_FORMATS)
    EXTS: ClassVar = tuple(f".{k}" for k in get_pandas_formats())
    type: ClassVar = "pandas"

    @classmethod
@@ -537,10 +602,13 @@ def process(
        obj: Location,
        copy_data: bool = True,
        modifier: Optional[str] = None,
        batch: Optional[int] = None,
        **kwargs,
    ) -> MlemMeta:
        ext = modifier or posixpath.splitext(obj.path)[1][1:]
        fmt = PANDAS_FORMATS[ext]
        fmt = get_pandas_formats()[ext]
        if batch:
            fmt = update_batch_args(ext, fmt, batch)
        read_args = fmt.read_args or {}
        read_args.update(kwargs)
        with obj.open("rb") as f:
@@ -555,3 +623,27 @@ def process(
            )
        }
        return meta


def update_batch_args(
    type_: str, fmt: PandasFormat, batch: int
) -> PandasFormat:
    # Check if batch reading is supported for the specified type_ format
    if type_ == "csv":
        fmt.read_func = read_batch_csv_with_unnamed
        fmt.read_args = {"chunksize": batch}
    elif type_ == "json":
        fmt.read_func = partial(read_batch_reset_index, pd.read_json)
        # JSON batch-reading requires line-delimited data and orient to be "records"
        fmt.read_args = {
            "chunksize": batch,
            "lines": True,
            "orient": "records",
        }
    elif type_ == "stata":
        fmt.read_func = partial(read_batch_reset_index, pd.read_stata)
        fmt.read_args = {"chunksize": batch}
    else:
        raise UnsupportedDatasetBatchLoadingType(type_)

    return fmt
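One possible shape for the "separate set of args for batch and non-batch reading" suggested in the review above; the class, field, and method names here are assumptions for illustration, not part of the current `PandasFormat`:

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple


@dataclass
class BatchAwareFormat:
    """Sketch: keep batch config next to the regular one, so nothing is mutated."""

    read_func: Callable
    read_args: Optional[Dict[str, Any]] = None
    batch_read_func: Optional[Callable] = None
    batch_read_args: Optional[Dict[str, Any]] = None

    def resolve_read(
        self, batch: Optional[int] = None
    ) -> Tuple[Callable, Dict[str, Any]]:
        # Pick the (func, args) pair without touching shared state.
        if batch is None:
            return self.read_func, dict(self.read_args or {})
        if self.batch_read_func is None:
            # or UnsupportedDatasetBatchLoadingType from this PR
            raise ValueError("batch reading is not supported for this format")
        return self.batch_read_func, {
            **(self.batch_read_args or {}),
            "chunksize": batch,
        }
```

With something along these lines, neither `update_batch_args` nor the per-call `get_pandas_formats()` rebuild would be needed, since no shared object gets mutated.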
> Here you should flatten all the batches if there are any. Here is an example of how I see it:
> Suppose you have `dataframe=[1,2,3,4]` (1 column, 4 rows) saved to a csv file. You load its metadata without loading the value, say `dt = DatasetMeta(dataset_type=DataFrameType(...), ...)`. Then you call `apply` with `data=[dt]`. If the batch arg is not provided, `get_dataset_value` will load the actual dataframe and in the end `res = [w.call_method(..., dataframe([1,2,3,4]))]`. But if you provided `batch=2`, `dt.read_batch` should be called. If you iterate through it, you will get 2 parts of the dataframe, and in the end `res = [w.call_method(..., dataframe([1,2])), w.call_method(..., dataframe([3,4]))]`.

> 👍🏻 Makes sense, I implemented something similar yesterday, but I couldn't sync the changes to this PR because the private fork no longer points to this repository.
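A rough sketch of the flow described in the comment above; `get_dataset_value` and `w.call_method` are taken from the comment, everything else (function name, signatures, iteration protocol) is assumed:

```python
# Sketch of the reviewer's description, not the final implementation.
def apply_to_dataset(w, method, dt, batch=None):
    if batch is None:
        # Non-batch path: load the whole dataframe and call the method once,
        # e.g. res == [w.call_method(method, dataframe([1, 2, 3, 4]))]
        return [w.call_method(method, get_dataset_value(dt))]
    # Batch path: dt.read_batch yields parts of the dataframe, e.g. [1, 2]
    # then [3, 4], and the method is applied to each part before the next
    # one is read, so the full dataset never has to sit in memory at once.
    return [w.call_method(method, chunk) for chunk in dt.read_batch(batch)]
```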