Skip to content

Commit

Permalink
fix interval array generator
Browse files Browse the repository at this point in the history
Signed-off-by: BubbleCal <[email protected]>
  • Loading branch information
BubbleCal committed Jul 12, 2024
1 parent 8f879f1 commit fe2c459
Showing 1 changed file with 60 additions and 6 deletions.
66 changes: 60 additions & 6 deletions rust/lance-datagen/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

use std::{iter, marker::PhantomData, sync::Arc};

use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano};
use arrow::{
array::{ArrayData, AsArray},
buffer::{BooleanBuffer, Buffer, OffsetBuffer, ScalarBuffer},
Expand All @@ -14,7 +15,7 @@ use arrow_array::{
Array, FixedSizeBinaryArray, FixedSizeListArray, ListArray, PrimitiveArray, RecordBatch,
RecordBatchOptions, RecordBatchReader, StringArray, StructArray,
};
use arrow_schema::{ArrowError, DataType, Field, Fields, Schema, SchemaRef};
use arrow_schema::{ArrowError, DataType, Field, Fields, IntervalUnit, Schema, SchemaRef};
use futures::{stream::BoxStream, StreamExt};
use rand::{distributions::Uniform, Rng, RngCore, SeedableRng};

Expand Down Expand Up @@ -596,6 +597,59 @@ impl ArrayGenerator for RandomFixedSizeBinaryGenerator {
}
}

pub struct RandomIntervalGenerator {
unit: IntervalUnit,
data_type: DataType,
}

impl RandomIntervalGenerator {
pub fn new(unit: IntervalUnit) -> Self {
Self {
unit,
data_type: DataType::Interval(unit),
}
}
}

impl ArrayGenerator for RandomIntervalGenerator {
fn generate(
&mut self,
length: RowCount,
rng: &mut rand_xoshiro::Xoshiro256PlusPlus,
) -> Result<Arc<dyn arrow_array::Array>, ArrowError> {
match self.unit {
IntervalUnit::YearMonth => {
let months = (0..length.0).map(|_| rng.gen::<i32>()).collect::<Vec<_>>();
Ok(Arc::new(arrow_array::IntervalYearMonthArray::from(months)))
}
IntervalUnit::MonthDayNano => {
let day_time_array = (0..length.0)
.map(|_| IntervalMonthDayNano::new(rng.gen(), rng.gen(), rng.gen()))
.collect::<Vec<_>>();
Ok(Arc::new(arrow_array::IntervalMonthDayNanoArray::from(
day_time_array,
)))
}
IntervalUnit::DayTime => {
let day_time_array = (0..length.0)
.map(|_| IntervalDayTime::new(rng.gen(), rng.gen()))
.collect::<Vec<_>>();
Ok(Arc::new(arrow_array::IntervalDayTimeArray::from(
day_time_array,
)))
}
}
}

fn data_type(&self) -> &DataType {
&self.data_type
}

fn element_size_bytes(&self) -> Option<ByteCount> {
Some(ByteCount::from(12))
}
}

pub struct RandomBinaryGenerator {
bytes_per_element: ByteCount,
scale_to_utf8: bool,
Expand Down Expand Up @@ -1461,6 +1515,10 @@ pub mod array {
Box::new(RandomFixedSizeBinaryGenerator::new(size))
}

pub fn rand_interval(unit: IntervalUnit) -> Box<dyn ArrayGenerator> {
Box::new(RandomIntervalGenerator::new(unit))
}

/// Create a generator of randomly sampled date32 values
///
/// Instead of sampling the entire range, all values will be drawn from the last year as this
Expand Down Expand Up @@ -1663,11 +1721,7 @@ pub mod array {
TimeUnit::Microsecond => rand::<DurationMicrosecondType>(),
TimeUnit::Nanosecond => rand::<DurationNanosecondType>(),
},
DataType::Interval(unit) => match unit {
IntervalUnit::DayTime => rand::<IntervalDayTimeType>(),
IntervalUnit::MonthDayNano => rand::<IntervalMonthDayNanoType>(),
IntervalUnit::YearMonth => rand::<IntervalYearMonthType>(),
},
DataType::Interval(unit) => rand_interval(*unit),
DataType::Date32 => rand_date32(),
DataType::Date64 => rand_date64(),
DataType::Time32(resolution) => rand_time32(resolution),
Expand Down

0 comments on commit fe2c459

Please sign in to comment.