diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index df19dfc6..b1072438 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -421,7 +421,7 @@ def fill_null(self, value: Any | Expr | None = None) -> Expr: _to_pyarrow_types = { float: pa.float64(), int: pa.int64(), - str: pa.string_view(), + str: pa.string(), bool: pa.bool_(), } diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 570a6ce5..9552876f 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -295,7 +295,7 @@ def decode(input: Expr, encoding: Expr) -> Expr: def array_to_string(expr: Expr, delimiter: Expr) -> Expr: """Converts each element to its text representation.""" - return Expr(f.array_to_string(expr.expr, delimiter.expr)) + return Expr(f.array_to_string(expr.expr, delimiter.expr.cast(pa.string()))) def array_join(expr: Expr, delimiter: Expr) -> Expr: @@ -924,7 +924,7 @@ def to_timestamp(arg: Expr, *formatters: Expr) -> Expr: return f.to_timestamp(arg.expr) formatters = [f.expr for f in formatters] - return Expr(f.to_timestamp(arg.expr, *formatters)) + return Expr(f.to_timestamp(arg.expr.cast(pa.string()), *formatters)) def to_timestamp_millis(arg: Expr, *formatters: Expr) -> Expr: @@ -1065,7 +1065,10 @@ def struct(*args: Expr) -> Expr: def named_struct(name_pairs: list[tuple[str, Expr]]) -> Expr: """Returns a struct with the given names and arguments pairs.""" - name_pair_exprs = [[Expr.literal(pair[0]), pair[1]] for pair in name_pairs] + name_pair_exprs = [ + [Expr.literal(pa.scalar(pair[0], type=pa.string())), pair[1]] + for pair in name_pairs + ] # flatten name_pairs = [x.expr for xs in name_pair_exprs for x in xs] @@ -1422,7 +1425,9 @@ def array_sort(array: Expr, descending: bool = False, null_first: bool = False) nulls_first = "NULLS FIRST" if null_first else "NULLS LAST" return Expr( f.array_sort( - array.expr, Expr.literal(desc).expr, Expr.literal(nulls_first).expr + array.expr, + Expr.literal(pa.scalar(desc, type=pa.string())).expr, + Expr.literal(pa.scalar(nulls_first, type=pa.string())).expr, ) ) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index d81e04c8..77f88aa4 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -130,7 +130,10 @@ def test_relational_expr(test_ctx): ctx = SessionContext() batch = pa.RecordBatch.from_arrays( - [pa.array([1, 2, 3]), pa.array(["alpha", "beta", "gamma"])], + [ + pa.array([1, 2, 3]), + pa.array(["alpha", "beta", "gamma"], type=pa.string_view()), + ], names=["a", "b"], ) df = ctx.create_dataframe([[batch]], name="batch_array") @@ -145,7 +148,8 @@ def test_relational_expr(test_ctx): assert df.filter(col("b") == "beta").count() == 1 assert df.filter(col("b") != "beta").count() == 2 - assert df.filter(col("a") == "beta").count() == 0 + with pytest.raises(Exception): + df.filter(col("a") == "beta").count() def test_expr_to_variant(): diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index e6fd41d8..1ba4cfd8 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -34,9 +34,9 @@ def df(): # create a RecordBatch and a new DataFrame from it batch = pa.RecordBatch.from_arrays( [ - pa.array(["Hello", "World", "!"]), + pa.array(["Hello", "World", "!"], type=pa.string_view()), pa.array([4, 5, 6]), - pa.array(["hello ", " world ", " !"]), + pa.array(["hello ", " world ", " !"], type=pa.string_view()), pa.array( [ datetime(2022, 12, 31), @@ -88,8 +88,8 @@ def test_literal(df): assert len(result) == 1 result = result[0] assert result.column(0) == pa.array([1] * 3) - assert result.column(1) == pa.array(["1"] * 3) - assert result.column(2) == pa.array(["OK"] * 3) + assert result.column(1) == pa.array(["1"] * 3, type=pa.string_view()) + assert result.column(2) == pa.array(["OK"] * 3, type=pa.string_view()) assert result.column(3) == pa.array([3.14] * 3) assert result.column(4) == pa.array([True] * 3) assert result.column(5) == pa.array([b"hello world"] * 3) @@ -97,7 +97,9 @@ def test_literal(df): def test_lit_arith(df): """Test literals with arithmetic operations""" - df = df.select(literal(1) + column("b"), f.concat(column("a"), literal("!"))) + df = df.select( + literal(1) + column("b"), f.concat(column("a").cast(pa.string()), literal("!")) + ) result = df.collect() assert len(result) == 1 result = result[0] @@ -578,21 +580,33 @@ def test_array_function_obj_tests(stmt, py_expr): f.ascii(column("a")), pa.array([72, 87, 33], type=pa.int32()), ), # H = 72; W = 87; ! = 33 - (f.bit_length(column("a")), pa.array([40, 40, 8], type=pa.int32())), - (f.btrim(literal(" World ")), pa.array(["World", "World", "World"])), + ( + f.bit_length(column("a").cast(pa.string())), + pa.array([40, 40, 8], type=pa.int32()), + ), + ( + f.btrim(literal(" World ")), + pa.array(["World", "World", "World"], type=pa.string_view()), + ), (f.character_length(column("a")), pa.array([5, 5, 1], type=pa.int32())), (f.chr(literal(68)), pa.array(["D", "D", "D"])), ( f.concat_ws("-", column("a"), literal("test")), pa.array(["Hello-test", "World-test", "!-test"]), ), - (f.concat(column("a"), literal("?")), pa.array(["Hello?", "World?", "!?"])), + ( + f.concat(column("a").cast(pa.string()), literal("?")), + pa.array(["Hello?", "World?", "!?"]), + ), (f.initcap(column("c")), pa.array(["Hello ", " World ", " !"])), (f.left(column("a"), literal(3)), pa.array(["Hel", "Wor", "!"])), (f.length(column("c")), pa.array([6, 7, 2], type=pa.int32())), (f.lower(column("a")), pa.array(["hello", "world", "!"])), (f.lpad(column("a"), literal(7)), pa.array([" Hello", " World", " !"])), - (f.ltrim(column("c")), pa.array(["hello ", "world ", "!"])), + ( + f.ltrim(column("c")), + pa.array(["hello ", "world ", "!"], type=pa.string_view()), + ), ( f.md5(column("a")), pa.array( @@ -618,19 +632,25 @@ def test_array_function_obj_tests(stmt, py_expr): f.rpad(column("a"), literal(8)), pa.array(["Hello ", "World ", "! "]), ), - (f.rtrim(column("c")), pa.array(["hello", " world", " !"])), + ( + f.rtrim(column("c")), + pa.array(["hello", " world", " !"], type=pa.string_view()), + ), ( f.split_part(column("a"), literal("l"), literal(1)), pa.array(["He", "Wor", "!"]), ), (f.starts_with(column("a"), literal("Wor")), pa.array([False, True, False])), (f.strpos(column("a"), literal("o")), pa.array([5, 2, 0], type=pa.int32())), - (f.substr(column("a"), literal(3)), pa.array(["llo", "rld", ""])), + ( + f.substr(column("a"), literal(3)), + pa.array(["llo", "rld", ""], type=pa.string_view()), + ), ( f.translate(column("a"), literal("or"), literal("ld")), pa.array(["Helll", "Wldld", "!"]), ), - (f.trim(column("c")), pa.array(["hello", "world", "!"])), + (f.trim(column("c")), pa.array(["hello", "world", "!"], type=pa.string_view())), (f.upper(column("c")), pa.array(["HELLO ", " WORLD ", " !"])), (f.ends_with(column("a"), literal("llo")), pa.array([True, False, False])), ( @@ -772,9 +792,9 @@ def test_temporal_functions(df): f.date_trunc(literal("month"), column("d")), f.datetrunc(literal("day"), column("d")), f.date_bin( - literal("15 minutes"), + literal("15 minutes").cast(pa.string()), column("d"), - literal("2001-01-01 00:02:30"), + literal("2001-01-01 00:02:30").cast(pa.string()), ), f.from_unixtime(literal(1673383974)), f.to_timestamp(literal("2023-09-07 05:06:14.523952")), @@ -836,8 +856,8 @@ def test_case(df): result = df.collect() result = result[0] assert result.column(0) == pa.array([10, 8, 8]) - assert result.column(1) == pa.array(["Hola", "Mundo", "!!"]) - assert result.column(2) == pa.array(["Hola", "Mundo", None]) + assert result.column(1) == pa.array(["Hola", "Mundo", "!!"], type=pa.string_view()) + assert result.column(2) == pa.array(["Hola", "Mundo", None], type=pa.string_view()) def test_when_with_no_base(df): @@ -855,8 +875,10 @@ def test_when_with_no_base(df): result = df.collect() result = result[0] assert result.column(0) == pa.array([4, 5, 6]) - assert result.column(1) == pa.array(["too small", "just right", "too big"]) - assert result.column(2) == pa.array(["Hello", None, None]) + assert result.column(1) == pa.array( + ["too small", "just right", "too big"], type=pa.string_view() + ) + assert result.column(2) == pa.array(["Hello", None, None], type=pa.string_view()) def test_regr_funcs_sql(df): @@ -999,8 +1021,13 @@ def test_regr_funcs_df(func, expected): def test_binary_string_functions(df): df = df.select( - f.encode(column("a"), literal("base64")), - f.decode(f.encode(column("a"), literal("base64")), literal("base64")), + f.encode(column("a").cast(pa.string()), literal("base64").cast(pa.string())), + f.decode( + f.encode( + column("a").cast(pa.string()), literal("base64").cast(pa.string()) + ), + literal("base64").cast(pa.string()), + ), ) result = df.collect() assert len(result) == 1