diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 5cf8969aa46d..fc2cdbb7518d 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1952,6 +1952,11 @@ impl SessionState { &self.config } + /// Return the mutable [`SessionConfig`]. + pub fn config_mut(&mut self) -> &mut SessionConfig { + &mut self.config + } + /// Return the physical optimizers pub fn physical_optimizers(&self) -> &[Arc] { &self.physical_optimizers.rules diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 0a7a87c7d81a..e29030e61457 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -501,13 +501,54 @@ impl SessionConfig { /// /// [^1]: Compare that to [`ConfigOptions`] which only supports [`ScalarValue`] payloads. pub fn with_extension(mut self, ext: Arc) -> Self + where + T: Send + Sync + 'static, + { + self.set_extension(ext); + self + } + + /// Set extension. Pretty much the same as [`with_extension`](Self::with_extension), but take + /// mutable reference instead of owning it. Useful if you want to add another extension after + /// the [`SessionConfig`] is created. + /// + /// # Example + /// ``` + /// use std::sync::Arc; + /// use datafusion_execution::config::SessionConfig; + /// + /// // application-specific extension types + /// struct Ext1(u8); + /// struct Ext2(u8); + /// struct Ext3(u8); + /// + /// let ext1a = Arc::new(Ext1(10)); + /// let ext1b = Arc::new(Ext1(11)); + /// let ext2 = Arc::new(Ext2(2)); + /// + /// let mut cfg = SessionConfig::default(); + /// + /// // will only remember the last Ext1 + /// cfg.set_extension(Arc::clone(&ext1a)); + /// cfg.set_extension(Arc::clone(&ext1b)); + /// cfg.set_extension(Arc::clone(&ext2)); + /// + /// let ext1_received = cfg.get_extension::().unwrap(); + /// assert!(!Arc::ptr_eq(&ext1_received, &ext1a)); + /// assert!(Arc::ptr_eq(&ext1_received, &ext1b)); + /// + /// let ext2_received = cfg.get_extension::().unwrap(); + /// assert!(Arc::ptr_eq(&ext2_received, &ext2)); + /// + /// assert!(cfg.get_extension::().is_none()); + /// ``` + pub fn set_extension(&mut self, ext: Arc) where T: Send + Sync + 'static, { let ext = ext as Arc; let id = TypeId::of::(); self.extensions.insert(id, ext); - self } /// Get extension, if any for the specified type `T` exists.