diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py index 55849edf0..671c4c209 100644 --- a/datafusion/tests/test_context.py +++ b/datafusion/tests/test_context.py @@ -40,6 +40,8 @@ def test_create_context_with_all_valid_args(): repartition_windows=False, parquet_pruning=False, config_options=None, + memory_pool_size=1073741824, + spill_path="." ) # verify that at least some of the arguments worked diff --git a/examples/dataframe-parquet.py b/examples/dataframe-parquet.py index 31a8aa645..48b4ad45c 100644 --- a/examples/dataframe-parquet.py +++ b/examples/dataframe-parquet.py @@ -18,7 +18,7 @@ from datafusion import SessionContext from datafusion import functions as f -ctx = SessionContext() +ctx = SessionContext(memory_pool_size=1073741824, spill_path="/tmp") df = ctx.read_parquet( "/mnt/bigdata/nyctaxi/yellow/2021/yellow_tripdata_2021-01.parquet" ).aggregate([f.col("passenger_count")], [f.count_star()]) diff --git a/examples/sql-parquet.py b/examples/sql-parquet.py index 7b2db6f2b..5f9a12edf 100644 --- a/examples/sql-parquet.py +++ b/examples/sql-parquet.py @@ -17,7 +17,7 @@ from datafusion import SessionContext -ctx = SessionContext() +ctx = SessionContext(memory_pool_size=1073741824, spill_path="/tmp") ctx.register_parquet( "taxi", "/mnt/bigdata/nyctaxi/yellow/2021/yellow_tripdata_2021-01.parquet" ) diff --git a/examples/sql-to-pandas.py b/examples/sql-to-pandas.py index 3569e6d8c..cfa96b2b2 100644 --- a/examples/sql-to-pandas.py +++ b/examples/sql-to-pandas.py @@ -19,7 +19,7 @@ # Create a DataFusion context -ctx = SessionContext() +ctx = SessionContext(memory_pool_size=1073741824, spill_path="/tmp") # Register table with context ctx.register_parquet("taxi", "yellow_tripdata_2021-01.parquet") diff --git a/src/context.rs b/src/context.rs index 3990a4c37..34c091841 100644 --- a/src/context.rs +++ b/src/context.rs @@ -39,6 +39,9 @@ use datafusion::arrow::record_batch::RecordBatch; use datafusion::datasource::datasource::TableProvider; 
use datafusion::datasource::MemTable; use datafusion::execution::context::{SessionConfig, SessionContext}; +use datafusion::execution::disk_manager::DiskManagerConfig; +use datafusion::execution::memory_pool::GreedyMemoryPool; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::prelude::{ AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions, }; @@ -66,7 +69,9 @@ impl PySessionContext { repartition_windows = "true", parquet_pruning = "true", target_partitions = "None", - config_options = "None" + config_options = "None", + memory_pool_size = "None", + spill_path = "None" )] #[new] fn new( @@ -80,6 +85,8 @@ impl PySessionContext { parquet_pruning: bool, target_partitions: Option<usize>, config_options: Option<HashMap<String, String>>, + memory_pool_size: Option<usize>, + spill_path: Option<&str>, ) -> PyResult<Self> { let mut cfg = SessionConfig::new() .with_information_schema(information_schema) @@ -103,9 +110,20 @@ impl PySessionContext { Some(x) => cfg.with_target_partitions(x), }; - Ok(PySessionContext { - ctx: SessionContext::with_config(cfg_full), - }) + let mut runtime_config = datafusion::execution::runtime_env::RuntimeConfig::new(); + + if let Some(size) = memory_pool_size { + runtime_config = runtime_config.with_memory_pool(Arc::new(GreedyMemoryPool::new(size))); + } + if let Some(path) = spill_path { + runtime_config = runtime_config + .with_disk_manager(DiskManagerConfig::new_specified(vec![path.into()])); + } + + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + let ctx = SessionContext::with_config_rt(cfg_full, runtime); + + Ok(PySessionContext { ctx }) } /// Register an object store with the given name