Skip to content

Commit

Permalink
fix tests
Browse files: browse the repository at this point in the history
  • Loading branch information
lhoestq committed Jan 14, 2025
1 parent d48bf53 commit 7386a84
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/datasets/iterable_dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1220,7 +1220,7 @@ def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[Tuple[Key
if not isinstance(output_table, pa.Table):
raise TypeError(
f"Provided `function` which is applied to {formatter.table_type} returns a variable of type "
f"{type(output_table)}. Make sure provided `function` returns a {formatter.table_type} to update the dataset."
f"{type(output)}. Make sure provided `function` returns a {formatter.table_type} to update the dataset."
)
# we don't need to merge results for consistency with Dataset.map which merges iff both input and output are dicts
# then remove the unwanted columns
Expand Down Expand Up @@ -1419,7 +1419,7 @@ def _iter(self):
yield key, example

def _iter_arrow(self, max_chunksize: Optional[int] = None):
formatter = get_formatter(self.formatting) if self.formatting else ArrowFormatter()
formatter = get_formatter(self.formatting.format_type) if self.formatting else ArrowFormatter()
if self.ex_iterable.iter_arrow:
iterator = self.ex_iterable.iter_arrow()
else:
Expand Down Expand Up @@ -1456,10 +1456,10 @@ def _iter_arrow(self, max_chunksize: Optional[int] = None):
# then apply the transform
output = self.function(*function_args, **self.fn_kwargs)
mask = _table_output_to_arrow(output)
if not isinstance(mask, (pa.Array, pa.BooleanScalar)):
if not isinstance(mask, (bool, pa.Array, pa.BooleanScalar)):
raise TypeError(
f"Provided `function` which is applied to {formatter.table_type} returns a variable of type "
f"{type(output_table)}. Make sure provided `function` returns a {formatter.column_type} to update the dataset."
f"{type(output)}. Make sure provided `function` returns a {formatter.column_type} to update the dataset."
)
# return output
if self.batched:
Expand Down

0 comments on commit 7386a84

Please sign in to comment.