From a0fe5dccc80171403464cfe80ff2d7c790fd5c9a Mon Sep 17 00:00:00 2001 From: Yureka Date: Thu, 23 Nov 2023 14:18:57 +0100 Subject: [PATCH] resolve reverse dns concurrently --- src/api.rs | 139 ++++++++++++++++++++++++++++++++++----------------- src/store.rs | 13 +---- 2 files changed, 95 insertions(+), 57 deletions(-) diff --git a/src/api.rs b/src/api.rs index 71a17d5..969e6a6 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,19 +1,19 @@ -use crate::store::{Query, QueryLimits, Store, QueryResult, RouteAttrs, RouterId, ResolvedRouteAttrs, ResolvedNexthop, NetQuery}; +use crate::store::{NetQuery, Query, QueryLimits, QueryResult, ResolvedNexthop, Store}; use axum::body::StreamBody; use axum::extract::{Query as AxumQuery, State}; use axum::http::StatusCode; -use axum::response::{Response, IntoResponse}; +use axum::response::{IntoResponse, Response}; use axum::routing::get; use axum::Router; +use futures_util::{FutureExt, StreamExt}; use hickory_resolver::config::LookupIpStrategy; use hickory_resolver::TokioAsyncResolver; -use futures_util::{FutureExt, StreamExt}; -use log::*; -use serde::{Serialize, Deserialize}; use ipnet::IpNet; -use std::net::IpAddr; -use std::collections::HashMap; +use log::*; +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; use std::convert::Infallible; +use std::net::IpAddr; use std::net::SocketAddr; use std::sync::Arc; @@ -31,18 +31,12 @@ pub struct ApiServerConfig { } #[derive(Debug, Clone, Serialize)] -pub enum NexthopResolved { - ReverseDns(String), - RouterId(RouterId), -} - -#[derive(Debug, Clone, Serialize, Default)] -pub struct RouteAttrsResolved { - #[serde(flatten)] - inner: RouteAttrs, - communities_resolved: HashMap<(u16, u16), String>, - large_communities_resolved: HashMap<(u32, u32, u32), String>, - nexthop_resolved: Option, +pub enum ApiResult { + Route(QueryResult), + ReverseDns { + nexthop: IpAddr, + nexthop_resolved: ResolvedNexthop, + }, } // Make our own error that wraps `anyhow::Error`. @@ -78,7 +72,13 @@ async fn parse_or_resolve(resolver: &TokioAsyncResolver, name: String) -> anyhow return Ok(addr.into()); } - Ok(resolver.lookup_ip(&name).await?.iter().next().ok_or(anyhow::anyhow!("Name resolution failure"))?.into()) + Ok(resolver + .lookup_ip(&format!("{}.", name)) + .await? + .iter() + .next() + .ok_or(anyhow::anyhow!("Name resolution failure"))? + .into()) } async fn query( @@ -89,7 +89,9 @@ async fn query( let net_query = match query.net_query { NetQuery::Contains(name) => NetQuery::Contains(parse_or_resolve(&resolver, name).await?), - NetQuery::MostSpecific(name) => NetQuery::MostSpecific(parse_or_resolve(&resolver, name).await?), + NetQuery::MostSpecific(name) => { + NetQuery::MostSpecific(parse_or_resolve(&resolver, name).await?) + } NetQuery::Exact(name) => NetQuery::Exact(parse_or_resolve(&resolver, name).await?), NetQuery::OrLonger(name) => NetQuery::OrLonger(parse_or_resolve(&resolver, name).await?), }; @@ -98,7 +100,7 @@ async fn query( table_query: query.table_query, net_query, limits: query.limits, - as_path_regex: query.as_path_regex + as_path_regex: query.as_path_regex, }; let mut limits = query.limits.take().unwrap_or(cfg.query_limits.clone()); @@ -108,32 +110,77 @@ async fn query( cfg.query_limits.max_results_per_table, ); query.limits = Some(limits); - let stream = store.get_routes(query) - .then(move |route| { - let resolver = resolver.clone(); - async move { - QueryResult { - client: route.client, net: route.net, session: route.session, - state: route.state, table: route.table, - attrs: ResolvedRouteAttrs { - resolved_communities: Default::default(), - resolved_large_communities: Default::default(), - resolved_nexthop: match route.attrs.nexthop.as_ref() { - Some(nexthop) => match resolver.reverse_lookup(*nexthop).await.ok().and_then(|reverse| reverse.iter().next().map(|x| x.0.clone())) { - Some(reverse) => ResolvedNexthop::ReverseDns(reverse.to_string()), - None => ResolvedNexthop::None, + + // for deduplicating the nexthop resolutions + let have_resolved = Arc::new(std::sync::Mutex::new(HashSet::new())); + + let stream = store + .get_routes(query) + .flat_map_unordered(None, move |route| { + enum StreamState { + SendRoute, + SendDns, + Done, + } + + futures_util::stream::unfold( + ( + StreamState::SendRoute, + route, + resolver.clone(), + have_resolved.clone(), + ), + move |(state, route, resolver, have_resolved)| { + Box::pin(async move { + match state { + StreamState::SendRoute => Some(( + ApiResult::Route(route.clone()), + (StreamState::SendDns, route, resolver, have_resolved), + )), + StreamState::SendDns => match route.attrs.nexthop.clone() { + Some(nexthop) => { + if have_resolved.lock().unwrap().insert(nexthop) { + resolver + .reverse_lookup(nexthop) + .await + .ok() + .and_then(|reverse| { + reverse.iter().next().map(|x| x.0.clone()) + }) + .map(|x| { + ( + ApiResult::ReverseDns { + nexthop, + nexthop_resolved: + ResolvedNexthop::ReverseDns( + x.to_string(), + ), + }, + ( + StreamState::Done, + route, + resolver, + have_resolved, + ), + ) + }) + } else { + None + } + } + None => None, + }, + StreamState::Done => None, } - None => ResolvedNexthop::None, - }, - inner: route.attrs, + }) }, - } - } - }) - .map(|route| { - let json = serde_json::to_string(&route).unwrap(); - Ok::<_, Infallible>(format!("{}\n", json)) - }); + ) + }) + .map(|result| { + let json = serde_json::to_string(&result).unwrap(); + Ok::<_, Infallible>(format!("{}\n", json)) + }); + Ok(StreamBody::new(stream)) } diff --git a/src/store.rs b/src/store.rs index 88dabff..0c36ddc 100644 --- a/src/store.rs +++ b/src/store.rs @@ -36,15 +36,6 @@ pub enum ResolvedNexthop { ReverseDns(String), } -#[derive(Debug, Clone, Serialize, Default)] -pub struct ResolvedRouteAttrs { - #[serde(flatten)] - pub inner: RouteAttrs, - pub resolved_communities: HashMap<(u16, u16), String>, - pub resolved_large_communities: HashMap<(u32, u32, u32), String>, - pub resolved_nexthop: ResolvedNexthop, -} - #[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct SessionId { @@ -132,7 +123,7 @@ pub struct Query { #[derive(Debug, Clone, Serialize)] #[serde(deny_unknown_fields)] -pub struct QueryResult { +pub struct QueryResult { pub state: RouteState, pub net: IpNet, #[serde(flatten)] @@ -142,7 +133,7 @@ pub struct QueryResult { #[serde(flatten)] pub session: Option, #[serde(flatten)] - pub attrs: T, + pub attrs: RouteAttrs, } #[derive(Debug, Clone, Serialize, Deserialize)]