Skip to content

Commit

Permalink
feat: add push unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hydra-yse committed Nov 19, 2024
1 parent 8d2fa73 commit 01c49df
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 11 deletions.
142 changes: 134 additions & 8 deletions lib/core/src/sync/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,16 @@ impl SyncService {
#[cfg(test)]
mod tests {
use anyhow::{anyhow, Result};
use std::sync::Arc;
use tokio::sync::mpsc;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::{mpsc, Mutex};

use crate::{
prelude::Signer,
persist::Persister,
prelude::{Direction, PaymentState, Signer},
sync::model::SyncState,
test_utils::{
persist::new_persister,
chain_swap::new_chain_swap,
persist::{new_persister, new_receive_swap, new_send_swap},
sync::{
new_chain_sync_data, new_receive_sync_data, new_send_sync_data, MockSyncerClient,
},
Expand Down Expand Up @@ -332,7 +335,8 @@ mod tests {
];

let (incoming_tx, incoming_rx) = mpsc::channel::<Record>(10);
let client = Box::new(MockSyncerClient::new(incoming_rx));
let outgoing_records = Arc::new(Mutex::new(HashMap::new()));
let client = Box::new(MockSyncerClient::new(incoming_rx, outgoing_records.clone()));
let sync_service =
SyncService::new("".to_string(), persister.clone(), signer.clone(), client);

Expand Down Expand Up @@ -367,7 +371,6 @@ mod tests {
let new_description = Some("description".to_string());
let new_claim_address = Some("claim_address".to_string());
let new_accept_zero_conf = false;
let new_server_lockup_tx_id = Some("server_lockup_tx_id".to_string());
let sync_data = vec![
SyncData::Receive(new_receive_sync_data(
new_payment_hash.clone(),
Expand All @@ -382,7 +385,6 @@ mod tests {
new_description.clone(),
new_claim_address.clone(),
Some(new_accept_zero_conf),
new_server_lockup_tx_id.clone(),
)),
];
let incoming_records = vec![
Expand Down Expand Up @@ -428,11 +430,135 @@ mod tests {
assert_eq!(chain_swap.claim_address, new_claim_address);
assert_eq!(chain_swap.description, new_description);
assert_eq!(chain_swap.accept_zero_conf, new_accept_zero_conf);
assert_eq!(chain_swap.server_lockup_tx_id, new_server_lockup_tx_id);
} else {
return Err(anyhow!("Chain swap not found"));
}

Ok(())
}

fn get_outgoing_record<'a>(
persister: Arc<Persister>,
outgoing: &'a HashMap<String, Record>,
data_id: &'a str,
) -> Result<&'a Record> {
let sync_state = persister
.get_sync_state_by_data_id(data_id)?
.ok_or(anyhow::anyhow!("Expected existing swap state"))?;
let Some(record) = outgoing.get(&sync_state.record_id) else {
return Err(anyhow::anyhow!(
"Expecting existing record in client's outgoing list"
));
};
Ok(record)
}

#[tokio::test]
async fn test_outgoing_sync() -> Result<()> {
let (_temp_dir, persister) = new_persister()?;
let persister = Arc::new(persister);

let signer: Arc<Box<dyn Signer>> = Arc::new(Box::new(MockSigner::new()));

let (_incoming_tx, incoming_rx) = mpsc::channel::<Record>(10);
let outgoing_records = Arc::new(Mutex::new(HashMap::new()));
let client = Box::new(MockSyncerClient::new(incoming_rx, outgoing_records.clone()));
let sync_service =
SyncService::new("".to_string(), persister.clone(), signer.clone(), client);

// Test insert
persister.insert_receive_swap(&new_receive_swap(None))?;
persister.insert_send_swap(&new_send_swap(None))?;
persister.insert_chain_swap(&new_chain_swap(Direction::Incoming, None, true, None))?;

sync_service.push().await?;

let outgoing = outgoing_records.lock().await;
assert_eq!(outgoing.len(), 3);
drop(outgoing);

// Test conflict
let swap = new_receive_swap(None);
persister.insert_receive_swap(&swap)?;

sync_service.push().await?;

let outgoing = outgoing_records.lock().await;
assert_eq!(outgoing.len(), 4);
let record = get_outgoing_record(persister.clone(), &outgoing, &swap.id)?;
persister.set_sync_state(SyncState {
data_id: swap.id.clone(),
record_id: record.id.clone(),
record_revision: 90, // Set a wrong record revision
is_local: true,
})?;
drop(outgoing);

sync_service.push().await?;

let outgoing = outgoing_records.lock().await;
assert_eq!(outgoing.len(), 4); // No records were added
drop(outgoing);

// Test update before push
let swap = new_send_swap(None);
persister.insert_send_swap(&swap)?;
let new_preimage = Some("new-preimage");
persister.try_handle_send_swap_update(
&swap.id,
PaymentState::Pending,
new_preimage.clone(),
None,
None,
)?;

sync_service.push().await?;

let outgoing = outgoing_records.lock().await;

let record = get_outgoing_record(persister.clone(), &outgoing, &swap.id)?;
let decrypted_record = record.clone().decrypt(signer.clone())?;
assert_eq!(decrypted_record.data.id(), &swap.id);
match decrypted_record.data {
SyncData::Send(data) => {
assert_eq!(data.preimage, new_preimage.map(|p| p.to_string()));
}
_ => {
return Err(anyhow::anyhow!("Unexpected sync data type received."));
}
}
drop(outgoing);

// Test update after push
let swap = new_send_swap(None);
persister.insert_send_swap(&swap)?;

sync_service.push().await?;

let new_preimage = Some("new-preimage");
persister.try_handle_send_swap_update(
&swap.id,
PaymentState::Pending,
new_preimage.clone(),
None,
None,
)?;

sync_service.push().await?;

let outgoing = outgoing_records.lock().await;
let record = get_outgoing_record(persister.clone(), &outgoing, &swap.id)?;
let decrypted_record = record.clone().decrypt(signer.clone())?;
assert_eq!(decrypted_record.data.id(), &swap.id);
match decrypted_record.data {
SyncData::Send(data) => {
assert_eq!(data.preimage, new_preimage.map(|p| p.to_string()),);
}
_ => {
return Err(anyhow::anyhow!("Unexpected sync data type received."));
}
}

Ok(())
}
}
36 changes: 33 additions & 3 deletions lib/core/src/test_utils/sync.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#![cfg(test)]

use std::{collections::HashMap, sync::Arc};

use crate::{
prelude::Direction,
sync::{
Expand All @@ -8,6 +10,7 @@ use crate::{
data::{ChainSyncData, ReceiveSyncData, SendSyncData},
sync::{
ListChangesReply, ListChangesRequest, Record, SetRecordReply, SetRecordRequest,
SetRecordStatus,
},
},
},
Expand All @@ -18,12 +21,17 @@ use tokio::sync::{mpsc::Receiver, Mutex};

pub(crate) struct MockSyncerClient {
pub(crate) incoming_rx: Mutex<Receiver<Record>>,
pub(crate) outgoing_records: Arc<Mutex<HashMap<String, Record>>>,
}

impl MockSyncerClient {
pub(crate) fn new(incoming_rx: Receiver<Record>) -> Self {
pub(crate) fn new(
incoming_rx: Receiver<Record>,
outgoing_records: Arc<Mutex<HashMap<String, Record>>>,
) -> Self {
Self {
incoming_rx: Mutex::new(incoming_rx),
outgoing_records,
}
}
}
Expand All @@ -34,8 +42,30 @@ impl SyncerClient for MockSyncerClient {
todo!()
}

async fn push(&self, _req: SetRecordRequest) -> Result<SetRecordReply> {
todo!()
async fn push(&self, req: SetRecordRequest) -> Result<SetRecordReply> {
if let Some(mut record) = req.record {
let mut outgoing_records = self.outgoing_records.lock().await;

if let Some(existing_record) = outgoing_records.get(&record.id) {
if existing_record.revision != record.revision {
return Ok(SetRecordReply {
status: SetRecordStatus::Conflict as i32,
new_revision: 0,
});
}
}

record.revision = outgoing_records.len() as u64 + 1;
let record_revision = record.revision;

outgoing_records.insert(record.id.clone(), record);
return Ok(SetRecordReply {
status: SetRecordStatus::Success as i32,
new_revision: record_revision,
});
}

return Err(anyhow::anyhow!("No record was sent"));
}

async fn pull(&self, _req: ListChangesRequest) -> Result<ListChangesReply> {
Expand Down

0 comments on commit 01c49df

Please sign in to comment.