Skip to content

Commit

Permalink
Initial attempt at only sending events relevant to the worker
Browse files Browse the repository at this point in the history
  • Loading branch information
romac committed Mar 18, 2021
1 parent 652e78f commit 74c7ab1
Showing 1 changed file with 121 additions and 26 deletions.
147 changes: 121 additions & 26 deletions relayer/src/supervisor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use ibc::{
ics24_host::identifier::{ChainId, ChannelId, PortId},
Height,
};
use tracing::info;

use crate::{
chain::handle::ChainHandle,
Expand Down Expand Up @@ -121,28 +122,35 @@ impl Supervisor {

loop {
for batch in subscription_a.try_iter() {
self.process_batch(batch.unwrap_or_clone())?;
self.process_batch(self.chains.a.clone(), batch.unwrap_or_clone())?;
}

for batch in subscription_b.try_iter() {
self.process_batch(batch.unwrap_or_clone())?;
self.process_batch(self.chains.b.clone(), batch.unwrap_or_clone())?;
}

std::thread::sleep(Duration::from_millis(600));
}
}

/// Process a batch of events received from a chain.
fn process_batch(&mut self, batch: EventBatch) -> Result<(), BoxError> {
fn process_batch(
&mut self,
src_chain: Box<dyn ChainHandle>,
batch: EventBatch,
) -> Result<(), BoxError> {
assert_eq!(src_chain.id(), batch.chain_id);

let height = batch.height;
let chain_id = batch.chain_id.clone();

let direction = if chain_id == self.chains.a.id() {
Direction::AtoB
} else {
Direction::BtoA
};

let collected = collect_events(batch);
let collected = collect_events(src_chain.as_ref(), batch);

if collected.has_new_blocks() {
for worker in self.workers.values() {
Expand All @@ -155,8 +163,11 @@ impl Supervisor {
continue;
}

let worker = self.worker_for_object(object, direction);
worker.send_packet_events(height, events, chain_id.clone())?;
println!("[{}] events: {:?}", chain_id, events);

if let Some(worker) = self.worker_for_object(object, direction) {
worker.send_packet_events(height, events, chain_id.clone())?;
}
}

Ok(())
Expand All @@ -169,17 +180,29 @@ impl Supervisor {
///
/// The `direction` parameter indicates in which direction the worker should
/// relay events.
fn worker_for_object(&mut self, object: Object, direction: Direction) -> &WorkerHandle {
fn worker_for_object(&mut self, object: Object, direction: Direction) -> Option<&WorkerHandle> {
if self.workers.contains_key(&object) {
&self.workers[&object]
Some(&self.workers[&object])
} else {
let chains = match direction {
Direction::AtoB => self.chains.clone(),
Direction::BtoA => self.chains.clone().swap(),
};

if object.src_chain_id() != &chains.a.id() || object.dst_chain_id() != &chains.b.id() {
info!(
"object {:?} is not relevant to worker for chains {}/{}",
object,
chains.a.id(),
chains.b.id()
);

return None;
}

let worker = Worker::spawn(chains, object.clone());
self.workers.entry(object).or_insert(worker)
let worker = self.workers.entry(object).or_insert(worker);
Some(worker)
}
}
}
Expand Down Expand Up @@ -251,6 +274,8 @@ impl Worker {
/// A unidirectional path from a source chain, channel and port.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct UnidirectionalChannelPath {
/// Destination chain identifier.
pub dst_chain_id: ChainId,
/// Source chain identifier.
pub src_chain_id: ChainId,
/// Source channel identiier.
Expand Down Expand Up @@ -278,40 +303,70 @@ impl From<UnidirectionalChannelPath> for Object {
}

impl Object {
pub fn src_chain_id(&self) -> &ChainId {
match self {
Self::UnidirectionalChannelPath(ref path) => &path.src_chain_id,
}
}

pub fn dst_chain_id(&self) -> &ChainId {
match self {
Self::UnidirectionalChannelPath(ref path) => &path.dst_chain_id,
}
}

/// Build the object associated with the given [`SendPacket`] event.
pub fn for_send_packet(e: &SendPacket, chain_id: &ChainId) -> Self {
pub fn for_send_packet(e: &SendPacket, src_chain: &dyn ChainHandle) -> Self {
let dst_chain_id =
get_counterparty_chain(src_chain, &e.packet.source_channel, &e.packet.source_port)
.unwrap();

UnidirectionalChannelPath {
src_chain_id: chain_id.clone(),
dst_chain_id,
src_chain_id: src_chain.id(),
src_channel_id: e.packet.source_channel.clone(),
src_port_id: e.packet.source_port.clone(),
}
.into()
}

/// Build the object associated with the given [`WriteAcknowledgement`] event.
pub fn for_write_ack(e: &WriteAcknowledgement, chain_id: &ChainId) -> Self {
pub fn for_write_ack(e: &WriteAcknowledgement, src_chain: &dyn ChainHandle) -> Self {
let dst_chain_id =
get_counterparty_chain(src_chain, &e.packet.source_channel, &e.packet.source_port)
.unwrap();

UnidirectionalChannelPath {
src_chain_id: chain_id.clone(),
dst_chain_id,
src_chain_id: src_chain.id(),
src_channel_id: e.packet.destination_channel.clone(),
src_port_id: e.packet.destination_port.clone(),
}
.into()
}

/// Build the object associated with the given [`TimeoutPacket`] event.
pub fn for_timeout_packet(e: &TimeoutPacket, chain_id: &ChainId) -> Self {
pub fn for_timeout_packet(e: &TimeoutPacket, src_chain: &dyn ChainHandle) -> Self {
let dst_chain_id =
get_counterparty_chain(src_chain, &e.packet.source_channel, &e.packet.source_port)
.unwrap();

UnidirectionalChannelPath {
src_chain_id: chain_id.clone(),
dst_chain_id,
src_chain_id: src_chain.id(),
src_channel_id: e.src_channel_id().clone(),
src_port_id: e.src_port_id().clone(),
}
.into()
}

/// Build the object associated with the given [`CloseInit`] event.
pub fn for_close_init_channel(e: &CloseInit, chain_id: &ChainId) -> Self {
pub fn for_close_init_channel(e: &CloseInit, src_chain: &dyn ChainHandle) -> Self {
let dst_chain_id = get_counterparty_chain(src_chain, e.channel_id(), &e.port_id()).unwrap();

UnidirectionalChannelPath {
src_chain_id: chain_id.clone(),
dst_chain_id,
src_chain_id: src_chain.id(),
src_channel_id: e.channel_id().clone(),
src_port_id: e.port_id().clone(),
}
Expand Down Expand Up @@ -350,27 +405,28 @@ impl CollectedEvents {

/// Collect the events we are interested in from an [`EventBatch`],
/// and maps each [`IbcEvent`] to their corresponding [`Object`].
pub fn collect_events(batch: EventBatch) -> CollectedEvents {
pub fn collect_events(src_chain: &dyn ChainHandle, batch: EventBatch) -> CollectedEvents {
let mut collected = CollectedEvents::new(batch.height, batch.chain_id);

for event in batch.events {
match event {
IbcEvent::NewBlock(inner) => {
collected.new_blocks.push(inner);
}
IbcEvent::SendPacket(ref inner) => {
let object = Object::for_send_packet(inner, &collected.chain_id);
IbcEvent::SendPacket(ref packet) => {
let object = Object::for_send_packet(packet, src_chain);
collected.per_object.entry(object).or_default().push(event);
}
IbcEvent::TimeoutPacket(ref inner) => {
let object = Object::for_timeout_packet(inner, &collected.chain_id);
IbcEvent::TimeoutPacket(ref packet) => {
let object = Object::for_timeout_packet(packet, src_chain);
collected.per_object.entry(object).or_default().push(event);
}
IbcEvent::WriteAcknowledgement(ref inner) => {
let object = Object::for_write_ack(inner, &collected.chain_id);
IbcEvent::WriteAcknowledgement(ref packet) => {
let object = Object::for_write_ack(packet, src_chain);
collected.per_object.entry(object).or_default().push(event);
}
IbcEvent::CloseInitChannel(ref inner) => {
let object = Object::for_close_init_channel(inner, &collected.chain_id);
IbcEvent::CloseInitChannel(ref packet) => {
let object = Object::for_close_init_channel(packet, src_chain);
collected.per_object.entry(object).or_default().push(event);
}
_ => (),
Expand All @@ -379,3 +435,42 @@ pub fn collect_events(batch: EventBatch) -> CollectedEvents {

collected
}

// TODO: Memoize this result
fn get_counterparty_chain(
src_chain: &dyn ChainHandle,
src_channel_id: &ChannelId,
src_port_id: &PortId,
) -> Result<ChainId, BoxError> {
info!(
chain_id = %src_chain.id(),
src_channel_id = %src_channel_id,
src_port_id = %src_port_id,
"getting counterparty chain"
);

use ibc::ics02_client::client_state::ClientState;

let src_channel = src_chain.query_channel(src_port_id, src_channel_id, Height::zero())?;

// TODO: Check channel state?

let src_connection_id = src_channel
.connection_hops()
.first()
.ok_or_else(|| format!("no connection hops for channel '{}'", src_channel_id))?;

let src_connection = src_chain.query_connection(&src_connection_id, Height::zero())?;

// TODO: Check connection state?

let client_id = src_connection.client_id();
let client_state = src_chain.query_client_state(client_id, Height::zero())?;

info!(
chain_id=%src_chain.id(), src_channel_id=%src_channel_id, src_port_id=%src_port_id,
"counterparty chain: {}", client_state.chain_id()
);

Ok(client_state.chain_id())
}

0 comments on commit 74c7ab1

Please sign in to comment.