diff --git a/docs/src/filters/writing_custom_filters.md b/docs/src/filters/writing_custom_filters.md index ed0b99dac1..5382e981f1 100644 --- a/docs/src/filters/writing_custom_filters.md +++ b/docs/src/filters/writing_custom_filters.md @@ -20,7 +20,7 @@ A [trait][Filter] representing an actual [Filter][built-in-filters] instance in A [trait][FilterFactory] representing a type that knows how to create instances of a particular type of [Filter]. - An implementation provides a `name` and `create_filter` method. -- `create_filter` takes in [configuration][filter configuration] for the filter to create and returns a [FilterInstance] type containing a new instance of its filter type. +- `create_filter` takes in [configuration][filter configuration] for the filter to create and returns a [FilterInstance] type containing a new instance of its filter type. `name` returns the Filter name - a unique identifier of filters of the created type (e.g quilkin.filters.debug.v1alpha1.Debug). ### FilterRegistry @@ -72,7 +72,7 @@ We start with the [Filter] implementation # // src/main.rs use quilkin::filters::prelude::*; - + struct Greet; impl Filter for Greet { @@ -94,31 +94,70 @@ Next, we implement a [FilterFactory] for it and give it a name: # #![allow(unused)] # fn main() { # +# #[derive(Default)] # struct Greet; +# impl Greet { +# fn new(_: Config) -> Self { +# <_>::default() +# } +# } # impl Filter for Greet {} -# use quilkin::filters::Filter; +# impl StaticFilter for Greet { +# const NAME: &'static str = "greet.v1"; +# type Configuration = Config; +# type BinaryConfiguration = prost_types::Struct; +# +# fn new(config: Option) -> Result { +# Ok(Greet::new(config.unwrap_or_default())) +# } +# } // src/main.rs use quilkin::filters::prelude::*; -pub const NAME: &str = "greet.v1"; - -pub fn factory() -> DynFilterFactory { - Box::from(GreetFilterFactory) +#[derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)] +struct Config { + greeting: String, } -struct 
GreetFilterFactory; -impl FilterFactory for GreetFilterFactory { - fn name(&self) -> &'static str { - NAME +impl Default for Config { + fn default() -> Self { + Self { + greeting: "World".into(), + } } +} - fn config_schema(&self) -> schemars::schema::RootSchema { - schemars::schema_for!(serde_json::Value) +impl TryFrom for Config { + type Error = Error; + + fn try_from(map: prost_types::Struct) -> Result { + let greeting = map.fields.get("greeting") + .and_then(|v| v.kind.clone()) + .and_then(|kind| { + match kind { + prost_types::value::Kind::StringValue(string) => Some(string), + _ => None, + } + }).ok_or_else(|| { + Error::FieldInvalid { + field: "greeting".into(), + reason: "Missing".into() + } + })?; + + Ok(Self { greeting }) } +} - fn create_filter(&self, _: CreateFilterArgs) -> Result { - let filter: Box = Box::new(Greet); - Ok(FilterInstance::new(serde_json::Value::Null, filter)) +impl From for prost_types::Struct { + fn from(config: Config) -> Self { + Self { + fields: <_>::from([ + ("greeting".into(), prost_types::Value { + kind: Some(prost_types::value::Kind::StringValue(config.greeting)) + }) + ]) + } } } # } @@ -130,7 +169,7 @@ impl FilterFactory for GreetFilterFactory { #### 3. Start the proxy We can run the proxy in the exact manner as the default Quilkin binary using the [run][runner::run] function, passing in our custom [FilterFactory]. -Let's add a main function that does that. Quilkin relies on the [Tokio] async runtime, so we need to import that +Let's add a main function that does that. Quilkin relies on the [Tokio] async runtime, so we need to import that crate and wrap our main function with it. Add Tokio as a dependency in `Cargo.toml`. 
@@ -220,41 +259,36 @@ First let's create the config for our static configuration: ```rust,no_run,noplayground // src/main.rs # use serde::{Deserialize, Serialize}; -# #[derive(Serialize, Deserialize, Debug)] +# use quilkin::filters::prelude::*; +# #[derive(Serialize, Default, Deserialize, Debug, schemars::JsonSchema)] # struct Config { # greeting: String, # } -# use quilkin::filters::prelude::*; +# #[derive(Default)] # struct Greet(String); +# impl Greet { +# fn new(_: Config) -> Self { <_>::default() } +# } # impl Filter for Greet { } -use quilkin::config::ConfigType; - -pub const NAME: &str = "greet.v1"; - -pub fn factory() -> DynFilterFactory { - Box::from(GreetFilterFactory) -} - -struct GreetFilterFactory; -impl FilterFactory for GreetFilterFactory { - fn name(&self) -> &'static str { - NAME - } - - fn config_schema(&self) -> schemars::schema::RootSchema { - schemars::schema_for!(serde_json::Value) - } - - fn create_filter(&self, args: CreateFilterArgs) -> Result { - let config = match args.config.unwrap() { - ConfigType::Static(config) => { - serde_yaml::from_str::(serde_yaml::to_string(&config).unwrap().as_str()) - .unwrap() - } - ConfigType::Dynamic(_) => unimplemented!("dynamic config is not yet supported for this filter"), - }; - let filter: Box = Box::new(Greet(config.greeting)); - Ok(FilterInstance::new(serde_json::Value::Null, filter)) +# impl TryFrom for Config { +# type Error = Error; +# fn try_from(map: prost_types::Struct) -> Result { +# todo!() +# } +# } +# impl TryFrom for prost_types::Struct { +# type Error = Error; +# fn try_from(map: Config) -> Result { +# todo!() +# } +# } +impl StaticFilter for Greet { +# const NAME: &'static str = "greet.v1"; +# type Configuration = Config; +# type BinaryConfiguration = prost_types::Struct; +# + fn new(config: Option) -> Result { + Ok(Greet::new(config.unwrap_or_default())) } } ``` @@ -282,7 +316,7 @@ let config = match args.config.unwrap() { The [Dynamic][ConfigType::dynamic] contains the serialized 
[Protobuf] message received from the [management server] for the [Filter] to create. As a result, its contents are entirely opaque to Quilkin and it is represented with the [Prost Any][prost-any] type so the [FilterFactory] -can interpret its contents however it wishes. +can interpret its contents however it wishes. However, it usually contains a Protobuf equivalent of the filter's static configuration. ###### 1. Add the proto parsing crates to `Cargo.toml`: @@ -334,9 +368,9 @@ recreating the grpc package name as Rust modules: ###### 4. Decode the serialized proto message into a config: If the message contains a Protobuf equivalent of the filter's static configuration, we can -leverage the [deserialize][ConfigType::deserialize] method to deserialize either a static or dynamic config. +leverage the [deserialize][ConfigType::deserialize] method to deserialize either a static or dynamic config. The function automatically deserializes and converts from the Protobuf type if the input contains a dynamic -configuration. +configuration. As a result, the function requires that the [std::convert::TryFrom] is implemented from our dynamic config type to a static equivalent. 
diff --git a/examples/quilkin-filter-example/src/main.rs b/examples/quilkin-filter-example/src/main.rs index 5b97fef648..f9fc7e8647 100644 --- a/examples/quilkin-filter-example/src/main.rs +++ b/examples/quilkin-filter-example/src/main.rs @@ -15,8 +15,7 @@ */ // ANCHOR: include_proto -quilkin::include_proto!("greet"); -use greet::Greet as ProtoGreet; +mod proto { tonic::include_proto!("greet"); } // ANCHOR_END: include_proto use quilkin::filters::prelude::*; @@ -31,15 +30,23 @@ struct Config { // ANCHOR_END: serde_config // ANCHOR: TryFrom -impl TryFrom for Config { +impl TryFrom for Config { type Error = ConvertProtoConfigError; - fn try_from(p: ProtoGreet) -> Result { - Ok(Config { + fn try_from(p: proto::Greet) -> Result { + Ok(Self { greeting: p.greeting, }) } } + +impl From for proto::Greet { + fn from(config: Config) -> Self { + Self { + greeting: config.greeting, + } + } +} // ANCHOR_END: TryFrom // ANCHOR: filter @@ -60,28 +67,15 @@ impl Filter for Greet { // ANCHOR_END: filter // ANCHOR: factory -pub const NAME: &str = "greet.v1"; - -pub fn factory() -> DynFilterFactory { - Box::from(GreetFilterFactory) -} - -struct GreetFilterFactory; -impl FilterFactory for GreetFilterFactory { - fn name(&self) -> &'static str { - NAME - } +use quilkin::filters::StaticFilter; - fn config_schema(&self) -> schemars::schema::RootSchema { - schemars::schema_for!(Config) - } +impl StaticFilter for Greet { + const NAME: &'static str = "greet.v1"; + type Configuration = Config; + type BinaryConfiguration = proto::Greet; - fn create_filter(&self, args: CreateFilterArgs) -> Result { - let (config_json, config) = self - .require_config(args.config)? 
- .deserialize::(self.name())?; - let filter: Box = Box::new(Greet(config.greeting)); - Ok(FilterInstance::new(config_json, filter)) + fn new(config: Option) -> Result { + Ok(Self(Self::ensure_config_exists(config)?.greeting)) } } // ANCHOR_END: factory @@ -91,7 +85,7 @@ impl FilterFactory for GreetFilterFactory { async fn main() { quilkin::run( quilkin::Config::builder().build(), - vec![self::factory()].into_iter(), + vec![Greet::factory()].into_iter(), ) .await .unwrap(); diff --git a/src/config/config_type.rs b/src/config/config_type.rs index c1f5d19763..571f9702a3 100644 --- a/src/config/config_type.rs +++ b/src/config/config_type.rs @@ -18,7 +18,7 @@ use std::convert::TryFrom; use bytes::Bytes; -use crate::filters::{ConvertProtoConfigError, Error}; +use crate::filters::Error; /// The configuration of a [`Filter`][crate::filters::Filter] from either a /// static or dynamic source. @@ -48,19 +48,21 @@ impl ConfigType { /// It returns both the deserialized, as well as, a JSON representation /// of the provided config. /// It returns an error if any of the serialization or deserialization steps fail. 
- pub fn deserialize( + pub fn deserialize( self, filter_name: &str, - ) -> Result<(serde_json::Value, Static), Error> + ) -> Result<(serde_json::Value, TextConfiguration), Error> where - Dynamic: prost::Message + Default, - Static: serde::Serialize - + for<'de> serde::Deserialize<'de> - + TryFrom, + BinaryConfiguration: prost::Message + Default, + TextConfiguration: + serde::Serialize + for<'de> serde::Deserialize<'de> + TryFrom, + Error: From<>::Error>, { match self { ConfigType::Static(ref config) => serde_yaml::to_string(config) - .and_then(|raw_config| serde_yaml::from_str::(raw_config.as_str())) + .and_then(|raw_config| { + serde_yaml::from_str::(raw_config.as_str()) + }) .map_err(|err| { Error::DeserializeFailed(format!( "filter `{filter_name}`: failed to YAML deserialize config: {err}", @@ -76,7 +78,7 @@ impl ConfigType { "filter `{filter_name}`: config decode error: {err}", )) }) - .and_then(|config| Static::try_from(config).map_err(Error::ConvertProtoConfig)) + .and_then(|config| TextConfiguration::try_from(config).map_err(From::from)) .and_then(|config| { Self::get_json_config(filter_name, &config) .map(|config_json| (config_json, config)) diff --git a/src/filters.rs b/src/filters.rs index 61a00f1fec..a368445ba6 100644 --- a/src/filters.rs +++ b/src/filters.rs @@ -16,15 +16,15 @@ //! Filters for processing packets. 
+mod chain;
 mod error;
 mod factory;
+mod metadata;
 mod read;
 mod registry;
 mod set;
 mod write;
 
-pub(crate) mod chain;
-
 pub mod capture;
 pub mod compress;
 pub mod concatenate_bytes;
@@ -34,7 +34,6 @@ pub mod firewall;
 pub mod load_balancer;
 pub mod local_rate_limit;
 pub mod r#match;
-pub mod metadata;
 pub mod pass;
 pub mod token_router;
 
@@ -43,21 +42,84 @@ pub mod token_router;
 pub mod prelude {
     pub use super::{
         ConvertProtoConfigError, CreateFilterArgs, DynFilterFactory, Error, Filter, FilterFactory,
-        FilterInstance, ReadContext, ReadResponse, WriteContext, WriteResponse,
+        FilterInstance, ReadContext, ReadResponse, StaticFilter, WriteContext, WriteResponse,
     };
 }
 
 // Core Filter types
 pub use self::{
+    capture::Capture,
+    compress::Compress,
+    concatenate_bytes::ConcatenateBytes,
+    debug::Debug,
+    drop::Drop,
     error::{ConvertProtoConfigError, Error},
     factory::{CreateFilterArgs, DynFilterFactory, FilterFactory, FilterInstance},
+    firewall::Firewall,
+    load_balancer::LoadBalancer,
+    local_rate_limit::LocalRateLimit,
+    pass::Pass,
+    r#match::Match,
     read::{ReadContext, ReadResponse},
     registry::FilterRegistry,
     set::{FilterMap, FilterSet},
+    token_router::TokenRouter,
     write::{WriteContext, WriteResponse},
 };
 
-pub(crate) use self::chain::{FilterChain, SharedFilterChain};
+pub(crate) use self::chain::{Error as FilterChainError, FilterChain, SharedFilterChain};
+
+/// Statically safe version of [`Filter`], if you're writing a Rust filter, you
+/// should implement [`StaticFilter`] in addition to [`Filter`], as
+/// [`StaticFilter`] guarantees all of the required properties through the type
+/// system, allowing Quilkin to take care of the virtual table boilerplate
+/// automatically at compile-time.
+pub trait StaticFilter: Filter + Sized
+// This where clause simply states that `Configuration`'s and
+// `BinaryConfiguration`'s `Error` types are compatible with `filters::Error`.
+where
+    Error: From<>::Error>
+        + From<>::Error>,
+{
+    /// The globally unique name of the filter.
+    const NAME: &'static str;
+    /// The human-readable configuration of the filter. **Must** be [`serde`]
+    /// compatible, have a JSON schema, and be convertible to and
+    /// from [`Self::BinaryConfiguration`].
+    type Configuration: schemars::JsonSchema
+        + serde::Serialize
+        + for<'de> serde::Deserialize<'de>
+        + TryFrom;
+    /// The binary configuration of the filter. **Must** be [`prost`] compatible,
+    /// and be convertible to and from [`Self::Configuration`].
+    type BinaryConfiguration: prost::Message
+        + Default
+        + TryFrom
+        + Send
+        + Sync
+        + Sized;
+
+    /// Instantiates a new [`StaticFilter`] from the given configuration, if any.
+    /// # Errors
+    /// If the provided configuration is invalid.
+    fn new(config: Option) -> Result;
+
+    /// Creates a new dynamic [`FilterFactory`] virtual table.
+    fn factory() -> DynFilterFactory
+    where
+        Self: 'static,
+    {
+        Box::from(std::marker::PhantomData:: Self>)
+    }
+
+    /// Convenience method for providing a consistent error message for filters
+    /// which require a fully initialized [`Self::Configuration`].
+    fn ensure_config_exists(
+        config: Option,
+    ) -> Result {
+        config.ok_or(Error::MissingConfig(Self::NAME))
+    }
+}
 
 /// Trait for routing and manipulating packets.
 ///
diff --git a/src/filters/capture.rs b/src/filters/capture.rs
index 48d2580220..6557453a52 100644
--- a/src/filters/capture.rs
+++ b/src/filters/capture.rs
@@ -32,14 +32,8 @@ use self::{
 };
 
 use self::quilkin::filters::capture::v1alpha1 as proto;
-pub use config::{Config, Strategy};
-
-pub const NAME: &str = "quilkin.filters.capture.v1alpha1.Capture";
 
-/// Creates a new factory for generating capture filters.
-pub fn factory() -> DynFilterFactory {
-    Box::from(CaptureFactory::new())
-}
+pub use config::{Config, Strategy};
 
 /// Trait to implement different strategies for capturing packet data.
pub trait CaptureStrategy { @@ -48,7 +42,7 @@ pub trait CaptureStrategy { fn capture(&self, contents: &mut Vec, metrics: &Metrics) -> Option; } -struct Capture { +pub struct Capture { capture: Box, /// metrics reporter for this filter. metrics: Metrics, @@ -83,31 +77,15 @@ impl Filter for Capture { } } -struct CaptureFactory; - -impl CaptureFactory { - pub fn new() -> Self { - CaptureFactory - } -} - -impl FilterFactory for CaptureFactory { - fn name(&self) -> &'static str { - NAME - } - - fn config_schema(&self) -> schemars::schema::RootSchema { - schemars::schema_for!(Config) - } +impl StaticFilter for Capture { + const NAME: &'static str = "quilkin.filters.capture.v1alpha1.Capture"; + type Configuration = Config; + type BinaryConfiguration = proto::Capture; - fn create_filter(&self, args: CreateFilterArgs) -> Result { - let (config_json, config) = self - .require_config(args.config)? - .deserialize::(self.name())?; - let filter = Capture::new(config, Metrics::new()?); - Ok(FilterInstance::new( - config_json, - Box::new(filter) as Box, + fn new(config: Option) -> Result { + Ok(Capture::new( + Self::ensure_config_exists(config)?, + Metrics::new()?, )) } } @@ -121,14 +99,11 @@ mod tests { use crate::{ endpoint::{Endpoint, Endpoints}, filters::metadata::CAPTURED_BYTES, - filters::prelude::*, metadata::Value, test_utils::assert_write_no_change, }; - use super::{ - Capture, CaptureFactory, CaptureStrategy, Config, Metrics, Prefix, Regex, Strategy, Suffix, - }; + use super::*; const TOKEN_KEY: &str = "TOKEN"; @@ -138,7 +113,7 @@ mod tests { #[test] fn factory_valid_config_all() { - let factory = CaptureFactory::new(); + let factory = Capture::factory(); let mut map = Mapping::new(); map.insert( YamlValue::String("metadataKey".into()), @@ -165,7 +140,7 @@ mod tests { #[test] fn factory_valid_config_defaults() { - let factory = CaptureFactory::new(); + let factory = Capture::factory(); let mut map = Mapping::new(); map.insert("suffix".into(), { let mut map = 
Mapping::new(); @@ -185,7 +160,7 @@ mod tests { #[test] fn factory_invalid_config() { - let factory = CaptureFactory::new(); + let factory = Capture::factory(); let mut map = Mapping::new(); map.insert( YamlValue::String("size".into()), @@ -249,7 +224,7 @@ mod tests { fn regex_capture() { let metrics = Metrics::new().unwrap(); let end = Regex { - pattern: regex::bytes::Regex::new(".{3}$").unwrap(), + pattern: ::regex::bytes::Regex::new(".{3}$").unwrap(), }; let mut contents = b"helloabc".to_vec(); let result = end.capture(&mut contents, &metrics).unwrap(); diff --git a/src/filters/chain.rs b/src/filters/chain.rs index aa1fab37d9..3b1f2f9c64 100644 --- a/src/filters/chain.rs +++ b/src/filters/chain.rs @@ -216,7 +216,7 @@ mod tests { use crate::{ config, endpoint::{Endpoint, Endpoints, UpstreamEndpoints}, - filters::debug, + filters::Debug, test_utils::{new_test_chain, TestFilterFactory}, }; @@ -224,12 +224,12 @@ mod tests { #[test] fn from_config() { - let provider = debug::factory(); + let provider = Debug::factory(); // everything is fine let filter_configs = &[config::Filter { name: provider.name().into(), - config: Default::default(), + config: Some(serde_yaml::Mapping::default().into()), }]; let chain = FilterChain::try_create(filter_configs).unwrap(); diff --git a/src/filters/compress.rs b/src/filters/compress.rs index ac21c62198..0f963eeaed 100644 --- a/src/filters/compress.rs +++ b/src/filters/compress.rs @@ -29,15 +29,8 @@ use metrics::Metrics; pub use config::{Action, Config, Mode}; -pub const NAME: &str = "quilkin.filters.compress.v1alpha1.Compress"; - -/// Returns a factory for creating compression filters. 
-pub fn factory() -> DynFilterFactory { - Box::from(CompressFactory::new()) -} - /// Filter for compressing and decompressing packet data -struct Compress { +pub struct Compress { metrics: Metrics, compression_mode: Mode, on_read: Action, @@ -145,46 +138,27 @@ impl Filter for Compress { } } -struct CompressFactory {} - -impl CompressFactory { - pub fn new() -> Self { - CompressFactory {} - } -} +impl StaticFilter for Compress { + const NAME: &'static str = "quilkin.filters.compress.v1alpha1.Compress"; + type Configuration = Config; + type BinaryConfiguration = proto::Compress; -impl FilterFactory for CompressFactory { - fn name(&self) -> &'static str { - NAME - } - - fn config_schema(&self) -> schemars::schema::RootSchema { - schemars::schema_for!(Config) - } - - fn create_filter(&self, args: CreateFilterArgs) -> Result { - let (config_json, config) = self - .require_config(args.config)? - .deserialize::(self.name())?; - let filter = Compress::new(config, Metrics::new()?); - Ok(FilterInstance::new( - config_json, - Box::new(filter) as Box, + fn new(config: Option) -> Result { + Ok(Compress::new( + Self::ensure_config_exists(config)?, + Metrics::new()?, )) } } #[cfg(test)] mod tests { - use std::convert::TryFrom; - use serde_yaml::{Mapping, Value}; use tracing_test::traced_test; - use crate::endpoint::{Endpoint, Endpoints, UpstreamEndpoints}; - use crate::filters::{ - compress::{compressor::Snappy, Compressor}, - CreateFilterArgs, Filter, FilterFactory, ReadContext, WriteContext, + use crate::{ + endpoint::{Endpoint, Endpoints, UpstreamEndpoints}, + filters::compress::compressor::Snappy, }; use super::*; @@ -277,7 +251,7 @@ mod tests { #[test] fn default_mode_factory() { - let factory = CompressFactory::new(); + let factory = Compress::factory(); let mut map = Mapping::new(); map.insert( Value::String("on_read".into()), @@ -296,7 +270,7 @@ mod tests { #[test] fn config_factory() { - let factory = CompressFactory::new(); + let factory = Compress::factory(); let mut 
map = Mapping::new(); map.insert(Value::String("mode".into()), Value::String("SNAPPY".into())); map.insert( diff --git a/src/filters/concatenate_bytes.rs b/src/filters/concatenate_bytes.rs index 3f24318c6f..3d43a4abb3 100644 --- a/src/filters/concatenate_bytes.rs +++ b/src/filters/concatenate_bytes.rs @@ -23,18 +23,11 @@ use crate::filters::prelude::*; use self::quilkin::filters::concatenate_bytes::v1alpha1 as proto; pub use config::{Config, Strategy}; -pub const NAME: &str = "quilkin.filters.concatenate_bytes.v1alpha1.ConcatenateBytes"; - -/// Returns a factory for creating concatenation filters. -pub fn factory() -> DynFilterFactory { - Box::from(ConcatBytesFactory) -} - /// The `ConcatenateBytes` filter's job is to add a byte packet to either the /// beginning or end of each UDP packet that passes through. This is commonly /// used to provide an auth token to each packet, so they can be /// routed appropriately. -struct ConcatenateBytes { +pub struct ConcatenateBytes { on_read: Strategy, on_write: Strategy, bytes: Vec, @@ -80,26 +73,12 @@ impl Filter for ConcatenateBytes { } } -#[derive(Default)] -struct ConcatBytesFactory; - -impl FilterFactory for ConcatBytesFactory { - fn name(&self) -> &'static str { - NAME - } - - fn config_schema(&self) -> schemars::schema::RootSchema { - schemars::schema_for!(Config) - } +impl StaticFilter for ConcatenateBytes { + const NAME: &'static str = "quilkin.filters.concatenate_bytes.v1alpha1.ConcatenateBytes"; + type Configuration = Config; + type BinaryConfiguration = proto::ConcatenateBytes; - fn create_filter(&self, args: CreateFilterArgs) -> Result { - let (config_json, config) = self - .require_config(args.config)? 
- .deserialize::(self.name())?; - let filter = ConcatenateBytes::new(config); - Ok(FilterInstance::new( - config_json, - Box::new(filter) as Box, - )) + fn new(config: Option) -> Result { + Ok(ConcatenateBytes::new(Self::ensure_config_exists(config)?)) } } diff --git a/src/filters/debug.rs b/src/filters/debug.rs index 525e94b2dd..38b6b69500 100644 --- a/src/filters/debug.rs +++ b/src/filters/debug.rs @@ -25,86 +25,47 @@ use tracing::info; use self::quilkin::filters::debug::v1alpha1 as proto; /// Debug logs all incoming and outgoing packets -struct Debug {} - -pub const NAME: &str = "quilkin.filters.debug.v1alpha1.Debug"; - -/// Creates a new factory for generating debug filters. -pub fn factory() -> DynFilterFactory { - Box::from(DebugFactory::new()) +pub struct Debug { + config: Config, } impl Debug { /// Constructor for the Debug. Pass in a "id" to append a string to your log messages from this /// Filter. - - fn new(_: Option) -> Self { - Debug {} + fn new(config: Option) -> Self { + Self { + config: config.unwrap_or_default(), + } } } impl Filter for Debug { #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] fn read(&self, ctx: ReadContext) -> Option { - info!(source = ?&ctx.source, contents = ?packet_to_string(ctx.contents.clone()), "Read filter event"); + info!(id = ?self.config.id, source = ?&ctx.source, contents = ?String::from_utf8_lossy(&ctx.contents), "Read filter event"); Some(ctx.into()) } #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] fn write(&self, ctx: WriteContext) -> Option { - info!(endpoint = ?ctx.endpoint.address, source = ?&ctx.source, - dest = ?&ctx.dest, contents = ?packet_to_string(ctx.contents.clone()), "Write filter event"); + info!(id = ?self.config.id, endpoint = ?ctx.endpoint.address, source = ?&ctx.source, + dest = ?&ctx.dest, contents = ?String::from_utf8_lossy(&ctx.contents), "Write filter event"); Some(ctx.into()) } } -/// packet_to_string takes the content, and attempts to convert 
it to a string. -/// Returns a string of "error decoding packet" on failure. -fn packet_to_string(contents: Vec) -> String { - match String::from_utf8(contents) { - Ok(str) => str, - Err(_) => String::from("error decoding packet as UTF-8"), - } -} - -/// Factory for the Debug -struct DebugFactory {} - -impl DebugFactory { - pub fn new() -> Self { - DebugFactory {} - } -} - -impl FilterFactory for DebugFactory { - fn name(&self) -> &'static str { - NAME - } - - fn config_schema(&self) -> schemars::schema::RootSchema { - schemars::schema_for!(Config) - } +impl StaticFilter for Debug { + const NAME: &'static str = "quilkin.filters.debug.v1alpha1.Debug"; + type Configuration = Config; + type BinaryConfiguration = proto::Debug; - fn create_filter(&self, args: CreateFilterArgs) -> Result { - let config: Option<(_, Config)> = args - .config - .map(|config| config.deserialize::(self.name())) - .transpose()?; - - let (config_json, config) = config - .map(|(config_json, config)| (config_json, Some(config))) - .unwrap_or_else(|| (serde_json::Value::Null, None)); - let filter = Debug::new(config.and_then(|cfg| cfg.id)); - - Ok(FilterInstance::new( - config_json, - Box::new(filter) as Box, - )) + fn new(config: Option) -> Result { + Ok(Debug::new(config)) } } /// A Debug filter's configuration. -#[derive(Serialize, Deserialize, Debug, schemars::JsonSchema)] +#[derive(Serialize, Default, Deserialize, Debug, schemars::JsonSchema)] pub struct Config { /// Identifier that will be optionally included with each log message. 
pub id: Option, @@ -153,7 +114,7 @@ mod tests { #[test] fn from_config_with_id() { let mut map = Mapping::new(); - let factory = DebugFactory::new(); + let factory = Debug::factory(); map.insert(Value::from("id"), Value::from("name")); assert!(factory @@ -164,7 +125,7 @@ mod tests { #[test] fn from_config_without_id() { let mut map = Mapping::new(); - let factory = DebugFactory::new(); + let factory = Debug::factory(); map.insert(Value::from("id"), Value::from("name")); assert!(factory @@ -175,7 +136,7 @@ mod tests { #[test] fn from_config_should_error() { let mut map = Mapping::new(); - let factory = DebugFactory::new(); + let factory = Debug::factory(); map.insert(Value::from("id"), Value::Sequence(vec![])); assert!(factory diff --git a/src/filters/drop.rs b/src/filters/drop.rs index e19f2716ac..164a8eff99 100644 --- a/src/filters/drop.rs +++ b/src/filters/drop.rs @@ -22,15 +22,10 @@ use serde::{Deserialize, Serialize}; crate::include_proto!("quilkin.filters.drop.v1alpha1"); use self::quilkin::filters::drop::v1alpha1 as proto; -/// Always drops a packet, mostly useful in combination with other filters. -struct Drop; - -pub const NAME: &str = "quilkin.filters.drop.v1alpha1.Drop"; +pub const NAME: &str = Drop::NAME; -/// Creates a new factory for generating debug filters. -pub fn factory() -> DynFilterFactory { - Box::from(DropFactory::new()) -} +/// Always drops a packet, mostly useful in combination with other filters. 
+pub struct Drop; impl Drop { fn new() -> Self { @@ -50,38 +45,13 @@ impl Filter for Drop { } } -/// Factory for the Debug -struct DropFactory; - -impl DropFactory { - pub fn new() -> Self { - Self - } -} - -impl FilterFactory for DropFactory { - fn name(&self) -> &'static str { - NAME - } - - fn config_schema(&self) -> schemars::schema::RootSchema { - schemars::schema_for!(Config) - } - - fn create_filter(&self, args: CreateFilterArgs) -> Result { - let config: Option<(_, Config)> = args - .config - .map(|config| config.deserialize::(self.name())) - .transpose()?; - - let (config_json, _) = config - .map(|(config_json, config)| (config_json, Some(config))) - .unwrap_or_else(|| (serde_json::Value::Null, None)); +impl StaticFilter for Drop { + const NAME: &'static str = "quilkin.filters.drop.v1alpha1.Drop"; + type Configuration = Config; + type BinaryConfiguration = proto::Drop; - Ok(FilterInstance::new( - config_json, - Box::new(Drop::new()) as Box, - )) + fn new(_: Option) -> Result { + Ok(Drop::new()) } } diff --git a/src/filters/error.rs b/src/filters/error.rs index 5124cfd44f..4ab5b3e3d6 100644 --- a/src/filters/error.rs +++ b/src/filters/error.rs @@ -26,6 +26,8 @@ use crate::filters::{Filter, FilterFactory}; pub enum Error { #[error("filter `{}` not found", .0)] NotFound(String), + #[error("Expected <{}> message, received <{}> ", expected, actual)] + MismatchedTypes { expected: String, actual: String }, #[error("filter `{}` requires configuration, but none provided", .0)] MissingConfig(&'static str), #[error("field `{}` is invalid, reason: {}", field, reason)] @@ -70,6 +72,18 @@ impl From for Error { } } +impl From for Error { + fn from(error: prost::DecodeError) -> Self { + Self::ConvertProtoConfig(ConvertProtoConfigError::new(error, None)) + } +} + +impl From for Error { + fn from(error: ConvertProtoConfigError) -> Self { + Self::ConvertProtoConfig(error) + } +} + /// An error representing failure to convert a filter's protobuf configuration /// to its 
static representation. #[derive(Debug, PartialEq, thiserror::Error)] diff --git a/src/filters/factory.rs b/src/filters/factory.rs index b389464364..d6ed8544c8 100644 --- a/src/filters/factory.rs +++ b/src/filters/factory.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use crate::{ config::ConfigType, - filters::{Error, Filter}, + filters::{Error, Filter, StaticFilter}, }; /// An owned pointer to a dynamic [`FilterFactory`] instance. @@ -63,6 +63,13 @@ pub trait FilterFactory: Sync + Send { /// Returns a filter based on the provided arguments. fn create_filter(&self, args: CreateFilterArgs) -> Result; + /// Converts YAML configuration into its Protobuf equivalent. + fn encode_config_to_protobuf(&self, args: serde_yaml::Value) + -> Result; + + /// Converts Protobuf configuration into its YAML equivalent. + fn encode_config_to_yaml(&self, args: prost_types::Any) -> Result; + /// Returns the [`ConfigType`] from the provided Option, otherwise it returns /// Error::MissingConfig if the Option is None. fn require_config(&self, config: Option) -> Result { @@ -70,6 +77,49 @@ pub trait FilterFactory: Sync + Send { } } +impl FilterFactory for std::marker::PhantomData F> +where + F: StaticFilter + 'static, + Error: From<>::Error> + + From<>::Error>, +{ + fn name(&self) -> &'static str { + F::NAME + } + + fn config_schema(&self) -> schemars::schema::RootSchema { + schemars::schema_for!(F::Configuration) + } + + /// Returns a filter based on the provided arguments. + fn create_filter(&self, args: CreateFilterArgs) -> Result { + let (config_json, config): (_, Option) = if let Some(config) = args.config + { + config + .deserialize::(self.name()) + .map(|(j, c)| (j, Some(c)))? + } else { + (serde_json::Value::Null, None) + }; + + Ok(FilterInstance::new( + config_json, + Box::from(F::new(config)?) 
as Box, + )) + } + + fn encode_config_to_protobuf( + &self, + config: serde_yaml::Value, + ) -> Result { + crate::prost::encode_to_protobuf::(self, config) + } + + fn encode_config_to_yaml(&self, config: prost_types::Any) -> Result { + crate::prost::encode_to_yaml::(self, config) + } +} + /// Arguments needed to create a new filter. pub struct CreateFilterArgs { /// Configuration for the filter. diff --git a/src/filters/firewall.rs b/src/filters/firewall.rs index d3e429cb2c..d4800f7e7f 100644 --- a/src/filters/firewall.rs +++ b/src/filters/firewall.rs @@ -14,8 +14,6 @@ * limitations under the License. */ -//! Filter for allowing/blocking traffic by IP and port. - use tracing::debug; use crate::filters::firewall::metrics::Metrics; @@ -30,43 +28,8 @@ mod metrics; pub use config::{Action, Config, PortRange, PortRangeError, Rule}; -pub const NAME: &str = "quilkin.filters.firewall.v1alpha1.Firewall"; - -pub fn factory() -> DynFilterFactory { - Box::from(FirewallFactory::new()) -} - -struct FirewallFactory {} - -impl FirewallFactory { - pub fn new() -> Self { - Self {} - } -} - -impl FilterFactory for FirewallFactory { - fn name(&self) -> &'static str { - NAME - } - - fn config_schema(&self) -> schemars::schema::RootSchema { - schemars::schema_for!(Config) - } - - fn create_filter(&self, args: CreateFilterArgs) -> Result { - let (config_json, config) = self - .require_config(args.config)? - .deserialize::(self.name())?; - - let filter = Firewall::new(config, Metrics::new()?); - Ok(FilterInstance::new( - config_json, - Box::new(filter) as Box, - )) - } -} - -struct Firewall { +/// Filter for allowing/blocking traffic by IP and port. 
+pub struct Firewall { metrics: Metrics, on_read: Vec, on_write: Vec, @@ -82,6 +45,19 @@ impl Firewall { } } +impl StaticFilter for Firewall { + const NAME: &'static str = "quilkin.filters.firewall.v1alpha1.Firewall"; + type Configuration = Config; + type BinaryConfiguration = proto::Firewall; + + fn new(config: Option) -> Result { + Ok(Firewall::new( + Self::ensure_config_exists(config)?, + Metrics::new()?, + )) + } +} + impl Filter for Firewall { #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] fn read(&self, ctx: ReadContext) -> Option { diff --git a/src/filters/load_balancer.rs b/src/filters/load_balancer.rs index cc49997593..51b1356b5a 100644 --- a/src/filters/load_balancer.rs +++ b/src/filters/load_balancer.rs @@ -19,25 +19,25 @@ crate::include_proto!("quilkin.filters.load_balancer.v1alpha1"); mod config; mod endpoint_chooser; -use crate::filters::{prelude::*, DynFilterFactory}; - -pub use config::{Config, Policy}; -use endpoint_chooser::EndpointChooser; - use self::quilkin::filters::load_balancer::v1alpha1 as proto; +use crate::filters::prelude::*; +use endpoint_chooser::EndpointChooser; -pub const NAME: &str = "quilkin.filters.load_balancer.v1alpha1.LoadBalancer"; - -/// Returns a factory for creating load balancing filters. -pub fn factory() -> DynFilterFactory { - Box::from(LoadBalancerFilterFactory) -} +pub use config::{Config, Policy}; /// Balances packets over the upstream endpoints. 
-struct LoadBalancer { +pub struct LoadBalancer { endpoint_chooser: Box, } +impl LoadBalancer { + fn new(config: Config) -> Self { + Self { + endpoint_chooser: config.policy.as_endpoint_chooser(), + } + } +} + impl Filter for LoadBalancer { fn read(&self, mut ctx: ReadContext) -> Option { self.endpoint_chooser.choose_endpoints(&mut ctx); @@ -45,28 +45,13 @@ impl Filter for LoadBalancer { } } -struct LoadBalancerFilterFactory; +impl StaticFilter for LoadBalancer { + const NAME: &'static str = "quilkin.filters.load_balancer.v1alpha1.LoadBalancer"; + type Configuration = Config; + type BinaryConfiguration = proto::LoadBalancer; -impl FilterFactory for LoadBalancerFilterFactory { - fn name(&self) -> &'static str { - NAME - } - - fn config_schema(&self) -> schemars::schema::RootSchema { - schemars::schema_for!(Config) - } - - fn create_filter(&self, args: CreateFilterArgs) -> Result { - let (config_json, config) = self - .require_config(args.config)? - .deserialize::(self.name())?; - let filter = LoadBalancer { - endpoint_chooser: config.policy.as_endpoint_chooser(), - }; - Ok(FilterInstance::new( - config_json, - Box::new(filter) as Box, - )) + fn new(config: Option) -> Result { + Ok(LoadBalancer::new(Self::ensure_config_exists(config)?)) } } @@ -74,16 +59,11 @@ impl FilterFactory for LoadBalancerFilterFactory { mod tests { use std::{collections::HashSet, net::Ipv4Addr}; - use crate::{ - endpoint::{Endpoint, EndpointAddress, Endpoints}, - filters::{ - load_balancer::LoadBalancerFilterFactory, CreateFilterArgs, Filter, FilterFactory, - ReadContext, - }, - }; + use super::*; + use crate::endpoint::{Endpoint, EndpointAddress, Endpoints}; fn create_filter(config: &str) -> Box { - let factory = LoadBalancerFilterFactory; + let factory = LoadBalancer::factory(); factory .create_filter(CreateFilterArgs::fixed(Some( serde_yaml::from_str(config).unwrap(), diff --git a/src/filters/local_rate_limit.rs b/src/filters/local_rate_limit.rs index dec001b5e4..5f0d49ae4c 100644 --- 
a/src/filters/local_rate_limit.rs +++ b/src/filters/local_rate_limit.rs @@ -34,13 +34,6 @@ use metrics::Metrics; crate::include_proto!("quilkin.filters.local_rate_limit.v1alpha1"); use self::quilkin::filters::local_rate_limit::v1alpha1 as proto; -pub const NAME: &str = "quilkin.filters.local_rate_limit.v1alpha1.LocalRateLimit"; - -/// Creates a new factory for generating rate limiting filters. -pub fn factory() -> DynFilterFactory { - Box::from(LocalRateLimitFactory::new()) -} - // TODO: we should make these values configurable and transparent to the filter. /// SESSION_TIMEOUT_SECONDS is the default session timeout. pub const SESSION_TIMEOUT_SECONDS: Duration = Duration::from_secs(60); @@ -71,7 +64,7 @@ struct Bucket { /// applies rate limiting on packets received from a downstream connection (processed /// through [`LocalRateLimit::read`]). Packets coming from upstream endpoints /// flow through the filter untouched. -struct LocalRateLimit { +pub struct LocalRateLimit { /// Tracks rate limiting state per source address. state: TtlMap, /// Filter configuration. @@ -83,12 +76,19 @@ struct LocalRateLimit { impl LocalRateLimit { /// new returns a new LocalRateLimit. It spawns a future in the background /// that periodically refills the rate limiter's tokens. - fn new(config: Config, metrics: Metrics) -> Self { - LocalRateLimit { + fn new(config: Config, metrics: Metrics) -> Result { + if config.period < 1 { + return Err(Error::FieldInvalid { + field: "period".into(), + reason: "value must be at least 1 second".into(), + }); + } + + Ok(LocalRateLimit { state: TtlMap::new(SESSION_TIMEOUT_SECONDS, SESSION_EXPIRY_POLL_INTERVAL), config, metrics, - } + }) } /// acquire_token is called on behalf of every packet that is eligible @@ -164,41 +164,13 @@ impl Filter for LocalRateLimit { } } -/// Creates instances of [`LocalRateLimit`]. 
-struct LocalRateLimitFactory {} - -impl LocalRateLimitFactory { - pub fn new() -> Self { - LocalRateLimitFactory {} - } -} - -impl FilterFactory for LocalRateLimitFactory { - fn name(&self) -> &'static str { - NAME - } - - fn config_schema(&self) -> schemars::schema::RootSchema { - schemars::schema_for!(Config) - } - - fn create_filter(&self, args: CreateFilterArgs) -> Result { - let (config_json, config) = self - .require_config(args.config)? - .deserialize::(self.name())?; +impl StaticFilter for LocalRateLimit { + const NAME: &'static str = "quilkin.filters.local_rate_limit.v1alpha1.LocalRateLimit"; + type Configuration = Config; + type BinaryConfiguration = proto::LocalRateLimit; - if config.period < 1 { - Err(Error::FieldInvalid { - field: "period".into(), - reason: "value must be at least 1 second".into(), - }) - } else { - let filter = LocalRateLimit::new(config, Metrics::new()?); - Ok(FilterInstance::new( - config_json, - Box::new(filter) as Box, - )) - } + fn new(config: Option) -> Result { + Self::new(Self::ensure_config_exists(config)?, Metrics::new()?) 
} } @@ -245,17 +217,14 @@ mod tests { use tokio::time; use super::*; - use crate::config::ConfigType; - use crate::endpoint::{Endpoint, EndpointAddress, Endpoints}; - use crate::filters::local_rate_limit::LocalRateLimitFactory; - use crate::filters::{ - local_rate_limit::{metrics::Metrics, Config, LocalRateLimit}, - CreateFilterArgs, Filter, FilterFactory, ReadContext, + use crate::{ + config::ConfigType, + endpoint::{Endpoint, Endpoints}, + test_utils::assert_write_no_change, }; - use crate::test_utils::assert_write_no_change; fn rate_limiter(config: Config) -> LocalRateLimit { - LocalRateLimit::new(config, Metrics::new().unwrap()) + LocalRateLimit::new(config, Metrics::new().unwrap()).unwrap() } fn address_pair() -> (EndpointAddress, EndpointAddress) { @@ -280,7 +249,7 @@ mod tests { #[tokio::test] async fn config_minimum_period() { - let factory = LocalRateLimitFactory::new(); + let factory = LocalRateLimit::factory(); let config = " max_packets: 10 period: 0 diff --git a/src/filters/match.rs b/src/filters/match.rs index 436da18603..d8cb251fda 100644 --- a/src/filters/match.rs +++ b/src/filters/match.rs @@ -23,14 +23,8 @@ use crate::{config::ConfigType, filters::prelude::*, metadata::Value}; use self::quilkin::filters::matches::v1alpha1 as proto; use crate::filters::r#match::metrics::Metrics; -pub use config::Config; - -pub const NAME: &str = "quilkin.filters.match.v1alpha1.Match"; -/// Creates a new factory for generating match filters. 
-pub fn factory() -> DynFilterFactory { - Box::from(MatchFactory::new()) -} +pub use config::Config; struct ConfigInstance { metadata_key: String, @@ -61,19 +55,19 @@ impl ConfigInstance { } } -struct MatchInstance { +pub struct Match { metrics: Metrics, on_read_filters: Option, on_write_filters: Option, } -impl MatchInstance { +impl Match { fn new(config: Config, metrics: Metrics) -> Result { let on_read_filters = config.on_read.map(ConfigInstance::new).transpose()?; let on_write_filters = config.on_write.map(ConfigInstance::new).transpose()?; if on_read_filters.is_none() && on_write_filters.is_none() { - return Err(Error::MissingConfig(NAME)); + return Err(Error::MissingConfig(Self::NAME)); } Ok(Self { @@ -113,7 +107,7 @@ where } } -impl Filter for MatchInstance { +impl Filter for Match { #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] fn read(&self, ctx: ReadContext) -> Option { match_filter( @@ -137,33 +131,13 @@ impl Filter for MatchInstance { } } -struct MatchFactory; - -impl MatchFactory { - pub fn new() -> Self { - Self - } -} - -impl FilterFactory for MatchFactory { - fn name(&self) -> &'static str { - NAME - } - - fn config_schema(&self) -> schemars::schema::RootSchema { - schemars::schema_for!(Config) - } - - fn create_filter(&self, args: CreateFilterArgs) -> Result { - let (config_json, config) = self - .require_config(args.config)? - .deserialize::(self.name())?; +impl StaticFilter for Match { + const NAME: &'static str = "quilkin.filters.match.v1alpha1.Match"; + type Configuration = Config; + type BinaryConfiguration = proto::Match; - let filter = MatchInstance::new(config, Metrics::new()?)?; - Ok(FilterInstance::new( - config_json, - Box::new(filter) as Box, - )) + fn new(config: Option) -> Result { + Self::new(Self::ensure_config_exists(config)?, Metrics::new()?) 
} } @@ -174,7 +148,7 @@ mod tests { filters::{ r#match::{ config::{Branch, DirectionalConfig, Fallthrough, Filter as ConfigFilter}, - Config, MatchInstance, Metrics, + Config, Match, Metrics, }, Filter, ReadContext, WriteContext, }, @@ -196,7 +170,7 @@ mod tests { }), on_write: None, }; - let filter = MatchInstance::new(config, metrics).unwrap(); + let filter = Match::new(config, metrics).unwrap(); let endpoint: Endpoint = Default::default(); let contents = "hello".to_string().into_bytes(); diff --git a/src/filters/match/config.rs b/src/filters/match/config.rs index 2793a8b6e6..db69c75ee6 100644 --- a/src/filters/match/config.rs +++ b/src/filters/match/config.rs @@ -29,6 +29,17 @@ pub struct Config { pub on_write: Option, } +impl TryFrom for proto::Match { + type Error = crate::filters::Error; + + fn try_from(config: Config) -> Result { + Ok(Self { + on_read: config.on_read.map(TryFrom::try_from).transpose()?, + on_write: config.on_write.map(TryFrom::try_from).transpose()?, + }) + } +} + impl TryFrom for Config { type Error = ConvertProtoConfigError; @@ -65,6 +76,22 @@ pub struct DirectionalConfig { pub fallthrough: Fallthrough, } +impl TryFrom for proto::r#match::Config { + type Error = crate::filters::Error; + + fn try_from(config: DirectionalConfig) -> Result { + Ok(Self { + metadata_key: Some(config.metadata_key), + branches: config + .branches + .into_iter() + .map(TryFrom::try_from) + .collect::>()?, + fallthrough: config.fallthrough.try_into().map(Some)?, + }) + } +} + impl TryFrom for DirectionalConfig { type Error = eyre::Report; @@ -98,6 +125,17 @@ pub struct Branch { pub filter: Filter, } +impl TryFrom for proto::r#match::Branch { + type Error = crate::filters::Error; + + fn try_from(branch: Branch) -> Result { + Ok(Self { + value: Some(branch.value.into()), + filter: branch.filter.try_into().map(Some)?, + }) + } +} + impl TryFrom for Branch { type Error = eyre::Report; @@ -218,14 +256,23 @@ impl<'de> Deserialize<'de> for Filter { } } -/// The behaviour 
when the none of branches match. Defaults to dropping packets. -#[derive(Debug, PartialEq, Serialize, Deserialize, schemars::JsonSchema)] -#[serde(transparent)] -pub struct Fallthrough(pub Filter); +impl TryFrom for proto::r#match::Filter { + type Error = crate::filters::Error; -impl Default for Fallthrough { - fn default() -> Self { - Self(Filter::new(crate::filters::drop::NAME)) + fn try_from(filter: Filter) -> Result { + Ok(Self { + config: match filter.config { + Some(ConfigType::Dynamic(any)) => Some(any), + Some(ConfigType::Static(value)) => { + crate::filters::FilterRegistry::get_factory(&filter.id) + .ok_or_else(|| crate::filters::Error::NotFound(filter.id.clone()))? + .encode_config_to_protobuf(value) + .map(Some)? + } + None => None, + }, + id: Some(filter.id), + }) } } @@ -244,6 +291,24 @@ impl TryFrom for Filter { } } +/// The behaviour when none of the branches match. Defaults to dropping packets. +#[derive(Debug, PartialEq, Serialize, Deserialize, schemars::JsonSchema)] +#[serde(transparent)] +pub struct Fallthrough(pub Filter); + +impl Default for Fallthrough { + fn default() -> Self { + Self(Filter::new(crate::filters::drop::NAME)) + } +} + +impl TryFrom for proto::r#match::Filter { + type Error = crate::filters::Error; + fn try_from(fallthrough: Fallthrough) -> Result { + fallthrough.0.try_into() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/filters/pass.rs b/src/filters/pass.rs index e174cc4210..f859e00424 100644 --- a/src/filters/pass.rs +++ b/src/filters/pass.rs @@ -24,14 +24,7 @@ use self::quilkin::filters::pass::v1alpha1 as proto; /// Allows a packet to pass through, mostly useful in combination with /// other filters. -struct Pass; - -pub const NAME: &str = "quilkin.filters.pass.v1alpha1.Pass"; - -/// Creates a new factory for generating debug filters. 
-pub fn factory() -> DynFilterFactory { - Box::from(PassFactory::new()) -} +pub struct Pass; impl Pass { fn new() -> Self { @@ -51,38 +44,13 @@ impl Filter for Pass { } } -/// Factory for the Debug -struct PassFactory; - -impl PassFactory { - pub fn new() -> Self { - Self - } -} - -impl FilterFactory for PassFactory { - fn name(&self) -> &'static str { - NAME - } - - fn config_schema(&self) -> schemars::schema::RootSchema { - schemars::schema_for!(Config) - } - - fn create_filter(&self, args: CreateFilterArgs) -> Result { - let config: Option<(_, Config)> = args - .config - .map(|config| config.deserialize::(self.name())) - .transpose()?; - - let (config_json, _) = config - .map(|(config_json, config)| (config_json, Some(config))) - .unwrap_or_else(|| (serde_json::Value::Null, None)); +impl StaticFilter for Pass { + const NAME: &'static str = "quilkin.filters.pass.v1alpha1.Pass"; + type Configuration = Config; + type BinaryConfiguration = proto::Pass; - Ok(FilterInstance::new( - config_json, - Box::new(Pass::new()) as Box, - )) + fn new(_config: Option) -> Result { + Ok(Pass::new()) } } diff --git a/src/filters/registry.rs b/src/filters/registry.rs index c461f651a5..93d96fa458 100644 --- a/src/filters/registry.rs +++ b/src/filters/registry.rs @@ -50,6 +50,12 @@ impl FilterRegistry { Some(filter) => filter, } } + + /// Returns a [`FilterFactory`] for a given `key`. Returning `None` if the + /// factory cannot be found. 
+ pub fn get_factory(key: &str) -> Option> { + REGISTRY.load().get(key).cloned() + } } #[cfg(test)] diff --git a/src/filters/set.rs b/src/filters/set.rs index 6bdd2b0799..3b47729f68 100644 --- a/src/filters/set.rs +++ b/src/filters/set.rs @@ -16,7 +16,7 @@ use std::{iter::FromIterator, sync::Arc}; -use crate::filters::{self, DynFilterFactory}; +use crate::filters::{self, DynFilterFactory, StaticFilter}; #[cfg(doc)] use crate::filters::{FilterFactory, FilterRegistry}; @@ -52,17 +52,17 @@ impl FilterSet { pub fn default_with(filters: impl IntoIterator) -> Self { Self::with( [ - filters::capture::factory(), - filters::compress::factory(), - filters::concatenate_bytes::factory(), - filters::debug::factory(), - filters::drop::factory(), - filters::firewall::factory(), - filters::load_balancer::factory(), - filters::local_rate_limit::factory(), - filters::r#match::factory(), - filters::pass::factory(), - filters::token_router::factory(), + filters::Capture::factory(), + filters::Compress::factory(), + filters::ConcatenateBytes::factory(), + filters::Debug::factory(), + filters::Drop::factory(), + filters::Firewall::factory(), + filters::LoadBalancer::factory(), + filters::LocalRateLimit::factory(), + filters::Match::factory(), + filters::Pass::factory(), + filters::TokenRouter::factory(), ] .into_iter() .chain(filters), diff --git a/src/filters/token_router.rs b/src/filters/token_router.rs index 94304e1b58..4ff35f3807 100644 --- a/src/filters/token_router.rs +++ b/src/filters/token_router.rs @@ -35,16 +35,9 @@ use metrics::Metrics; use self::quilkin::filters::token_router::v1alpha1 as proto; -pub const NAME: &str = "quilkin.filters.token_router.v1alpha1.TokenRouter"; - -/// Returns a factory for creating token routing filters. -pub fn factory() -> DynFilterFactory { - Box::from(TokenRouterFactory::new()) -} - /// Filter that only allows packets to be passed to Endpoints that have a matching /// connection_id to the token stored in the Filter's dynamic metadata. 
-struct TokenRouter { +pub struct TokenRouter { metadata_key: Arc, metrics: Metrics, } @@ -58,44 +51,15 @@ impl TokenRouter { } } -/// Factory for the TokenRouter filter -struct TokenRouterFactory {} +impl StaticFilter for TokenRouter { + const NAME: &'static str = "quilkin.filters.token_router.v1alpha1.TokenRouter"; + type Configuration = Config; + type BinaryConfiguration = proto::TokenRouter; -impl TokenRouterFactory { - pub fn new() -> Self { - TokenRouterFactory {} - } -} - -impl FilterFactory for TokenRouterFactory { - fn name(&self) -> &'static str { - NAME - } - - fn config_schema(&self) -> schemars::schema::RootSchema { - schemars::schema_for!(Config) - } - - fn create_filter(&self, args: CreateFilterArgs) -> Result { - let (config_json, config) = args - .config - .map(|config| config.deserialize::(self.name())) - .unwrap_or_else(|| { - let config = Config::default(); - serde_json::to_value(&config) - .map_err(|err| { - Error::DeserializeFailed(format!( - "failed to JSON deserialize default config: {err}", - )) - }) - .map(|config_json| (config_json, config)) - })?; - - let filter = TokenRouter::new(config, Metrics::new()?); - - Ok(FilterInstance::new( - config_json, - Box::new(filter) as Box, + fn new(config: Option) -> Result { + Ok(TokenRouter::new( + config.unwrap_or_default(), + Metrics::new()?, )) } } @@ -187,20 +151,14 @@ impl TryFrom for Config { #[cfg(test)] mod tests { - use std::convert::TryFrom; - use std::ops::Deref; - use std::sync::Arc; - use serde_yaml::Mapping; - use crate::endpoint::{Endpoint, Endpoints, Metadata}; - use crate::metadata::Value; - use crate::test_utils::assert_write_no_change; + use crate::{ + endpoint::{Endpoint, Endpoints, Metadata}, + test_utils::assert_write_no_change, + }; use super::*; - use crate::filters::{ - metadata::CAPTURED_BYTES, CreateFilterArgs, Filter, FilterFactory, ReadContext, - }; const TOKEN_KEY: &str = "TOKEN"; @@ -244,7 +202,7 @@ mod tests { #[test] fn factory_custom_tokens() { - let factory = 
TokenRouterFactory::new(); + let factory = TokenRouter::factory(); let mut map = Mapping::new(); map.insert( serde_yaml::Value::String("metadataKey".into()), @@ -262,12 +220,12 @@ mod tests { Arc::new(TOKEN_KEY.into()), Value::Bytes(b"123".to_vec().into()), ); - assert_read(filter.deref(), ctx); + assert_read(&*filter, ctx); } #[test] fn factory_empty_config() { - let factory = TokenRouterFactory::new(); + let factory = TokenRouter::factory(); let map = Mapping::new(); let filter = factory @@ -281,12 +239,12 @@ mod tests { Arc::new(CAPTURED_BYTES.into()), Value::Bytes(b"123".to_vec().into()), ); - assert_read(filter.deref(), ctx); + assert_read(&*filter, ctx); } #[test] fn factory_no_config() { - let factory = TokenRouterFactory::new(); + let factory = TokenRouter::factory(); let filter = factory .create_filter(CreateFilterArgs::fixed(None)) @@ -297,7 +255,7 @@ mod tests { Arc::new(CAPTURED_BYTES.into()), Value::Bytes(b"123".to_vec().into()), ); - assert_read(filter.deref(), ctx); + assert_read(&*filter, ctx); } #[test] diff --git a/src/prost.rs b/src/prost.rs index b3e111a641..835a704294 100644 --- a/src/prost.rs +++ b/src/prost.rs @@ -19,6 +19,51 @@ use prost_types::value::Kind; use serde_yaml::Value; +pub fn encode_to_protobuf( + factory: &dyn crate::filters::FilterFactory, + config: Value, +) -> Result +where + C: for<'de> serde::Deserialize<'de> + TryInto, + M: prost::Message, + crate::filters::Error: From<>::Error>, +{ + let config: C = serde_yaml::from_value(config)?; + + Ok(prost_types::Any { + type_url: factory.name().into(), + value: crate::prost::encode::(&config.try_into()?)?, + }) +} + +pub fn encode_to_yaml( + factory: &dyn crate::filters::FilterFactory, + config: prost_types::Any, +) -> Result +where + C: serde::Serialize, + M: prost::Message + TryInto + Default, + crate::filters::Error: From<>::Error>, +{ + if factory.name() != config.type_url { + return Err(crate::filters::Error::MismatchedTypes { + expected: factory.name().into(), + actual: 
config.type_url, + }); + } + + let message = M::decode(&*config.value)?; + + Ok(serde_yaml::to_value(&message.try_into()?)?) +} + +pub fn encode(message: &M) -> Result, prost::EncodeError> { + let mut buf = Vec::new(); + buf.reserve(message.encoded_len()); + message.encode(&mut buf)?; + Ok(buf) +} + pub fn mapping_from_kind(kind: Kind) -> Option { match value_from_kind(kind) { Value::Mapping(mapping) => Some(mapping), diff --git a/src/proxy/builder.rs b/src/proxy/builder.rs index ec14d9c26f..e183941521 100644 --- a/src/proxy/builder.rs +++ b/src/proxy/builder.rs @@ -21,7 +21,7 @@ use tonic::transport::Endpoint as TonicEndpoint; use crate::{ config::{self, Config, ManagementServer, Proxy, Source, ValidationError, ValueInvalidArgs}, endpoint::Endpoints, - filters::chain::Error as FilterChainError, + filters::FilterChainError, proxy::{ server::metrics::Metrics as ProxyMetrics, sessions::metrics::Metrics as SessionMetrics, Admin as ProxyAdmin, Health, Server, diff --git a/src/proxy/config_dump.rs b/src/proxy/config_dump.rs index ed73e44a48..2b7c7ffb6e 100644 --- a/src/proxy/config_dump.rs +++ b/src/proxy/config_dump.rs @@ -97,7 +97,11 @@ fn create_config_dump_json( #[cfg(test)] mod tests { use super::handle_request; - use crate::{cluster::SharedCluster, endpoint::Endpoint, filters::SharedFilterChain}; + use crate::{ + cluster::SharedCluster, + endpoint::Endpoint, + filters::{SharedFilterChain, StaticFilter}, + }; #[tokio::test] async fn test_handle_request() { @@ -105,7 +109,7 @@ mod tests { SharedCluster::new_static_cluster(vec![Endpoint::new(([127, 0, 0, 1], 8080).into())]) .unwrap(); let filter_chain = SharedFilterChain::new(&[crate::config::Filter { - name: crate::filters::debug::NAME.into(), + name: crate::filters::Debug::NAME.into(), config: Some(serde_yaml::from_str("id: hello").unwrap()), }]) .unwrap(); diff --git a/src/test_utils.rs b/src/test_utils.rs index c76049a6a9..e3a0bb9520 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -15,18 +15,22 @@ */ 
/// Common utilities for testing -use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; -use std::str::from_utf8; -use std::sync::Arc; +use std::{ + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + str::from_utf8, + sync::Arc, +}; use tokio::net::UdpSocket; use tokio::sync::{mpsc, oneshot, watch}; -use crate::config::{Builder as ConfigBuilder, Config}; -use crate::endpoint::{Endpoint, EndpointAddress, Endpoints}; -use crate::filters::{prelude::*, FilterRegistry}; -use crate::metadata::Value; -use crate::proxy::{Builder, PendingValidation}; +use crate::{ + config::{Builder as ConfigBuilder, Config}, + endpoint::{Endpoint, EndpointAddress, Endpoints}, + filters::{prelude::*, FilterRegistry}, + metadata::Value, + proxy::{Builder, PendingValidation}, +}; pub struct TestFilterFactory; @@ -39,6 +43,17 @@ impl FilterFactory for TestFilterFactory { schemars::schema_for_value!(serde_json::Value::Null) } + fn encode_config_to_protobuf( + &self, + _config: serde_yaml::Value, + ) -> Result { + Ok(<_>::default()) + } + + fn encode_config_to_yaml(&self, _config: prost_types::Any) -> Result { + Ok(serde_yaml::Value::Null) + } + fn create_filter(&self, _: CreateFilterArgs) -> Result { Ok(Self::create_empty_filter()) } diff --git a/src/xds/listener.rs b/src/xds/listener.rs index 6ce1dd8fc4..4fb89e364a 100644 --- a/src/xds/listener.rs +++ b/src/xds/listener.rs @@ -159,35 +159,35 @@ impl ListenerManager { #[cfg(test)] mod tests { - use super::ListenerManager; - use crate::filters::prelude::*; - use crate::xds::{ - config::listener::v3::{ - filter::ConfigType, Filter as LdsFilter, FilterChain as LdsFilterChain, Listener, - }, - service::discovery::v3::{DiscoveryRequest, DiscoveryResponse}, - }; - use std::time::Duration; - use crate::endpoint::{Endpoint, Endpoints, UpstreamEndpoints}; - use crate::filters::{ - ConvertProtoConfigError, DynFilterFactory, FilterRegistry, SharedFilterChain, - }; - use crate::xds::LISTENER_TYPE; - use prost::Message; use serde::{Deserialize, Serialize}; - use 
std::convert::TryFrom; - use tokio::sync::mpsc; use tokio::time; + use super::*; + use crate::{ + endpoint::{Endpoint, Endpoints, UpstreamEndpoints}, + filters::prelude::*, + xds::config::listener::v3::{ + filter::ConfigType, Filter as LdsFilter, FilterChain as LdsFilterChain, + }, + }; + // A simple filter that will be used in the following tests. // It appends a string to each payload. const APPEND_TYPE_URL: &str = "filter.append"; + #[derive(Clone, PartialEq, Serialize, Deserialize, schemars::JsonSchema)] pub struct Append { pub value: Option, } + + impl Append { + fn load() { + FilterRegistry::register([Self::factory()]) + } + } + #[derive(Clone, PartialEq, prost::Message)] pub struct ProtoAppend { #[prost(message, optional, tag = "1")] @@ -221,38 +221,22 @@ mod tests { } } - fn load_append_filter() { - FilterRegistry::register([DynFilterFactory::from(Box::from(AppendFactory))]) - } - - struct AppendFactory; + impl StaticFilter for Append { + const NAME: &'static str = APPEND_TYPE_URL; + type Configuration = Append; + type BinaryConfiguration = ProtoAppend; - impl FilterFactory for AppendFactory { - fn name(&self) -> &'static str { - APPEND_TYPE_URL - } - - fn config_schema(&self) -> schemars::schema::RootSchema { - schemars::schema_for!(Append) - } + fn new(config: Option) -> Result { + let config = Self::ensure_config_exists(config)?; - fn create_filter(&self, args: CreateFilterArgs) -> Result { - let (config_json, filter) = args - .config - .map(|config| config.deserialize::(self.name())) - .transpose()? - .unwrap(); - if filter.value.as_ref().unwrap() == "reject" { - Err(Error::FieldInvalid { + if config.value.as_ref().unwrap() == "reject" { + return Err(Error::FieldInvalid { field: "value".into(), reason: "reject requested".into(), - }) - } else { - Ok(FilterInstance::new( - config_json, - Box::new(filter) as Box, - )) + }); } + + Ok(config) } } @@ -262,7 +246,7 @@ mod tests { // LDS filters and it can build up a filter chain from it. 
// Prepare a filter registry with the filter factories we need for the test. - load_append_filter(); + Append::load(); let filter_chain = SharedFilterChain::empty(); let (discovery_req_tx, mut discovery_req_rx) = mpsc::channel(10); let mut manager = ListenerManager::new(filter_chain.clone(), discovery_req_tx); @@ -353,7 +337,7 @@ mod tests { // contains no filter chain. // Prepare a filter registry with the filter factories we need for the test. - load_append_filter(); + Append::load(); let filter_chain = SharedFilterChain::empty(); let (discovery_req_tx, mut discovery_req_rx) = mpsc::channel(10); let mut manager = ListenerManager::new(filter_chain.clone(), discovery_req_tx); @@ -446,7 +430,7 @@ mod tests { async fn listener_manager_reject_updates() { // Test that the manager returns NACK DiscoveryRequests for updates it failed to process. - load_append_filter(); + Append::load(); let filter_chain = SharedFilterChain::empty(); let (discovery_req_tx, mut discovery_req_rx) = mpsc::channel(10); let mut manager = ListenerManager::new(filter_chain.clone(), discovery_req_tx); @@ -602,17 +586,11 @@ mod tests { ); } - #[allow(deprecated)] fn create_lds_filter_chain(filters: Vec) -> LdsFilterChain { LdsFilterChain { - filter_chain_match: None, filters, - use_proxy_proto: None, - metadata: None, - transport_socket: None, - transport_socket_connect_timeout: None, name: "test-lds-filter-chain".into(), - on_demand_configuration: None, + ..<_>::default() } } diff --git a/tests/capture.rs b/tests/capture.rs index ac22ed0aac..289db94381 100644 --- a/tests/capture.rs +++ b/tests/capture.rs @@ -21,7 +21,7 @@ use tokio::time::{timeout, Duration}; use quilkin::{ config::{Builder, Filter}, endpoint::Endpoint, - filters::{capture, token_router}, + filters::{Capture, StaticFilter, TokenRouter}, metadata::MetadataView, test_utils::TestHelper, }; @@ -48,11 +48,11 @@ quilkin.dev: .with_static( vec![ Filter { - name: capture::factory().name().into(), + name: Capture::factory().name().into(), 
config: serde_yaml::from_str(capture_yaml).unwrap(), }, Filter { - name: token_router::factory().name().into(), + name: TokenRouter::factory().name().into(), config: None, }, ], diff --git a/tests/compress.rs b/tests/compress.rs index e39c22c245..45633e22f5 100644 --- a/tests/compress.rs +++ b/tests/compress.rs @@ -21,7 +21,7 @@ use tokio::time::{timeout, Duration}; use quilkin::{ config::{Builder, Filter}, endpoint::Endpoint, - filters::compress, + filters::{Compress, StaticFilter}, test_utils::TestHelper, }; @@ -40,7 +40,7 @@ on_write: COMPRESS .with_port(server_port) .with_static( vec![Filter { - name: compress::factory().name().into(), + name: Compress::factory().name().into(), config: serde_yaml::from_str(yaml).unwrap(), }], vec![Endpoint::new(echo)], @@ -59,7 +59,7 @@ on_write: DECOMPRESS .with_port(client_port) .with_static( vec![Filter { - name: compress::factory().name().into(), + name: Compress::factory().name().into(), config: serde_yaml::from_str(yaml).unwrap(), }], vec![Endpoint::new((Ipv4Addr::LOCALHOST, server_port).into())], diff --git a/tests/concatenate_bytes.rs b/tests/concatenate_bytes.rs index 107923813b..a1a54859a1 100644 --- a/tests/concatenate_bytes.rs +++ b/tests/concatenate_bytes.rs @@ -21,7 +21,7 @@ use tokio::time::{timeout, Duration}; use quilkin::{ config::{Builder, Filter}, endpoint::Endpoint, - filters::concatenate_bytes, + filters::{ConcatenateBytes, StaticFilter}, test_utils::TestHelper, }; @@ -39,7 +39,7 @@ bytes: YWJj #abc .with_port(server_port) .with_static( vec![Filter { - name: concatenate_bytes::factory().name().into(), + name: ConcatenateBytes::factory().name().into(), config: serde_yaml::from_str(yaml).unwrap(), }], vec![Endpoint::new(echo)], diff --git a/tests/filter_order.rs b/tests/filter_order.rs index 7eee3ea0db..6bc991ce99 100644 --- a/tests/filter_order.rs +++ b/tests/filter_order.rs @@ -22,7 +22,7 @@ use tokio::time::{timeout, Duration}; use quilkin::{ config::{Builder, Filter}, endpoint::Endpoint, - 
filters::{compress, concatenate_bytes}, + filters::{Compress, ConcatenateBytes, StaticFilter}, test_utils::TestHelper, }; @@ -60,15 +60,15 @@ on_write: DECOMPRESS .with_static( vec![ Filter { - name: concatenate_bytes::factory().name().into(), + name: ConcatenateBytes::factory().name().into(), config: serde_yaml::from_str(yaml_concat_read).unwrap(), }, Filter { - name: concatenate_bytes::factory().name().into(), + name: ConcatenateBytes::factory().name().into(), config: serde_yaml::from_str(yaml_concat_write).unwrap(), }, Filter { - name: compress::factory().name().into(), + name: Compress::factory().name().into(), config: serde_yaml::from_str(yaml_compress).unwrap(), }, ], diff --git a/tests/filters.rs b/tests/filters.rs index eae5f32acd..91fd468346 100644 --- a/tests/filters.rs +++ b/tests/filters.rs @@ -24,7 +24,7 @@ use serde_yaml::{Mapping, Value}; use quilkin::{ config::{Builder as ConfigBuilder, Filter}, endpoint::Endpoint, - filters::debug, + filters::{Debug, StaticFilter}, test_utils::{load_test_filters, TestHelper}, Builder as ProxyBuilder, }; @@ -102,7 +102,7 @@ async fn debug_filter() { let mut t = TestHelper::default(); // handy for grabbing the configuration name - let factory = debug::factory(); + let factory = Debug::factory(); // create an echo server as an endpoint. let echo = t.run_echo_server().await; diff --git a/tests/firewall.rs b/tests/firewall.rs index af908084a3..496892815d 100644 --- a/tests/firewall.rs +++ b/tests/firewall.rs @@ -14,10 +14,12 @@ * limitations under the License. 
*/ -use quilkin::config::{Builder, Filter}; -use quilkin::endpoint::Endpoint; -use quilkin::filters::firewall; -use quilkin::test_utils::TestHelper; +use quilkin::{ + config::{Builder, Filter}, + endpoint::Endpoint, + filters::{Firewall, StaticFilter}, + test_utils::TestHelper, +}; use std::net::SocketAddr; use tokio::sync::oneshot::Receiver; use tokio::time::{timeout, Duration}; @@ -104,7 +106,7 @@ async fn test(t: &mut TestHelper, server_port: u16, yaml: &str) -> Receiver