/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/string_agg.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! [`StringAgg`] accumulator for the `string_agg` function |
19 | | |
20 | | use arrow::array::ArrayRef; |
21 | | use arrow_schema::DataType; |
22 | | use datafusion_common::cast::as_generic_string_array; |
23 | | use datafusion_common::Result; |
24 | | use datafusion_common::{not_impl_err, ScalarValue}; |
25 | | use datafusion_expr::function::AccumulatorArgs; |
26 | | use datafusion_expr::{ |
27 | | Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility, |
28 | | }; |
29 | | use datafusion_physical_expr::expressions::Literal; |
30 | | use std::any::Any; |
31 | | |
// Invokes `make_udaf_expr_and_func!` (defined outside this file) for `StringAgg`,
// passing the names `string_agg` and `string_agg_udaf`, the argument names
// `expr` and `delimiter`, and the user-facing description string.
// NOTE(review): presumably this generates the `string_agg(expr, delimiter)` expression
// helper and the `string_agg_udaf()` accessor — confirm against the macro definition.
32 | | make_udaf_expr_and_func!( |
33 | | StringAgg, |
34 | | string_agg, |
35 | | expr delimiter, |
36 | | "Concatenates the values of string expressions and places separator values between them", |
37 | | string_agg_udaf |
38 | | ); |
39 | | |
40 | | /// STRING_AGG aggregate expression |
41 | | #[derive(Debug)] |
42 | | pub struct StringAgg { |
43 | | signature: Signature, |
44 | | } |
45 | | |
46 | | impl StringAgg { |
47 | | /// Create a new StringAgg aggregate function |
48 | 0 | pub fn new() -> Self { |
49 | 0 | Self { |
50 | 0 | signature: Signature::one_of( |
51 | 0 | vec![ |
52 | 0 | TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), |
53 | 0 | TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), |
54 | 0 | TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]), |
55 | 0 | ], |
56 | 0 | Volatility::Immutable, |
57 | 0 | ), |
58 | 0 | } |
59 | 0 | } |
60 | | } |
61 | | |
62 | | impl Default for StringAgg { |
63 | 0 | fn default() -> Self { |
64 | 0 | Self::new() |
65 | 0 | } |
66 | | } |
67 | | |
68 | | impl AggregateUDFImpl for StringAgg { |
69 | 0 | fn as_any(&self) -> &dyn Any { |
70 | 0 | self |
71 | 0 | } |
72 | | |
73 | 0 | fn name(&self) -> &str { |
74 | 0 | "string_agg" |
75 | 0 | } |
76 | | |
77 | 0 | fn signature(&self) -> &Signature { |
78 | 0 | &self.signature |
79 | 0 | } |
80 | | |
81 | 0 | fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { |
82 | 0 | Ok(DataType::LargeUtf8) |
83 | 0 | } |
84 | | |
85 | 0 | fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
86 | 0 | if let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() { |
87 | 0 | return match lit.value() { |
88 | 0 | ScalarValue::Utf8(Some(delimiter)) |
89 | 0 | | ScalarValue::LargeUtf8(Some(delimiter)) => { |
90 | 0 | Ok(Box::new(StringAggAccumulator::new(delimiter.as_str()))) |
91 | | } |
92 | | ScalarValue::Utf8(None) |
93 | | | ScalarValue::LargeUtf8(None) |
94 | 0 | | ScalarValue::Null => Ok(Box::new(StringAggAccumulator::new(""))), |
95 | 0 | e => not_impl_err!("StringAgg not supported for delimiter {}", e), |
96 | | }; |
97 | 0 | } |
98 | 0 |
|
99 | 0 | not_impl_err!("expect literal") |
100 | 0 | } |
101 | | } |
102 | | |
/// Running state for a single `string_agg` aggregation.
#[derive(Debug)]
pub(crate) struct StringAggAccumulator {
    /// Concatenation built so far; `None` until something has been appended.
    values: Option<String>,
    /// Text placed between consecutive values.
    delimiter: String,
}

impl StringAggAccumulator {
    /// Creates an accumulator with no accumulated text that will use
    /// `delimiter` as the separator.
    pub fn new(delimiter: &str) -> Self {
        let delimiter = String::from(delimiter);
        Self {
            values: None,
            delimiter,
        }
    }
}
117 | | |
118 | | impl Accumulator for StringAggAccumulator { |
119 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
120 | 0 | let string_array: Vec<_> = as_generic_string_array::<i64>(&values[0])? |
121 | 0 | .iter() |
122 | 0 | .filter_map(|v| v.as_ref().map(ToString::to_string)) |
123 | 0 | .collect(); |
124 | 0 | if !string_array.is_empty() { |
125 | 0 | let s = string_array.join(self.delimiter.as_str()); |
126 | 0 | let v = self.values.get_or_insert("".to_string()); |
127 | 0 | if !v.is_empty() { |
128 | 0 | v.push_str(self.delimiter.as_str()); |
129 | 0 | } |
130 | 0 | v.push_str(s.as_str()); |
131 | 0 | } |
132 | 0 | Ok(()) |
133 | 0 | } |
134 | | |
135 | 0 | fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
136 | 0 | self.update_batch(values)?; |
137 | 0 | Ok(()) |
138 | 0 | } |
139 | | |
140 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
141 | 0 | Ok(vec![self.evaluate()?]) |
142 | 0 | } |
143 | | |
144 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
145 | 0 | Ok(ScalarValue::LargeUtf8(self.values.clone())) |
146 | 0 | } |
147 | | |
148 | 0 | fn size(&self) -> usize { |
149 | 0 | std::mem::size_of_val(self) |
150 | 0 | + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0) |
151 | 0 | + self.delimiter.capacity() |
152 | 0 | } |
153 | | } |