Skip to content

Commit

Permalink
[core] add a DNS cache (#3197)
Browse files Browse the repository at this point in the history
* [forwarder] add dns cache.

* [core] adding test for DNS cache.

* [core] fixing dns tests + minor issues with logic.

* [core][dnscache] make sure we work over locations, return valid urls.

* [dnscache][test] resolve intake.

* [test][transaction] fix broken test.

[core][tests] further test fixes.

* [core][dns] enable for all endpoints.

* [forwarder] use _is_affirmative for option parsing.

* [core][forwarder] parse integer string to int.
  • Loading branch information
truthbk authored Mar 8, 2017
1 parent db6f644 commit 5abae9b
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 2 deletions.
31 changes: 29 additions & 2 deletions ddagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from socket import error as socket_error, gaierror
import sys
import threading
from urlparse import urlparse
import zlib

# For pickle & PID files, see issue 293
Expand All @@ -47,11 +48,14 @@
get_config,
get_logging_config,
get_url_endpoint,
get_version
get_version,
_is_affirmative
)
import modules
from transaction import Transaction, TransactionManager
from util import get_uuid
from utils.net import DEFAULT_DNS_TTL, DNSCache



from utils.hostname import get_hostname
Expand Down Expand Up @@ -209,6 +213,8 @@ def __sizeof__(self):

def get_url(self, endpoint, api_key):
endpoint_base_url = get_url_endpoint(endpoint)
if self._application.agent_dns_caching:
endpoint_base_url = self._application.get_from_dns_cache(endpoint_base_url)
return "{0}/intake/{1}?api_key={2}".format(endpoint_base_url, self._msg_type, api_key)

def flush(self):
Expand Down Expand Up @@ -292,6 +298,8 @@ class APIMetricTransaction(MetricTransaction):

def get_url(self, endpoint, api_key):
endpoint_base_url = get_url_endpoint(endpoint)
if self._application.agent_dns_caching:
endpoint_base_url = self._application.get_from_dns_cache(endpoint_base_url)
return "{0}/api/v1/series/?api_key={1}".format(endpoint_base_url, api_key)

def get_data(self):
Expand All @@ -303,6 +311,8 @@ class APIServiceCheckTransaction(AgentTransaction):

def get_url(self, endpoint, api_key):
endpoint_base_url = get_url_endpoint(endpoint)
if self._application.agent_dns_caching:
endpoint_base_url = self._application.get_from_dns_cache(endpoint_base_url)
return "{0}/api/v1/check_run/?api_key={1}".format(endpoint_base_url, api_key)


