From f2344d25caa897988c519fdd19362888742d37b4 Mon Sep 17 00:00:00 2001 From: Leslie Su <3530611790@qq.com> Date: Wed, 6 Nov 2024 02:22:20 +0800 Subject: [PATCH] feat: Add `Time`/`Interval`/`Decimal`/`Utf8View` in aggregate fuzz testing (#13226) * support Time/Interval/Decimal types in data generator. * introduce RandomNativeData trait. * fix bug. * support utf8view type in data generator. * fix clippy. * fix bug. --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 47 ++++- .../aggregation_fuzzer/data_generator.rs | 173 +++++++++++++++--- .../src/min_max/min_max_bytes.rs | 4 + test-utils/src/array_gen/decimal.rs | 79 ++++++++ test-utils/src/array_gen/mod.rs | 3 + test-utils/src/array_gen/primitive.rs | 90 +++------ test-utils/src/array_gen/random_data.rs | 102 +++++++++++ test-utils/src/array_gen/string.rs | 28 ++- 8 files changed, 433 insertions(+), 93 deletions(-) create mode 100644 test-utils/src/array_gen/decimal.rs create mode 100644 test-utils/src/array_gen/random_data.rs diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 4cb2b1bfbc5c..16f539b75967 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -23,6 +23,10 @@ use arrow::datatypes::DataType; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use arrow_array::types::Int64Type; +use arrow_schema::{ + IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, +}; use datafusion::common::Result; use datafusion::datasource::MemTable; use datafusion::physical_expr::aggregate::AggregateExprBuilder; @@ -45,7 +49,7 @@ use crate::fuzz_cases::aggregation_fuzzer::{ use datafusion_common::HashMap; use datafusion_physical_expr_common::sort_expr::LexOrdering; use rand::rngs::StdRng; -use rand::{Rng, SeedableRng}; +use rand::{thread_rng, Rng, SeedableRng}; use tokio::task::JoinSet; // ======================================================================== @@ -151,6 +155,7 @@ async fn test_count() { /// 1. Floating point numbers /// 1. structured types fn baseline_config() -> DatasetGeneratorConfig { + let mut rng = thread_rng(); let columns = vec![ ColumnDescr::new("i8", DataType::Int8), ColumnDescr::new("i16", DataType::Int16), @@ -162,13 +167,45 @@ fn baseline_config() -> DatasetGeneratorConfig { ColumnDescr::new("u64", DataType::UInt64), ColumnDescr::new("date32", DataType::Date32), ColumnDescr::new("date64", DataType::Date64), - // TODO: date/time columns - // todo decimal columns + ColumnDescr::new("time32_s", DataType::Time32(TimeUnit::Second)), + ColumnDescr::new("time32_ms", DataType::Time32(TimeUnit::Millisecond)), + ColumnDescr::new("time64_us", DataType::Time64(TimeUnit::Microsecond)), + ColumnDescr::new("time64_ns", DataType::Time64(TimeUnit::Nanosecond)), + ColumnDescr::new( + "interval_year_month", + DataType::Interval(IntervalUnit::YearMonth), + ), + ColumnDescr::new( + "interval_day_time", + DataType::Interval(IntervalUnit::DayTime), + ), + ColumnDescr::new( + "interval_month_day_nano", + DataType::Interval(IntervalUnit::MonthDayNano), + ), + // begin decimal columns + ColumnDescr::new("decimal128", { + // Generate valid precision and scale for Decimal128 randomly. + let precision: u8 = rng.gen_range(1..=DECIMAL128_MAX_PRECISION); + // It's safe to cast `precision` to i8 type directly. + let scale: i8 = rng.gen_range( + i8::MIN..=std::cmp::min(precision as i8, DECIMAL128_MAX_SCALE), + ); + DataType::Decimal128(precision, scale) + }), + ColumnDescr::new("decimal256", { + // Generate valid precision and scale for Decimal256 randomly. + let precision: u8 = rng.gen_range(1..=DECIMAL256_MAX_PRECISION); + // It's safe to cast `precision` to i8 type directly. + let scale: i8 = rng.gen_range( + i8::MIN..=std::cmp::min(precision as i8, DECIMAL256_MAX_SCALE), + ); + DataType::Decimal256(precision, scale) + }), // begin string columns ColumnDescr::new("utf8", DataType::Utf8), ColumnDescr::new("largeutf8", DataType::LargeUtf8), - // TODO add support for utf8view in data generator - // ColumnDescr::new("utf8view", DataType::Utf8View), + ColumnDescr::new("utf8view", DataType::Utf8View), // todo binary // low cardinality columns ColumnDescr::new("u8_low", DataType::UInt8).with_max_num_distinct(10), diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs index aafa5ed7f66b..88133a134e4d 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs @@ -18,11 +18,14 @@ use std::sync::Arc; use arrow::datatypes::{ - Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, - Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + ByteArrayType, ByteViewType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, + Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, LargeUtf8Type, + StringViewType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, Utf8Type, }; use arrow_array::{ArrayRef, RecordBatch}; -use arrow_schema::{DataType, Field, Schema}; +use arrow_schema::{DataType, Field, IntervalUnit, Schema, TimeUnit}; use datafusion_common::{arrow_datafusion_err, DataFusionError, Result}; use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; @@ -32,7 +35,7 @@ use rand::{ thread_rng, Rng, SeedableRng, }; use test_utils::{ - array_gen::{PrimitiveArrayGenerator, StringArrayGenerator}, + array_gen::{DecimalArrayGenerator, PrimitiveArrayGenerator, StringArrayGenerator}, stagger_batch, }; @@ -219,7 +222,7 @@ struct RecordBatchGenerator { } macro_rules! generate_string_array { - ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $OFFSET_TYPE:ty) => {{ + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE: ident) => {{ let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; let max_len = $BATCH_GEN_RNG.gen_range(1..50); @@ -232,25 +235,47 @@ macro_rules! generate_string_array { rng: $ARRAY_GEN_RNG, }; - generator.gen_data::<$OFFSET_TYPE>() + match $ARROW_TYPE::DATA_TYPE { + DataType::Utf8 => generator.gen_data::(), + DataType::LargeUtf8 => generator.gen_data::(), + DataType::Utf8View => generator.gen_string_view(), + _ => unreachable!(), + } + }}; +} + +macro_rules! generate_decimal_array { + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT: expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $PRECISION: ident, $SCALE: ident, $ARROW_TYPE: ident) => {{ + let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); + let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; + + let mut generator = DecimalArrayGenerator { + precision: $PRECISION, + scale: $SCALE, + num_decimals: $NUM_ROWS, + num_distinct_decimals: $MAX_NUM_DISTINCT, + null_pct, + rng: $ARRAY_GEN_RNG, + }; + + generator.gen_data::<$ARROW_TYPE>() }}; } macro_rules! generate_primitive_array { - ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident) => { - paste::paste! {{ - let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); - let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; - - let mut generator = PrimitiveArrayGenerator { - num_primitives: $NUM_ROWS, - num_distinct_primitives: $MAX_NUM_DISTINCT, - null_pct, - rng: $ARRAY_GEN_RNG, - }; - - generator.gen_data::<$ARROW_TYPE>() - }}} + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident) => {{ + let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); + let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; + + let mut generator = PrimitiveArrayGenerator { + num_primitives: $NUM_ROWS, + num_distinct_primitives: $MAX_NUM_DISTINCT, + null_pct, + rng: $ARRAY_GEN_RNG, + }; + + generator.gen_data::<$ARROW_TYPE>() + }}; } impl RecordBatchGenerator { @@ -432,6 +457,100 @@ impl RecordBatchGenerator { Date64Type ) } + DataType::Time32(TimeUnit::Second) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Time32SecondType + ) + } + DataType::Time32(TimeUnit::Millisecond) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Time32MillisecondType + ) + } + DataType::Time64(TimeUnit::Microsecond) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Time64MicrosecondType + ) + } + DataType::Time64(TimeUnit::Nanosecond) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Time64NanosecondType + ) + } + DataType::Interval(IntervalUnit::YearMonth) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + IntervalYearMonthType + ) + } + DataType::Interval(IntervalUnit::DayTime) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + IntervalDayTimeType + ) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + IntervalMonthDayNanoType + ) + } + DataType::Decimal128(precision, scale) => { + generate_decimal_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + precision, + scale, + Decimal128Type + ) + } + DataType::Decimal256(precision, scale) => { + generate_decimal_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + precision, + scale, + Decimal256Type + ) + } DataType::Utf8 => { generate_string_array!( self, @@ -439,7 +558,7 @@ impl RecordBatchGenerator { max_num_distinct, batch_gen_rng, array_gen_rng, - i32 + Utf8Type ) } DataType::LargeUtf8 => { @@ -449,7 +568,17 @@ impl RecordBatchGenerator { max_num_distinct, batch_gen_rng, array_gen_rng, - i64 + LargeUtf8Type + ) + } + DataType::Utf8View => { + generate_string_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + StringViewType ) } _ => { diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs index 501454edf77c..a09d616ec822 100644 --- a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -338,6 +338,10 @@ impl GroupsAccumulator for MinMaxBytesAccumulator { /// This is a heuristic to avoid allocating too many small buffers fn capacity_to_view_block_size(data_capacity: usize) -> u32 { let max_block_size = 2 * 1024 * 1024; + // Avoid block size equal to zero when calling `with_fixed_block_size()`. + if data_capacity == 0 { + return 1; + } if let Ok(block_size) = u32::try_from(data_capacity) { block_size.min(max_block_size) } else { diff --git a/test-utils/src/array_gen/decimal.rs b/test-utils/src/array_gen/decimal.rs new file mode 100644 index 000000000000..f878a830c4eb --- /dev/null +++ b/test-utils/src/array_gen/decimal.rs @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, PrimitiveArray, PrimitiveBuilder, UInt32Array}; +use arrow::datatypes::DecimalType; +use rand::rngs::StdRng; +use rand::Rng; + +use super::random_data::RandomNativeData; + +/// Randomly generate decimal arrays +pub struct DecimalArrayGenerator { + /// The precision of the decimal type + pub precision: u8, + /// The scale of the decimal type + pub scale: i8, + /// The total number of decimals in the output + pub num_decimals: usize, + /// The number of distinct decimals in the columns + pub num_distinct_decimals: usize, + /// The percentage of nulls in the columns + pub null_pct: f64, + /// Random number generator + pub rng: StdRng, +} + +impl DecimalArrayGenerator { + /// Create a Decimal128Array / Decimal256Array with random values. + pub fn gen_data(&mut self) -> ArrayRef + where + D: DecimalType + RandomNativeData, + { + // table of decimals from which to draw + let distinct_decimals: PrimitiveArray = { + let mut decimal_builder = + PrimitiveBuilder::::with_capacity(self.num_distinct_decimals); + for _ in 0..self.num_distinct_decimals { + decimal_builder + .append_option(Some(D::generate_random_native_data(&mut self.rng))); + } + + decimal_builder + .finish() + .with_precision_and_scale(self.precision, self.scale) + .unwrap() + }; + + // pick num_decimals randomly from the distinct decimal table + let indicies: UInt32Array = (0..self.num_decimals) + .map(|_| { + if self.rng.gen::() < self.null_pct { + None + } else if self.num_distinct_decimals > 1 { + let range = 1..(self.num_distinct_decimals as u32); + Some(self.rng.gen_range(range)) + } else { + Some(0) + } + }) + .collect(); + + let options = None; + arrow::compute::take(&distinct_decimals, &indicies, options).unwrap() + } +} diff --git a/test-utils/src/array_gen/mod.rs b/test-utils/src/array_gen/mod.rs index 4a799ae737d7..8e0e39ddfdce 100644 --- a/test-utils/src/array_gen/mod.rs +++ b/test-utils/src/array_gen/mod.rs @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. +mod decimal; mod primitive; +mod random_data; mod string; +pub use decimal::DecimalArrayGenerator; pub use primitive::PrimitiveArrayGenerator; pub use string::StringArrayGenerator; diff --git a/test-utils/src/array_gen/primitive.rs b/test-utils/src/array_gen/primitive.rs index 0581862d63bd..2469cbf44660 100644 --- a/test-utils/src/array_gen/primitive.rs +++ b/test-utils/src/array_gen/primitive.rs @@ -17,42 +17,10 @@ use arrow::array::{ArrayRef, ArrowPrimitiveType, PrimitiveArray, UInt32Array}; use arrow::datatypes::DataType; -use rand::distributions::Standard; -use rand::prelude::Distribution; use rand::rngs::StdRng; use rand::Rng; -/// Trait for converting type safely from a native type T impl this trait. -pub trait FromNative: std::fmt::Debug + Send + Sync + Copy + Default { - /// Convert native type from i64. - fn from_i64(_: i64) -> Option { - None - } -} - -macro_rules! native_type { - ($t: ty $(, $from:ident)*) => { - impl FromNative for $t { - $( - #[inline] - fn $from(v: $t) -> Option { - Some(v) - } - )* - } - }; -} - -native_type!(i8); -native_type!(i16); -native_type!(i32); -native_type!(i64, from_i64); -native_type!(u8); -native_type!(u16); -native_type!(u32); -native_type!(u64); -native_type!(f32); -native_type!(f64); +use super::random_data::RandomNativeData; /// Randomly generate primitive array pub struct PrimitiveArrayGenerator { @@ -70,41 +38,33 @@ pub struct PrimitiveArrayGenerator { impl PrimitiveArrayGenerator { pub fn gen_data(&mut self) -> ArrayRef where - A: ArrowPrimitiveType, - A::Native: FromNative, - Standard: Distribution<::Native>, + A: ArrowPrimitiveType + RandomNativeData, { // table of primitives from which to draw - let distinct_primitives: PrimitiveArray = (0..self.num_distinct_primitives) - .map(|_| { - Some(match A::DATA_TYPE { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 - | DataType::Date32 => self.rng.gen::(), - - DataType::Date64 => { - // TODO: constrain this range to valid dates if necessary - let date_value = self.rng.gen_range(i64::MIN..=i64::MAX); - let millis_per_day = 86_400_000; - let adjusted_value = date_value - (date_value % millis_per_day); - A::Native::from_i64(adjusted_value).unwrap() - } + let distinct_primitives: PrimitiveArray = match A::DATA_TYPE { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Interval(_) => (0..self.num_distinct_primitives) + .map(|_| Some(A::generate_random_native_data(&mut self.rng))) + .collect(), - _ => { - let arrow_type = A::DATA_TYPE; - panic!("Unsupported arrow data type: {arrow_type}") - } - }) - }) - .collect(); + _ => { + let arrow_type = A::DATA_TYPE; + panic!("Unsupported arrow data type: {arrow_type}") + } + }; // pick num_primitves randomly from the distinct string table let indicies: UInt32Array = (0..self.num_primitives) diff --git a/test-utils/src/array_gen/random_data.rs b/test-utils/src/array_gen/random_data.rs new file mode 100644 index 000000000000..23227100d73f --- /dev/null +++ b/test-utils/src/array_gen/random_data.rs @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrowPrimitiveType; +use arrow::datatypes::{ + i256, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float32Type, + Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTime, + IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, + IntervalYearMonthType, Time32MillisecondType, Time32SecondType, + Time64MicrosecondType, Time64NanosecondType, UInt16Type, UInt32Type, UInt64Type, + UInt8Type, +}; +use rand::distributions::Standard; +use rand::prelude::Distribution; +use rand::rngs::StdRng; +use rand::Rng; + +/// Generate corresponding NativeType value randomly according to +/// ArrowPrimitiveType. +pub trait RandomNativeData: ArrowPrimitiveType { + fn generate_random_native_data(rng: &mut StdRng) -> Self::Native; +} + +macro_rules! basic_random_data { + ($ARROW_TYPE: ty) => { + impl RandomNativeData for $ARROW_TYPE + where + Standard: Distribution, + { + #[inline] + fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { + rng.gen::() + } + } + }; +} + +basic_random_data!(Int8Type); +basic_random_data!(Int16Type); +basic_random_data!(Int32Type); +basic_random_data!(Int64Type); +basic_random_data!(UInt8Type); +basic_random_data!(UInt16Type); +basic_random_data!(UInt32Type); +basic_random_data!(UInt64Type); +basic_random_data!(Float32Type); +basic_random_data!(Float64Type); +basic_random_data!(Date32Type); +basic_random_data!(Time32SecondType); +basic_random_data!(Time32MillisecondType); +basic_random_data!(Time64MicrosecondType); +basic_random_data!(Time64NanosecondType); +basic_random_data!(IntervalYearMonthType); +basic_random_data!(Decimal128Type); + +impl RandomNativeData for Date64Type { + fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { + // TODO: constrain this range to valid dates if necessary + let date_value = rng.gen_range(i64::MIN..=i64::MAX); + let millis_per_day = 86_400_000; + date_value - (date_value % millis_per_day) + } +} + +impl RandomNativeData for IntervalDayTimeType { + fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { + IntervalDayTime { + days: rng.gen::(), + milliseconds: rng.gen::(), + } + } +} + +impl RandomNativeData for IntervalMonthDayNanoType { + fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { + IntervalMonthDayNano { + months: rng.gen::(), + days: rng.gen::(), + nanoseconds: rng.gen::(), + } + } +} + +impl RandomNativeData for Decimal256Type { + fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { + i256::from_parts(rng.gen::(), rng.gen::()) + } +} diff --git a/test-utils/src/array_gen/string.rs b/test-utils/src/array_gen/string.rs index fbfa2bb941e0..b5cef6321bc8 100644 --- a/test-utils/src/array_gen/string.rs +++ b/test-utils/src/array_gen/string.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, UInt32Array}; +use arrow::array::{ + ArrayRef, GenericStringArray, OffsetSizeTrait, StringViewArray, UInt32Array, +}; use rand::rngs::StdRng; use rand::Rng; @@ -59,6 +61,30 @@ impl StringArrayGenerator { let options = None; arrow::compute::take(&distinct_strings, &indicies, options).unwrap() } + + /// Creates a StringViewArray with random strings. + pub fn gen_string_view(&mut self) -> ArrayRef { + let distinct_string_views: StringViewArray = (0..self.num_distinct_strings) + .map(|_| Some(random_string(&mut self.rng, self.max_len))) + .collect(); + + // pick num_strings randomly from the distinct string table + let indicies: UInt32Array = (0..self.num_strings) + .map(|_| { + if self.rng.gen::() < self.null_pct { + None + } else if self.num_distinct_strings > 1 { + let range = 1..(self.num_distinct_strings as u32); + Some(self.rng.gen_range(range)) + } else { + Some(0) + } + }) + .collect(); + + let options = None; + arrow::compute::take(&distinct_string_views, &indicies, options).unwrap() + } } /// Return a string of random characters of length 1..=max_len