Skip to content

Commit

Permalink
Fix issue 976 (#980)
Browse files Browse the repository at this point in the history
* Allow optional dtype

* Add test

* Make dtype required, not optional

Also remove the special-case pattern match on `[]`.
As far as I can tell, that clause doesn't actually optimize anything.

* Also require dtype for structs

* Try to not panic

* Use `map_err` to simplify logic

* Fix linting errors (I think)
  • Loading branch information
billylanchantin authored Sep 9, 2024
1 parent c31b573 commit ffcb2e6
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 33 deletions.
4 changes: 2 additions & 2 deletions lib/explorer/polars_backend/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,8 @@ defmodule Explorer.PolarsBackend.Native do
def s_from_list_str(_name, _val), do: err()
def s_from_list_binary(_name, _val), do: err()
def s_from_list_categories(_name, _val), do: err()
def s_from_list_of_series(_name, _val), do: err()
def s_from_list_of_series_as_structs(_name, _val), do: err()
def s_from_list_of_series(_name, _val, _dtype), do: err()
def s_from_list_of_series_as_structs(_name, _val, _dtype), do: err()
def s_from_binary_f32(_name, _val), do: err()
def s_from_binary_f64(_name, _val), do: err()
def s_from_binary_s8(_name, _val), do: err()
Expand Down
25 changes: 7 additions & 18 deletions lib/explorer/polars_backend/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -127,28 +127,17 @@ defmodule Explorer.PolarsBackend.Shared do

def from_list(list, dtype), do: from_list(list, dtype, "")

def from_list([], {:list, _} = dtype, name) do
polars_series = Native.s_from_list_of_series(name, [])
{:ok, casted} = Native.s_cast(polars_series, dtype)
casted
end

def from_list(list, {:list, inner_dtype} = _dtype, name) when is_list(list) do
def from_list(list, {:list, inner_dtype} = dtype, name) do
series =
Enum.map(list, fn maybe_inner_list ->
if is_list(maybe_inner_list), do: from_list(maybe_inner_list, inner_dtype, name)
Enum.map(list, fn
inner_list when is_list(inner_list) -> from_list(inner_list, inner_dtype, name)
_ -> nil
end)

Native.s_from_list_of_series(name, series)
end

def from_list([], {:struct, _} = dtype, name) do
polars_series = Native.s_from_list_of_series_as_structs(name, [])
{:ok, casted} = Native.s_cast(polars_series, dtype)
casted
Native.s_from_list_of_series(name, series, dtype)
end

def from_list(list, {:struct, fields}, name) when is_list(list) do
def from_list(list, {:struct, fields} = dtype, name) when is_list(list) do
columns = Map.new(fields, fn {k, _v} -> {k, []} end)

columns =
Expand All @@ -172,7 +161,7 @@ defmodule Explorer.PolarsBackend.Shared do
|> from_list(inner_dtype, field)
end

Native.s_from_list_of_series_as_structs(name, series)
Native.s_from_list_of_series_as_structs(name, series, dtype)
end

def from_list(list, dtype, name) when is_list(list) do
Expand Down
46 changes: 33 additions & 13 deletions native/explorer/src/series/from_list.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use crate::atoms;
use crate::datatypes::{ExDate, ExDateTime, ExDuration, ExNaiveDateTime, ExTime, ExTimeUnit};
use crate::datatypes::{
ExDate, ExDateTime, ExDuration, ExNaiveDateTime, ExSeriesDtype, ExTime, ExTimeUnit,
};
use crate::{ExSeries, ExplorerError};

use polars::datatypes::DataType;
use polars::prelude::*;
use rustler::{Binary, Error, ListIterator, NifResult, Term, TermType};
use std::slice;
Expand Down Expand Up @@ -307,10 +310,16 @@ pub fn s_from_list_categories(name: &str, val: Term) -> NifResult<ExSeries> {
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_from_list_of_series(name: &str, series_term: Term) -> NifResult<ExSeries> {
pub fn s_from_list_of_series(
name: &str,
series_term: Term,
ex_dtype: ExSeriesDtype,
) -> NifResult<ExSeries> {
let dtype = DataType::try_from(&ex_dtype).unwrap();

series_term
.decode::<Vec<Option<ExSeries>>>()
.map(|series_vec| {
.and_then(|series_vec| {
let lists: Vec<Option<Series>> = series_vec
.iter()
.map(|maybe_series| {
Expand All @@ -320,27 +329,38 @@ pub fn s_from_list_of_series(name: &str, series_term: Term) -> NifResult<ExSerie
})
.collect();

ExSeries::new(Series::new(name, lists))
Series::new(name, lists).cast(&dtype).map_err(|err| {
let message = format!("from_list/2 cannot create series of lists: {err:?}");
Error::RaiseTerm(Box::new(message))
})
})
.map(ExSeries::new)
}

/// Builds a struct series named `name` from the decoded list of field series.
/// Failures while building the struct (and, in the three-argument form, while
/// casting it to the requested dtype) are raised with a
/// "from_list/2 cannot create series of structs" message.
#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_from_list_of_series_as_structs(name: &str, series_term: Term) -> NifResult<ExSeries> {
pub fn s_from_list_of_series_as_structs(
name: &str,
series_term: Term,
ex_dtype: ExSeriesDtype,
) -> NifResult<ExSeries> {
// NOTE(review): `unwrap` panics if the dtype conversion returns `Err`; unlike
// the cast further down, this failure is not mapped to a raised term.
let dtype = DataType::try_from(&ex_dtype).unwrap();
let series_vec = series_term.decode::<Vec<ExSeries>>()?;
match StructChunked::from_series(

StructChunked::from_series(
name,
series_vec
.into_iter()
.map(|s| s.clone_inner())
.collect::<Vec<_>>()
.as_slice(),
) {
Ok(struct_chunked) => Ok(ExSeries::new(struct_chunked.into_series())),
Err(err) => {
let message = format!("from_list/2 cannot create series of structs: {err:?}");
Err(Error::RaiseTerm(Box::new(message)))
}
}
)
.map(|struct_chunked| struct_chunked.into_series())
// Cast to the caller-supplied dtype after the struct series is built.
.and_then(|series| series.cast(&dtype))
.map_err(|err| {
let message = format!("from_list/2 cannot create series of structs: {err:?}");
Error::RaiseTerm(Box::new(message))
})
.map(ExSeries::new)
}

macro_rules! from_binary {
Expand Down
13 changes: 13 additions & 0 deletions test/explorer/data_frame_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,19 @@ defmodule Explorer.DataFrameTest do

assert DF.dtypes(df) == %{"dates" => :date}
end

test "lists of structs of lists with nils (issue #976)" do
  # Regression for issue #976: a nil standing in for an inner list, nested
  # inside a list of structs, must survive construction with an explicit dtype.
  rows = [
    %{a: [%{b: [""]}]},
    %{a: [%{b: nil}]}
  ]

  column_dtypes = [{"a", {:list, {:struct, [{"b", {:list, :string}}]}}}]

  df = Explorer.DataFrame.new(rows, dtypes: column_dtypes)

  expected = %{"a" => [[%{"b" => [""]}], [%{"b" => nil}]]}
  assert DF.to_columns(df) == expected
end
end

describe "mask/2" do
Expand Down

0 comments on commit ffcb2e6

Please sign in to comment.