Skip to content

Commit

Permalink
fix lit function to allow multiple types (#1130)
Browse files Browse the repository at this point in the history
  • Loading branch information
jimexist authored Oct 17, 2021
1 parent 4fa7e64 commit 4b15fa5
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 7 deletions.
22 changes: 17 additions & 5 deletions python/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ use crate::{expression, types::PyDataType};
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_plan;
use datafusion::physical_plan::functions::Volatility;
use pyo3::{prelude::*, types::PyTuple, wrap_pyfunction};
use pyo3::{
exceptions::PyTypeError, prelude::*, types::PyTuple, wrap_pyfunction, Python,
};
use std::sync::Arc;

/// Expression representing a column on the existing plan.
Expand All @@ -36,10 +38,20 @@ fn col(name: &str) -> expression::Expression {
/// Expression representing a constant value
#[pyfunction]
#[pyo3(text_signature = "(value)")]
fn lit(value: i32) -> expression::Expression {
expression::Expression {
expr: logical_plan::lit(value),
}
fn lit(value: &PyAny) -> PyResult<expression::Expression> {
let expr = if let Ok(v) = value.extract::<i64>() {
logical_plan::lit(v)
} else if let Ok(v) = value.extract::<f64>() {
logical_plan::lit(v)
} else if let Ok(v) = value.extract::<String>() {
logical_plan::lit(v)
} else {
return Err(PyTypeError::new_err(format!(
"Unsupported value {}, expected one of i64, f64, or String type",
value
)));
};
Ok(expression::Expression { expr })
}

#[pyfunction]
Expand Down
2 changes: 1 addition & 1 deletion python/tests/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import numpy as np
import pyarrow as pa
import pyarrow.csv
import pyarrow.parquet as pq

# used to write parquet files
import pyarrow.parquet as pq


def data():
Expand Down
1 change: 0 additions & 1 deletion python/tests/test_df_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def ctx():


def test_register_record_batches(ctx):

# create a RecordBatch and register it as memtable
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
Expand Down
54 changes: 54 additions & 0 deletions python/tests/test_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pyarrow as pa
import pytest
from datafusion import ExecutionContext
from datafusion import functions as f


@pytest.fixture
def df():
ctx = ExecutionContext()
# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
[pa.array(["Hello", "World", "!"]), pa.array([4, 5, 6])],
names=["a", "b"],
)
return ctx.create_dataframe([[batch]])


def test_lit(df):
"""test lit function"""
df = df.select(f.lit(1), f.lit("1"), f.lit("OK"), f.lit(3.14))
result = df.collect()
assert len(result) == 1
result = result[0]
assert result.column(0) == pa.array([1] * 3)
assert result.column(1) == pa.array(["1"] * 3)
assert result.column(2) == pa.array(["OK"] * 3)
assert result.column(3) == pa.array([3.14] * 3)


def test_lit_arith(df):
"""test lit function within arithmatics"""
df = df.select(f.lit(1) + f.col("b"), f.concat(f.col("a"), f.lit("!")))
result = df.collect()
assert len(result) == 1
result = result[0]
assert result.column(0) == pa.array([5, 6, 7])
assert result.column(1) == pa.array(["Hello!", "World!", "!!"])

0 comments on commit 4b15fa5

Please sign in to comment.