Skip to content

Commit

Permalink
Do not convert io_dtypes to maps, closes #971
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Aug 27, 2024
1 parent e5d7ef9 commit 943feeb
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 9 deletions.
7 changes: 4 additions & 3 deletions lib/explorer/backend/data_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ defmodule Explorer.Backend.DataFrame do
@type column_name :: String.t()
@type dtype :: Explorer.Series.dtype()
@type dtypes :: %{column_name() => dtype()}
@type io_dtypes :: [{column_name(), dtype()}]

@type basic_types :: float() | integer() | String.t() | Date.t() | DateTime.t()
@type mutate_value ::
Expand All @@ -46,7 +47,7 @@ defmodule Explorer.Backend.DataFrame do
# IO: CSV
@callback from_csv(
entry :: fs_entry(),
dtypes,
io_dtypes,
delimiter :: String.t(),
nil_values :: list(String.t()),
skip_rows :: integer(),
Expand All @@ -71,7 +72,7 @@ defmodule Explorer.Backend.DataFrame do

@callback load_csv(
contents :: String.t(),
dtypes,
io_dtypes,
delimiter :: String.t(),
nil_values :: list(String.t()),
skip_rows :: integer(),
Expand Down Expand Up @@ -153,7 +154,7 @@ defmodule Explorer.Backend.DataFrame do
@callback lazy() :: module()
@callback lazy(df) :: df
@callback compute(df) :: df
@callback from_tabular(Table.Reader.t(), dtypes) :: df
@callback from_tabular(Table.Reader.t(), io_dtypes) :: df
@callback from_series([{binary(), Series.t()}]) :: df
@callback to_rows(df, atom_keys? :: boolean()) :: [map()]
@callback to_rows_stream(df, atom_keys? :: boolean(), chunk_size :: integer()) :: Enumerable.t()
Expand Down
6 changes: 4 additions & 2 deletions lib/explorer/data_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ defmodule Explorer.DataFrame do
end

defp check_dtypes!(dtypes) do
Map.new(dtypes, fn
Enum.map(dtypes, fn
{key, value} when is_atom(key) ->
{Atom.to_string(key), check_dtype!(key, value)}

Expand Down Expand Up @@ -524,7 +524,8 @@ defmodule Explorer.DataFrame do
* `:delimiter` - A single character used to separate fields within a record. (default: `","`)
* `:dtypes` - A list/map of `{"column_name", dtype}` tuples. Any non-specified column has its type
* `:dtypes` - A list of shape `[column_name: dtype]`. The column names must match the ones in the
CSV header or the ones given in the `:columns` option. Any non-specified column has its type
imputed from the first 1000 rows. (default: `[]`)
* `:header` - Does the file have a header of column names as the first row or not? (default: `true`)
Expand Down Expand Up @@ -1827,6 +1828,7 @@ defmodule Explorer.DataFrame do
backend.from_series(pairs)

{:tensor, data} ->
dtypes = Map.new(dtypes)
s_backend = df_backend_to_s_backend(backend)

pairs =
Expand Down
7 changes: 4 additions & 3 deletions lib/explorer/polars_backend/data_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ defmodule Explorer.PolarsBackend.DataFrame do
delimiter,
true,
columns,
Map.to_list(dtypes),
dtypes,
encoding,
nil_values,
parse_dates,
Expand Down Expand Up @@ -207,7 +207,7 @@ defmodule Explorer.PolarsBackend.DataFrame do
delimiter,
true,
columns,
Map.to_list(dtypes),
dtypes,
encoding,
nil_values,
parse_dates,
Expand Down Expand Up @@ -514,7 +514,8 @@ defmodule Explorer.PolarsBackend.DataFrame do
def compute(df), do: df

@impl true
def from_tabular(tabular, dtypes) do
def from_tabular(tabular, io_dtypes) do
dtypes = Map.new(io_dtypes)
{_, %{columns: keys}, _} = reader = init_reader!(tabular)
columns = Table.to_columns(reader)

Expand Down
2 changes: 1 addition & 1 deletion lib/explorer/polars_backend/lazy_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ defmodule Explorer.PolarsBackend.LazyFrame do
skip_rows_after_header,
delimiter,
true,
Map.to_list(dtypes),
dtypes,
encoding,
nil_values,
parse_dates,
Expand Down
22 changes: 22 additions & 0 deletions test/explorer/data_frame/csv_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,28 @@ defmodule Explorer.DataFrame.CSVTest do
assert city[13] == "Aberdeen, Aberdeen City, UK"
end

test "load_csv/2 dtypes - mismatched names" do
text = """
first_name , last_name , dob
Alice , Ant , 01/02/1970
Billy , Bat , 03/04/1990
"""

types = [
{"first_name", :string},
{"last_name", :string},
{"dob", :string}
]

assert text
|> DF.load_csv!(dtypes: types)
|> DF.to_columns(atom_keys: true) == %{
dob: [" 01/02/1970", " 03/04/1990"],
first_name: ["Alice ", "Billy "],
last_name: [" Ant ", " Bat "]
}
end

test "load_csv/2 dtypes - all as strings" do
csv =
"""
Expand Down

0 comments on commit 943feeb

Please sign in to comment.