From 4b15fa56067af50c593049c6cf43b79cbfa6d183 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Sun, 17 Oct 2021 22:42:51 +0800 Subject: [PATCH] fix lit function to allow multiple types (#1130) --- python/src/functions.rs | 22 ++++++++++---- python/tests/generic.py | 2 +- python/tests/test_df_sql.py | 1 - python/tests/test_functions.py | 54 ++++++++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+), 7 deletions(-) create mode 100644 python/tests/test_functions.py diff --git a/python/src/functions.rs b/python/src/functions.rs index 6633f0afefaa..8611ca54b566 100644 --- a/python/src/functions.rs +++ b/python/src/functions.rs @@ -21,7 +21,9 @@ use crate::{expression, types::PyDataType}; use datafusion::arrow::datatypes::DataType; use datafusion::logical_plan; use datafusion::physical_plan::functions::Volatility; -use pyo3::{prelude::*, types::PyTuple, wrap_pyfunction}; +use pyo3::{ + exceptions::PyTypeError, prelude::*, types::PyTuple, wrap_pyfunction, Python, +}; use std::sync::Arc; /// Expression representing a column on the existing plan. 
@@ -36,10 +38,20 @@ fn col(name: &str) -> expression::Expression { /// Expression representing a constant value #[pyfunction] #[pyo3(text_signature = "(value)")] -fn lit(value: i32) -> expression::Expression { - expression::Expression { - expr: logical_plan::lit(value), - } +fn lit(value: &PyAny) -> PyResult<expression::Expression> { + let expr = if let Ok(v) = value.extract::<i64>() { + logical_plan::lit(v) + } else if let Ok(v) = value.extract::<f64>() { + logical_plan::lit(v) + } else if let Ok(v) = value.extract::<String>() { + logical_plan::lit(v) + } else { + return Err(PyTypeError::new_err(format!( + "Unsupported value {}, expected one of i64, f64, or String type", + value + ))); + }; + Ok(expression::Expression { expr }) } #[pyfunction] diff --git a/python/tests/generic.py b/python/tests/generic.py index 8d5adaaaf956..1f984a40adaa 100644 --- a/python/tests/generic.py +++ b/python/tests/generic.py @@ -20,9 +20,9 @@ import numpy as np import pyarrow as pa import pyarrow.csv -import pyarrow.parquet as pq # used to write parquet files +import pyarrow.parquet as pq def data(): diff --git a/python/tests/test_df_sql.py b/python/tests/test_df_sql.py index ebc38b16427e..c6eac6bb2ffc 100644 --- a/python/tests/test_df_sql.py +++ b/python/tests/test_df_sql.py @@ -26,7 +26,6 @@ def ctx(): def test_register_record_batches(ctx): - # create a RecordBatch and register it as memtable batch = pa.RecordBatch.from_arrays( [pa.array([1, 2, 3]), pa.array([4, 5, 6])], diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py new file mode 100644 index 000000000000..c6c1cf6905f8 --- /dev/null +++ b/python/tests/test_functions.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pyarrow as pa +import pytest +from datafusion import ExecutionContext +from datafusion import functions as f + + +@pytest.fixture +def df(): + ctx = ExecutionContext() + # create a RecordBatch and a new DataFrame from it + batch = pa.RecordBatch.from_arrays( + [pa.array(["Hello", "World", "!"]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + return ctx.create_dataframe([[batch]]) + + +def test_lit(df): + """test lit function""" + df = df.select(f.lit(1), f.lit("1"), f.lit("OK"), f.lit(3.14)) + result = df.collect() + assert len(result) == 1 + result = result[0] + assert result.column(0) == pa.array([1] * 3) + assert result.column(1) == pa.array(["1"] * 3) + assert result.column(2) == pa.array(["OK"] * 3) + assert result.column(3) == pa.array([3.14] * 3) + + +def test_lit_arith(df): + """test lit function within arithmetic""" + df = df.select(f.lit(1) + f.col("b"), f.concat(f.col("a"), f.lit("!"))) + result = df.collect() + assert len(result) == 1 + result = result[0] + assert result.column(0) == pa.array([5, 6, 7]) + assert result.column(1) == pa.array(["Hello!", "World!", "!!"])