Skip to content

Commit

Permalink
feat: Add map_extract module and function (#11969)
Browse files Browse the repository at this point in the history
* feat: Add map_extract module and function

* chore: Fix fmt

* chore: Add tests

* chore: Simplify

* chore: Simplify

* chore: Fix clippy

* doc: Add user doc

* feat: use Signature::user_defined

* chore: Update tests

* chore: Fix fmt

* chore: Fix clippy

* chore

* chore: typo

* chore: Check args len in return_type

* doc: Update doc

* chore: Simplify logic

* chore: check args earlier

* feat: Support UTF8VIEW

* chore: Update doc

* chore: Fic clippy

* refacotr: Use MutableArrayData

* chore

* refactor: Avoid type conversion

* chore: Fix clippy

* chore: Follow DuckDB

* Update datafusion/functions-nested/src/map_extract.rs

Co-authored-by: Jay Zhan <[email protected]>

* chore: Fix fmt

---------

Co-authored-by: Jay Zhan <[email protected]>
  • Loading branch information
Weijun-H and jayzhan211 authored Aug 17, 2024
1 parent b06e8b0 commit 72b6a49
Show file tree
Hide file tree
Showing 5 changed files with 302 additions and 1 deletion.
17 changes: 16 additions & 1 deletion datafusion/common/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use arrow_array::{
Array, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait,
RecordBatchOptions,
};
use arrow_schema::DataType;
use arrow_schema::{DataType, Fields};
use sqlparser::ast::Ident;
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
Expand Down Expand Up @@ -753,6 +753,21 @@ pub fn combine_limit(
(combined_skip, combined_fetch)
}

pub fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> {
match data_type {
DataType::Map(field, _) => {
let field_data_type = field.data_type();
match field_data_type {
DataType::Struct(fields) => Ok(fields),
_ => {
_internal_err!("Expected a Struct type, got {:?}", field_data_type)
}
}
}
_ => _internal_err!("Expected a Map type, got {:?}", data_type),
}
}

#[cfg(test)]
mod tests {
use crate::ScalarValue::Null;
Expand Down
3 changes: 3 additions & 0 deletions datafusion/functions-nested/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub mod flatten;
pub mod length;
pub mod make_array;
pub mod map;
pub mod map_extract;
pub mod planner;
pub mod position;
pub mod range;
Expand Down Expand Up @@ -81,6 +82,7 @@ pub mod expr_fn {
pub use super::flatten::flatten;
pub use super::length::array_length;
pub use super::make_array::make_array;
pub use super::map_extract::map_extract;
pub use super::position::array_position;
pub use super::position::array_positions;
pub use super::range::gen_series;
Expand Down Expand Up @@ -143,6 +145,7 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
replace::array_replace_all_udf(),
replace::array_replace_udf(),
map::map_udf(),
map_extract::map_extract_udf(),
]
}

Expand Down
173 changes: 173 additions & 0 deletions datafusion/functions-nested/src/map_extract.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`ScalarUDFImpl`] definitions for map_extract functions.
use arrow::array::{ArrayRef, Capacities, MutableArrayData};
use arrow_array::{make_array, ListArray};

use arrow::datatypes::DataType;
use arrow_array::{Array, MapArray};
use arrow_buffer::OffsetBuffer;
use arrow_schema::Field;
use datafusion_common::utils::get_map_entry_field;

use datafusion_common::{cast::as_map_array, exec_err, Result};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;
use std::vec;

use crate::utils::make_scalar_function;

// Create static instances of ScalarUDFs for each function
make_udf_expr_and_func!(
MapExtract,
map_extract,
map key,
"Return a list containing the value for a given key or an empty list if the key is not contained in the map.",
map_extract_udf
);

#[derive(Debug)]
pub(super) struct MapExtract {
signature: Signature,
aliases: Vec<String>,
}

impl MapExtract {
pub fn new() -> Self {
Self {
signature: Signature::user_defined(Volatility::Immutable),
aliases: vec![String::from("element_at")],
}
}
}

impl ScalarUDFImpl for MapExtract {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"map_extract"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types.len() != 2 {
return exec_err!("map_extract expects two arguments");
}
let map_type = &arg_types[0];
let map_fields = get_map_entry_field(map_type)?;
Ok(DataType::List(Arc::new(Field::new(
"item",
map_fields.last().unwrap().data_type().clone(),
true,
))))
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
make_scalar_function(map_extract_inner)(args)
}

fn aliases(&self) -> &[String] {
&self.aliases
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 2 {
return exec_err!("map_extract expects two arguments");
}

let field = get_map_entry_field(&arg_types[0])?;
Ok(vec![
arg_types[0].clone(),
field.first().unwrap().data_type().clone(),
])
}
}

fn general_map_extract_inner(
map_array: &MapArray,
query_keys_array: &dyn Array,
) -> Result<ArrayRef> {
let keys = map_array.keys();
let mut offsets = vec![0_i32];

let values = map_array.values();
let original_data = values.to_data();
let capacity = Capacities::Array(original_data.len());

let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], true, capacity);

for (row_index, offset_window) in map_array.value_offsets().windows(2).enumerate() {
let start = offset_window[0] as usize;
let end = offset_window[1] as usize;
let len = end - start;

let query_key = query_keys_array.slice(row_index, 1);

let value_index =
(0..len).find(|&i| keys.slice(start + i, 1).as_ref() == query_key.as_ref());

match value_index {
Some(index) => {
mutable.extend(0, start + index, start + index + 1);
}
None => {
mutable.extend_nulls(1);
}
}
offsets.push(offsets[row_index] + 1);
}

let data = mutable.freeze();

