diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index 46178e480..bb005f72e 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -10,13 +10,18 @@ defmodule Explorer.Series do * `:date` - Date type that unwraps to `Elixir.Date` * `{:datetime, precision}` - DateTime type with millisecond/microsecond/nanosecond precision that unwraps to `Elixir.NaiveDateTime` * `{:duration, precision}` - Duration type with millisecond/microsecond/nanosecond precision that unwraps to `Explorer.Duration` - * `{:f, size}` - a 64-bit or 32-bit floating point number. The atom `:float` can be used as an alias for `{:f, 64}`. + * `{:f, size}` - a 64-bit or 32-bit floating point number * `:integer` - 64-bit signed integer * `:string` - UTF-8 encoded binary * `:time` - Time type that unwraps to `Elixir.Time` * `{:list, dtype}` - A recursive dtype that can store lists. Examples: `{:list, :integer}` or a nested list dtype like `{:list, {:list, :integer}}`. + The following data type aliases are also supported: + + * The atom `:float` as an alias for `{:f, 64}` to mirror Elixir's floats + * The atoms `:f32` and `:f64` as aliases to `{:f, 32}` and `{:f, 64}` for Nx compabitility + A series must consist of a single data type only. Series may have `nil` values in them. The series `dtype` can be retrieved via the `dtype/1` function or directly accessed as `series.dtype`. A `series.name` field is also available, but it is always `nil` unless @@ -264,6 +269,18 @@ defmodule Explorer.Series do string ["1", nil] > + iex> Explorer.Series.from_list([1, 2], dtype: :f32) + #Explorer.Series< + Polars[2] + f32 [1.0, 2.0] + > + + iex> Explorer.Series.from_list([1, nil, 2], dtype: :float) + #Explorer.Series< + Polars[3] + f64 [1.0, nil, 2.0] + > + The `dtype` option is particulary important if a `:binary` series is desired, because by default binary series will have the dtype of `:string`: diff --git a/lib/explorer/shared.ex b/lib/explorer/shared.ex index be3ceac0d..8d0308179 100644 --- a/lib/explorer/shared.ex +++ b/lib/explorer/shared.ex @@ -38,7 +38,8 @@ defmodule Explorer.Shared do end def normalise_dtype(dtype) when dtype in @non_list_types, do: dtype - def normalise_dtype(:float), do: {:f, 64} + def normalise_dtype(dtype) when dtype in [:float, :f64], do: {:f, 64} + def normalise_dtype(:f32), do: {:f, 32} def normalise_dtype(_dtype), do: nil @doc """ diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index 680b1141d..9b1ebd982 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -148,10 +148,23 @@ defmodule Explorer.SeriesTest do end test "integers as {:f, 64}" do - s = Series.from_list([1, 2, 3, 4], dtype: :float) - assert s[0] == 1.0 - assert Series.to_list(s) === [1.0, 2.0, 3.0, 4.0] - assert Series.dtype(s) == {:f, 64} + # Shortcuts and the "real" dtype + for dtype <- [:float, :f64, {:f, 64}] do + s = Series.from_list([1, 2, 3, 4], dtype: dtype) + assert s[0] === 1.0 + assert Series.to_list(s) === [1.0, 2.0, 3.0, 4.0] + assert Series.dtype(s) === {:f, 64} + end + end + + test "integers as {:f, 32}" do + # Shortcut and the "real" dtype + for dtype <- [:f32, {:f, 32}] do + s = Series.from_list([1, 2, 3, 4], dtype: dtype) + assert s[0] === 1.0 + assert Series.to_list(s) === [1.0, 2.0, 3.0, 4.0] + assert Series.dtype(s) === {:f, 32} + end end test "mixing integers with an invalid atom" do