Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Push SessionState into FileFormat (#4349) #4699

Merged
merged 3 commits into from
Dec 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ impl TableProvider for DataFrame {

async fn scan(
&self,
_ctx: &SessionState,
_state: &SessionState,
projection: Option<&Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/datasource/datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub trait TableProvider: Sync + Send {
/// parallelized or distributed.
async fn scan(
&self,
ctx: &SessionState,
state: &SessionState,
projection: Option<&Vec<usize>>,
filters: &[Expr],
// limit can be used to reduce the amount scanned
Expand Down Expand Up @@ -94,7 +94,7 @@ pub trait TableProviderFactory: Sync + Send {
/// Create a TableProvider with the given url
async fn create(
&self,
ctx: &SessionState,
state: &SessionState,
cmd: &CreateExternalTable,
) -> Result<Arc<dyn TableProvider>>;
}
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/empty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl TableProvider for EmptyTable {

async fn scan(
&self,
_ctx: &SessionState,
_state: &SessionState,
projection: Option<&Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
Expand Down
59 changes: 38 additions & 21 deletions datafusion/core/src/datasource/file_format/avro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use object_store::{GetResult, ObjectMeta, ObjectStore};
use super::FileFormat;
use crate::avro_to_arrow::read_avro_schema_from_reader;
use crate::error::Result;
use crate::execution::context::SessionState;
use crate::logical_expr::Expr;
use crate::physical_plan::file_format::{AvroExec, FileScanConfig};
use crate::physical_plan::ExecutionPlan;
Expand All @@ -47,6 +48,7 @@ impl FileFormat for AvroFormat {

async fn infer_schema(
&self,
_state: &SessionState,
store: &Arc<dyn ObjectStore>,
objects: &[ObjectMeta],
) -> Result<SchemaRef> {
Expand All @@ -68,6 +70,7 @@ impl FileFormat for AvroFormat {

async fn infer_stats(
&self,
_state: &SessionState,
_store: &Arc<dyn ObjectStore>,
_table_schema: SchemaRef,
_object: &ObjectMeta,
Expand All @@ -77,6 +80,7 @@ impl FileFormat for AvroFormat {

async fn create_physical_plan(
&self,
_state: &SessionState,
conf: FileScanConfig,
_filters: &[Expr],
) -> Result<Arc<dyn ExecutionPlan>> {
Expand All @@ -101,10 +105,11 @@ mod tests {
#[tokio::test]
async fn read_small_batches() -> Result<()> {
let config = SessionConfig::new().with_batch_size(2);
let ctx = SessionContext::with_config(config);
let task_ctx = ctx.task_ctx();
let session_ctx = SessionContext::with_config(config);
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let projection = None;
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&state, "alltypes_plain.avro", projection, None).await?;
let stream = exec.execute(0, task_ctx)?;

let tt_batches = stream
Expand All @@ -124,9 +129,10 @@ mod tests {
#[tokio::test]
async fn read_limit() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let projection = None;
let exec = get_exec("alltypes_plain.avro", projection, Some(1)).await?;
let exec = get_exec(&state, "alltypes_plain.avro", projection, Some(1)).await?;
let batches = collect(exec, task_ctx).await?;
assert_eq!(1, batches.len());
assert_eq!(11, batches[0].num_columns());
Expand All @@ -138,9 +144,10 @@ mod tests {
#[tokio::test]
async fn read_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let projection = None;
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&state, "alltypes_plain.avro", projection, None).await?;

let x: Vec<String> = exec
.schema()
Expand Down Expand Up @@ -190,9 +197,10 @@ mod tests {
#[tokio::test]
async fn read_bool_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let projection = Some(vec![1]);
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&state, "alltypes_plain.avro", projection, None).await?;

let batches = collect(exec, task_ctx).await?;
assert_eq!(batches.len(), 1);
Expand All @@ -216,9 +224,10 @@ mod tests {
#[tokio::test]
async fn read_i32_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let projection = Some(vec![0]);
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&state, "alltypes_plain.avro", projection, None).await?;

let batches = collect(exec, task_ctx).await?;
assert_eq!(batches.len(), 1);
Expand All @@ -239,9 +248,10 @@ mod tests {
#[tokio::test]
async fn read_i96_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let projection = Some(vec![10]);
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&state, "alltypes_plain.avro", projection, None).await?;

let batches = collect(exec, task_ctx).await?;
assert_eq!(batches.len(), 1);
Expand All @@ -262,9 +272,10 @@ mod tests {
#[tokio::test]
async fn read_f32_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let projection = Some(vec![6]);
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&state, "alltypes_plain.avro", projection, None).await?;

let batches = collect(exec, task_ctx).await?;
assert_eq!(batches.len(), 1);
Expand All @@ -288,9 +299,10 @@ mod tests {
#[tokio::test]
async fn read_f64_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let projection = Some(vec![7]);
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&state, "alltypes_plain.avro", projection, None).await?;

let batches = collect(exec, task_ctx).await?;
assert_eq!(batches.len(), 1);
Expand All @@ -314,9 +326,10 @@ mod tests {
#[tokio::test]
async fn read_binary_alltypes_plain_avro() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let projection = Some(vec![9]);
let exec = get_exec("alltypes_plain.avro", projection, None).await?;
let exec = get_exec(&state, "alltypes_plain.avro", projection, None).await?;

let batches = collect(exec, task_ctx).await?;
assert_eq!(batches.len(), 1);
Expand All @@ -338,14 +351,15 @@ mod tests {
}

async fn get_exec(
state: &SessionState,
file_name: &str,
projection: Option<Vec<usize>>,
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let testdata = crate::test_util::arrow_test_data();
let store_root = format!("{}/avro", testdata);
let format = AvroFormat {};
scan_format(&format, &store_root, file_name, projection, limit).await
scan_format(state, &format, &store_root, file_name, projection, limit).await
}
}

Expand All @@ -356,13 +370,16 @@ mod tests {

use super::super::test_util::scan_format;
use crate::error::DataFusionError;
use crate::prelude::SessionContext;

#[tokio::test]
async fn test() -> Result<()> {
let session_ctx = SessionContext::new();
let state = session_ctx.state();
let format = AvroFormat {};
let testdata = crate::test_util::arrow_test_data();
let filename = "avro/alltypes_plain.avro";
let result = scan_format(&format, &testdata, filename, None, None).await;
let result = scan_format(&state, &format, &testdata, filename, None, None).await;
assert!(matches!(
result,
Err(DataFusionError::NotImplemented(msg))
Expand Down
26 changes: 19 additions & 7 deletions datafusion/core/src/datasource/file_format/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use super::FileFormat;
use crate::datasource::file_format::file_type::FileCompressionType;
use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD;
use crate::error::Result;
use crate::execution::context::SessionState;
use crate::logical_expr::Expr;
use crate::physical_plan::file_format::{CsvExec, FileScanConfig};
use crate::physical_plan::ExecutionPlan;
Expand Down Expand Up @@ -113,6 +114,7 @@ impl FileFormat for CsvFormat {

async fn infer_schema(
&self,
_state: &SessionState,
store: &Arc<dyn ObjectStore>,
objects: &[ObjectMeta],
) -> Result<SchemaRef> {
Expand Down Expand Up @@ -150,6 +152,7 @@ impl FileFormat for CsvFormat {

async fn infer_stats(
&self,
_state: &SessionState,
_store: &Arc<dyn ObjectStore>,
_table_schema: SchemaRef,
_object: &ObjectMeta,
Expand All @@ -159,6 +162,7 @@ impl FileFormat for CsvFormat {

async fn create_physical_plan(
&self,
_state: &SessionState,
conf: FileScanConfig,
_filters: &[Expr],
) -> Result<Arc<dyn ExecutionPlan>> {
Expand All @@ -184,11 +188,12 @@ mod tests {
#[tokio::test]
async fn read_small_batches() -> Result<()> {
let config = SessionConfig::new().with_batch_size(2);
let ctx = SessionContext::with_config(config);
let session_ctx = SessionContext::with_config(config);
let state = session_ctx.state();
let task_ctx = state.task_ctx();
// skip column 9 that overflows the automaticly discovered column type of i64 (u64 would work)
let projection = Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12]);
let exec = get_exec("aggregate_test_100.csv", projection, None).await?;
let task_ctx = ctx.task_ctx();
let exec = get_exec(&state, "aggregate_test_100.csv", projection, None).await?;
let stream = exec.execute(0, task_ctx)?;

let tt_batches: i32 = stream
Expand All @@ -212,9 +217,11 @@ mod tests {
#[tokio::test]
async fn read_limit() -> Result<()> {
let session_ctx = SessionContext::new();
let state = session_ctx.state();
let task_ctx = session_ctx.task_ctx();
let projection = Some(vec![0, 1, 2, 3]);
let exec = get_exec("aggregate_test_100.csv", projection, Some(1)).await?;
let exec =
get_exec(&state, "aggregate_test_100.csv", projection, Some(1)).await?;
let batches = collect(exec, task_ctx).await?;
assert_eq!(1, batches.len());
assert_eq!(4, batches[0].num_columns());
Expand All @@ -225,8 +232,11 @@ mod tests {

#[tokio::test]
async fn infer_schema() -> Result<()> {
let session_ctx = SessionContext::new();
let state = session_ctx.state();

let projection = None;
let exec = get_exec("aggregate_test_100.csv", projection, None).await?;
let exec = get_exec(&state, "aggregate_test_100.csv", projection, None).await?;

let x: Vec<String> = exec
.schema()
Expand Down Expand Up @@ -259,9 +269,10 @@ mod tests {
#[tokio::test]
async fn read_char_column() -> Result<()> {
let session_ctx = SessionContext::new();
let state = session_ctx.state();
let task_ctx = session_ctx.task_ctx();
let projection = Some(vec![0]);
let exec = get_exec("aggregate_test_100.csv", projection, None).await?;
let exec = get_exec(&state, "aggregate_test_100.csv", projection, None).await?;

let batches = collect(exec, task_ctx).await.expect("Collect batches");

Expand All @@ -281,12 +292,13 @@ mod tests {
}

async fn get_exec(
state: &SessionState,
file_name: &str,
projection: Option<Vec<usize>>,
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let root = format!("{}/csv", crate::test_util::arrow_test_data());
let format = CsvFormat::default();
scan_format(&format, &root, file_name, projection, limit).await
scan_format(state, &format, &root, file_name, projection, limit).await
}
}
Loading