diff --git a/rust/lance-datafusion/src/exec.rs b/rust/lance-datafusion/src/exec.rs index cad0a359c0..8872d8a6dc 100644 --- a/rust/lance-datafusion/src/exec.rs +++ b/rust/lance-datafusion/src/exec.rs @@ -22,6 +22,8 @@ use datafusion::{ datasource::streaming::StreamingTable, execution::{ context::{SessionConfig, SessionContext, SessionState}, + disk_manager::DiskManagerConfig, + memory_pool::GreedyMemoryPool, runtime_env::{RuntimeConfig, RuntimeEnv}, TaskContext, }, @@ -147,12 +149,36 @@ impl ExecutionPlan for OneShotExec { } } +#[derive(Debug, Clone)] +pub struct LanceExecutionOptions { + pub use_spilling: bool, + pub mem_pool_size: u64, +} + +impl Default for LanceExecutionOptions { + fn default() -> Self { + Self { + use_spilling: false, + mem_pool_size: 1024 * 1024 * 100, + } + } +} + /// Executes a plan using default session & runtime configuration /// /// Only executes a single partition. Panics if the plan has more than one partition. -pub fn execute_plan(plan: Arc<dyn ExecutionPlan>) -> Result<SendableRecordBatchStream> { +pub fn execute_plan( + plan: Arc<dyn ExecutionPlan>, + options: LanceExecutionOptions, +) -> Result<SendableRecordBatchStream> { let session_config = SessionConfig::new(); - let runtime_config = RuntimeConfig::new(); + let mut runtime_config = RuntimeConfig::new(); + if options.use_spilling { + runtime_config.disk_manager = DiskManagerConfig::NewOs; + runtime_config.memory_pool = Some(Arc::new(GreedyMemoryPool::new( + options.mem_pool_size as usize, + ))); + } let runtime_env = Arc::new(RuntimeEnv::new(runtime_config)?); let session_state = SessionState::new_with_config_rt(session_config, runtime_env); // NOTE: we are only executing the first partition here. 
Therefore, if diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index 1d4cd9acf5..a8af03e278 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -42,7 +42,7 @@ use futures::{ use lance_core::{Error, Result}; use lance_datafusion::{ chunker::chunk_concat_stream, - exec::{execute_plan, OneShotExec}, + exec::{execute_plan, LanceExecutionOptions, OneShotExec}, }; use roaring::RoaringBitmap; use serde::{Serialize, Serializer}; @@ -1124,7 +1124,13 @@ impl BtreeTrainingSource for BTreeUpdater { // them back into a single partition. let all_data = Arc::new(UnionExec::new(vec![old_input, new_input])); let ordered = Arc::new(SortPreservingMergeExec::new(vec![sort_expr], all_data)); - let unchunked = execute_plan(ordered)?; + let unchunked = execute_plan( + ordered, + LanceExecutionOptions { + use_spilling: true, + ..Default::default() + }, + )?; Ok(chunk_concat_stream(unchunked, chunk_size as usize)) } } diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 611f1067d8..986dd40704 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -43,7 +43,7 @@ use futures::stream::{Stream, StreamExt}; use futures::TryStreamExt; use lance_arrow::floats::{coerce_float_vector, FloatType}; use lance_core::{ROW_ID, ROW_ID_FIELD}; -use lance_datafusion::exec::execute_plan; +use lance_datafusion::exec::{execute_plan, LanceExecutionOptions}; use lance_datafusion::expr::parse_substrait; use lance_index::vector::{Query, DIST_COL}; use lance_index::{scalar::expression::ScalarIndexExpr, DatasetIndexExt}; @@ -667,12 +667,18 @@ impl Scanner { #[instrument(skip_all)] pub async fn try_into_stream(&self) -> Result<DatasetRecordBatchStream> { let plan = self.create_plan().await?; - Ok(DatasetRecordBatchStream::new(execute_plan(plan)?)) + Ok(DatasetRecordBatchStream::new(execute_plan( + plan, + LanceExecutionOptions::default(), + )?)) } - pub(crate) async fn 
try_into_dfstream(&self) -> Result<SendableRecordBatchStream> { + pub(crate) async fn try_into_dfstream( + &self, + options: LanceExecutionOptions, + ) -> Result<SendableRecordBatchStream> { let plan = self.create_plan().await?; - execute_plan(plan) + execute_plan(plan, options) } pub async fn try_into_batch(&self) -> Result<RecordBatch> { @@ -705,7 +711,7 @@ plan, plan_schema, )?); - let mut stream = execute_plan(count_plan)?; + let mut stream = execute_plan(count_plan, LanceExecutionOptions::default())?; // A count plan will always return a single batch with a single row. if let Some(first_batch) = stream.next().await { diff --git a/rust/lance/src/index/scalar.rs b/rust/lance/src/index/scalar.rs index 7e5851b0fd..2202cc7aeb 100644 --- a/rust/lance/src/index/scalar.rs +++ b/rust/lance/src/index/scalar.rs @@ -19,7 +19,7 @@ use std::sync::Arc; use async_trait::async_trait; use datafusion::physical_plan::SendableRecordBatchStream; -use lance_datafusion::chunker::chunk_concat_stream; +use lance_datafusion::{chunker::chunk_concat_stream, exec::LanceExecutionOptions}; use lance_index::scalar::{ btree::{train_btree_index, BTreeIndex, BtreeTrainingSource}, flat::FlatIndexMetadata, @@ -63,7 +63,12 @@ impl BtreeTrainingSource for TrainingRequest { )]))? .project(&[&self.column])?; - let ordered_batches = scan.try_into_dfstream().await?; + let ordered_batches = scan + .try_into_dfstream(LanceExecutionOptions { + use_spilling: true, + ..Default::default() + }) + .await?; Ok(chunk_concat_stream(ordered_batches, chunk_size as usize)) } }