diff --git a/relayer/src/supervisor.rs b/relayer/src/supervisor.rs index e7c86e44fc..b75c577082 100644 --- a/relayer/src/supervisor.rs +++ b/relayer/src/supervisor.rs @@ -14,6 +14,7 @@ use ibc::{ ics24_host::identifier::{ChainId, ChannelId, PortId}, Height, }; +use tracing::info; use crate::{ chain::handle::ChainHandle, @@ -121,11 +122,11 @@ impl Supervisor { loop { for batch in subscription_a.try_iter() { - self.process_batch(batch.unwrap_or_clone())?; + self.process_batch(self.chains.a.clone(), batch.unwrap_or_clone())?; } for batch in subscription_b.try_iter() { - self.process_batch(batch.unwrap_or_clone())?; + self.process_batch(self.chains.b.clone(), batch.unwrap_or_clone())?; } std::thread::sleep(Duration::from_millis(600)); @@ -133,16 +134,23 @@ impl Supervisor { } /// Process a batch of events received from a chain. - fn process_batch(&mut self, batch: EventBatch) -> Result<(), BoxError> { + fn process_batch( + &mut self, + src_chain: Box, + batch: EventBatch, + ) -> Result<(), BoxError> { + assert_eq!(src_chain.id(), batch.chain_id); + let height = batch.height; let chain_id = batch.chain_id.clone(); + let direction = if chain_id == self.chains.a.id() { Direction::AtoB } else { Direction::BtoA }; - let collected = collect_events(batch); + let collected = collect_events(src_chain.as_ref(), batch); if collected.has_new_blocks() { for worker in self.workers.values() { @@ -155,8 +163,11 @@ impl Supervisor { continue; } - let worker = self.worker_for_object(object, direction); - worker.send_packet_events(height, events, chain_id.clone())?; + println!("[{}] events: {:?}", chain_id, events); + + if let Some(worker) = self.worker_for_object(object, direction) { + worker.send_packet_events(height, events, chain_id.clone())?; + } } Ok(()) @@ -169,17 +180,29 @@ impl Supervisor { /// /// The `direction` parameter indicates in which direction the worker should /// relay events. - fn worker_for_object(&mut self, object: Object, direction: Direction) -> &WorkerHandle { + fn worker_for_object(&mut self, object: Object, direction: Direction) -> Option<&WorkerHandle> { if self.workers.contains_key(&object) { - &self.workers[&object] + Some(&self.workers[&object]) } else { let chains = match direction { Direction::AtoB => self.chains.clone(), Direction::BtoA => self.chains.clone().swap(), }; + if object.src_chain_id() != &chains.a.id() || object.dst_chain_id() != &chains.b.id() { + info!( + "object {:?} is not relevant to worker for chains {}/{}", + object, + chains.a.id(), + chains.b.id() + ); + + return None; + } + let worker = Worker::spawn(chains, object.clone()); - self.workers.entry(object).or_insert(worker) + let worker = self.workers.entry(object).or_insert(worker); + Some(worker) } } } @@ -251,6 +274,8 @@ impl Worker { /// A unidirectional path from a source chain, channel and port. #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct UnidirectionalChannelPath { + /// Destination chain identifier. + pub dst_chain_id: ChainId, /// Source chain identifier. pub src_chain_id: ChainId, /// Source channel identiier. @@ -278,10 +303,27 @@ impl From for Object { } impl Object { + pub fn src_chain_id(&self) -> &ChainId { + match self { + Self::UnidirectionalChannelPath(ref path) => &path.src_chain_id, + } + } + + pub fn dst_chain_id(&self) -> &ChainId { + match self { + Self::UnidirectionalChannelPath(ref path) => &path.dst_chain_id, + } + } + /// Build the object associated with the given [`SendPacket`] event. - pub fn for_send_packet(e: &SendPacket, chain_id: &ChainId) -> Self { + pub fn for_send_packet(e: &SendPacket, src_chain: &dyn ChainHandle) -> Self { + let dst_chain_id = + get_counterparty_chain(src_chain, &e.packet.source_channel, &e.packet.source_port) + .unwrap(); + UnidirectionalChannelPath { - src_chain_id: chain_id.clone(), + dst_chain_id, + src_chain_id: src_chain.id(), src_channel_id: e.packet.source_channel.clone(), src_port_id: e.packet.source_port.clone(), } @@ -289,9 +331,14 @@ impl Object { } /// Build the object associated with the given [`WriteAcknowledgement`] event. - pub fn for_write_ack(e: &WriteAcknowledgement, chain_id: &ChainId) -> Self { + pub fn for_write_ack(e: &WriteAcknowledgement, src_chain: &dyn ChainHandle) -> Self { + let dst_chain_id = + get_counterparty_chain(src_chain, &e.packet.source_channel, &e.packet.source_port) + .unwrap(); + UnidirectionalChannelPath { - src_chain_id: chain_id.clone(), + dst_chain_id, + src_chain_id: src_chain.id(), src_channel_id: e.packet.destination_channel.clone(), src_port_id: e.packet.destination_port.clone(), } @@ -299,9 +346,14 @@ impl Object { } /// Build the object associated with the given [`TimeoutPacket`] event. - pub fn for_timeout_packet(e: &TimeoutPacket, chain_id: &ChainId) -> Self { + pub fn for_timeout_packet(e: &TimeoutPacket, src_chain: &dyn ChainHandle) -> Self { + let dst_chain_id = + get_counterparty_chain(src_chain, &e.packet.source_channel, &e.packet.source_port) + .unwrap(); + UnidirectionalChannelPath { - src_chain_id: chain_id.clone(), + dst_chain_id, + src_chain_id: src_chain.id(), src_channel_id: e.src_channel_id().clone(), src_port_id: e.src_port_id().clone(), } @@ -309,9 +361,12 @@ impl Object { } /// Build the object associated with the given [`CloseInit`] event. - pub fn for_close_init_channel(e: &CloseInit, chain_id: &ChainId) -> Self { + pub fn for_close_init_channel(e: &CloseInit, src_chain: &dyn ChainHandle) -> Self { + let dst_chain_id = get_counterparty_chain(src_chain, e.channel_id(), &e.port_id()).unwrap(); + UnidirectionalChannelPath { - src_chain_id: chain_id.clone(), + dst_chain_id, + src_chain_id: src_chain.id(), src_channel_id: e.channel_id().clone(), src_port_id: e.port_id().clone(), } @@ -350,27 +405,28 @@ impl CollectedEvents { /// Collect the events we are interested in from an [`EventBatch`], /// and maps each [`IbcEvent`] to their corresponding [`Object`]. -pub fn collect_events(batch: EventBatch) -> CollectedEvents { +pub fn collect_events(src_chain: &dyn ChainHandle, batch: EventBatch) -> CollectedEvents { let mut collected = CollectedEvents::new(batch.height, batch.chain_id); + for event in batch.events { match event { IbcEvent::NewBlock(inner) => { collected.new_blocks.push(inner); } - IbcEvent::SendPacket(ref inner) => { - let object = Object::for_send_packet(inner, &collected.chain_id); + IbcEvent::SendPacket(ref packet) => { + let object = Object::for_send_packet(packet, src_chain); collected.per_object.entry(object).or_default().push(event); } - IbcEvent::TimeoutPacket(ref inner) => { - let object = Object::for_timeout_packet(inner, &collected.chain_id); + IbcEvent::TimeoutPacket(ref packet) => { + let object = Object::for_timeout_packet(packet, src_chain); collected.per_object.entry(object).or_default().push(event); } - IbcEvent::WriteAcknowledgement(ref inner) => { - let object = Object::for_write_ack(inner, &collected.chain_id); + IbcEvent::WriteAcknowledgement(ref packet) => { + let object = Object::for_write_ack(packet, src_chain); collected.per_object.entry(object).or_default().push(event); } - IbcEvent::CloseInitChannel(ref inner) => { - let object = Object::for_close_init_channel(inner, &collected.chain_id); + IbcEvent::CloseInitChannel(ref packet) => { + let object = Object::for_close_init_channel(packet, src_chain); collected.per_object.entry(object).or_default().push(event); } _ => (), @@ -379,3 +435,42 @@ pub fn collect_events(batch: EventBatch) -> CollectedEvents { collected } + +// TODO: Memoize this result +fn get_counterparty_chain( + src_chain: &dyn ChainHandle, + src_channel_id: &ChannelId, + src_port_id: &PortId, +) -> Result { + info!( + chain_id = %src_chain.id(), + src_channel_id = %src_channel_id, + src_port_id = %src_port_id, + "getting counterparty chain" + ); + + use ibc::ics02_client::client_state::ClientState; + + let src_channel = src_chain.query_channel(src_port_id, src_channel_id, Height::zero())?; + + // TODO: Check channel state? + + let src_connection_id = src_channel + .connection_hops() + .first() + .ok_or_else(|| format!("no connection hops for channel '{}'", src_channel_id))?; + + let src_connection = src_chain.query_connection(&src_connection_id, Height::zero())?; + + // TODO: Check connection state? + + let client_id = src_connection.client_id(); + let client_state = src_chain.query_client_state(client_id, Height::zero())?; + + info!( + chain_id=%src_chain.id(), src_channel_id=%src_channel_id, src_port_id=%src_port_id, + "counterparty chain: {}", client_state.chain_id() + ); + + Ok(client_state.chain_id()) +}