Stream CSV file during schema inference (#4661)
* Stream CSV file during schema inference

* Change error type

* Refactor

* Add test for streaming CSV schema inference

* Make CSV schema inference test more robust
Jefffrey authored Dec 23, 2022
1 parent 6a4e0df commit 720bdb0
Showing 3 changed files with 286 additions and 21 deletions.
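For orientation before the diff: the new inference loop splits each object's byte stream into newline-aligned chunks, infers a schema per chunk, accumulates the set of data types observed for every column, and resolves conflicts once the row budget is exhausted. A minimal, self-contained sketch of that accumulate-and-resolve idea (illustrative only; plain strings stand in for arrow's DataType, and every name here is hypothetical rather than taken from the commit):

```rust
use std::collections::HashSet;

// Record the data type observed for each column in one chunk.
fn record_chunk(type_sets: &mut [HashSet<&'static str>], chunk_types: &[&'static str]) {
    for (set, ty) in type_sets.iter_mut().zip(chunk_types) {
        set.insert(*ty);
    }
}

// Resolve one column's accumulated types: a single type wins,
// Int64 + Float64 widens to Float64, any other conflict falls back to Utf8.
fn resolve(types: &HashSet<&'static str>) -> &'static str {
    match types.len() {
        1 => *types.iter().next().unwrap(),
        2 if types.contains("Int64") && types.contains("Float64") => "Float64",
        _ => "Utf8",
    }
}

fn main() {
    let mut columns = vec![HashSet::new(), HashSet::new()];
    record_chunk(&mut columns, &["Int64", "Int64"]);  // first chunk
    record_chunk(&mut columns, &["Float64", "Utf8"]); // second chunk
    assert_eq!(resolve(&columns[0]), "Float64"); // Int64 + Float64 -> Float64
    assert_eq!(resolve(&columns[1]), "Utf8");    // Int64 + Utf8   -> Utf8
}
```

The resolution rules mirror build_schema_helper in the diff below; the real code works with arrow Field and DataType values instead of strings.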
180 changes: 159 additions & 21 deletions datafusion/core/src/datasource/file_format/csv.rs
@@ -19,16 +19,17 @@
use std::any::Any;

use std::collections::HashSet;
use std::sync::Arc;

use arrow::datatypes::Schema;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::{self, datatypes::SchemaRef};
use async_trait::async_trait;
use bytes::Buf;

use datafusion_common::DataFusionError;

use futures::TryFutureExt;
use futures::{pin_mut, StreamExt, TryStreamExt};
use object_store::{ObjectMeta, ObjectStore};

use super::FileFormat;
@@ -37,7 +38,9 @@ use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD;
use crate::error::Result;
use crate::execution::context::SessionState;
use crate::logical_expr::Expr;
use crate::physical_plan::file_format::{CsvExec, FileScanConfig};
use crate::physical_plan::file_format::{
newline_delimited_stream, CsvExec, FileScanConfig,
};
use crate::physical_plan::ExecutionPlan;
use crate::physical_plan::Statistics;

@@ -122,27 +125,75 @@ impl FileFormat for CsvFormat {

let mut records_to_read = self.schema_infer_max_rec.unwrap_or(usize::MAX);

for object in objects {
let data = store
'iterating_objects: for object in objects {
// stream to only read as many rows as needed into memory
let stream = store
.get(&object.location)
.and_then(|r| r.bytes())
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;

let decoder = self.file_compression_type.convert_read(data.reader())?;
let (schema, records_read) = arrow::csv::reader::infer_reader_schema(
decoder,
self.delimiter,
Some(records_to_read),
self.has_header,
)?;
schemas.push(schema.clone());
if records_read == 0 {
continue;
.await?
.into_stream()
.map_err(|e| DataFusionError::External(Box::new(e)));
let stream = newline_delimited_stream(stream);
pin_mut!(stream);

let mut column_names = vec![];
let mut column_type_possibilities = vec![];
let mut first_chunk = true;

'reading_object: while let Some(data) = stream.next().await.transpose()? {
let (Schema { fields, .. }, records_read) =
arrow::csv::reader::infer_reader_schema(
self.file_compression_type.convert_read(data.reader())?,
self.delimiter,
Some(records_to_read),
// only consider header for first chunk
self.has_header && first_chunk,
)?;
records_to_read -= records_read;

if first_chunk {
// set up initial structures for recording inferred schema across chunks
(column_names, column_type_possibilities) = fields
.into_iter()
.map(|field| {
let mut possibilities = HashSet::new();
if records_read > 0 {
// at least 1 data row read, record the inferred datatype
possibilities.insert(field.data_type().clone());
}
(field.name().clone(), possibilities)
})
.unzip();
first_chunk = false;
} else {
if fields.len() != column_type_possibilities.len() {
return Err(DataFusionError::Execution(
format!(
"Encountered unequal lengths between records on CSV file whilst inferring schema. \
Expected {} records, found {} records",
column_type_possibilities.len(),
fields.len()
)
));
}

column_type_possibilities.iter_mut().zip(fields).for_each(
|(possibilities, field)| {
possibilities.insert(field.data_type().clone());
},
);
}

if records_to_read == 0 {
break 'reading_object;
}
}
records_to_read -= records_read;

schemas.push(build_schema_helper(
column_names,
&column_type_possibilities,
));
if records_to_read == 0 {
break;
break 'iterating_objects;
}
}

@@ -176,14 +227,50 @@
}
}

fn build_schema_helper(names: Vec<String>, types: &[HashSet<DataType>]) -> Schema {
let fields = names
.into_iter()
.zip(types)
.map(|(field_name, data_type_possibilities)| {
// ripped from arrow::csv::reader::infer_reader_schema_with_csv_options
// determine data type based on possible types
// if there are incompatible types, use DataType::Utf8
match data_type_possibilities.len() {
1 => Field::new(
field_name,
data_type_possibilities.iter().next().unwrap().clone(),
true,
),
2 => {
if data_type_possibilities.contains(&DataType::Int64)
&& data_type_possibilities.contains(&DataType::Float64)
{
// we have an integer and double, fall down to double
Field::new(field_name, DataType::Float64, true)
} else {
// default to Utf8 for conflicting datatypes (e.g bool and int)
Field::new(field_name, DataType::Utf8, true)
}
}
_ => Field::new(field_name, DataType::Utf8, true),
}
})
.collect();
Schema::new(fields)
}

#[cfg(test)]
mod tests {
use super::super::test_util::scan_format;
use super::*;
use crate::datasource::file_format::test_util::VariableStream;
use crate::physical_plan::collect;
use crate::prelude::{SessionConfig, SessionContext};
use bytes::Bytes;
use chrono::DateTime;
use datafusion_common::cast::as_string_array;
use futures::StreamExt;
use object_store::path::Path;

#[tokio::test]
async fn read_small_batches() -> Result<()> {
@@ -291,6 +378,57 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_infer_schema_stream() -> Result<()> {
let session_ctx = SessionContext::new();
let state = session_ctx.state();
let variable_object_store =
Arc::new(VariableStream::new(Bytes::from("1,2,3,4,5\n"), 200));
let object_meta = ObjectMeta {
location: Path::parse("/")?,
last_modified: DateTime::default(),
size: usize::MAX,
};

let num_rows_to_read = 100;
let csv_format = CsvFormat {
has_header: false,
schema_infer_max_rec: Some(num_rows_to_read),
..Default::default()
};
let inferred_schema = csv_format
.infer_schema(
&state,
&(variable_object_store.clone() as Arc<dyn ObjectStore>),
&[object_meta],
)
.await?;

let actual_fields: Vec<_> = inferred_schema
.fields()
.iter()
.map(|f| format!("{}: {:?}", f.name(), f.data_type()))
.collect();
assert_eq!(
vec![
"column_1: Int64",
"column_2: Int64",
"column_3: Int64",
"column_4: Int64",
"column_5: Int64"
],
actual_fields
);
// ensure that schema inference did not try to read the entire file;
// it should only read as many rows as configured in the CsvFormat
assert_eq!(
num_rows_to_read,
variable_object_store.get_iterations_detected()
);

Ok(())
}

async fn get_exec(
state: &SessionState,
file_name: &str,
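The new infer_schema above relies on newline_delimited_stream (re-exported from physical_plan::file_format later in this diff) so that every chunk handed to the Arrow CSV reader ends on a record boundary; its implementation is not part of this change. As a rough illustration of the idea only — a simplified, synchronous sketch with hypothetical names, whereas the real helper operates on an async stream of Bytes:

```rust
// Simplified illustration of newline-delimited re-chunking: buffer incoming
// byte chunks and only emit data up to the last complete newline seen so far.
fn rechunk_on_newlines(chunks: impl IntoIterator<Item = Vec<u8>>) -> Vec<Vec<u8>> {
    let mut buffer: Vec<u8> = Vec::new();
    let mut out: Vec<Vec<u8>> = Vec::new();
    for chunk in chunks {
        buffer.extend_from_slice(&chunk);
        if let Some(pos) = buffer.iter().rposition(|&b| b == b'\n') {
            // emit everything up to and including the last newline
            out.push(buffer.drain(..=pos).collect());
        }
    }
    if !buffer.is_empty() {
        out.push(buffer); // trailing data without a final newline
    }
    out
}

fn main() {
    let chunks = vec![b"a,1\nb,".to_vec(), b"2\nc,3".to_vec()];
    let rechunked = rechunk_on_newlines(chunks);
    assert_eq!(
        rechunked,
        vec![b"a,1\n".to_vec(), b"b,2\n".to_vec(), b"c,3".to_vec()]
    );
}
```

With chunks aligned this way, each call to infer_reader_schema in the loop above only ever sees complete CSV records, which is what makes the per-chunk schemas safe to merge.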
126 changes: 126 additions & 0 deletions datafusion/core/src/datasource/file_format/mod.rs
@@ -87,11 +87,20 @@ pub trait FileFormat: Send + Sync + fmt::Debug {

#[cfg(test)]
pub(crate) mod test_util {
use std::ops::Range;
use std::sync::Mutex;

use super::*;
use crate::datasource::listing::PartitionedFile;
use crate::datasource::object_store::ObjectStoreUrl;
use crate::test::object_store::local_unpartitioned_file;
use bytes::Bytes;
use futures::stream::BoxStream;
use futures::StreamExt;
use object_store::local::LocalFileSystem;
use object_store::path::Path;
use object_store::{GetResult, ListResult, MultipartId};
use tokio::io::AsyncWrite;

pub async fn scan_format(
state: &SessionState,
@@ -135,4 +144,121 @@ pub(crate) mod test_util {
.await?;
Ok(exec)
}

/// Mock ObjectStore to provide a variable stream of bytes on get
/// Able to keep track of how many iterations of the provided bytes were repeated
#[derive(Debug)]
pub struct VariableStream {
bytes_to_repeat: Bytes,
max_iterations: usize,
iterations_detected: Arc<Mutex<usize>>,
}

impl std::fmt::Display for VariableStream {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "VariableStream")
}
}

#[async_trait]
impl ObjectStore for VariableStream {
async fn put(&self, _location: &Path, _bytes: Bytes) -> object_store::Result<()> {
unimplemented!()
}

async fn put_multipart(
&self,
_location: &Path,
) -> object_store::Result<(MultipartId, Box<dyn AsyncWrite + Unpin + Send>)>
{
unimplemented!()
}

async fn abort_multipart(
&self,
_location: &Path,
_multipart_id: &MultipartId,
) -> object_store::Result<()> {
unimplemented!()
}

async fn get(&self, _location: &Path) -> object_store::Result<GetResult> {
let bytes = self.bytes_to_repeat.clone();
let arc = self.iterations_detected.clone();
Ok(GetResult::Stream(
futures::stream::repeat_with(move || {
let arc_inner = arc.clone();
*arc_inner.lock().unwrap() += 1;
Ok(bytes.clone())
})
.take(self.max_iterations)
.boxed(),
))
}

async fn get_range(
&self,
_location: &Path,
_range: Range<usize>,
) -> object_store::Result<Bytes> {
unimplemented!()
}

async fn get_ranges(
&self,
_location: &Path,
_ranges: &[Range<usize>],
) -> object_store::Result<Vec<Bytes>> {
unimplemented!()
}

async fn head(&self, _location: &Path) -> object_store::Result<ObjectMeta> {
unimplemented!()
}

async fn delete(&self, _location: &Path) -> object_store::Result<()> {
unimplemented!()
}

async fn list(
&self,
_prefix: Option<&Path>,
) -> object_store::Result<BoxStream<'_, object_store::Result<ObjectMeta>>>
{
unimplemented!()
}

async fn list_with_delimiter(
&self,
_prefix: Option<&Path>,
) -> object_store::Result<ListResult> {
unimplemented!()
}

async fn copy(&self, _from: &Path, _to: &Path) -> object_store::Result<()> {
unimplemented!()
}

async fn copy_if_not_exists(
&self,
_from: &Path,
_to: &Path,
) -> object_store::Result<()> {
unimplemented!()
}
}

impl VariableStream {
pub fn new(bytes_to_repeat: Bytes, max_iterations: usize) -> Self {
Self {
bytes_to_repeat,
max_iterations,
iterations_detected: Arc::new(Mutex::new(0)),
}
}

pub fn get_iterations_detected(&self) -> usize {
*self.iterations_detected.lock().unwrap()
}
}
}
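VariableStream's bookkeeping comes down to a shared counter bumped each time the repeated payload is served, which is what lets the CSV test assert how many chunks were actually pulled. A tiny self-contained illustration of that pattern (a plain iterator stands in for the boxed async stream; names are hypothetical):

```rust
use std::sync::{Arc, Mutex};

fn main() {
    // Shared counter, cloned into the closure that produces each repeated chunk.
    let iterations = Arc::new(Mutex::new(0usize));
    let counter = iterations.clone();

    let served: Vec<&str> = std::iter::repeat_with(move || {
        *counter.lock().unwrap() += 1; // record that one more chunk was handed out
        "1,2,3,4,5\n"
    })
    .take(3)
    .collect();

    assert_eq!(served.len(), 3);
    // The original Arc observes every increment made inside the closure.
    assert_eq!(*iterations.lock().unwrap(), 3);
}
```

Because a consumer that stops early, like the bounded schema inference, simply stops pulling, the counter records exactly how much of the simulated file was read.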
1 change: 1 addition & 0 deletions datafusion/core/src/physical_plan/file_format/mod.rs
@@ -28,6 +28,7 @@ mod parquet;

pub(crate) use self::csv::plan_to_csv;
pub use self::csv::CsvExec;
pub(crate) use self::delimited_stream::newline_delimited_stream;
pub(crate) use self::parquet::plan_to_parquet;
pub use self::parquet::{ParquetExec, ParquetFileMetrics, ParquetFileReaderFactory};
use arrow::{