Expand Down Expand Up @@ -397,6 +407,7 @@ def __init__(self, port, agentConfig, watchdog=True,
self._port = int(port)
self._agentConfig = agentConfig
self._metrics = {}
self._dns_cache = None
AgentTransaction.set_application(self)
AgentTransaction.set_endpoints(agentConfig['endpoints'])
if agentConfig['endpoints'] == {}:
Expand All @@ -414,7 +425,11 @@ def __init__(self, port, agentConfig, watchdog=True,
AgentTransaction.set_tr_manager(self._tr_manager)

self._watchdog = None
self.skip_ssl_validation = skip_ssl_validation or agentConfig.get('skip_ssl_validation', False)
self.skip_ssl_validation = skip_ssl_validation or _is_affirmative(agentConfig.get('skip_ssl_validation'))
self.agent_dns_caching = _is_affirmative(agentConfig.get('dns_caching'))
self.agent_dns_ttl = int(agentConfig.get('dns_ttl', DEFAULT_DNS_TTL))
if self.agent_dns_caching:
self._dns_cache = DNSCache(ttl=self.agent_dns_ttl)
self.use_simple_http_client = use_simple_http_client
if self.skip_ssl_validation:
log.info("Skipping SSL hostname validation, useful when using a transparent proxy")
Expand All @@ -425,6 +440,18 @@ def __init__(self, port, agentConfig, watchdog=True,
self._watchdog = Watchdog.create(watchdog_timeout,
max_resets=WATCHDOG_HIGH_ACTIVITY_THRESHOLD)


def get_from_dns_cache(self, url):
if not self.agent_dns_caching:
log.debug('Caching disabled, not resolving.')
return url

location = urlparse(url)
resolve = self._dns_cache.resolve(location.netloc)
return "{scheme}://{ip}".format(scheme=location.scheme,
ip=resolve)


def log_request(self, handler):
""" Override the tornado logging method.
If everything goes well, log level is DEBUG.
Expand Down
2 changes: 2 additions & 0 deletions tests/core/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def testCustomEndpoint(self):

app = Application()
app.skip_ssl_validation = False
app.agent_dns_caching = False
app._agentConfig = config
app.use_simple_http_client = True

Expand Down Expand Up @@ -183,6 +184,7 @@ def testEndpoints(self):

app = Application()
app.skip_ssl_validation = False
app.agent_dns_caching = False
app._agentConfig = config
app.use_simple_http_client = True

Expand Down
26 changes: 26 additions & 0 deletions tests/core/test_utils_net.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
# stdlib
from unittest import TestCase
from mock import MagicMock, patch
import socket
from urlparse import urlparse
from time import sleep

# 3p
from nose.plugins.skip import SkipTest

# project
from utils.net import inet_pton, _inet_pton_win
from utils.net import IPV6_V6ONLY, IPPROTO_IPV6
from utils.net import DNSCache
from config import get_url_endpoint

DEFAULT_ENDPOINT = "https://app.datadoghq.com"

class TestUtilsNet(TestCase):
DNS_TTL = 3

def test__inet_pton_win(self):

if _inet_pton_win != inet_pton:
Expand All @@ -29,3 +37,21 @@ def test_constants(self):

if not hasattr(socket, 'IPV6_V6ONLY'):
self.assertEqual(IPV6_V6ONLY, 27)

def test_dns_cache(self):
side_effects = [(None, None, ['1.1.1.1', '2.2.2.2']),
(None, None, ['3.3.3.3'])]
mock_resolve = MagicMock(side_effect=side_effects)
cache = DNSCache(self.DNS_TTL)
with patch('socket.gethostbyaddr', mock_resolve):
ip = cache.resolve('foo')
self.assertTrue(ip in side_effects[0][2])
sleep(self.DNS_TTL + 1)
ip = cache.resolve('foo')
self.assertTrue(ip in side_effects[1][2])

# resolve intake
endpoint = get_url_endpoint(DEFAULT_ENDPOINT)
location = urlparse(endpoint)
ip = cache.resolve(location.netloc)
self.assertNotEqual(ip, location.netloc)
27 changes: 27 additions & 0 deletions utils/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@

# lib
import ctypes
import time
import random
import socket


# 3p

# project
Expand All @@ -26,6 +29,7 @@
except AttributeError:
IPV6_V6ONLY = 27 # from `Ws2ipdef.h`

DEFAULT_DNS_TTL = 300

class sockaddr(ctypes.Structure):
_fields_ = [("sa_family", ctypes.c_short),
Expand All @@ -34,6 +38,29 @@ class sockaddr(ctypes.Structure):
("ipv6_addr", ctypes.c_byte * 16),
("__pad2", ctypes.c_ulong)]

class DNSCache(object):
"""
Simple, rudimentary DNS cache
"""
def __init__(self, ttl=DEFAULT_DNS_TTL):
self._cache = {}
self._ttl = ttl
random.seed()

def resolve(self, url):
try:
ts, entry = self._cache.get(url, (None, None))
now = int(time.time())
if not ts or ts < now:
_, _, entry = socket.gethostbyaddr(url)
ttl = now + self._ttl
self._cache[url] = ttl, entry

resolve = entry[random.randint(0, len(entry)-1)]
except Exception:
resolve = url

return resolve

def _inet_pton_win(address_family, ip_string):
"""
Expand Down

0 comments on commit 5abae9b

Please sign in to comment.