diff --git a/common.py b/common.py index f6d3eb5a..5b3098ac 100644 --- a/common.py +++ b/common.py @@ -51,7 +51,7 @@ def set_marathon_auth_args(parser): class DCOSAuth(AuthBase): def __init__(self, credentials, ca_cert): - creds = json.loads(credentials) + creds = cleanup_json(json.loads(credentials)) self.uid = creds['uid'] self.private_key = creds['private_key'] self.login_endpoint = creds['login_endpoint'] @@ -128,3 +128,11 @@ def set_logging_args(parser): help="Set log level", default="DEBUG") return parser + + +def cleanup_json(data): + if isinstance(data, dict): + return {k: cleanup_json(v) for k, v in data.items() if v is not None} + if isinstance(data, list): + return [cleanup_json(e) for e in data] + return data diff --git a/marathon_lb.py b/marathon_lb.py index 4f73c8bc..dbfeeddf 100755 --- a/marathon_lb.py +++ b/marathon_lb.py @@ -43,7 +43,7 @@ import requests from common import (get_marathon_auth_params, set_logging_args, - set_marathon_auth_args, setup_logging) + set_marathon_auth_args, setup_logging, cleanup_json) from config import ConfigTemplater, label_keys from lrucache import LRUCache from utils import (CurlHttpEventStream, get_task_ip_and_ports, ip_cache, @@ -189,16 +189,18 @@ def api_req_raw(self, method, path, auth, body=None, **kwargs): response.raise_for_status() - if 'message' in response.json(): + resp_json = cleanup_json(response.json()) + if 'message' in resp_json: response.reason = "%s (%s)" % ( response.reason, - response.json()['message']) + resp_json['message']) return response def api_req(self, method, path, **kwargs): - return self.api_req_raw(method, path, self.__auth, + data = self.api_req_raw(method, path, self.__auth, verify=self.__verify, **kwargs).json() + return cleanup_json(data) def create(self, app_json): return self.api_req('POST', ['apps'], app_json) @@ -1661,7 +1663,7 @@ def process_sse_events(marathon, processor): # marathon sometimes sends more than one json per event # e.g. {}\r\n{}\r\n\r\n for real_event_data in re.split(r'\r\n', event.data): - data = json.loads(real_event_data) + data = load_json(real_event_data) logger.info( "received event of type {0}" .format(data['eventType'])) @@ -1677,6 +1679,10 @@ def process_sse_events(marathon, processor): processor.stop() +def load_json(data_str): + return cleanup_json(json.loads(data_str)) + + if __name__ == '__main__': # Process arguments arg_parser = get_arg_parser() diff --git a/tests/test_marathon_lb.py b/tests/test_marathon_lb.py index 25112348..85d0a5d1 100644 --- a/tests/test_marathon_lb.py +++ b/tests/test_marathon_lb.py @@ -3036,3 +3036,42 @@ def test_backend_disabled_and_enablede(self): server agent2_2_2_2_2_1025 2.2.2.2:1025 check inter 3s fall 11 ''' self.assertMultiLineEqual(config, expected) + + +class TestFunctions(unittest.TestCase): + + def test_json_number(self): + json_value = '1' + data = marathon_lb.load_json(json_value) + expected = 1 + self.assertEquals(data, expected) + + def test_json_string(self): + json_value = '"1"' + data = marathon_lb.load_json(json_value) + expected = "1" + self.assertEquals(data, expected) + + def test_json_nested_null_dict_remain(self): + json_value = '{"key":null,"key2":"y","key3":{"key4":null,"key5":"x"}}' + data = marathon_lb.load_json(json_value) + expected = {'key3': {'key5': 'x'}, 'key2': 'y'} + self.assertEquals(data, expected) + + def test_json_nested_null_dict(self): + json_value = '{"key":null,"key2":"y","key3":{"key4":null}}' + data = marathon_lb.load_json(json_value) + expected = {'key3': {}, 'key2': 'y'} + self.assertEquals(data, expected) + + def test_json_simple_list_dict(self): + json_value = '["k1",{"k2":null,"k3":"v3"},"k4"]' + data = marathon_lb.load_json(json_value) + expected = ['k1', {'k3': 'v3'}, 'k4'] + self.assertEquals(data, expected) + + def test_json_nested_null_dict_list(self): + json_value = '["k1",{"k2":null,"k3":["k4",{"k5":null}]},"k6"]' + data = marathon_lb.load_json(json_value) + expected = ['k1', {'k3': ['k4', {}]}, 'k6'] + self.assertEquals(data, expected) diff --git a/zdd.py b/zdd.py index a6b08710..eb759636 100755 --- a/zdd.py +++ b/zdd.py @@ -17,7 +17,7 @@ import six.moves.urllib as urllib from common import (get_marathon_auth_params, set_logging_args, - set_marathon_auth_args, setup_logging) + set_marathon_auth_args, setup_logging, cleanup_json) from utils import (get_task_ip_and_ports, get_app_port_mappings) from zdd_exceptions import ( AppCreateException, AppDeleteException, AppScaleException, @@ -76,12 +76,12 @@ def marathon_get_request(args, path): def list_marathon_apps(args): response = marathon_get_request(args, "/v2/apps") - return response.json()['apps'] + return cleanup_json(response.json())['apps'] def fetch_marathon_app(args, app_id): response = marathon_get_request(args, "/v2/apps" + app_id) - return response.json()['app'] + return cleanup_json(response.json())['app'] def _get_alias_records(hostname): @@ -575,7 +575,7 @@ def select_next_port(app): def select_next_colour(app): - if app['labels'].get('HAPROXY_DEPLOYMENT_COLOUR') == 'blue': + if app.get('labels', {}).get('HAPROXY_DEPLOYMENT_COLOUR') == 'blue': return 'green' else: return 'blue'