Skip to content

Commit

Permalink
Allow reloading dns cached
Browse files Browse the repository at this point in the history
  • Loading branch information
magec committed Apr 24, 2023
1 parent e2f1aa2 commit 3e0fbbd
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 72 deletions.
5 changes: 5 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::sync::Arc;
use tokio::fs::File;
use tokio::io::AsyncReadExt;

use crate::dns_cache::CachedResolver;
use crate::errors::Error;
use crate::pool::{ClientServerMap, ConnectionPool};
use crate::sharding::ShardingFunction;
Expand Down Expand Up @@ -1032,6 +1033,10 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, E
}
};
let new_config = get_config();
match CachedResolver::from_config().await {
Ok(_) => (),
Err(err) => error!("DNS cache reinitialization error: {:?}", err),
};

if old_config.pools != new_config.pools {
info!("Pool configuration changed");
Expand Down
172 changes: 122 additions & 50 deletions src/dns_cache.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
use crate::config::get_config;
use crate::errors::Error;
use arc_swap::ArcSwap;
use log::{debug, error, info};
use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
use std::collections::{HashMap, HashSet};
use std::io;
use std::net::IpAddr;
use std::sync::Arc;
use std::sync::RwLock;
use tokio::time::{sleep, Duration};
use trust_dns_resolver::error::ResolveResult;
use trust_dns_resolver::error::{ResolveError, ResolveResult};
use trust_dns_resolver::lookup_ip::LookupIp;
use trust_dns_resolver::TokioAsyncResolver;

/// Cached Resolver Globally available
pub static CACHED_RESOLVER: Lazy<ArcSwap<Option<ArcSwap<CachedResolver>>>> =
Lazy::new(|| ArcSwap::from_pointee(None));
pub static CACHED_RESOLVER: Lazy<ArcSwap<CachedResolver>> =
Lazy::new(|| ArcSwap::from_pointee(CachedResolver::default()));

// Ip addressed are returned as a set of addresses
// so we can compare.
Expand Down Expand Up @@ -70,23 +70,45 @@ impl From<LookupIp> for AddrSet {
/// // You can now check if an 'old' lookup differs from what it's currently
/// // store in cache by using `has_changed`.
/// resolver.has_changed("www.example.com.", addrset)
#[derive(Default)]
pub struct CachedResolver {
// The configuration of the cached_resolver.
config: CachedResolverConfig,

// This is the hash that contains the hash.
data: Arc<RwLock<HashMap<String, AddrSet>>>,
data: Option<RwLock<HashMap<String, AddrSet>>>,

// The resolver to be used for DNS queries.
resolver: Arc<TokioAsyncResolver>,
resolver: Option<TokioAsyncResolver>,

// The RefreshLoop
refresh_loop: RwLock<Option<tokio::task::JoinHandle<()>>>,
}

///
/// Configuration
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Default, PartialEq)]
pub struct CachedResolverConfig {
/// Amount of time in secods that a resolved dns address is considered stale.
pub dns_max_ttl: u64,
dns_max_ttl: u64,

/// Enabled or disabled? (this is so we can reload config)
enabled: bool,
}

impl CachedResolverConfig {
fn new(dns_max_ttl: u64, enabled: bool) -> Self {
CachedResolverConfig {
dns_max_ttl,
enabled,
}
}
}

impl From<crate::config::Config> for CachedResolverConfig {
fn from(config: crate::config::Config) -> Self {
CachedResolverConfig::new(config.general.dns_max_ttl, config.general.dns_cache_enabled)
}
}

