Fix step operations
AjayP13 committed Jan 31, 2024
1 parent c75f621 commit 1f1bbdb
Showing 4 changed files with 38 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -66,7 +66,7 @@ warn_unused_ignores = true
 mypy_path = "src/_stubs"
 
 [[tool.mypy.overrides]]
-module = "click,wandb,wandb.*,click.testing,flaky,tensorflow,torch_xla,jax,datasets.features.features,datasets.iterable_dataset,datasets.fingerprint,datasets.builder,datasets.splits,datasets.utils,datasets.utils.version,pyarrow.lib,huggingface_hub,huggingface_hub.utils._headers,huggingface_hub.utils._errors,dill,dill.source,transformers,bitsandbytes,sqlitedict,optimum.bettertransformer,optimum.bettertransformer.models,optimum.utils,transformers.utils.quantization_config,sortedcontainers,peft,psutil,ring,ctransformers,petals,petals.client.inference_session,hivemind.p2p.p2p_daemon_bindings.utils,huggingface_hub.utils,tqdm,ctransformers.transformers,vllm,litellm,litellm.llms.palm,litellm.exceptions,sentence_transformers,faiss,huggingface_hub.utils._validators,evaluate,transformers.trainer_callback,transformers.training_args,trl,guidance,sentence_transformers.models.Transformer,trl.trainer.utils,transformers.trainer_utils,setfit,joblib,setfit.modeling,transformers.utils.notebook,mistralai.models.chat_completion,accelerate.utils,accelerate.utils.constants,accelerate,transformers.trainer,sentence_transformers.util,Pyro5,Pyro5.server,Pyro5.api,Pyro5,datadreamer"
+module = "click,wandb,wandb.*,click.testing,flaky,tensorflow,torch_xla,jax,datasets.features.features,datasets.iterable_dataset,datasets.fingerprint,datasets.builder,datasets.arrow_writer,datasets.splits,datasets.utils,datasets.utils.version,pyarrow.lib,huggingface_hub,huggingface_hub.utils._headers,huggingface_hub.utils._errors,dill,dill.source,transformers,bitsandbytes,sqlitedict,optimum.bettertransformer,optimum.bettertransformer.models,optimum.utils,transformers.utils.quantization_config,sortedcontainers,peft,psutil,ring,ctransformers,petals,petals.client.inference_session,hivemind.p2p.p2p_daemon_bindings.utils,huggingface_hub.utils,tqdm,ctransformers.transformers,vllm,litellm,litellm.llms.palm,litellm.exceptions,sentence_transformers,faiss,huggingface_hub.utils._validators,evaluate,transformers.trainer_callback,transformers.training_args,trl,guidance,sentence_transformers.models.Transformer,trl.trainer.utils,transformers.trainer_utils,setfit,joblib,setfit.modeling,transformers.utils.notebook,mistralai.models.chat_completion,accelerate.utils,accelerate.utils.constants,accelerate,transformers.trainer,sentence_transformers.util,Pyro5,Pyro5.server,Pyro5.api,Pyro5,datadreamer"
 ignore_missing_imports = true
 
 [tool.coverage.run]
30 changes: 19 additions & 11 deletions src/steps/step_operations.py
@@ -6,6 +6,7 @@
 from typing import TYPE_CHECKING, Any, Callable, Sequence, cast
 
 from datasets import Dataset, DatasetDict, IterableDataset, concatenate_datasets
+from datasets.arrow_writer import SchemaInferenceError
 from datasets.builder import DatasetGenerationError
 from datasets.fingerprint import Hasher
 from pyarrow.lib import ArrowInvalid, ArrowTypeError
@@ -30,7 +31,7 @@
 ##################################
 
 
-def _iterable_dataset_to_dataset(
+def _iterable_dataset_to_dataset(  # noqa: C901
     self,
     step: "Step",
     iterable_dataset: IterableDataset,
@@ -61,16 +61,23 @@ def dataset_generator(iterable_dataset):
             cache_path = os.path.join(
                 step_to_use._output_folder_path, ".datadreamer_save_cache"
             )
-            dataset = Dataset.from_generator(
-                partial(dataset_generator, iterable_dataset),
-                # Commenting out features=..., because it fails when remove_columns is
-                # used with something like map()
-                # features=step.output._features,
-                features=None,
-                cache_dir=cache_path,
-                writer_batch_size=writer_batch_size,
-                num_proc=default_to(save_num_proc, None),
-            )
+            for features in [None, step.output._features]:
+                try:
+                    dataset = Dataset.from_generator(
+                        partial(dataset_generator, iterable_dataset),
+                        features=features,
+                        cache_dir=cache_path,
+                        writer_batch_size=writer_batch_size,
+                        num_proc=default_to(save_num_proc, None),
+                    )
+                    break
+                except DatasetGenerationError as e:
+                    if features is None and isinstance(
+                        e.__cause__, SchemaInferenceError
+                    ):
+                        continue
+                    else:  # pragma: no cover
+                        raise e
         else:
             dataset = Dataset.from_list(list(dataset_generator(iterable_dataset)))
     except DatasetGenerationError as e:
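
The loop above asks `datasets` to infer the schema first (features=None) and only falls back to the step's declared features when inference fails because the generator produced no rows; inference stays the preferred path since, per the deleted comment, passing explicit features can break map() with remove_columns. Below is a minimal standalone sketch of the same fallback pattern, assuming only the datasets library; empty_generator and declared_features are illustrative names, not part of this commit:

    from datasets import Dataset, Features, Value
    from datasets.arrow_writer import SchemaInferenceError
    from datasets.builder import DatasetGenerationError

    def empty_generator():
        # Yields no rows, so Arrow has nothing to infer a schema from.
        return iter(())

    declared_features = Features({"out1": Value("null")})

    for features in [None, declared_features]:
        try:
            # features=None asks Arrow to infer the schema from the first rows;
            # an empty generator makes that inference impossible.
            dataset = Dataset.from_generator(empty_generator, features=features)
            break
        except DatasetGenerationError as e:
            if features is None and isinstance(e.__cause__, SchemaInferenceError):
                continue  # retry with the explicitly declared features
            raise

    assert dataset.num_rows == 0  # a valid, empty dataset
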
4 changes: 2 additions & 2 deletions src/steps/step_output.py
@@ -14,7 +14,7 @@
     iterable_dataset,
 )
 from datasets.builder import DatasetGenerationError
-from datasets.features.features import Features
+from datasets.features.features import Features, Value
 from datasets.iterable_dataset import (
     _apply_feature_types_on_batch,
     _apply_feature_types_on_example,
@@ -527,7 +527,7 @@ def to_dict_generator_wrapper(_value, output_names, _value_is_batched):
         )
 
         # Create an IterableDataset if given a generator function of dicts
-        features = Features([(n, None) for n in output_names])
+        features = Features([(n, Value("null")) for n in output_names])
         if is_lazy and _is_lazy_type(_value):
             # Make sure the generator returns a dict and the keys are correct
             try:
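
The declared features are now actually consumed by the fallback in step_operations.py, so each not-yet-known column needs a valid placeholder type rather than a bare None; Value("null") is the explicit Arrow null dtype that plays that role. A minimal sketch, assuming only the datasets library; output_names is illustrative:

    from datasets import Features, Value

    output_names = ["out1", "out2"]

    # Value("null") is the explicit Arrow null dtype, a valid placeholder
    # feature for columns whose real types are not known yet.
    features = Features([(n, Value("null")) for n in output_names])
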
16 changes: 16 additions & 0 deletions src/tests/steps/test_step_operations.py
@@ -348,6 +348,22 @@ def test_map_lazy_pickle_error(
                 step.map(lambda row: {"out1": row["out1"]}, total_num_rows=3).output
             )
 
+    def test_map_empty_generator(
+        self, create_datadreamer, create_test_step: Callable[..., Step]
+    ):
+        with create_datadreamer():
+            step = create_test_step(name="my-step", inputs=None, output_names=["out1"])
+
+            def empty_generator():
+                return iter(())
+
+            step._set_output(LazyRows(empty_generator, total_num_rows=0))
+            map_step = step.map(lambda row: {"out1": row["out1"]}, lazy=False)
+            assert isinstance(map_step, MapStep)
+            assert isinstance(map_step.output, OutputDataset)
+            assert map_step.output.num_rows == 0
+            assert list(map_step.output["out1"]) == []
+
 
 class TestFilter:
     def test_filter_on_dataset(
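
The new test drives the fallback end to end: a lazy step whose generator yields nothing makes schema inference fail, and map() must still materialize a valid zero-row dataset. The fix distinguishes that case via e.__cause__, which relies on standard Python exception chaining; a self-contained illustration follows, with both exception classes redefined locally as stand-ins:

    class DatasetGenerationError(Exception): ...
    class SchemaInferenceError(Exception): ...

    try:
        try:
            raise SchemaInferenceError("no rows to infer a schema from")
        except SchemaInferenceError as inner:
            # "raise ... from ..." sets __cause__ on the outer exception,
            # which is what the isinstance(e.__cause__, ...) check inspects.
            raise DatasetGenerationError("dataset generation failed") from inner
    except DatasetGenerationError as e:
        assert isinstance(e.__cause__, SchemaInferenceError)
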
