diff --git a/Cargo.toml b/Cargo.toml index 600a8b1..5bba8be 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,4 +25,5 @@ name = "basic" path = "examples/basic.rs" [features] -serde_support = ["serde", "serde_derive", "snowflake/serde_support"] \ No newline at end of file +serde_support = ["serde", "serde_derive", "snowflake/serde_support"] +map = [] diff --git a/README.md b/README.md index fd1cadf..f4c32a3 100644 --- a/README.md +++ b/README.md @@ -66,3 +66,4 @@ fn main() { * [Cecile Tonglet](https://github.com/cecton) * [Drakulix](https://github.com/Drakulix) * [石博文](https://github.com/sbwtw) +* [Robin Stumm](https://github.com/dermetfan) diff --git a/src/tree.rs b/src/tree.rs index 3702903..d2a2f55 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -1484,6 +1484,80 @@ impl Tree { Ok(LevelOrderTraversalIds::new(self, node_id.clone())) } + /// Returns a new `Tree` which nodes' values are mapped using the provided function. + /// + /// If the mapping function returns an error it is returned immediately + /// and the process is aborted. + /// + /// ``` + /// # use id_tree::*; + /// # use id_tree::InsertBehavior::*; + /// + /// let mut tree = Tree::new(); + /// tree.insert(Node::new(1), AsRoot).unwrap(); + /// + /// fn map(x: i32) -> Result { + /// Ok(x * 10) + /// } + /// + /// let tree = tree.map(map).unwrap(); + /// + /// assert_eq!(10, *tree.get(&tree.root_node_id().unwrap()).unwrap().data()); + /// ``` + /// + #[cfg(feature = "map")] + pub fn map(mut self, mut map: F) -> Result, E> + where + F: FnMut(T) -> Result, + { + let tree_id = ProcessUniqueId::new(); + + Ok(Tree { + id: tree_id, + root: self.root.as_ref().map(|x| NodeId { + tree_id, + index: x.index, + }), + nodes: self + .nodes + .drain(..) + .map(|mut x| { + match x.as_mut().map(|y| { + Ok(Node { + data: map(unsafe { + std::mem::replace(&mut y.data, std::mem::zeroed()) + })?, + parent: y.parent.as_ref().map(|z| NodeId { + tree_id, + index: z.index, + }), + children: y + .children + .iter() + .map(|z| NodeId { + tree_id, + index: z.index, + }) + .collect(), + }) + }) { + None => Ok(None), + Some(Ok(y)) => Ok(Some(y)), + Some(Err(y)) => Err(y), + } + }) + .collect::>()?, + free_ids: self + .free_ids + .iter() + .map(|x| NodeId { + tree_id, + index: x.index, + }) + .collect(), + }) + } + // Nothing should make it past this function. // If there is a way for a NodeId to be invalid, it should be caught here. fn is_valid_node_id(&self, node_id: &NodeId) -> (bool, Option) { @@ -2990,4 +3064,70 @@ mod tree_tests { // ensure the tree and the cloned tree are equal assert_eq!(tree, cloned); } + + #[cfg(feature = "map")] + #[test] + fn test_map() { + use InsertBehavior::*; + + let mut tree = Tree::new(); + let root_id = tree.insert(Node::new(0), AsRoot).unwrap(); + let node_1_id = tree.insert(Node::new(1), UnderNode(&root_id)).unwrap(); + let node_2_id = tree.insert(Node::new(2), UnderNode(&root_id)).unwrap(); + let _node_3_id = tree.insert(Node::new(3), UnderNode(&node_1_id)).unwrap(); + let node_4_id = tree.insert(Node::new(4), UnderNode(&node_2_id)).unwrap(); + tree.take_node(node_4_id); + + let tree_id = tree.id; + + // ensure errors from the mapping function are propagated + assert_eq!(Err(()), tree.clone().map(|_| Err(()) as Result<(), ()>)); + + let mapped = tree + .map(|x| Ok(x.to_string()) as Result) + .unwrap(); + assert!(mapped.root.is_some()); + + // ensure mapped tree has a new id + assert_ne!(tree_id, mapped.id); + + // ensure mapped tree's root is using the new tree id + assert_eq!(mapped.root.as_ref().map(|x| x.tree_id), Some(mapped.id)); + + // ensure mapped tree's free_ids is using the new tree id + assert_eq!(mapped.free_ids[0].tree_id, mapped.id); + + // ensure nodes' parent are using the new tree id + assert_eq!( + mapped.nodes[1] + .as_ref() + .map(|x| x.parent.as_ref().map(|x| x.tree_id)), + Some(Some(mapped.id)) + ); + + // ensure nodes' children are using the new tree id + assert_eq!( + mapped + .children(mapped.root.as_ref().unwrap()) + .unwrap() + .next() + .map(|x| x.parent.as_ref().map(|x| x.tree_id)), + Some(Some(mapped.id)) + ); + + // ensure the mapping is correct + assert_eq!( + mapped + .traverse_level_order(mapped.root.as_ref().unwrap()) + .unwrap() + .map(Node::data) + .collect::>(), + vec![ + &0.to_string(), + &1.to_string(), + &2.to_string(), + &3.to_string() + ] + ); + } }