diff --git a/crossbeam-skiplist/src/base.rs b/crossbeam-skiplist/src/base.rs
index 862a1ae3d..826e8188c 100644
--- a/crossbeam-skiplist/src/base.rs
+++ b/crossbeam-skiplist/src/base.rs
@@ -471,6 +471,21 @@ where
/// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist.
pub fn get_or_insert(&self, key: K, value: V, guard: &Guard) -> RefEntry<'_, K, V> {
+ self.insert_internal(key, || value, false, guard)
+ }
+
+ /// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist,
+ /// where value is calculated with a function.
+ ///
+ ///
+ /// Note: Another thread may write key value first, leading to the result of this closure
+ /// discarded. If closure is modifying some other state (such as shared counters or shared
+ /// objects), it may lead to undesired behaviour such as counters being changed without
+ /// result of closure inserted
+ pub fn get_or_insert_with(&self, key: K, value: F, guard: &Guard) -> RefEntry<'_, K, V>
+ where
+ F: FnOnce() -> V,
+ {
self.insert_internal(key, value, false, guard)
}
@@ -831,13 +846,16 @@ where
/// Inserts an entry with the specified `key` and `value`.
///
/// If `replace` is `true`, then any existing entry with this key will first be removed.
- fn insert_internal(
+ fn insert_internal(
&self,
key: K,
- value: V,
+ value: F,
replace: bool,
guard: &Guard,
- ) -> RefEntry<'_, K, V> {
+ ) -> RefEntry<'_, K, V>
+ where
+ F: FnOnce() -> V,
+ {
self.check_guard(guard);
unsafe {
@@ -876,6 +894,9 @@ where
}
}
+ // create value before creating node, so extra allocation doesn't happen if value() function panics
+ let value = value();
+
// Create a new node.
let height = self.random_height();
let (node, n) = {
@@ -1061,7 +1082,7 @@ where
/// If there is an existing entry with this key, it will be removed before inserting the new
/// one.
pub fn insert(&self, key: K, value: V, guard: &Guard) -> RefEntry<'_, K, V> {
- self.insert_internal(key, value, true, guard)
+ self.insert_internal(key, || value, true, guard)
}
/// Removes an entry with the specified `key` from the map and returns it.
diff --git a/crossbeam-skiplist/src/map.rs b/crossbeam-skiplist/src/map.rs
index b035d1fc1..6beb3fb91 100644
--- a/crossbeam-skiplist/src/map.rs
+++ b/crossbeam-skiplist/src/map.rs
@@ -254,6 +254,39 @@ where
Entry::new(self.inner.get_or_insert(key, value, guard))
}
+ /// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist,
+ /// where value is calculated with a function.
+ ///
+ ///
+ /// Note: Another thread may write key value first, leading to the result of this closure
+ /// discarded. If closure is modifying some other state (such as shared counters or shared
+ /// objects), it may lead to undesired behaviour such as counters being changed without
+ /// result of closure inserted
+ ////
+ /// This function returns an [`Entry`] which
+ /// can be used to access the key's associated value.
+ ///
+ ///
+ /// # Example
+ /// ```
+ /// use crossbeam_skiplist::SkipMap;
+ ///
+ /// let ages = SkipMap::new();
+ /// let gates_age = ages.get_or_insert_with("Bill Gates", || 64);
+ /// assert_eq!(*gates_age.value(), 64);
+ ///
+ /// ages.insert("Steve Jobs", 65);
+ /// let jobs_age = ages.get_or_insert_with("Steve Jobs", || -1);
+ /// assert_eq!(*jobs_age.value(), 65);
+ /// ```
+ pub fn get_or_insert_with(&self, key: K, value_fn: F) -> Entry<'_, K, V>
+ where
+ F: FnOnce() -> V,
+ {
+ let guard = &epoch::pin();
+ Entry::new(self.inner.get_or_insert_with(key, value_fn, guard))
+ }
+
/// Returns an iterator over all entries in the map,
/// sorted by key.
///
diff --git a/crossbeam-skiplist/tests/base.rs b/crossbeam-skiplist/tests/base.rs
index f08e1409e..c0af2d1b4 100644
--- a/crossbeam-skiplist/tests/base.rs
+++ b/crossbeam-skiplist/tests/base.rs
@@ -431,6 +431,76 @@ fn get_or_insert() {
assert_eq!(*s.get_or_insert(6, 600, guard).value(), 600);
}
+#[test]
+fn get_or_insert_with() {
+ let guard = &epoch::pin();
+ let s = SkipList::new(epoch::default_collector().clone());
+ s.insert(3, 3, guard);
+ s.insert(5, 5, guard);
+ s.insert(1, 1, guard);
+ s.insert(4, 4, guard);
+ s.insert(2, 2, guard);
+
+ assert_eq!(*s.get(&4, guard).unwrap().value(), 4);
+ assert_eq!(*s.insert(4, 40, guard).value(), 40);
+ assert_eq!(*s.get(&4, guard).unwrap().value(), 40);
+
+ assert_eq!(*s.get_or_insert_with(4, || 400, guard).value(), 40);
+ assert_eq!(*s.get(&4, guard).unwrap().value(), 40);
+ assert_eq!(*s.get_or_insert_with(6, || 600, guard).value(), 600);
+}
+
+#[test]
+fn get_or_insert_with_panic() {
+ use std::panic;
+
+ let s = SkipList::new(epoch::default_collector().clone());
+ let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
+ let guard = &epoch::pin();
+ s.get_or_insert_with(4, || panic!(), guard);
+ }));
+ assert!(res.is_err());
+ assert!(s.is_empty());
+ let guard = &epoch::pin();
+ assert_eq!(*s.get_or_insert_with(4, || 40, guard).value(), 40);
+ assert_eq!(s.len(), 1);
+}
+
+#[test]
+fn get_or_insert_with_parallel_run() {
+ use std::sync::{Arc, Mutex};
+
+ let s = Arc::new(SkipList::new(epoch::default_collector().clone()));
+ let s2 = s.clone();
+ let called = Arc::new(Mutex::new(false));
+ let called2 = called.clone();
+ let handle = std::thread::spawn(move || {
+ let guard = &epoch::pin();
+ assert_eq!(
+ *s2.get_or_insert_with(
+ 7,
+ || {
+ *called2.lock().unwrap() = true;
+
+ // allow main thread to run before we return result
+ std::thread::sleep(std::time::Duration::from_secs(4));
+ 70
+ },
+ guard,
+ )
+ .value(),
+ 700
+ );
+ });
+ std::thread::sleep(std::time::Duration::from_secs(2));
+ let guard = &epoch::pin();
+
+ // main thread writes the value first
+ assert_eq!(*s.get_or_insert(7, 700, guard).value(), 700);
+ handle.join().unwrap();
+ assert!(*called.lock().unwrap());
+}
+
#[test]
fn get_next_prev() {
let guard = &epoch::pin();
diff --git a/crossbeam-skiplist/tests/map.rs b/crossbeam-skiplist/tests/map.rs
index 06a0567e0..d00658505 100644
--- a/crossbeam-skiplist/tests/map.rs
+++ b/crossbeam-skiplist/tests/map.rs
@@ -370,6 +370,67 @@ fn get_or_insert() {
assert_eq!(*s.get_or_insert(6, 600).value(), 600);
}
+#[test]
+fn get_or_insert_with() {
+ let s = SkipMap::new();
+ s.insert(3, 3);
+ s.insert(5, 5);
+ s.insert(1, 1);
+ s.insert(4, 4);
+ s.insert(2, 2);
+
+ assert_eq!(*s.get(&4).unwrap().value(), 4);
+ assert_eq!(*s.insert(4, 40).value(), 40);
+ assert_eq!(*s.get(&4).unwrap().value(), 40);
+
+ assert_eq!(*s.get_or_insert_with(4, || 400).value(), 40);
+ assert_eq!(*s.get(&4).unwrap().value(), 40);
+ assert_eq!(*s.get_or_insert_with(6, || 600).value(), 600);
+}
+
+#[test]
+fn get_or_insert_with_panic() {
+ use std::panic;
+
+ let s = SkipMap::new();
+ let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
+ s.get_or_insert_with(4, || panic!());
+ }));
+ assert!(res.is_err());
+ assert!(s.is_empty());
+ assert_eq!(*s.get_or_insert_with(4, || 40).value(), 40);
+ assert_eq!(s.len(), 1);
+}
+
+#[test]
+fn get_or_insert_with_parallel_run() {
+ use std::sync::{Arc, Mutex};
+
+ let s = Arc::new(SkipMap::new());
+ let s2 = s.clone();
+ let called = Arc::new(Mutex::new(false));
+ let called2 = called.clone();
+ let handle = std::thread::spawn(move || {
+ assert_eq!(
+ *s2.get_or_insert_with(7, || {
+ *called2.lock().unwrap() = true;
+
+ // allow main thread to run before we return result
+ std::thread::sleep(std::time::Duration::from_secs(4));
+ 70
+ })
+ .value(),
+ 700
+ );
+ });
+ std::thread::sleep(std::time::Duration::from_secs(2));
+
+ // main thread writes the value first
+ assert_eq!(*s.get_or_insert(7, 700).value(), 700);
+ handle.join().unwrap();
+ assert!(*called.lock().unwrap());
+}
+
#[test]
fn get_next_prev() {
let s = SkipMap::new();