Ok(Arc::new(ListArray::new(
Arc::new(Field::new("item", map_array.value_type().clone(), true)),
OffsetBuffer::<i32>::new(offsets.into()),
Arc::new(make_array(data)),
None,
)))
}

fn map_extract_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("map_extract expects two arguments");
}

let map_array = match args[0].data_type() {
DataType::Map(_, _) => as_map_array(&args[0])?,
_ => return exec_err!("The first argument in map_extract must be a map"),
};

let key_type = map_array.key_type();

if key_type != args[1].data_type() {
return exec_err!(
"The key type {} does not match the map key type {}",
args[1].data_type(),
key_type
);
}

general_map_extract_inner(map_array, &args[1])
}
81 changes: 81 additions & 0 deletions datafusion/sqllogictest/test_files/map.slt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,22 @@
# specific language governing permissions and limitations
# under the License.

statement ok
CREATE TABLE map_array_table_1
AS VALUES
(MAP {1: [1, NULL, 3], 2: [4, NULL, 6], 3: [7, 8, 9]}, 1, 1.0, '1'),
(MAP {4: [1, NULL, 3], 5: [4, NULL, 6], 6: [7, 8, 9]}, 5, 5.0, '5'),
(MAP {7: [1, NULL, 3], 8: [9, NULL, 6], 9: [7, 8, 9]}, 4, 4.0, '4')
;

statement ok
CREATE TABLE map_array_table_2
AS VALUES
(MAP {'1': [1, NULL, 3], '2': [4, NULL, 6], '3': [7, 8, 9]}, 1, 1.0, '1'),
(MAP {'4': [1, NULL, 3], '5': [4, NULL, 6], '6': [7, 8, 9]}, 5, 5.0, '5'),
(MAP {'7': [1, NULL, 3], '8': [9, NULL, 6], '9': [7, 8, 9]}, 4, 4.0, '4')
;

statement ok
CREATE EXTERNAL TABLE data
STORED AS PARQUET
Expand Down Expand Up @@ -493,3 +509,68 @@ select cardinality(map([1, 2, 3], ['a', 'b', 'c'])), cardinality(MAP {'a': 1, 'b
cardinality(MAP {'a': MAP {1:'a', 2:'b', 3:'c'}, 'b': MAP {2:'c', 4:'d'} });
----
3 2 0 2

# map_extract
# key is string
query ????
select map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'a'), map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'b'),
map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'c'), map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'd');
----
[1] [] [3] []

# key is integer
query ????
select map_extract(MAP {1: 1, 2: NULL, 3:3}, 1), map_extract(MAP {1: 1, 2: NULL, 3:3}, 2),
map_extract(MAP {1: 1, 2: NULL, 3:3}, 3), map_extract(MAP {1: 1, 2: NULL, 3:3}, 4);
----
[1] [] [3] []

# value is list
query ????
select map_extract(MAP {1: [1, 2], 2: NULL, 3:[3]}, 1), map_extract(MAP {1: [1, 2], 2: NULL, 3:[3]}, 2),
map_extract(MAP {1: [1, 2], 2: NULL, 3:[3]}, 3), map_extract(MAP {1: [1, 2], 2: NULL, 3:[3]}, 4);
----
[[1, 2]] [] [[3]] []

# key in map and query key are different types
query ?????
select map_extract(MAP {1: 1, 2: 2, 3:3}, '1'), map_extract(MAP {1: 1, 2: 2, 3:3}, 1.0),
map_extract(MAP {1.0: 1, 2: 2, 3:3}, '1'), map_extract(MAP {'1': 1, '2': 2, '3':3}, 1.0),
map_extract(MAP {arrow_cast('1', 'Utf8View'): 1, arrow_cast('2', 'Utf8View'): 2, arrow_cast('3', 'Utf8View'):3}, '1');
----
[1] [1] [1] [] [1]

# map_extract with columns
query ???
select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7) from map_array_table_1;
----
[[1, , 3]] [] []
[] [[4, , 6]] []
[] [] [[1, , 3]]

query ???
select map_extract(column1, column2), map_extract(column1, column3), map_extract(column1, column4) from map_array_table_1;
----
[[1, , 3]] [[1, , 3]] [[1, , 3]]
[[4, , 6]] [[4, , 6]] [[4, , 6]]
[] [] []

query ???
select map_extract(column1, column2), map_extract(column1, column3), map_extract(column1, column4) from map_array_table_2;
----
[[1, , 3]] [] [[1, , 3]]
[[4, , 6]] [] [[4, , 6]]
[] [] []

query ???
select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7) from map_array_table_2;
----
[[1, , 3]] [] []
[] [[4, , 6]] []
[] [] [[1, , 3]]

statement ok
drop table map_array_table_1;

statement ok
drop table map_array_table_2;
29 changes: 29 additions & 0 deletions docs/source/user-guide/sql/scalar_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -3640,6 +3640,7 @@ Unwraps struct fields into columns.

- [map](#map)
- [make_map](#make_map)
- [map_extract](#map_extract)

### `map`

Expand Down Expand Up @@ -3700,6 +3701,34 @@ SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', null);
{POST: 41, HEAD: 33, PATCH: }
```

### `map_extract`

Return a list containing the value for a given key or an empty list if the key is not contained in the map.

```
map_extract(map, key)
```

#### Arguments

- `map`: Map expression.
Can be a constant, column, or function, and any combination of map operators.
- `key`: Key to extract from the map.
Can be a constant, column, or function, any combination of arithmetic or
string operators, or a named expression of previous listed.

#### Example

```
SELECT map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'a');
----
[1]
```

#### Aliases

- element_at

## Hashing Functions

- [digest](#digest)
Expand Down

0 comments on commit 72b6a49

Please sign in to comment.