diff --git a/src/config.rs b/src/config.rs index 57fadf332..c0bf9101e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -11,6 +11,7 @@ use std::sync::Arc; use tokio::fs::File; use tokio::io::AsyncReadExt; +use crate::dns_cache::CachedResolver; use crate::errors::Error; use crate::pool::{ClientServerMap, ConnectionPool}; use crate::sharding::ShardingFunction; @@ -833,6 +834,10 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result (), + Err(err) => error!("DNS cache reinitialization error: {:?}", err), + }; if old_config.pools != new_config.pools { info!("Pool configuration changed"); diff --git a/src/dns_cache.rs b/src/dns_cache.rs index 475382f82..75bdab370 100644 --- a/src/dns_cache.rs +++ b/src/dns_cache.rs @@ -1,7 +1,7 @@ use crate::config::get_config; use crate::errors::Error; use arc_swap::ArcSwap; -use log::{debug, error, info}; +use log::{debug, error, info, warn}; use once_cell::sync::Lazy; use std::collections::{HashMap, HashSet}; use std::io; @@ -9,13 +9,13 @@ use std::net::IpAddr; use std::sync::Arc; use std::sync::RwLock; use tokio::time::{sleep, Duration}; -use trust_dns_resolver::error::ResolveResult; +use trust_dns_resolver::error::{ResolveError, ResolveResult}; use trust_dns_resolver::lookup_ip::LookupIp; use trust_dns_resolver::TokioAsyncResolver; /// Cached Resolver Globally available -pub static CACHED_RESOLVER: Lazy>>> = - Lazy::new(|| ArcSwap::from_pointee(None)); +pub static CACHED_RESOLVER: Lazy> = + Lazy::new(|| ArcSwap::from_pointee(CachedResolver::default())); // Ip addressed are returned as a set of addresses // so we can compare. @@ -70,23 +70,45 @@ impl From for AddrSet { /// // You can now check if an 'old' lookup differs from what it's currently /// // store in cache by using `has_changed`. /// resolver.has_changed("www.example.com.", addrset) +#[derive(Default)] pub struct CachedResolver { // The configuration of the cached_resolver. config: CachedResolverConfig, // This is the hash that contains the hash. - data: Arc>>, + data: Option>>, // The resolver to be used for DNS queries. - resolver: Arc, + resolver: Option, + + // The RefreshLoop + refresh_loop: RwLock>>, } /// /// Configuration -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default, PartialEq)] pub struct CachedResolverConfig { /// Amount of time in secods that a resolved dns address is considered stale. - pub dns_max_ttl: u64, + dns_max_ttl: u64, + + /// Enabled or disabled? (this is so we can reload config) + enabled: bool, +} + +impl CachedResolverConfig { + fn new(dns_max_ttl: u64, enabled: bool) -> Self { + CachedResolverConfig { + dns_max_ttl, + enabled, + } + } +} + +impl From for CachedResolverConfig { + fn from(config: crate::config::Config) -> Self { + CachedResolverConfig::new(config.general.dns_max_ttl, config.general.dns_cache_enabled) + } } impl CachedResolver { @@ -109,24 +131,39 @@ impl CachedResolver { /// # }) /// ``` /// - pub async fn new(config: CachedResolverConfig) -> io::Result> { + pub async fn new(config: CachedResolverConfig, data: Option>) -> Result, io::Error> { // Construct a new Resolver with default configuration options - let resolver = Arc::new(TokioAsyncResolver::tokio_from_system_conf()?); - let data = Arc::new(RwLock::new(HashMap::new())); + let resolver = Some(TokioAsyncResolver::tokio_from_system_conf()?); + + let data = if let Some(hash) = data { + Some(RwLock::new(hash)) + } else { + Some(RwLock::new(HashMap::new())) + }; - let self_ref = Arc::new(Self { + let instance = Arc::new(Self { config, resolver, data, + refresh_loop: RwLock::new(None), }); - let clone_self_ref = self_ref.clone(); - info!("Scheduling DNS refresh loop"); - tokio::task::spawn(async move { - clone_self_ref.refresh_dns_entries_loop().await; - }); + if instance.enabled() { + info!("Scheduling DNS refresh loop"); + let refresh_loop = tokio::task::spawn({ + let instance = instance.clone(); + async move { + instance.refresh_dns_entries_loop().await; + } + }); + *(instance.refresh_loop.write().unwrap()) = Some(refresh_loop); + } + + Ok(instance) + } - Ok(self_ref) + pub fn enabled(&self) -> bool { + self.config.enabled } // Schedules the refresher @@ -139,8 +176,10 @@ impl CachedResolver { // an array with keys. let mut hostnames: Vec = Vec::new(); { - for hostname in self.data.read().unwrap().keys() { - hostnames.push(hostname.clone()); + if let Some(ref data) = self.data { + for hostname in data.read().unwrap().keys() { + hostnames.push(hostname.clone()); + } } } @@ -208,10 +247,14 @@ impl CachedResolver { } None => { debug!("Not found, executing a dns query!"); - let addr_set = AddrSet::from(self.resolver.lookup_ip(host).await?); - debug!("Obtained: {:?}", addr_set); - self.store_in_cache(host, addr_set.clone()); - Ok(addr_set) + if let Some(ref resolver) = self.resolver { + let addr_set = AddrSet::from(resolver.lookup_ip(host).await?); + debug!("Obtained: {:?}", addr_set); + self.store_in_cache(host, addr_set.clone()); + Ok(addr_set) + } else { + Err(ResolveError::from("No resolver available")) + } } } } @@ -227,9 +270,10 @@ impl CachedResolver { // Fetches an AddrSet from the inner cache adquiring the read lock. fn fetch_from_cache(&self, key: &str) -> Option { - let hash = &self.data.read().unwrap(); - if let Some(addr_set) = hash.get(key) { - return Some(addr_set.clone()); + if let Some(ref hash) = self.data { + if let Some(addr_set) = hash.read().unwrap().get(key) { + return Some(addr_set.clone()); + } } None } @@ -237,38 +281,49 @@ impl CachedResolver { // Sets up the global CACHED_RESOLVER static variable so we can globally use DNS // cache. pub async fn from_config() -> Result<(), Error> { - let config = get_config(); - - // Configure dns_cache if enabled - if config.general.dns_cache_enabled { - info!("Starting Dns cache"); - let cached_resolver_config = CachedResolverConfig { - dns_max_ttl: config.general.dns_max_ttl, - }; - return match CachedResolver::new(cached_resolver_config).await { + let cached_resolver = CACHED_RESOLVER.load(); + let desired_config = CachedResolverConfig::from(get_config()); + + if cached_resolver.config != desired_config { + if let Some(ref refresh_loop) = *(cached_resolver.refresh_loop.write().unwrap()) { + warn!("Killing Dnscache refresh loop as its configuration is being reloaded"); + refresh_loop.abort() + } + let new_resolver = if let Some(ref data) = cached_resolver.data { + let data = Some(data.read().unwrap().clone()); + CachedResolver::new(desired_config, data).await + } else { + CachedResolver::new(desired_config, None).await + }; + + match new_resolver { Ok(ok) => { - let value = Some(ArcSwap::from(ok)); - CACHED_RESOLVER.store(Arc::new(value)); - Ok(()) + CACHED_RESOLVER.store(ok); + Ok(()) } Err(err) => { - let message = format!("Error Starting cached_resolver error: {:?}, will continue without this feature.", err); - Err(Error::DNSCachedError(message)) + let message = format!("Error setting up cached_resolver. Error: {:?}, will continue without this feature.", err); + Err(Error::DNSCachedError(message)) } - }; - } - Ok(()) + } + } else { + Ok(()) + } } // Stores the AddrSet in cache adquiring the write lock. fn store_in_cache(&self, host: &str, addr_set: AddrSet) { - self.data - .write() - .unwrap() - .insert(host.to_string(), addr_set); + if let Some(ref data) = self.data { + data + .write() + .unwrap() + .insert(host.to_string(), addr_set); + } else { + error!("Could not insert, Hash not initialized"); + } } -} +} #[cfg(test)] mod tests { use super::*; @@ -276,14 +331,20 @@ mod tests { #[tokio::test] async fn new() { - let config = CachedResolverConfig { dns_max_ttl: 10 }; + let config = CachedResolverConfig { + dns_max_ttl: 10, + enabled: true, + }; let resolver = CachedResolver::new(config).await; assert!(resolver.is_ok()); } #[tokio::test] async fn lookup_ip() { - let config = CachedResolverConfig { dns_max_ttl: 10 }; + let config = CachedResolverConfig { + dns_max_ttl: 10, + enabled: true, + }; let resolver = CachedResolver::new(config).await.unwrap(); let response = resolver.lookup_ip("www.google.com.").await; assert!(response.is_ok()); @@ -291,7 +352,10 @@ mod tests { #[tokio::test] async fn has_changed() { - let config = CachedResolverConfig { dns_max_ttl: 10 }; + let config = CachedResolverConfig { + dns_max_ttl: 10, + enabled: true, + }; let resolver = CachedResolver::new(config).await.unwrap(); let hostname = "www.google.com."; let response = resolver.lookup_ip(hostname).await; @@ -301,7 +365,10 @@ mod tests { #[tokio::test] async fn unknown_host() { - let config = CachedResolverConfig { dns_max_ttl: 10 }; + let config = CachedResolverConfig { + dns_max_ttl: 10, + enabled: true, + }; let resolver = CachedResolver::new(config).await.unwrap(); let hostname = "www.idontexists."; let response = resolver.lookup_ip(hostname).await; @@ -310,7 +377,10 @@ mod tests { #[tokio::test] async fn incorrect_address() { - let config = CachedResolverConfig { dns_max_ttl: 10 }; + let config = CachedResolverConfig { + dns_max_ttl: 10, + enabled: true, + }; let resolver = CachedResolver::new(config).await.unwrap(); let hostname = "w ww.idontexists."; let response = resolver.lookup_ip(hostname).await; @@ -324,7 +394,10 @@ mod tests { // if I cache here, it will miss after one cache iteration or two. async fn thread() { env_logger::init(); - let config = CachedResolverConfig { dns_max_ttl: 10 }; + let config = CachedResolverConfig { + dns_max_ttl: 10, + enabled: true, + }; let resolver = CachedResolver::new(config).await.unwrap(); let hostname = "www.google.com."; let response = resolver.lookup_ip(hostname).await; diff --git a/src/server.rs b/src/server.rs index af2e8950f..0ac71c707 100644 --- a/src/server.rs +++ b/src/server.rs @@ -87,25 +87,21 @@ impl Server { stats: Reporter, ) -> Result { let cached_resolver = CACHED_RESOLVER.load(); - let addr_set = match cached_resolver.as_ref() { - Some(cached_resolver) => { - if address.host.parse::().is_err() { - debug!("Resolving {}", &address.host); - match cached_resolver.load().lookup_ip(&address.host).await { - Ok(ok) => { - debug!("Obtained: {:?}", ok); - Some(ok) - } - Err(err) => { - warn!("Error trying to resolve {}, ({:?})", &address.host, err); - None - } - } - } else { + let mut addr_set: Option = None; + + // If we are caching addresses and hostname is not an IP + if cached_resolver.enabled() && address.host.parse::().is_err() { + debug!("Resolving {}", &address.host); + addr_set = match cached_resolver.lookup_ip(&address.host).await { + Ok(ok) => { + debug!("Obtained: {:?}", ok); + Some(ok) + } + Err(err) => { + warn!("Error trying to resolve {}, ({:?})", &address.host, err); None } } - None => None, }; let mut stream = @@ -592,13 +588,19 @@ impl Server { if self.bad { return self.bad; }; - - if let Some(cached_resolver) = CACHED_RESOLVER.load().as_ref() { + let cached_resolver = CACHED_RESOLVER.load(); + if cached_resolver.enabled() { if let Some(addr_set) = &self.addr_set { - if cached_resolver.load().has_changed(self.address.host.as_str(), addr_set) { - warn!("DNS changed for {}, it was {:?}. Dropping server connection.", self.address.host.as_str(), addr_set); - return true - } + if cached_resolver + .has_changed(self.address.host.as_str(), addr_set) + { + warn!( + "DNS changed for {}, it was {:?}. Dropping server connection.", + self.address.host.as_str(), + addr_set + ); + return true; + } } } false