Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
Fix batch dataset reading
Browse files Browse the repository at this point in the history
  • Loading branch information
terryyylim committed May 22, 2022
1 parent 575241b commit 8f9dbbf
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
6 changes: 4 additions & 2 deletions mlem/api/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from mlem.config import CONFIG_FILE_NAME, repo_config
from mlem.constants import PREDICT_METHOD_NAME
from mlem.core.dataset_type import DatasetAnalyzer
from mlem.core.errors import (
InvalidArgumentError,
MlemObjectNotFound,
Expand Down Expand Up @@ -85,12 +86,13 @@ def apply(
resolved_method = PREDICT_METHOD_NAME
echo(EMOJI_APPLY + f"Applying `{resolved_method}` method...")
if batch_size:
res: Any = []
res: Any = None
for part in data:
batch_dataset = get_dataset_value(part, batch_size)
for chunk in batch_dataset:
preds = w.call_method(resolved_method, chunk.data)
res += [*preds] # TODO: merge results
dt = DatasetAnalyzer.analyze(preds)
res = dt.combine(res, preds)
else:
res = [
w.call_method(resolved_method, get_dataset_value(part))
Expand Down
6 changes: 6 additions & 0 deletions mlem/contrib/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ class NumpyNdarrayType(
def _abstract_shape(shape):
return (None,) + shape[1:]

@staticmethod
def combine(original: np.ndarray, new: np.ndarray):
if original is None:
return new
return np.concatenate((original, new))

@classmethod
def process(cls, obj, **kwargs) -> DatasetType:
return NumpyNdarrayType(
Expand Down
4 changes: 4 additions & 0 deletions mlem/core/dataset_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def check_type(obj, exp_type, exc_type):
f"given dataset is of type: {type(obj)}, expected: {exp_type}"
)

@staticmethod
def combine(original: Any, new: Any):
raise NotImplementedError

@abstractmethod
def get_requirements(self) -> Requirements:
return get_object_requirements(self)
Expand Down

0 comments on commit 8f9dbbf

Please sign in to comment.