From 131db6d6fa9071a5b134b222e542d09cbd978d82 Mon Sep 17 00:00:00 2001 From: "Jose Fernandez (magec)" Date: Wed, 7 Dec 2022 13:18:42 +0100 Subject: [PATCH] Add dns_cache so server addresses are cached and invalidated when DNS changes. Adds a module to deal with dns_cache feature. It's main struct is CachedResolver, which is a simple thread safe hostname <-> Ips cache with the ability to refresh resolutions every `dns_max_ttl` seconds. This way, a client can check whether its ip address has changed. --- Cargo.lock | 288 +++++++++++++++++++++++++++++++ Cargo.toml | 2 + README.md | 2 + examples/docker/pgcat.toml | 9 + src/config.rs | 12 ++ src/dns_cache.rs | 339 +++++++++++++++++++++++++++++++++++++ src/errors.rs | 1 + src/lib.rs | 1 + src/main.rs | 12 +- src/server.rs | 42 ++++- 10 files changed, 705 insertions(+), 3 deletions(-) create mode 100644 src/dns_cache.rs diff --git a/Cargo.lock b/Cargo.lock index 7b18b2c1e..c7b3f49fe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,27 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6" +[[package]] +name = "async-stream" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e" +dependencies = [ + "async-stream-impl", + "futures-core", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "async-trait" version = "0.1.64" @@ -195,6 +216,12 @@ dependencies = [ "syn", ] +[[package]] +name = "data-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ee2393c4a91429dffb4bedf19f4d6abf27d8a732c8ce4980305d782e5426d57" + 
[[package]] name = "digest" version = "0.10.6" @@ -206,6 +233,18 @@ dependencies = [ "subtle", ] +[[package]] +name = "enum-as-inner" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9720bba047d567ffc8a3cba48bf19126600e249ab7f128e9233e6376976a116" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "env_logger" version = "0.10.0" @@ -252,6 +291,15 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "form_urlencoded" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8" +dependencies = [ + "percent-encoding", +] + [[package]] name = "fs_extra" version = "1.2.0" @@ -393,6 +441,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "heck" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9" + [[package]] name = "hermit-abi" version = "0.2.6" @@ -411,6 +465,17 @@ dependencies = [ "digest", ] +[[package]] +name = "hostname" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867" +dependencies = [ + "libc", + "match_cfg", + "winapi", +] + [[package]] name = "http" version = "0.2.8" @@ -499,6 +564,27 @@ dependencies = [ "cxx-build", ] +[[package]] +name = "idna" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8" +dependencies = [ + "matches", + "unicode-bidi", + "unicode-normalization", +] + 
+[[package]] +name = "idna" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "indexmap" version = "1.9.2" @@ -519,6 +605,24 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "ipconfig" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd302af1b90f2463a98fa5ad469fc212c8e3175a41c3068601bfa2727591c5be" +dependencies = [ + "socket2", + "widestring", + "winapi", + "winreg", +] + +[[package]] +name = "ipnet" +version = "2.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f88c5561171189e69df9d98bcf18fd5f9558300f7ea7b801eb8a0fd748bd8745" + [[package]] name = "is-terminal" version = "0.4.2" @@ -567,6 +671,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "libc" version = "0.2.139" @@ -582,6 +692,12 @@ dependencies = [ "cc", ] +[[package]] +name = "linked-hash-map" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" + [[package]] name = "linux-raw-sys" version = "0.1.4" @@ -607,6 +723,27 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "lru-cache" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31e24f1ad8321ca0e8a1e0ac13f23cb668e6f5466c2c57319f6a5cf1cc8e3b1c" +dependencies = [ + "linked-hash-map", +] + +[[package]] +name = "match_cfg" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4" + +[[package]] +name = 
"matches" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f" + [[package]] name = "md-5" version = "0.10.5" @@ -701,6 +838,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "percent-encoding" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" + [[package]] name = "pgcat" version = "0.6.0-alpha1" @@ -735,7 +878,9 @@ dependencies = [ "stringprep", "tokio", "tokio-rustls", + "tokio-test", "toml", + "trust-dns-resolver", ] [[package]] @@ -807,6 +952,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + [[package]] name = "quote" version = "1.0.23" @@ -872,6 +1023,16 @@ version = "0.6.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" +[[package]] +name = "resolv-conf" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52e44394d2086d010551b14b53b1f24e31647570cd1deb0379e2c21b329aba00" +dependencies = [ + "hostname", + "quick-error", +] + [[package]] name = "ring" version = "0.16.20" @@ -1083,6 +1244,26 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "thiserror" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb" +dependencies = [ + "proc-macro2", + 
"quote", + "syn", +] + [[package]] name = "time" version = "0.1.45" @@ -1151,6 +1332,30 @@ dependencies = [ "webpki", ] +[[package]] +name = "tokio-stream" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d660770404473ccd7bc9f8b28494a811bc18542b915c0855c51e8f419d5223ce" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-test" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53474327ae5e166530d17f2d956afcb4f8a004de581b3cae10f12006bc8163e3" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-util" version = "0.7.4" @@ -1213,9 +1418,21 @@ checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" dependencies = [ "cfg-if", "pin-project-lite", + "tracing-attributes", "tracing-core", ] +[[package]] +name = "tracing-attributes" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tracing-core" version = "0.1.30" @@ -1225,6 +1442,51 @@ dependencies = [ "once_cell", ] +[[package]] +name = "trust-dns-proto" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f7f83d1e4a0e4358ac54c5c3681e5d7da5efc5a7a632c90bb6d6669ddd9bc26" +dependencies = [ + "async-trait", + "cfg-if", + "data-encoding", + "enum-as-inner", + "futures-channel", + "futures-io", + "futures-util", + "idna 0.2.3", + "ipnet", + "lazy_static", + "rand", + "smallvec", + "thiserror", + "tinyvec", + "tokio", + "tracing", + "url", +] + +[[package]] +name = "trust-dns-resolver" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aff21aa4dcefb0a1afbfac26deb0adc93888c7d295fb63ab273ef276ba2b7cfe" 
+dependencies = [ + "cfg-if", + "futures-util", + "ipconfig", + "lazy_static", + "lru-cache", + "parking_lot", + "resolv-conf", + "smallvec", + "thiserror", + "tokio", + "tracing", + "trust-dns-proto", +] + [[package]] name = "try-lock" version = "0.2.4" @@ -1270,6 +1532,17 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +[[package]] +name = "url" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643" +dependencies = [ + "form_urlencoded", + "idna 0.3.0", + "percent-encoding", +] + [[package]] name = "version_check" version = "0.9.4" @@ -1372,6 +1645,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "widestring" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17882f045410753661207383517a6f62ec3dbeb6a4ed2acce01f0728238d1983" + [[package]] name = "winapi" version = "0.3.9" @@ -1459,3 +1738,12 @@ name = "windows_x86_64_msvc" version = "0.42.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5" + +[[package]] +name = "winreg" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" +dependencies = [ + "winapi", +] diff --git a/Cargo.toml b/Cargo.toml index 3cf0e7aa4..2d85d0ef8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,8 @@ phf = { version = "0.11.1", features = ["macros"] } exitcode = "1.1.2" futures = "0.3" socket2 = { version = "0.4.7", features = ["all"] } +trust-dns-resolver = "0.22" +tokio-test = "0.4.2" [target.'cfg(not(target_env = "msvc"))'.dependencies] jemallocator = "0.5.0" diff --git a/README.md b/README.md index 2ba7ba828..e55b3a9ee 100644 --- a/README.md +++ 
b/README.md @@ -274,6 +274,8 @@ The config can be reloaded by sending a `kill -s SIGHUP` to the process or by qu | `default_role` | no | | `primary_reads_enabled` | no | | `query_parser_enabled` | no | +| `dns_max_ttl` | no | +| `dns_cache_enabled` | no | ## Benchmarks diff --git a/examples/docker/pgcat.toml b/examples/docker/pgcat.toml index c41c8cdd6..2c7f1e3c6 100644 --- a/examples/docker/pgcat.toml +++ b/examples/docker/pgcat.toml @@ -41,6 +41,15 @@ log_client_disconnections = false # Reload config automatically if it changes. autoreload = false +# If enabled, hostname resolution will be cached +# and server connections will be invalidated if a change in the IP is +# detected. This check is done every `dns_max_ttl` seconds. +# dns_cache_enabled = false + +# The number of seconds to wait before checking the +# cached hostname resolutions again. 30 seconds by default. +# dns_max_ttl = 30 + # TLS # tls_certificate = "server.cert" # tls_private_key = "server.key" diff --git a/src/config.rs b/src/config.rs index f911d415a..57fadf332 100644 --- a/src/config.rs +++ b/src/config.rs @@ -174,6 +174,12 @@ pub struct General { #[serde(default)] // False pub log_client_disconnections: bool, + #[serde(default)] // False + pub dns_cache_enabled: bool, + + #[serde(default = "General::default_dns_max_ttl")] + pub dns_max_ttl: u64, + #[serde(default = "General::default_shutdown_timeout")] pub shutdown_timeout: u64, @@ -234,6 +240,10 @@ impl General { 60000 } + pub fn default_dns_max_ttl() -> u64 { + 30 + } + pub fn default_healthcheck_timeout() -> u64 { 1000 } @@ -270,6 +280,8 @@ impl Default for General { tcp_keepalives_interval: Self::default_tcp_keepalives_interval(), log_client_connections: false, log_client_disconnections: false, + dns_cache_enabled: false, + dns_max_ttl: Self::default_dns_max_ttl(), autoreload: false, tls_certificate: None, tls_private_key: None, diff --git a/src/dns_cache.rs b/src/dns_cache.rs new file mode 100644 index 000000000..475382f82 ---
/dev/null +++ b/src/dns_cache.rs @@ -0,0 +1,339 @@ +use crate::config::get_config; +use crate::errors::Error; +use arc_swap::ArcSwap; +use log::{debug, error, info}; +use once_cell::sync::Lazy; +use std::collections::{HashMap, HashSet}; +use std::io; +use std::net::IpAddr; +use std::sync::Arc; +use std::sync::RwLock; +use tokio::time::{sleep, Duration}; +use trust_dns_resolver::error::ResolveResult; +use trust_dns_resolver::lookup_ip::LookupIp; +use trust_dns_resolver::TokioAsyncResolver; + +/// Cached Resolver Globally available +pub static CACHED_RESOLVER: Lazy>>> = + Lazy::new(|| ArcSwap::from_pointee(None)); + +// IP addresses are returned as a set of addresses +// so we can compare them. +#[derive(Clone, PartialEq, Debug)] +pub struct AddrSet { + set: HashSet, +} + +impl AddrSet { + fn new() -> AddrSet { + AddrSet { + set: HashSet::new(), + } + } +} + +impl From for AddrSet { + fn from(lookup_ip: LookupIp) -> Self { + let mut addr_set = AddrSet::new(); + for address in lookup_ip.iter() { + addr_set.set.insert(address); + } + addr_set + } +} + +/// +/// A CachedResolver is a DNS resolution cache mechanism with customizable expiration time. +/// +/// The system works as follows: +/// +/// When a host is to be resolved, if we have not resolved it before, a new resolution is +/// executed and stored in the internal cache. Concurrently, every `dns_max_ttl` seconds, the +/// cache is refreshed. +/// +/// # Example: +/// +/// ``` +/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver}; +/// +/// # tokio_test::block_on(async { +/// let config = CachedResolverConfig{dns_max_ttl: 10}; +/// let resolver = CachedResolver::new(config).await.unwrap(); +/// let addrset = resolver.lookup_ip("www.example.com.").await.unwrap(); +/// # }) +/// ``` +/// +/// // Now the ip resolution is stored in local cache and subsequent +/// // calls will be returned from cache. Also, the cache is refreshed +/// // and updated every 10 seconds.
+/// +/// // You can now check if an 'old' lookup differs from what is currently +/// // stored in cache by using `has_changed`. +/// resolver.has_changed("www.example.com.", &addrset) +pub struct CachedResolver { + // The configuration of the cached_resolver. + config: CachedResolverConfig, + + // Maps each hostname to its cached AddrSet. + data: Arc>>, + + // The resolver to be used for DNS queries. + resolver: Arc, +} + +/// +/// Configuration +#[derive(Clone, Debug)] +pub struct CachedResolverConfig { + /// Amount of time in seconds after which a resolved DNS address is considered stale. + pub dns_max_ttl: u64, +} + +impl CachedResolver { + /// + /// Returns a new `Arc<CachedResolver>` based on the passed configuration. + /// It also starts the loop that will refresh cache entries. + /// + /// # Arguments: + /// + /// * `config` - The `CachedResolverConfig` to be used to create the resolver. + /// + /// # Example: + /// + /// ``` + /// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver}; + /// + /// # tokio_test::block_on(async { + /// let config = CachedResolverConfig{dns_max_ttl: 10}; + /// let resolver = CachedResolver::new(config); + /// # }) + /// ``` + /// + pub async fn new(config: CachedResolverConfig) -> io::Result> { + // Construct a new Resolver with default configuration options + let resolver = Arc::new(TokioAsyncResolver::tokio_from_system_conf()?); + let data = Arc::new(RwLock::new(HashMap::new())); + + let self_ref = Arc::new(Self { + config, + resolver, + data, + }); + let clone_self_ref = self_ref.clone(); + + info!("Scheduling DNS refresh loop"); + tokio::task::spawn(async move { + clone_self_ref.refresh_dns_entries_loop().await; + }); + + Ok(self_ref) + } + + // Schedules the refresher + async fn refresh_dns_entries_loop(&self) { + let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap(); + let interval = Duration::from_secs(self.config.dns_max_ttl); + loop { + debug!("Begin refreshing cached DNS addresses."); + // To minimize the time we hold the
lock, we first create + // an array with keys. + let mut hostnames: Vec = Vec::new(); + { + for hostname in self.data.read().unwrap().keys() { + hostnames.push(hostname.clone()); + } + } + + for hostname in hostnames.iter() { + let addrset = self + .fetch_from_cache(hostname.as_str()) + .expect("Could not obtain expected address from cache, this should not happen"); + + match resolver.lookup_ip(hostname).await { + Ok(lookup_ip) => { + let new_addrset = AddrSet::from(lookup_ip); + debug!( + "Obtained address for host ({}) -> ({:?})", + hostname, new_addrset + ); + + if addrset != new_addrset { + debug!( + "Addr changed from {:?} to {:?} updating cache.", + addrset, new_addrset + ); + self.store_in_cache(hostname, new_addrset); + } + } + Err(err) => { + error!( + "There was an error trying to resolv {}: ({}).", + hostname, err + ); + } + } + } + debug!("Finished refreshing cached DNS addresses."); + sleep(interval).await; + } + } + + /// Returns a `AddrSet` given the specified hostname. + /// + /// This method first tries to fetch the value from the cache, if it misses + /// then it is resolved and stored in the cache. TTL from records is ignored. + /// + /// # Arguments + /// + /// * `host` - A string slice referencing the hostname to be resolved. 
+ /// + /// # Example: + /// + /// ``` + /// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver}; + /// + /// # tokio_test::block_on(async { + /// let config = CachedResolverConfig { dns_max_ttl: 10 }; + /// let resolver = CachedResolver::new(config).await.unwrap(); + /// let response = resolver.lookup_ip("www.google.com."); + /// # }) + /// ``` + /// + pub async fn lookup_ip(&self, host: &str) -> ResolveResult { + debug!("Lookup up {} in cache", host); + match self.fetch_from_cache(host) { + Some(addr_set) => { + debug!("Cache hit!"); + Ok(addr_set) + } + None => { + debug!("Not found, executing a dns query!"); + let addr_set = AddrSet::from(self.resolver.lookup_ip(host).await?); + debug!("Obtained: {:?}", addr_set); + self.store_in_cache(host, addr_set.clone()); + Ok(addr_set) + } + } + } + + // + // Returns true if the stored host resolution differs from the AddrSet passed. + pub fn has_changed(&self, host: &str, addr_set: &AddrSet) -> bool { + if let Some(fetched_addr_set) = self.fetch_from_cache(host) { + return fetched_addr_set != *addr_set; + } + false + } + + // Fetches an AddrSet from the inner cache acquiring the read lock. + fn fetch_from_cache(&self, key: &str) -> Option { + let hash = &self.data.read().unwrap(); + if let Some(addr_set) = hash.get(key) { + return Some(addr_set.clone()); + } + None + } + + // Sets up the global CACHED_RESOLVER static variable so we can globally use DNS + // cache.
+ pub async fn from_config() -> Result<(), Error> { + let config = get_config(); + + // Configure dns_cache if enabled + if config.general.dns_cache_enabled { + info!("Starting Dns cache"); + let cached_resolver_config = CachedResolverConfig { + dns_max_ttl: config.general.dns_max_ttl, + }; + return match CachedResolver::new(cached_resolver_config).await { + Ok(ok) => { + let value = Some(ArcSwap::from(ok)); + CACHED_RESOLVER.store(Arc::new(value)); + Ok(()) + } + Err(err) => { + let message = format!("Error Starting cached_resolver error: {:?}, will continue without this feature.", err); + Err(Error::DNSCachedError(message)) + } + }; + } + Ok(()) + } + + // Stores the AddrSet in cache acquiring the write lock. + fn store_in_cache(&self, host: &str, addr_set: AddrSet) { + self.data + .write() + .unwrap() + .insert(host.to_string(), addr_set); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use trust_dns_resolver::error::ResolveError; + + #[tokio::test] + async fn new() { + let config = CachedResolverConfig { dns_max_ttl: 10 }; + let resolver = CachedResolver::new(config).await; + assert!(resolver.is_ok()); + } + + #[tokio::test] + async fn lookup_ip() { + let config = CachedResolverConfig { dns_max_ttl: 10 }; + let resolver = CachedResolver::new(config).await.unwrap(); + let response = resolver.lookup_ip("www.google.com.").await; + assert!(response.is_ok()); + } + + #[tokio::test] + async fn has_changed() { + let config = CachedResolverConfig { dns_max_ttl: 10 }; + let resolver = CachedResolver::new(config).await.unwrap(); + let hostname = "www.google.com."; + let response = resolver.lookup_ip(hostname).await; + let addr_set = response.unwrap(); + assert!(!resolver.has_changed(hostname, &addr_set)); + } + + #[tokio::test] + async fn unknown_host() { + let config = CachedResolverConfig { dns_max_ttl: 10 }; + let resolver = CachedResolver::new(config).await.unwrap(); + let hostname = "www.idontexists."; + let response = resolver.lookup_ip(hostname).await; +
assert!(matches!(response, Err(ResolveError { .. }))); + } + + #[tokio::test] + async fn incorrect_address() { + let config = CachedResolverConfig { dns_max_ttl: 10 }; + let resolver = CachedResolver::new(config).await.unwrap(); + let hostname = "w ww.idontexists."; + let response = resolver.lookup_ip(hostname).await; + assert!(matches!(response, Err(ResolveError { .. }))); + assert!(!resolver.has_changed(hostname, &AddrSet::new())); + } + + #[tokio::test] + // OK, this test is based on the fact that Google does DNS round-robin + // and does not respond with every available IP every time, so + // if I cache here, it will miss after one cache iteration or two. + async fn thread() { + env_logger::init(); + let config = CachedResolverConfig { dns_max_ttl: 10 }; + let resolver = CachedResolver::new(config).await.unwrap(); + let hostname = "www.google.com."; + let response = resolver.lookup_ip(hostname).await; + let addr_set = response.unwrap(); + assert!(!resolver.has_changed(hostname, &addr_set)); + let resolver_for_refresher = resolver.clone(); + let _thread_handle = tokio::task::spawn(async move { + resolver_for_refresher.refresh_dns_entries_loop().await; + }); + assert!(!resolver.has_changed(hostname, &addr_set)); + } +} diff --git a/src/errors.rs b/src/errors.rs index 4ac23a855..2a66f2398 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -12,6 +12,7 @@ pub enum Error { ClientError(String), TlsError, StatementTimeout, + DNSCachedError(String), ShuttingDown, ParseBytesError(String), } diff --git a/src/lib.rs b/src/lib.rs index e9a683f3d..7702f3070 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod config; pub mod constants; +pub mod dns_cache; pub mod errors; pub mod messages; pub mod pool; diff --git a/src/main.rs b/src/main.rs index cd408c491..3d85be4fc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -36,6 +36,7 @@ extern crate sqlparser; extern crate tokio; extern crate tokio_rustls; extern crate toml; +extern crate trust_dns_resolver; #[cfg(not(target_env =
"msvc"))] use jemallocator::Jemalloc; @@ -64,6 +65,7 @@ mod admin; mod client; mod config; mod constants; +mod dns_cache; mod errors; mod messages; mod pool; @@ -163,8 +165,14 @@ fn main() -> Result<(), Box> { let (stats_tx, stats_rx) = mpsc::channel(100_000); REPORTER.store(Arc::new(Reporter::new(stats_tx.clone()))); - // Connection pool that allows to query all shards and replicas. - match ConnectionPool::from_config(client_server_map.clone()).await { + // Starts (if enabled) dns cache before pools initialization + match dns_cache::CachedResolver::from_config().await { + Ok(_) => (), + Err(err) => error!("DNS cache initialization error: {:?}", err), + }; + + // Connection pool that allows to query all shards and replicas. + match ConnectionPool::from_config(client_server_map.clone()).await { Ok(_) => (), Err(err) => { error!("Pool error: {:?}", err); diff --git a/src/server.rs b/src/server.rs index 1d9bcd14f..af2e8950f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,6 +3,7 @@ use bytes::{Buf, BufMut, BytesMut}; use log::{debug, error, info, trace, warn}; use std::io::Read; +use std::net::IpAddr; use std::time::SystemTime; use tokio::io::{AsyncReadExt, BufReader}; use tokio::net::{ @@ -12,6 +13,7 @@ use tokio::net::{ use crate::config::{Address, User}; use crate::constants::*; +use crate::dns_cache::{AddrSet, CACHED_RESOLVER}; use crate::errors::Error; use crate::messages::*; use crate::pool::ClientServerMap; @@ -68,6 +70,9 @@ pub struct Server { // Last time that a successful server send or response happened last_activity: SystemTime, + + // Associated addresses used + addr_set: Option, } impl Server { @@ -81,6 +86,28 @@ impl Server { client_server_map: ClientServerMap, stats: Reporter, ) -> Result { + let cached_resolver = CACHED_RESOLVER.load(); + let addr_set = match cached_resolver.as_ref() { + Some(cached_resolver) => { + if address.host.parse::().is_err() { + debug!("Resolving {}", &address.host); + match 
cached_resolver.load().lookup_ip(&address.host).await { + Ok(ok) => { + debug!("Obtained: {:?}", ok); + Some(ok) + } + Err(err) => { + warn!("Error trying to resolve {}, ({:?})", &address.host, err); + None + } + } + } else { + None + } + } + None => None, + }; + let mut stream = match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await { Ok(stream) => stream, @@ -330,6 +357,7 @@ impl Server { bad: false, needs_cleanup: false, client_server_map, + addr_set, connected_at: chrono::offset::Utc::now().naive_utc(), stats, application_name: String::new(), @@ -561,7 +589,19 @@ impl Server { /// Server & client are out of sync, we must discard this connection. /// This happens with clients that misbehave. pub fn is_bad(&self) -> bool { - self.bad + if self.bad { + return self.bad; + }; + + if let Some(cached_resolver) = CACHED_RESOLVER.load().as_ref() { + if let Some(addr_set) = &self.addr_set { + if cached_resolver.load().has_changed(self.address.host.as_str(), addr_set) { + warn!("DNS changed for {}, it was {:?}. Dropping server connection.", self.address.host.as_str(), addr_set); + return true + } + } + } + false } /// Get server startup information to forward it to the client.