Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/string_agg.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! [`StringAgg`] accumulator for the `string_agg` function
19
20
use arrow::array::ArrayRef;
21
use arrow_schema::DataType;
22
use datafusion_common::cast::as_generic_string_array;
23
use datafusion_common::Result;
24
use datafusion_common::{not_impl_err, ScalarValue};
25
use datafusion_expr::function::AccumulatorArgs;
26
use datafusion_expr::{
27
    Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility,
28
};
29
use datafusion_physical_expr::expressions::Literal;
30
use std::any::Any;
31
32
// Registers the `string_agg` UDAF via the crate's `make_udaf_expr_and_func!`
// macro (defined elsewhere). Judging by the macro name, this generates the
// `string_agg(expr, delimiter)` expression helper and the `string_agg_udaf`
// accessor for [`StringAgg`] — confirm against the macro definition.
make_udaf_expr_and_func!(
    StringAgg,
    string_agg,
    expr delimiter,
    "Concatenates the values of string expressions and places separator values between them",
    string_agg_udaf
);
39
40
/// STRING_AGG aggregate expression
41
#[derive(Debug)]
42
pub struct StringAgg {
43
    signature: Signature,
44
}
45
46
impl StringAgg {
47
    /// Create a new StringAgg aggregate function
48
0
    pub fn new() -> Self {
49
0
        Self {
50
0
            signature: Signature::one_of(
51
0
                vec![
52
0
                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
53
0
                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
54
0
                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]),
55
0
                ],
56
0
                Volatility::Immutable,
57
0
            ),
58
0
        }
59
0
    }
60
}
61
62
impl Default for StringAgg {
63
0
    fn default() -> Self {
64
0
        Self::new()
65
0
    }
66
}
67
68
impl AggregateUDFImpl for StringAgg {
69
0
    fn as_any(&self) -> &dyn Any {
70
0
        self
71
0
    }
72
73
0
    fn name(&self) -> &str {
74
0
        "string_agg"
75
0
    }
76
77
0
    fn signature(&self) -> &Signature {
78
0
        &self.signature
79
0
    }
80
81
0
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
82
0
        Ok(DataType::LargeUtf8)
83
0
    }
84
85
0
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
86
0
        if let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() {
87
0
            return match lit.value() {
88
0
                ScalarValue::Utf8(Some(delimiter))
89
0
                | ScalarValue::LargeUtf8(Some(delimiter)) => {
90
0
                    Ok(Box::new(StringAggAccumulator::new(delimiter.as_str())))
91
                }
92
                ScalarValue::Utf8(None)
93
                | ScalarValue::LargeUtf8(None)
94
0
                | ScalarValue::Null => Ok(Box::new(StringAggAccumulator::new(""))),
95
0
                e => not_impl_err!("StringAgg not supported for delimiter {}", e),
96
            };
97
0
        }
98
0
99
0
        not_impl_err!("expect literal")
100
0
    }
101
}
102
103
/// Running state of a single `string_agg` aggregation.
#[derive(Debug)]
pub(crate) struct StringAggAccumulator {
    /// Concatenation built so far; `None` until the first non-null value
    /// has been accumulated.
    values: Option<String>,
    /// Separator placed between consecutive values.
    delimiter: String,
}
108
109
impl StringAggAccumulator {
110
0
    pub fn new(delimiter: &str) -> Self {
111
0
        Self {
112
0
            values: None,
113
0
            delimiter: delimiter.to_string(),
114
0
        }
115
0
    }
116
}
117
118
impl Accumulator for StringAggAccumulator {
119
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
120
0
        let string_array: Vec<_> = as_generic_string_array::<i64>(&values[0])?
121
0
            .iter()
122
0
            .filter_map(|v| v.as_ref().map(ToString::to_string))
123
0
            .collect();
124
0
        if !string_array.is_empty() {
125
0
            let s = string_array.join(self.delimiter.as_str());
126
0
            let v = self.values.get_or_insert("".to_string());
127
0
            if !v.is_empty() {
128
0
                v.push_str(self.delimiter.as_str());
129
0
            }
130
0
            v.push_str(s.as_str());
131
0
        }
132
0
        Ok(())
133
0
    }
134
135
0
    fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
136
0
        self.update_batch(values)?;
137
0
        Ok(())
138
0
    }
139
140
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
141
0
        Ok(vec![self.evaluate()?])
142
0
    }
143
144
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
145
0
        Ok(ScalarValue::LargeUtf8(self.values.clone()))
146
0
    }
147
148
0
    fn size(&self) -> usize {
149
0
        std::mem::size_of_val(self)
150
0
            + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0)
151
0
            + self.delimiter.capacity()
152
0
    }
153
}