diff --git a/crates/chain/src/local_chain.rs b/crates/chain/src/local_chain.rs index 100a9662c5..a86f1a77dc 100644 --- a/crates/chain/src/local_chain.rs +++ b/crates/chain/src/local_chain.rs @@ -187,6 +187,43 @@ impl CheckPoint { core::ops::Bound::Unbounded => true, }) } + + /// Inserts `block_id` at its height within the chain. + /// + /// The effect of `insert` depends on whether a height already exists. If it doesn't the + /// `block_id` we inserted and all pre-existing blocks higher than it will be re-inserted after + /// it. If the height already existed and has a conflicting block hash then it will be purged + /// along with all block followin it. The returned chain will have a tip of the `block_id` + /// passed in. Of course, if the `block_id` was already present then this just returns `self`. + #[must_use] + pub fn insert(self, block_id: BlockId) -> Self { + assert_ne!(block_id.height, 0, "cannot insert the genesis block"); + + let mut cp = self.clone(); + let mut tail = vec![]; + let base = loop { + if cp.height() == block_id.height { + if cp.hash() == block_id.hash { + return self; + } + // if we have a conflict we just return the inserted block because the tail is by + // implication invalid. + tail = vec![]; + break cp.prev().expect("can't be called on genesis block"); + } + + if cp.height() < block_id.height { + break cp; + } + + tail.push(cp.block_id()); + cp = cp.prev().expect("will break before genesis block"); + }; + + base + .extend(core::iter::once(block_id).chain(tail.into_iter().rev())) + .expect("tail is in order") + } } /// Iterates over checkpoints backwards. diff --git a/crates/chain/tests/test_local_chain.rs b/crates/chain/tests/test_local_chain.rs index 482792f501..636b7a4b7f 100644 --- a/crates/chain/tests/test_local_chain.rs +++ b/crates/chain/tests/test_local_chain.rs @@ -574,6 +574,77 @@ fn checkpoint_query() { } } +#[test] +fn checkpoint_insert() { + struct TestCase<'a> { + /// The name of the test. + name: &'a str, + /// The original checkpoint chain to call [`CheckPoint::insert`] on. + chain: &'a [(u32, BlockHash)], + /// The `block_id` to insert. + to_insert: (u32, BlockHash), + /// The expected final checkpoint chain after calling [`CheckPoint::insert`]. + exp_final_chain: &'a [(u32, BlockHash)], + } + + let test_cases = [ + TestCase { + name: "insert_above_tip", + chain: &[(1, h!("a")), (2, h!("b"))], + to_insert: (4, h!("d")), + exp_final_chain: &[(1, h!("a")), (2, h!("b")), (4, h!("d"))], + }, + TestCase { + name: "insert_already_exists_expect_no_change", + chain: &[(1, h!("a")), (2, h!("b")), (3, h!("c"))], + to_insert: (2, h!("b")), + exp_final_chain: &[(1, h!("a")), (2, h!("b")), (3, h!("c"))], + }, + TestCase { + name: "insert_in_middle", + chain: &[(2, h!("b")), (4, h!("d")), (5, h!("e"))], + to_insert: (3, h!("c")), + exp_final_chain: &[(2, h!("b")), (3, h!("c")), (4, h!("d")), (5, h!("e"))], + }, + TestCase { + name: "replace_one", + chain: &[(3, h!("c")), (4, h!("d")), (5, h!("e"))], + to_insert: (5, h!("E")), + exp_final_chain: &[(3, h!("c")), (4, h!("d")), (5, h!("E"))], + }, + TestCase { + name: "insert_conflict_should_evict", + chain: &[(3, h!("c")), (4, h!("d")), (5, h!("e")), (6, h!("f"))], + to_insert: (4, h!("D")), + exp_final_chain: &[(3, h!("c")), (4, h!("D"))], + }, + ]; + + fn genesis_block() -> impl Iterator { + core::iter::once((0, h!("_"))).map(BlockId::from) + } + + for (i, t) in test_cases.into_iter().enumerate() { + println!("Running [{}] '{}'", i, t.name); + + let chain = CheckPoint::from_block_ids( + genesis_block().chain(t.chain.iter().copied().map(BlockId::from)), + ) + .expect("test formed incorrectly, must construct checkpoint chain"); + + let exp_final_chain = CheckPoint::from_block_ids( + genesis_block().chain(t.exp_final_chain.iter().copied().map(BlockId::from)), + ) + .expect("test formed incorrectly, must construct checkpoint chain"); + + assert_eq!( + chain.insert(t.to_insert.into()), + exp_final_chain, + "unexpected final chain" + ); + } +} + #[test] fn local_chain_apply_header_connected_to() { fn header_from_prev_blockhash(prev_blockhash: BlockHash) -> Header {