From a191434d25e699c47f683d24eaa438be0ed05fef Mon Sep 17 00:00:00 2001 From: Erin Power Date: Tue, 25 Oct 2022 18:31:59 +0200 Subject: [PATCH] Refactor filter model from Context->Response into ref mut Context --- docs/src/filters/writing_custom_filters.md | 8 +- examples/quilkin-filter-example/src/main.rs | 8 +- src/config/slot.rs | 4 +- src/filters.rs | 35 ++--- src/filters/capture.rs | 37 ++--- src/filters/chain.rs | 165 +++++++++----------- src/filters/compress.rs | 142 +++++++++-------- src/filters/concatenate_bytes.rs | 8 +- src/filters/debug.rs | 8 +- src/filters/drop.rs | 4 +- src/filters/firewall.rs | 28 ++-- src/filters/load_balancer.rs | 22 +-- src/filters/local_rate_limit.rs | 18 +-- src/filters/match.rs | 29 ++-- src/filters/pass.rs | 8 +- src/filters/read.rs | 42 +---- src/filters/registry.rs | 12 +- src/filters/timestamp.rs | 18 +-- src/filters/token_router.rs | 23 ++- src/filters/write.rs | 55 +------ src/proxy.rs | 18 +-- src/proxy/sessions.rs | 22 +-- src/test_utils.rs | 38 ++--- 23 files changed, 320 insertions(+), 432 deletions(-) diff --git a/docs/src/filters/writing_custom_filters.md b/docs/src/filters/writing_custom_filters.md index d352d10dac..4da7b33d58 100644 --- a/docs/src/filters/writing_custom_filters.md +++ b/docs/src/filters/writing_custom_filters.md @@ -40,13 +40,13 @@ sent to a downstream client. use quilkin::filters::prelude::*; impl Filter for Greet { - fn read(&self, mut ctx: ReadContext) -> Option { + fn read(&self, ctx: &mut ReadContext) -> Option<()> { ctx.contents.extend(b"Hello"); - Some(ctx.into()) + Some(()) } - fn write(&self, mut ctx: WriteContext) -> Option { + fn write(&self, ctx: &mut WriteContext) -> Option<()> { ctx.contents.extend(b"Goodbye"); - Some(ctx.into()) + Some(()) } } ``` diff --git a/examples/quilkin-filter-example/src/main.rs b/examples/quilkin-filter-example/src/main.rs index 6b18ada422..e523d48724 100644 --- a/examples/quilkin-filter-example/src/main.rs +++ b/examples/quilkin-filter-example/src/main.rs @@ -58,15 +58,15 @@ struct Greet { } impl Filter for Greet { - fn read(&self, mut ctx: ReadContext) -> Option { + fn read(&self, ctx: &mut ReadContext) -> Option<()> { ctx.contents .splice(0..0, format!("{} ", self.config.greeting).into_bytes()); - Some(ctx.into()) + Some(()) } - fn write(&self, mut ctx: WriteContext) -> Option { + fn write(&self, ctx: &mut WriteContext) -> Option<()> { ctx.contents .splice(0..0, format!("{} ", self.config.greeting).into_bytes()); - Some(ctx.into()) + Some(()) } } // ANCHOR_END: filter diff --git a/src/config/slot.rs b/src/config/slot.rs index 7472bcbd33..6c43517562 100644 --- a/src/config/slot.rs +++ b/src/config/slot.rs @@ -186,11 +186,11 @@ impl JsonSchema for Slot { } impl crate::filters::Filter for Slot { - fn read(&self, ctx: ReadContext) -> Option { + fn read(&self, ctx: &mut ReadContext) -> Option<()> { self.load().read(ctx) } - fn write(&self, ctx: WriteContext) -> Option { + fn write(&self, ctx: &mut WriteContext) -> Option<()> { self.load().write(ctx) } } diff --git a/src/filters.rs b/src/filters.rs index 5b4e1254d6..fbecb071ad 100644 --- a/src/filters.rs +++ b/src/filters.rs @@ -43,7 +43,7 @@ pub mod token_router; pub mod prelude { pub use super::{ ConvertProtoConfigError, CreateFilterArgs, Error, Filter, FilterInstance, ReadContext, - ReadResponse, StaticFilter, WriteContext, WriteResponse, + StaticFilter, WriteContext, }; } @@ -62,12 +62,12 @@ pub use self::{ local_rate_limit::LocalRateLimit, pass::Pass, r#match::Match, - read::{ReadContext, ReadResponse}, + read::ReadContext, registry::FilterRegistry, set::{FilterMap, FilterSet}, timestamp::Timestamp, token_router::TokenRouter, - write::{WriteContext, WriteResponse}, + write::WriteContext, }; pub(crate) use self::chain::FilterChain; @@ -83,13 +83,13 @@ pub(crate) use self::chain::FilterChain; /// struct Greet; /// /// impl Filter for Greet { -/// fn read(&self, mut ctx: ReadContext) -> Option { +/// fn read(&self, ctx: &mut ReadContext) -> Option<()> { /// ctx.contents.splice(0..0, b"Hello ".into_iter().copied()); -/// Some(ctx.into()) +/// Some(()) /// } -/// fn write(&self, mut ctx: WriteContext) -> Option { +/// fn write(&self, ctx: &mut WriteContext) -> Option<()> { /// ctx.contents.splice(0..0, b"Goodbye ".into_iter().copied()); -/// Some(ctx.into()) +/// Some(()) /// } /// } /// @@ -194,23 +194,20 @@ pub trait Filter: Send + Sync { /// [`Filter::read`] is invoked when the proxy receives data from a /// downstream connection on the listening port. /// - /// This function should return a [`ReadResponse`] containing the array of - /// endpoints that the packet should be sent to and the packet that should - /// be sent (which may be manipulated) as well. If the packet should be - /// rejected, return [`None`]. By default, the context passes - /// through unchanged. - fn read(&self, ctx: ReadContext) -> Option { - Some(ctx.into()) + /// This function should return an `Some` if the packet processing should + /// proceed. If the packet should be rejected, it will return [`None`] + /// instead. By default, the context passes through unchanged. + fn read(&self, _: &mut ReadContext) -> Option<()> { + Some(()) } /// [`Filter::write`] is invoked when the proxy is about to send data to a /// downstream connection via the listening port after receiving it via one /// of the upstream Endpoints. /// - /// This function should return an [`WriteResponse`] containing the packet to - /// be sent (which may be manipulated). If the packet should be rejected, - /// return [`None`]. By default, the context passes through unchanged. - fn write(&self, ctx: WriteContext) -> Option { - Some(ctx.into()) + /// This function should return an `Some` if the packet processing should + /// proceed. If the packet should be rejected, it will return [`None`] + fn write(&self, _: &mut WriteContext) -> Option<()> { + Some(()) } } diff --git a/src/filters/capture.rs b/src/filters/capture.rs index e87ccba485..98252c4959 100644 --- a/src/filters/capture.rs +++ b/src/filters/capture.rs @@ -61,7 +61,7 @@ impl Capture { impl Filter for Capture { #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] - fn read(&self, mut ctx: ReadContext) -> Option { + fn read(&self, ctx: &mut ReadContext) -> Option<()> { let capture = self.capture.capture(&mut ctx.contents, &self.metrics); ctx.metadata .insert(self.is_present_key.clone(), Value::Bool(capture.is_some())); @@ -69,7 +69,7 @@ impl Filter for Capture { if let Some(value) = capture { tracing::trace!(key=&**self.metadata_key, %value, "captured value"); ctx.metadata.insert(self.metadata_key.clone(), value); - Some(ctx.into()) + Some(()) } else { tracing::trace!(key = &**self.metadata_key, "No value captured"); None @@ -163,13 +163,14 @@ mod tests { }; let filter = Capture::from_config(config.into()); let endpoints = vec![Endpoint::new("127.0.0.1:81".parse().unwrap())]; - let response = filter.read(ReadContext::new( - endpoints, - (std::net::Ipv4Addr::LOCALHOST, 80).into(), - "abc".to_string().into_bytes(), - )); + assert!(filter + .read(&mut ReadContext::new( + endpoints, + (std::net::Ipv4Addr::LOCALHOST, 80).into(), + "abc".to_string().into_bytes(), + )) + .is_none()); - assert!(response.is_none()); let count = filter.metrics.packets_dropped_total.get(); assert_eq!(1, count); } @@ -243,21 +244,21 @@ mod tests { F: Filter + ?Sized, { let endpoints = vec![Endpoint::new("127.0.0.1:81".parse().unwrap())]; - let response = filter - .read(ReadContext::new( - endpoints, - "127.0.0.1:80".parse().unwrap(), - "helloabc".to_string().into_bytes(), - )) - .unwrap(); + let mut context = ReadContext::new( + endpoints, + "127.0.0.1:80".parse().unwrap(), + "helloabc".to_string().into_bytes(), + ); + + filter.read(&mut context).unwrap(); if remove { - assert_eq!(b"hello".to_vec(), response.contents); + assert_eq!(b"hello", &*context.contents); } else { - assert_eq!(b"helloabc".to_vec(), response.contents); + assert_eq!(b"helloabc", &*context.contents); } - let token = response + let token = context .metadata .get(&Arc::new(key.into())) .unwrap() diff --git a/src/filters/chain.rs b/src/filters/chain.rs index 1029c451c5..ba5facdb3e 100644 --- a/src/filters/chain.rs +++ b/src/filters/chain.rs @@ -246,65 +246,56 @@ impl schemars::JsonSchema for FilterChain { } impl Filter for FilterChain { - fn read(&self, ctx: ReadContext) -> Option { + fn read(&self, ctx: &mut ReadContext) -> Option<()> { self.filters .iter() .zip(self.filter_read_duration_seconds.iter()) - .try_fold(ctx, |ctx, ((id, instance), histogram)| { + .try_fold((), |_, ((id, instance), histogram)| { tracing::trace!(%id, "read filtering packet"); - Some(ReadContext::with_response( - ctx.source.clone(), - match histogram.observe_closure_duration(|| instance.filter.read(ctx)) { - Some(response) => { - tracing::trace!(%id, "read passing packet"); - response - } - None => { - tracing::trace!(%id, "read dropping packet"); - crate::metrics::PACKETS_DROPPED - .with_label_values(&[crate::metrics::READ_DIRECTION_LABEL, id]) - .inc(); - return None; - } - }, - )) + match histogram.observe_closure_duration(|| instance.filter.read(ctx)) { + Some(()) => { + tracing::trace!(%id, "read passing packet"); + } + None => { + tracing::trace!(%id, "read dropping packet"); + crate::metrics::PACKETS_DROPPED + .with_label_values(&[crate::metrics::READ_DIRECTION_LABEL, id]) + .inc(); + return None; + } + } + + Some(()) }) - .map(ReadResponse::from) } - fn write(&self, ctx: WriteContext) -> Option { + fn write(&self, ctx: &mut WriteContext) -> Option<()> { self.filters .iter() .rev() .zip(self.filter_write_duration_seconds.iter().rev()) - .try_fold(ctx, |ctx, ((id, instance), histogram)| { + .try_fold((), |_, ((id, instance), histogram)| { tracing::trace!(%id, "write filtering packet"); - Some(WriteContext::with_response( - ctx.endpoint, - ctx.source.clone(), - ctx.dest.clone(), - match histogram.observe_closure_duration(|| instance.filter.write(ctx)) { - Some(response) => { - tracing::trace!(%id, "write passing packet"); - response - } - None => { - tracing::trace!(%id, "write dropping packet"); - crate::metrics::PACKETS_DROPPED - .with_label_values(&[crate::metrics::WRITE_DIRECTION_LABEL, id]) - .inc(); - return None; - } - }, - )) + match histogram.observe_closure_duration(|| instance.filter.write(ctx)) { + Some(()) => { + tracing::trace!(%id, "write passing packet"); + Some(()) + } + None => { + tracing::trace!(%id, "write dropping packet"); + crate::metrics::PACKETS_DROPPED + .with_label_values(&[crate::metrics::WRITE_DIRECTION_LABEL, id]) + .inc(); + None + } + } }) - .map(WriteResponse::from) } } #[cfg(test)] mod tests { - use std::{str::from_utf8, sync::Arc}; + use std::sync::Arc; use crate::{ config, @@ -349,49 +340,39 @@ mod tests { crate::test_utils::load_test_filters(); let config = new_test_config(); let endpoints_fixture = endpoints(); + let mut context = ReadContext::new( + endpoints_fixture.clone(), + "127.0.0.1:70".parse().unwrap(), + b"hello".to_vec(), + ); - let response = config - .filters - .read(ReadContext::new( - endpoints_fixture.clone(), - "127.0.0.1:70".parse().unwrap(), - b"hello".to_vec(), - )) - .unwrap(); - + config.filters.read(&mut context).unwrap(); let expected = endpoints_fixture.clone(); - assert_eq!(expected, response.endpoints.to_vec()); - assert_eq!( - "hello:odr:127.0.0.1:70", - from_utf8(response.contents.as_slice()).unwrap() - ); + + assert_eq!(expected, &*context.endpoints); + assert_eq!(b"hello:odr:127.0.0.1:70", &*context.contents); assert_eq!( "receive", - response.metadata[&"downstream".to_string()] + context.metadata[&"downstream".to_string()] .as_string() .unwrap() ); - let response = config - .filters - .write(WriteContext::new( - &endpoints_fixture[0], - endpoints_fixture[0].address.clone(), - "127.0.0.1:70".parse().unwrap(), - b"hello".to_vec(), - )) - .unwrap(); + let mut context = WriteContext::new( + endpoints_fixture[0].clone(), + endpoints_fixture[0].address.clone(), + "127.0.0.1:70".parse().unwrap(), + b"hello".to_vec(), + ); + config.filters.write(&mut context).unwrap(); assert_eq!( "receive", - response.metadata[&"upstream".to_string()] + context.metadata[&"upstream".to_string()] .as_string() .unwrap() ); - assert_eq!( - "hello:our:127.0.0.1:80:127.0.0.1:70", - from_utf8(response.contents.as_slice()).unwrap() - ); + assert_eq!(b"hello:our:127.0.0.1:80:127.0.0.1:70", &*context.contents,); } #[test] @@ -415,43 +396,41 @@ mod tests { .unwrap(); let endpoints_fixture = endpoints(); + let mut context = ReadContext::new( + endpoints_fixture.clone(), + "127.0.0.1:70".parse().unwrap(), + b"hello".to_vec(), + ); - let response = chain - .read(ReadContext::new( - endpoints_fixture.clone(), - "127.0.0.1:70".parse().unwrap(), - b"hello".to_vec(), - )) - .unwrap(); - + chain.read(&mut context).unwrap(); let expected = endpoints_fixture.clone(); - assert_eq!(expected, response.endpoints.to_vec()); + assert_eq!(expected, context.endpoints.to_vec()); assert_eq!( - "hello:odr:127.0.0.1:70:odr:127.0.0.1:70", - from_utf8(response.contents.as_slice()).unwrap() + b"hello:odr:127.0.0.1:70:odr:127.0.0.1:70", + &*context.contents ); assert_eq!( "receive:receive", - response.metadata[&"downstream".to_string()] + context.metadata[&"downstream".to_string()] .as_string() .unwrap() ); - let response = chain - .write(WriteContext::new( - &endpoints_fixture[0], - endpoints_fixture[0].address.clone(), - "127.0.0.1:70".parse().unwrap(), - b"hello".to_vec(), - )) - .unwrap(); + let mut context = WriteContext::new( + endpoints_fixture[0].clone(), + endpoints_fixture[0].address.clone(), + "127.0.0.1:70".parse().unwrap(), + b"hello".to_vec(), + ); + + chain.write(&mut context).unwrap(); assert_eq!( - "hello:our:127.0.0.1:80:127.0.0.1:70:our:127.0.0.1:80:127.0.0.1:70", - from_utf8(response.contents.as_slice()).unwrap() + b"hello:our:127.0.0.1:80:127.0.0.1:70:our:127.0.0.1:80:127.0.0.1:70", + &*context.contents, ); assert_eq!( "receive:receive", - response.metadata[&"upstream".to_string()] + context.metadata[&"upstream".to_string()] .as_string() .unwrap() ); diff --git a/src/filters/compress.rs b/src/filters/compress.rs index 9cc610d2fd..372fbe1a14 100644 --- a/src/filters/compress.rs +++ b/src/filters/compress.rs @@ -72,7 +72,7 @@ impl Compress { impl Filter for Compress { #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] - fn read(&self, mut ctx: ReadContext) -> Option { + fn read(&self, ctx: &mut ReadContext) -> Option<()> { let original_size = ctx.contents.len(); match self.on_read { @@ -84,7 +84,7 @@ impl Filter for Compress { self.metrics .compressed_bytes_total .inc_by(ctx.contents.len() as u64); - Some(ctx.into()) + Some(()) } Err(err) => self.failed_compression(&err), }, @@ -96,16 +96,16 @@ impl Filter for Compress { self.metrics .decompressed_bytes_total .inc_by(ctx.contents.len() as u64); - Some(ctx.into()) + Some(()) } Err(err) => self.failed_decompression(&err), }, - Action::DoNothing => Some(ctx.into()), + Action::DoNothing => Some(()), } } #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] - fn write(&self, mut ctx: WriteContext) -> Option { + fn write(&self, ctx: &mut WriteContext) -> Option<()> { let original_size = ctx.contents.len(); match self.on_write { Action::Compress => match self.compressor.encode(&mut ctx.contents) { @@ -116,7 +116,7 @@ impl Filter for Compress { self.metrics .compressed_bytes_total .inc_by(ctx.contents.len() as u64); - Some(ctx.into()) + Some(()) } Err(err) => self.failed_compression(&err), }, @@ -128,12 +128,12 @@ impl Filter for Compress { self.metrics .decompressed_bytes_total .inc_by(ctx.contents.len() as u64); - Some(ctx.into()) + Some(()) } Err(err) => self.failed_decompression(&err), }, - Action::DoNothing => Some(ctx.into()), + Action::DoNothing => Some(()), } } } @@ -281,47 +281,48 @@ mod tests { let expected = contents_fixture(); // read compress - let read_response = compress - .read(ReadContext::new( - vec![Endpoint::new("127.0.0.1:80".parse().unwrap())], - "127.0.0.1:8080".parse().unwrap(), - expected.clone(), - )) - .expect("should compress"); + let mut read_context = ReadContext::new( + vec![Endpoint::new("127.0.0.1:80".parse().unwrap())], + "127.0.0.1:8080".parse().unwrap(), + expected.clone(), + ); + compress.read(&mut read_context).expect("should compress"); - assert_ne!(expected, read_response.contents); + assert_ne!(expected, &*read_context.contents); assert!( - expected.len() > read_response.contents.len(), + expected.len() > read_context.contents.len(), "Original: {}. Compressed: {}", expected.len(), - read_response.contents.len() + read_context.contents.len() ); assert_eq!( expected.len() as u64, compress.metrics.decompressed_bytes_total.get() ); assert_eq!( - read_response.contents.len() as u64, + read_context.contents.len() as u64, compress.metrics.compressed_bytes_total.get() ); // write decompress - let write_response = compress - .write(WriteContext::new( - &Endpoint::new("127.0.0.1:80".parse().unwrap()), - "127.0.0.1:8080".parse().unwrap(), - "127.0.0.1:8081".parse().unwrap(), - read_response.contents.clone(), - )) + let mut write_context = WriteContext::new( + Endpoint::new("127.0.0.1:80".parse().unwrap()), + "127.0.0.1:8080".parse().unwrap(), + "127.0.0.1:8081".parse().unwrap(), + read_context.contents.clone(), + ); + + compress + .write(&mut write_context) .expect("should decompress"); - assert_eq!(expected, write_response.contents); + assert_eq!(expected, &*write_context.contents); assert_eq!(0, compress.metrics.packets_dropped_decompress.get()); assert_eq!(0, compress.metrics.packets_dropped_compress.get()); // multiply by two, because data was sent both upstream and downstream assert_eq!( - (read_response.contents.len() * 2) as u64, + (read_context.contents.len() * 2) as u64, compress.metrics.compressed_bytes_total.get() ); assert_eq!( @@ -369,14 +370,15 @@ mod tests { Metrics::new().unwrap(), ); - let write_response = compression.write(WriteContext::new( - &Endpoint::new("127.0.0.1:80".parse().unwrap()), - "127.0.0.1:8080".parse().unwrap(), - "127.0.0.1:8081".parse().unwrap(), - b"hello".to_vec(), - )); + assert!(compression + .write(&mut WriteContext::new( + Endpoint::new("127.0.0.1:80".parse().unwrap()), + "127.0.0.1:8080".parse().unwrap(), + "127.0.0.1:8081".parse().unwrap(), + b"hello".to_vec(), + )) + .is_none()); - assert!(write_response.is_none()); assert_eq!(1, compression.metrics.packets_dropped_decompress.get()); assert_eq!(0, compression.metrics.packets_dropped_compress.get()); @@ -389,18 +391,19 @@ mod tests { Metrics::new().unwrap(), ); - let read_response = compression.read(ReadContext::new( - vec![Endpoint::new("127.0.0.1:80".parse().unwrap())], - "127.0.0.1:8080".parse().unwrap(), - b"hello".to_vec(), - )); + assert!(compression + .read(&mut ReadContext::new( + vec![Endpoint::new("127.0.0.1:80".parse().unwrap())], + "127.0.0.1:8080".parse().unwrap(), + b"hello".to_vec(), + )) + .is_none()); assert!(logs_contain( "Packets are being dropped as they could not be decompressed" )); assert!(logs_contain("quilkin::filters::compress")); // the given name to the the logger by tracing - assert!(read_response.is_none()); assert_eq!(1, compression.metrics.packets_dropped_decompress.get()); assert_eq!(0, compression.metrics.packets_dropped_compress.get()); assert_eq!(0, compression.metrics.compressed_bytes_total.get()); @@ -418,21 +421,24 @@ mod tests { Metrics::new().unwrap(), ); - let read_response = compression.read(ReadContext::new( + let mut read_context = ReadContext::new( vec![Endpoint::new("127.0.0.1:80".parse().unwrap())], "127.0.0.1:8080".parse().unwrap(), b"hello".to_vec(), - )); - assert_eq!(b"hello".to_vec(), read_response.unwrap().contents); + ); + compression.read(&mut read_context).unwrap(); + assert_eq!(b"hello", &*read_context.contents); - let write_response = compression.write(WriteContext::new( - &Endpoint::new("127.0.0.1:80".parse().unwrap()), + let mut write_context = WriteContext::new( + Endpoint::new("127.0.0.1:80".parse().unwrap()), "127.0.0.1:8080".parse().unwrap(), "127.0.0.1:8081".parse().unwrap(), b"hello".to_vec(), - )); + ); + + compression.write(&mut write_context).unwrap(); - assert_eq!(b"hello".to_vec(), write_response.unwrap().contents) + assert_eq!(b"hello".to_vec(), &*write_context.contents) } #[test] @@ -482,33 +488,33 @@ mod tests { { let expected = contents_fixture(); // write compress - let write_response = filter - .write(WriteContext::new( - &Endpoint::new("127.0.0.1:80".parse().unwrap()), - "127.0.0.1:8080".parse().unwrap(), - "127.0.0.1:8081".parse().unwrap(), - expected.clone(), - )) - .expect("should compress"); + let mut write_context = WriteContext::new( + Endpoint::new("127.0.0.1:80".parse().unwrap()), + "127.0.0.1:8080".parse().unwrap(), + "127.0.0.1:8081".parse().unwrap(), + expected.clone(), + ); + + filter.write(&mut write_context).expect("should compress"); - assert_ne!(expected, write_response.contents); + assert_ne!(expected, &*write_context.contents); assert!( - expected.len() > write_response.contents.len(), + expected.len() > write_context.contents.len(), "Original: {}. Compressed: {}", expected.len(), - write_response.contents.len() + write_context.contents.len() ); // read decompress - let read_response = filter - .read(ReadContext::new( - vec![Endpoint::new("127.0.0.1:80".parse().unwrap())], - "127.0.0.1:8080".parse().unwrap(), - write_response.contents.clone(), - )) - .expect("should decompress"); + let mut read_context = ReadContext::new( + vec![Endpoint::new("127.0.0.1:80".parse().unwrap())], + "127.0.0.1:8080".parse().unwrap(), + write_context.contents.clone(), + ); + + filter.read(&mut read_context).expect("should decompress"); - assert_eq!(expected, read_response.contents); - (expected, write_response.contents) + assert_eq!(expected, &*read_context.contents); + (expected, write_context.contents.to_vec()) } } diff --git a/src/filters/concatenate_bytes.rs b/src/filters/concatenate_bytes.rs index b91f0dfa95..d93d49c4ff 100644 --- a/src/filters/concatenate_bytes.rs +++ b/src/filters/concatenate_bytes.rs @@ -44,7 +44,7 @@ impl ConcatenateBytes { } impl Filter for ConcatenateBytes { - fn read(&self, mut ctx: ReadContext) -> Option { + fn read(&self, ctx: &mut ReadContext) -> Option<()> { match self.on_read { Strategy::Append => { ctx.contents.extend(self.bytes.iter()); @@ -55,10 +55,10 @@ impl Filter for ConcatenateBytes { Strategy::DoNothing => {} } - Some(ctx.into()) + Some(()) } - fn write(&self, mut ctx: WriteContext) -> Option { + fn write(&self, ctx: &mut WriteContext) -> Option<()> { match self.on_write { Strategy::Append => { ctx.contents.extend(self.bytes.iter()); @@ -69,7 +69,7 @@ impl Filter for ConcatenateBytes { Strategy::DoNothing => {} } - Some(ctx.into()) + Some(()) } } diff --git a/src/filters/debug.rs b/src/filters/debug.rs index 3ff6dc6a03..1a42c1caa2 100644 --- a/src/filters/debug.rs +++ b/src/filters/debug.rs @@ -42,16 +42,16 @@ impl Debug { impl Filter for Debug { #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] - fn read(&self, ctx: ReadContext) -> Option { + fn read(&self, ctx: &mut ReadContext) -> Option<()> { info!(id = ?self.config.id, source = ?&ctx.source, contents = ?String::from_utf8_lossy(&ctx.contents), "Read filter event"); - Some(ctx.into()) + Some(()) } #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] - fn write(&self, ctx: WriteContext) -> Option { + fn write(&self, ctx: &mut WriteContext) -> Option<()> { info!(id = ?self.config.id, endpoint = ?ctx.endpoint.address, source = ?&ctx.source, dest = ?&ctx.dest, contents = ?String::from_utf8_lossy(&ctx.contents), "Write filter event"); - Some(ctx.into()) + Some(()) } } diff --git a/src/filters/drop.rs b/src/filters/drop.rs index 3cdd212d8a..ca7054086d 100644 --- a/src/filters/drop.rs +++ b/src/filters/drop.rs @@ -35,12 +35,12 @@ impl Drop { impl Filter for Drop { #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] - fn read(&self, _: ReadContext) -> Option { + fn read(&self, _: &mut ReadContext) -> Option<()> { None } #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] - fn write(&self, _: WriteContext) -> Option { + fn write(&self, _: &mut WriteContext) -> Option<()> { None } } diff --git a/src/filters/firewall.rs b/src/filters/firewall.rs index 07a2b5ba32..4a13206c45 100644 --- a/src/filters/firewall.rs +++ b/src/filters/firewall.rs @@ -60,7 +60,7 @@ impl StaticFilter for Firewall { impl Filter for Firewall { #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] - fn read(&self, ctx: ReadContext) -> Option { + fn read(&self, ctx: &mut ReadContext) -> Option<()> { for rule in &self.on_read { if rule.contains(ctx.source.to_socket_addr().ok()?) { return match rule.action { @@ -71,7 +71,7 @@ impl Filter for Firewall { source = ?ctx.source.to_string() ); self.metrics.packets_allowed_read.inc(); - Some(ctx.into()) + Some(()) } Action::Deny => { debug!(action = "Deny", event = "read", source = ?ctx.source); @@ -91,7 +91,7 @@ impl Filter for Firewall { } #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] - fn write(&self, ctx: WriteContext) -> Option { + fn write(&self, ctx: &mut WriteContext) -> Option<()> { for rule in &self.on_write { if rule.contains(ctx.source.to_socket_addr().ok()?) { return match rule.action { @@ -102,7 +102,7 @@ impl Filter for Firewall { source = ?ctx.source.to_string() ); self.metrics.packets_allowed_write.inc(); - Some(ctx.into()) + Some(()) } Action::Deny => { debug!(action = "Deny", event = "write", source = ?ctx.source); @@ -147,16 +147,16 @@ mod tests { }; let local_ip = [192, 168, 75, 20]; - let ctx = ReadContext::new( + let mut ctx = ReadContext::new( vec![Endpoint::new((Ipv4Addr::LOCALHOST, 8080).into())], (local_ip, 80).into(), vec![], ); - assert!(firewall.read(ctx).is_some()); + assert!(firewall.read(&mut ctx).is_some()); assert_eq!(1, firewall.metrics.packets_allowed_read.get()); assert_eq!(0, firewall.metrics.packets_denied_read.get()); - let ctx = ReadContext::new( + let mut ctx = ReadContext::new( vec![Endpoint::new((Ipv4Addr::LOCALHOST, 8080).into())], (local_ip, 2000).into(), vec![], @@ -164,7 +164,7 @@ mod tests { assert!(logs_contain("quilkin::filters::firewall")); // the given name to the the logger by tracing assert!(logs_contain("Allow")); - assert!(firewall.read(ctx).is_none()); + assert!(firewall.read(&mut ctx).is_none()); assert_eq!(1, firewall.metrics.packets_allowed_read.get()); assert_eq!(1, firewall.metrics.packets_denied_read.get()); @@ -187,23 +187,23 @@ mod tests { let endpoint = Endpoint::new((Ipv4Addr::LOCALHOST, 80).into()); let local_addr: crate::endpoint::EndpointAddress = (Ipv4Addr::LOCALHOST, 8081).into(); - let ctx = WriteContext::new( - &endpoint, + let mut ctx = WriteContext::new( + endpoint.clone(), ([192, 168, 75, 20], 80).into(), local_addr.clone(), vec![], ); - assert!(firewall.write(ctx).is_some()); + assert!(firewall.write(&mut ctx).is_some()); assert_eq!(1, firewall.metrics.packets_allowed_write.get()); assert_eq!(0, firewall.metrics.packets_denied_write.get()); - let ctx = WriteContext::new( - &endpoint, + let mut ctx = WriteContext::new( + endpoint, ([192, 168, 77, 20], 80).into(), local_addr, vec![], ); - assert!(firewall.write(ctx).is_none()); + assert!(firewall.write(&mut ctx).is_none()); assert_eq!(1, firewall.metrics.packets_allowed_write.get()); assert_eq!(1, firewall.metrics.packets_denied_write.get()); diff --git a/src/filters/load_balancer.rs b/src/filters/load_balancer.rs index 64be759f89..74f48fbef3 100644 --- a/src/filters/load_balancer.rs +++ b/src/filters/load_balancer.rs @@ -39,9 +39,9 @@ impl LoadBalancer { } impl Filter for LoadBalancer { - fn read(&self, mut ctx: ReadContext) -> Option { - self.endpoint_chooser.choose_endpoints(&mut ctx); - Some(ctx.into()) + fn read(&self, ctx: &mut ReadContext) -> Option<()> { + self.endpoint_chooser.choose_endpoints(ctx); + Some(()) } } @@ -67,13 +67,15 @@ mod tests { input_addresses: &[EndpointAddress], source: EndpointAddress, ) -> Vec { - filter - .read(ReadContext::new( - input_addresses.iter().cloned().map(Endpoint::new).collect(), - source, - vec![], - )) - .unwrap() + let mut context = ReadContext::new( + Vec::from_iter(input_addresses.iter().cloned().map(Endpoint::new)), + source, + vec![], + ); + + filter.read(&mut context).unwrap(); + + context .endpoints .iter() .map(|ep| ep.address.clone()) diff --git a/src/filters/local_rate_limit.rs b/src/filters/local_rate_limit.rs index cdcf23cf12..fb97c9d193 100644 --- a/src/filters/local_rate_limit.rs +++ b/src/filters/local_rate_limit.rs @@ -154,13 +154,11 @@ impl LocalRateLimit { } impl Filter for LocalRateLimit { - fn read(&self, ctx: ReadContext) -> Option { - self.acquire_token(&ctx.source) - .map(|()| ctx.into()) - .or_else(|| { - self.metrics.packets_dropped_total.inc(); - None - }) + fn read(&self, ctx: &mut ReadContext) -> Option<()> { + self.acquire_token(&ctx.source).or_else(|| { + self.metrics.packets_dropped_total.inc(); + None + }) } } @@ -236,10 +234,12 @@ mod tests { (Ipv4Addr::LOCALHOST, 8089).into(), )]; - let result = r.read(ReadContext::new(endpoints, address.clone(), vec![9])); + let mut context = ReadContext::new(endpoints, address.clone(), vec![9]); + let result = r.read(&mut context); if should_succeed { - assert_eq!(result.unwrap().contents, vec![9]); + result.unwrap(); + assert_eq!(context.contents, vec![9]); } else { assert!(result.is_none()); } diff --git a/src/filters/match.rs b/src/filters/match.rs index 46659fa110..489c364b14 100644 --- a/src/filters/match.rs +++ b/src/filters/match.rs @@ -85,19 +85,18 @@ impl Match { } } -fn match_filter<'config, Ctx, R>( +fn match_filter<'config, 'ctx, Ctx>( config: &'config Option, metrics: &'config Metrics, - ctx: Ctx, - get_metadata: impl for<'ctx> Fn(&'ctx Ctx, &'config String) -> Option<&'ctx Value>, - and_then: impl Fn(Ctx, &'config FilterInstance) -> Option, -) -> Option + ctx: &'ctx mut Ctx, + get_metadata: impl for<'value> Fn(&'value Ctx, &'config String) -> Option<&'value Value>, + and_then: impl Fn(&'ctx mut Ctx, &'config FilterInstance) -> Option<()>, +) -> Option<()> where - Ctx: Into, { match config { Some(config) => { - let value = (get_metadata)(&ctx, &config.metadata_key)?; + let value = (get_metadata)(ctx, &config.metadata_key)?; match config.branches.iter().find(|(key, _)| key == value) { Some((value, instance)) => { @@ -116,13 +115,13 @@ where } } } - None => Some(ctx.into()), + None => Some(()), } } impl Filter for Match { #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] - fn read(&self, ctx: ReadContext) -> Option { + fn read(&self, ctx: &mut ReadContext) -> Option<()> { tracing::trace!(metadata=?ctx.metadata); match_filter( &self.on_read_filters, @@ -134,7 +133,7 @@ impl Filter for Match { } #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] - fn write(&self, ctx: WriteContext) -> Option { + fn write(&self, ctx: &mut WriteContext) -> Option<()> { match_filter( &self.on_write_filters, &self.metrics, @@ -183,9 +182,9 @@ mod tests { // no config, so should make no change. filter - .write(WriteContext::new( - &endpoint, - endpoint.address.clone(), + .write(&mut WriteContext::new( + endpoint.clone(), + endpoint.address, "127.0.0.1:70".parse().unwrap(), contents.clone(), )) @@ -202,7 +201,7 @@ mod tests { ); ctx.metadata.insert(Arc::new(key.into()), "abc".into()); - filter.read(ctx).unwrap(); + filter.read(&mut ctx).unwrap(); assert_eq!(1, filter.metrics.packets_matched_total.get()); assert_eq!(0, filter.metrics.packets_fallthrough_total.get()); @@ -213,7 +212,7 @@ mod tests { ); ctx.metadata.insert(Arc::new(key.into()), "xyz".into()); - let result = filter.read(ctx); + let result = filter.read(&mut ctx); assert!(result.is_none()); assert_eq!(1, filter.metrics.packets_matched_total.get()); assert_eq!(1, filter.metrics.packets_fallthrough_total.get()); diff --git a/src/filters/pass.rs b/src/filters/pass.rs index dcb9cb2496..06cbfe275a 100644 --- a/src/filters/pass.rs +++ b/src/filters/pass.rs @@ -34,13 +34,13 @@ impl Pass { impl Filter for Pass { #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] - fn read(&self, ctx: ReadContext) -> Option { - Some(ctx.into()) + fn read(&self, _: &mut ReadContext) -> Option<()> { + Some(()) } #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] - fn write(&self, ctx: WriteContext) -> Option { - Some(ctx.into()) + fn write(&self, _: &mut WriteContext) -> Option<()> { + Some(()) } } diff --git a/src/filters/read.rs b/src/filters/read.rs index 6a8e108ab8..0f607329ab 100644 --- a/src/filters/read.rs +++ b/src/filters/read.rs @@ -45,44 +45,8 @@ impl ReadContext { } } - /// Creates a new [`ReadContext`] from a given [`ReadResponse`]. - pub fn with_response(source: EndpointAddress, response: ReadResponse) -> Self { - Self { - endpoints: response.endpoints, - source, - contents: response.contents, - metadata: response.metadata, - } - } -} - -impl From for ReadResponse { - fn from(ctx: ReadContext) -> Self { - Self { - endpoints: ctx.endpoints, - contents: ctx.contents, - metadata: ctx.metadata, - } + pub fn metadata(mut self, metadata: DynamicMetadata) -> Self { + self.metadata = metadata; + self } } - -/// The output of [`Filter::read`]. -/// -/// New instances are created from [`ReadContext`]. -/// -/// ```rust -/// # use quilkin::filters::{ReadContext, ReadResponse}; -/// fn read(ctx: ReadContext) -> Option { -/// Some(ctx.into()) -/// } -/// ``` -#[derive(Debug)] -#[non_exhaustive] -pub struct ReadResponse { - /// The upstream endpoints that the packet should be forwarded to. - pub endpoints: Vec, - /// Contents of the packet to be forwarded. - pub contents: Vec, - /// Arbitrary values that can be passed from one filter to another - pub metadata: DynamicMetadata, -} diff --git a/src/filters/registry.rs b/src/filters/registry.rs index c44e0b02f6..93cc69077b 100644 --- a/src/filters/registry.rs +++ b/src/filters/registry.rs @@ -66,18 +66,16 @@ mod tests { use super::*; use crate::endpoint::{Endpoint, EndpointAddress}; - use crate::filters::{ - Filter, FilterRegistry, ReadContext, ReadResponse, WriteContext, WriteResponse, - }; + use crate::filters::{Filter, FilterRegistry, ReadContext, WriteContext}; struct TestFilter {} impl Filter for TestFilter { - fn read(&self, _: ReadContext) -> Option { + fn read(&self, _: &mut ReadContext) -> Option<()> { None } - fn write(&self, _: WriteContext) -> Option { + fn write(&self, _: &mut WriteContext) -> Option<()> { None } } @@ -104,14 +102,14 @@ mod tests { let endpoint = Endpoint::new(addr.clone()); assert!(filter - .read(ReadContext::new( + .read(&mut ReadContext::new( vec![endpoint.clone()], addr.clone(), vec![] )) .is_some()); assert!(filter - .write(WriteContext::new(&endpoint, addr.clone(), addr, vec![],)) + .write(&mut WriteContext::new(endpoint, addr.clone(), addr, vec![],)) .is_some()); } } diff --git a/src/filters/timestamp.rs b/src/filters/timestamp.rs index b23fa063d5..c2de97f6ef 100644 --- a/src/filters/timestamp.rs +++ b/src/filters/timestamp.rs @@ -118,14 +118,14 @@ impl TryFrom for Timestamp { } impl Filter for Timestamp { - fn read(&self, ctx: ReadContext) -> Option { + fn read(&self, ctx: &mut ReadContext) -> Option<()> { self.observe(&ctx.metadata, READ_DIRECTION_LABEL); - Some(ctx.into()) + Some(()) } - fn write(&self, ctx: WriteContext) -> Option { + fn write(&self, ctx: &mut WriteContext) -> Option<()> { self.observe(&ctx.metadata, WRITE_DIRECTION_LABEL); - Some(ctx.into()) + Some(()) } } @@ -197,7 +197,7 @@ mod tests { Value::Number(Utc::now().timestamp() as u64), ); - let _ = filter.read(ctx).unwrap(); + filter.read(&mut ctx).unwrap(); assert_eq!(1, filter.metric(READ_DIRECTION_LABEL).get_sample_count()); } @@ -218,16 +218,14 @@ mod tests { ); let timestamp = Timestamp::from_config(Config::new(TIMESTAMP_KEY.to_string()).into()); let source = (std::net::Ipv4Addr::UNSPECIFIED, 0); - let ctx = ReadContext::new( + let mut ctx = ReadContext::new( vec![], source.into(), [0, 0, 0, 0, 99, 81, 55, 181].to_vec(), ); - let response = capture.read(ctx).unwrap(); - let _ = timestamp - .read(ReadContext::with_response(source.into(), response)) - .unwrap(); + capture.read(&mut ctx).unwrap(); + timestamp.read(&mut ctx).unwrap(); assert_eq!(1, timestamp.metric(READ_DIRECTION_LABEL).get_sample_count()); } diff --git a/src/filters/token_router.rs b/src/filters/token_router.rs index ce98af2382..03fecda59a 100644 --- a/src/filters/token_router.rs +++ b/src/filters/token_router.rs @@ -62,7 +62,7 @@ impl StaticFilter for TokenRouter { } impl Filter for TokenRouter { - fn read(&self, mut ctx: ReadContext) -> Option { + fn read(&self, ctx: &mut ReadContext) -> Option<()> { match ctx.metadata.get(self.metadata_key.as_ref()) { None => { tracing::trace!( @@ -92,7 +92,7 @@ impl Filter for TokenRouter { self.metrics.packets_dropped_no_endpoint_match.inc(); None } else { - Some(ctx.into()) + Some(()) } } _ => { @@ -107,10 +107,6 @@ impl Filter for TokenRouter { }, } } - - fn write(&self, ctx: WriteContext) -> Option { - Some(ctx.into()) - } } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq, schemars::JsonSchema)] @@ -249,13 +245,12 @@ mod tests { Value::Bytes(b"567".to_vec().into()), ); - let option = filter.read(ctx); - assert!(option.is_none()); + assert!(filter.read(&mut ctx).is_none()); assert_eq!(1, filter.metrics.packets_dropped_no_endpoint_match.get()); // no key - let ctx = new_ctx(); - assert!(filter.read(ctx).is_none()); + let mut ctx = new_ctx(); + assert!(filter.read(&mut ctx).is_none()); assert_eq!(1, filter.metrics.packets_dropped_no_token_found.get()); // wrong type key @@ -264,7 +259,7 @@ mod tests { Arc::new(CAPTURED_BYTES.into()), Value::String(String::from("wrong")), ); - assert!(filter.read(ctx).is_none()); + assert!(filter.read(&mut ctx).is_none()); assert_eq!(1, filter.metrics.packets_dropped_invalid_token.get()); } @@ -298,12 +293,12 @@ mod tests { ) } - fn assert_read(filter: &F, ctx: ReadContext) + fn assert_read(filter: &F, mut ctx: ReadContext) where F: Filter + ?Sized, { - let result = filter.read(ctx).unwrap(); + filter.read(&mut ctx).unwrap(); - assert_eq!(b"hello".to_vec(), result.contents); + assert_eq!(b"hello", &*ctx.contents); } } diff --git a/src/filters/write.rs b/src/filters/write.rs index 05ff6b0c7a..744d8f9df3 100644 --- a/src/filters/write.rs +++ b/src/filters/write.rs @@ -26,9 +26,9 @@ use crate::filters::Filter; /// The input arguments to [`Filter::write`]. #[non_exhaustive] -pub struct WriteContext<'a> { +pub struct WriteContext { /// The upstream endpoint that we're expecting packets from. - pub endpoint: &'a Endpoint, + pub endpoint: Endpoint, /// The source of the received packet. pub source: EndpointAddress, /// The destination of the received packet. @@ -39,33 +39,15 @@ pub struct WriteContext<'a> { pub metadata: DynamicMetadata, } -/// The output of [`Filter::write`]. -/// -/// New instances are created from [`WriteContext`]. -/// -/// ```rust -/// # use quilkin::filters::{WriteContext, WriteResponse}; -/// fn write(ctx: WriteContext) -> Option { -/// Some(ctx.into()) -/// } -/// ``` -#[non_exhaustive] -pub struct WriteResponse { - /// Contents of the packet to be sent back to the original sender. - pub contents: Vec, - /// Arbitrary values that can be passed from one filter to another. - pub metadata: DynamicMetadata, -} - -impl WriteContext<'_> { +impl WriteContext { /// Creates a new [`WriteContext`] pub fn new( - endpoint: &Endpoint, + endpoint: Endpoint, source: EndpointAddress, dest: EndpointAddress, contents: Vec, - ) -> WriteContext { - WriteContext { + ) -> Self { + Self { endpoint, source, dest, @@ -73,29 +55,4 @@ impl WriteContext<'_> { metadata: HashMap::new(), } } - - /// Creates a new [`WriteContext`] from a given [`WriteResponse`]. - pub fn with_response( - endpoint: &Endpoint, - source: EndpointAddress, - dest: EndpointAddress, - response: WriteResponse, - ) -> WriteContext { - WriteContext { - endpoint, - source, - dest, - contents: response.contents, - metadata: response.metadata, - } - } -} - -impl From> for WriteResponse { - fn from(ctx: WriteContext) -> Self { - Self { - contents: ctx.contents, - metadata: ctx.metadata, - } - } } diff --git a/src/proxy.rs b/src/proxy.rs index 42842005e6..a9b9d40961 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -240,17 +240,15 @@ impl Proxy { return; } - let result = args.config.filters.load().read(ReadContext::new( - endpoints, - packet.source.clone(), - packet.contents, - )); - - if let Some(response) = result { - for endpoint in response.endpoints.iter() { + let filters = args.config.filters.load(); + let mut context = ReadContext::new(endpoints, packet.source, packet.contents); + let result = filters.read(&mut context); + + if let Some(()) = result { + for endpoint in context.endpoints.iter() { Self::session_send_packet( - &response.contents, - packet.source.clone(), + &context.contents, + context.source.clone(), endpoint, args, ) diff --git a/src/proxy/sessions.rs b/src/proxy/sessions.rs index 6737a6a8b7..9da07209e8 100644 --- a/src/proxy/sessions.rs +++ b/src/proxy/sessions.rs @@ -247,20 +247,22 @@ impl Session { let result = Session::do_update_expiration(expiration, ttl) .and_then(|_| { + let mut context = WriteContext::new( + endpoint.clone(), + from.clone(), + dest.clone(), + packet.to_vec(), + ); config .filters .load() - .write(WriteContext::new( - endpoint, - from.clone(), - dest.clone(), - packet.to_vec(), - )) + .write(&mut context) .ok_or(Error::FilterDroppedPacket) + .map(|_| context) }) - .and_then(|response| { + .and_then(|context| { dest.to_socket_addr() - .map(|addr| (addr, response)) + .map(|addr| (addr, context)) .map_err(Error::ToSocketAddr) }); @@ -275,8 +277,8 @@ impl Session { }; match result { - Ok((addr, response)) => { - let packet = response.contents.as_slice(); + Ok((addr, context)) => { + let packet = context.contents.as_ref(); tracing::trace!(%from, dest = %addr, contents = %debug::bytes_to_string(packet), "sending packet downstream"); let _ = downstream_socket .send_to(packet, addr) diff --git a/src/test_utils.rs b/src/test_utils.rs index 15ec11634c..91eae08408 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -60,7 +60,7 @@ pub async fn available_addr() -> SocketAddr { pub struct TestFilter; impl Filter for TestFilter { - fn read(&self, mut ctx: ReadContext) -> Option { + fn read(&self, ctx: &mut ReadContext) -> Option<()> { // append values on each run ctx.metadata .entry(Arc::new("downstream".into())) @@ -69,10 +69,10 @@ impl Filter for TestFilter { ctx.contents .append(&mut format!(":odr:{}", ctx.source).into_bytes()); - Some(ctx.into()) + Some(()) } - fn write(&self, mut ctx: WriteContext) -> Option { + fn write(&self, ctx: &mut WriteContext) -> Option<()> { // append values on each run ctx.metadata .entry("upstream".to_string().into()) @@ -81,7 +81,7 @@ impl Filter for TestFilter { ctx.contents .append(&mut format!(":our:{}:{}", ctx.source, ctx.dest).into_bytes()); - Some(ctx.into()) + Some(()) } } @@ -264,18 +264,11 @@ where let endpoints = vec!["127.0.0.1:80".parse::().unwrap()]; let source = "127.0.0.1:90".parse().unwrap(); let contents = "hello".to_string().into_bytes(); + let mut context = ReadContext::new(endpoints.clone(), source, contents.clone()); - match filter.read(ReadContext::new( - endpoints.clone(), - source, - contents.clone(), - )) { - None => unreachable!("should return a result"), - Some(response) => { - assert_eq!(endpoints, response.endpoints); - assert_eq!(contents, response.contents); - } - } + filter.read(&mut context).unwrap(); + assert_eq!(endpoints, &*context.endpoints); + assert_eq!(contents, &*context.contents); } /// assert that write makes no changes @@ -285,16 +278,15 @@ where { let endpoint = "127.0.0.1:90".parse::().unwrap(); let contents = "hello".to_string().into_bytes(); - - match filter.write(WriteContext::new( - &endpoint, - endpoint.address.clone(), + let mut context = WriteContext::new( + endpoint.clone(), + endpoint.address, "127.0.0.1:70".parse().unwrap(), contents.clone(), - )) { - None => unreachable!("should return a result"), - Some(response) => assert_eq!(contents, response.contents), - } + ); + + filter.write(&mut context).unwrap(); + assert_eq!(contents, &*context.contents); } /// Opens a new socket bound to an ephemeral port