impl CachedResolver {
Expand All @@ -109,24 +131,42 @@ impl CachedResolver {
/// # })
/// ```
///
pub async fn new(config: CachedResolverConfig) -> io::Result<Arc<Self>> {
pub async fn new(
config: CachedResolverConfig,
data: Option<HashMap<String, AddrSet>>,
) -> Result<Arc<Self>, io::Error> {
// Construct a new Resolver with default configuration options
let resolver = Arc::new(TokioAsyncResolver::tokio_from_system_conf()?);
let data = Arc::new(RwLock::new(HashMap::new()));
let resolver = Some(TokioAsyncResolver::tokio_from_system_conf()?);

let data = if let Some(hash) = data {
Some(RwLock::new(hash))
} else {
Some(RwLock::new(HashMap::new()))
};

let self_ref = Arc::new(Self {
let instance = Arc::new(Self {
config,
resolver,
data,
refresh_loop: RwLock::new(None),
});
let clone_self_ref = self_ref.clone();

info!("Scheduling DNS refresh loop");
tokio::task::spawn(async move {
clone_self_ref.refresh_dns_entries_loop().await;
});
if instance.enabled() {
info!("Scheduling DNS refresh loop");
let refresh_loop = tokio::task::spawn({
let instance = instance.clone();
async move {
instance.refresh_dns_entries_loop().await;
}
});
*(instance.refresh_loop.write().unwrap()) = Some(refresh_loop);
}

Ok(self_ref)
Ok(instance)
}

pub fn enabled(&self) -> bool {
self.config.enabled
}

// Schedules the refresher
Expand All @@ -139,8 +179,10 @@ impl CachedResolver {
// an array with keys.
let mut hostnames: Vec<String> = Vec::new();
{
for hostname in self.data.read().unwrap().keys() {
hostnames.push(hostname.clone());
if let Some(ref data) = self.data {
for hostname in data.read().unwrap().keys() {
hostnames.push(hostname.clone());
}
}
}

Expand Down Expand Up @@ -208,10 +250,14 @@ impl CachedResolver {
}
None => {
debug!("Not found, executing a dns query!");
let addr_set = AddrSet::from(self.resolver.lookup_ip(host).await?);
debug!("Obtained: {:?}", addr_set);
self.store_in_cache(host, addr_set.clone());
Ok(addr_set)
if let Some(ref resolver) = self.resolver {
let addr_set = AddrSet::from(resolver.lookup_ip(host).await?);
debug!("Obtained: {:?}", addr_set);
self.store_in_cache(host, addr_set.clone());
Ok(addr_set)
} else {
Err(ResolveError::from("No resolver available"))
}
}
}
}
Expand All @@ -227,71 +273,88 @@ impl CachedResolver {

// Fetches an AddrSet from the inner cache adquiring the read lock.
fn fetch_from_cache(&self, key: &str) -> Option<AddrSet> {
let hash = &self.data.read().unwrap();
if let Some(addr_set) = hash.get(key) {
return Some(addr_set.clone());
if let Some(ref hash) = self.data {
if let Some(addr_set) = hash.read().unwrap().get(key) {
return Some(addr_set.clone());
}
}
None
}

// Sets up the global CACHED_RESOLVER static variable so we can globally use DNS
// cache.
pub async fn from_config() -> Result<(), Error> {
let config = get_config();
let cached_resolver = CACHED_RESOLVER.load();
let desired_config = CachedResolverConfig::from(get_config());

// Configure dns_cache if enabled
if config.general.dns_cache_enabled {
info!("Starting Dns cache");
let cached_resolver_config = CachedResolverConfig {
dns_max_ttl: config.general.dns_max_ttl,
if cached_resolver.config != desired_config {
if let Some(ref refresh_loop) = *(cached_resolver.refresh_loop.write().unwrap()) {
warn!("Killing Dnscache refresh loop as its configuration is being reloaded");
refresh_loop.abort()
}
let new_resolver = if let Some(ref data) = cached_resolver.data {
let data = Some(data.read().unwrap().clone());
CachedResolver::new(desired_config, data).await
} else {
CachedResolver::new(desired_config, None).await
};
return match CachedResolver::new(cached_resolver_config).await {

match new_resolver {
Ok(ok) => {
let value = Some(ArcSwap::from(ok));
CACHED_RESOLVER.store(Arc::new(value));
CACHED_RESOLVER.store(ok);
Ok(())
}
Err(err) => {
let message = format!("Error Starting cached_resolver error: {:?}, will continue without this feature.", err);
let message = format!("Error setting up cached_resolver. Error: {:?}, will continue without this feature.", err);
Err(Error::DNSCachedError(message))
}
};
}
} else {
Ok(())
}
Ok(())
}

// Stores the AddrSet in cache adquiring the write lock.
fn store_in_cache(&self, host: &str, addr_set: AddrSet) {
self.data
.write()
.unwrap()
.insert(host.to_string(), addr_set);
if let Some(ref data) = self.data {
data.write().unwrap().insert(host.to_string(), addr_set);
} else {
error!("Could not insert, Hash not initialized");
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use trust_dns_resolver::error::ResolveError;

#[tokio::test]
async fn new() {
let config = CachedResolverConfig { dns_max_ttl: 10 };
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config).await;
assert!(resolver.is_ok());
}

#[tokio::test]
async fn lookup_ip() {
let config = CachedResolverConfig { dns_max_ttl: 10 };
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config).await.unwrap();
let response = resolver.lookup_ip("www.google.com.").await;
assert!(response.is_ok());
}

#[tokio::test]
async fn has_changed() {
let config = CachedResolverConfig { dns_max_ttl: 10 };
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config).await.unwrap();
let hostname = "www.google.com.";
let response = resolver.lookup_ip(hostname).await;
Expand All @@ -301,7 +364,10 @@ mod tests {

#[tokio::test]
async fn unknown_host() {
let config = CachedResolverConfig { dns_max_ttl: 10 };
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config).await.unwrap();
let hostname = "www.idontexists.";
let response = resolver.lookup_ip(hostname).await;
Expand All @@ -310,7 +376,10 @@ mod tests {

#[tokio::test]
async fn incorrect_address() {
let config = CachedResolverConfig { dns_max_ttl: 10 };
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config).await.unwrap();
let hostname = "w ww.idontexists.";
let response = resolver.lookup_ip(hostname).await;
Expand All @@ -324,7 +393,10 @@ mod tests {
// if I cache here, it will miss after one cache iteration or two.
async fn thread() {
env_logger::init();
let config = CachedResolverConfig { dns_max_ttl: 10 };
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config).await.unwrap();
let hostname = "www.google.com.";
let response = resolver.lookup_ip(hostname).await;
Expand Down
37 changes: 15 additions & 22 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,25 +93,21 @@ impl Server {
auth_hash: Arc<RwLock<Option<String>>>,
) -> Result<Server, Error> {
let cached_resolver = CACHED_RESOLVER.load();
let addr_set = match cached_resolver.as_ref() {
Some(cached_resolver) => {
if address.host.parse::<IpAddr>().is_err() {
debug!("Resolving {}", &address.host);
match cached_resolver.load().lookup_ip(&address.host).await {
Ok(ok) => {
debug!("Obtained: {:?}", ok);
Some(ok)
}
Err(err) => {
warn!("Error trying to resolve {}, ({:?})", &address.host, err);
None
}
}
} else {
let mut addr_set: Option<AddrSet> = None;

// If we are caching addresses and hostname is not an IP
if cached_resolver.enabled() && address.host.parse::<IpAddr>().is_err() {
debug!("Resolving {}", &address.host);
addr_set = match cached_resolver.lookup_ip(&address.host).await {
Ok(ok) => {
debug!("Obtained: {:?}", ok);
Some(ok)
}
Err(err) => {
warn!("Error trying to resolve {}, ({:?})", &address.host, err);
None
}
}
None => None,
};

let mut stream =
Expand Down Expand Up @@ -729,13 +725,10 @@ impl Server {
if self.bad {
return self.bad;
};

if let Some(cached_resolver) = CACHED_RESOLVER.load().as_ref() {
let cached_resolver = CACHED_RESOLVER.load();
if cached_resolver.enabled() {
if let Some(addr_set) = &self.addr_set {
if cached_resolver
.load()
.has_changed(self.address.host.as_str(), addr_set)
{
if cached_resolver.has_changed(self.address.host.as_str(), addr_set) {
warn!(
"DNS changed for {}, it was {:?}. Dropping server connection.",
self.address.host.as_str(),
Expand Down

0 comments on commit 3e0fbbd

Please sign in to comment.