#!/usr/bin/env python3
"""Shared helpers for logging setup and Marathon / DC/OS authentication."""

import json
import logging
import os
import sys
import time
from logging.handlers import SysLogHandler

import jwt
import requests
from requests.auth import AuthBase


def setup_logging(logger, syslog_socket, log_format, log_level='DEBUG'):
    """Attach console (and optionally syslog) handlers to *logger*.

    Args:
        logger: The ``logging.Logger`` to configure.
        syslog_socket: Address handed to ``SysLogHandler``. The literal
            value ``'/dev/null'`` disables the syslog handler.
        log_format: Format string used for every handler.
        log_level: Name of a standard logging level (case-insensitive).

    Raises:
        ValueError: If *log_level* is not a standard level name.
    """
    log_level = log_level.upper()
    valid_levels = ('CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET')
    if log_level not in valid_levels:
        raise ValueError('Invalid log level: {}'.format(log_level))

    logger.setLevel(getattr(logging, log_level))
    formatter = logging.Formatter(log_format)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    if syslog_socket != '/dev/null':
        syslog_handler = SysLogHandler(syslog_socket)
        syslog_handler.setFormatter(formatter)
        logger.addHandler(syslog_handler)


def set_marathon_auth_args(parser):
    """Register the Marathon authentication options on *parser*.

    Returns:
        The same parser, to allow chaining.
    """
    parser.add_argument("--marathon-auth-credential-file",
                        help="Path to file containing a user/pass for the "
                        "Marathon HTTP API in the format of 'user:pass'.")
    parser.add_argument("--auth-credentials",
                        help="user/pass for the Marathon HTTP API in the "
                        "format of 'user:pass'.")
    parser.add_argument("--dcos-auth-credentials",
                        default=os.getenv('DCOS_SERVICE_ACCOUNT_CREDENTIAL'),
                        help="DC/OS service account credentials")
    parser.add_argument("--marathon-ca-cert",
                        help="CA certificate for Marathon HTTPS connections")
    return parser


class DCOSAuth(AuthBase):
    """``requests`` auth hook that logs in with a DC/OS service account.

    The credentials JSON must provide ``uid``, ``private_key`` and
    ``login_endpoint``. A short-lived RS256-signed JWT is posted to the
    login endpoint and the value of the returned ``dcos-acs-auth-cookie``
    cookie is sent as ``Authorization: token=<value>`` on every request.
    """

    def __init__(self, credentials, ca_cert):
        """Parse *credentials* (a JSON string) and remember the CA bundle.

        Args:
            credentials: JSON document with the service account details.
            ca_cert: CA certificate used to verify the login endpoint, or a
                falsy value.
        """
        creds = cleanup_json(json.loads(credentials))
        self.uid = creds['uid']
        self.private_key = creds['private_key']
        self.login_endpoint = creds['login_endpoint']
        # NOTE(review): without a CA certificate, TLS verification of the
        # login endpoint is switched off entirely. Kept for backward
        # compatibility; supply --marathon-ca-cert to enable verification.
        self.verify = ca_cert if ca_cert else False
        self.auth_header = None
        self.expiry = 0

    def __call__(self, auth_request):
        """Add a (refreshed if necessary) Authorization header."""
        self.refresh_auth_header()
        auth_request.headers['Authorization'] = self.auth_header
        return auth_request

    def refresh_auth_header(self):
        """Log in again when no token is cached or it is about to expire.

        The cached header and its expiry are only replaced after the login
        request has succeeded, so a failed refresh is retried on the next
        call instead of leaving a stale token marked as fresh.

        Raises:
            requests.HTTPError: If the login endpoint returns an error
                status.
            KeyError: If the response lacks the ``dcos-acs-auth-cookie``
                cookie.
        """
        now = int(time.time())
        # Refresh 10 seconds early so a token never expires mid-request.
        if self.auth_header and now < self.expiry - 10:
            return

        expiry = now + 3600
        payload = {
            'uid': self.uid,
            # This is the expiry of the auth request params
            'exp': now + 60,
        }
        token = jwt.encode(payload, self.private_key, algorithm='RS256')
        # PyJWT < 2.0 returns bytes, PyJWT >= 2.0 returns str.
        if isinstance(token, bytes):
            token = token.decode('ascii')

        data = {
            'uid': self.uid,
            'token': token,
            # This is the expiry for the token itself
            'exp': expiry,
        }
        r = requests.post(self.login_endpoint,
                          json=data,
                          timeout=(3.05, 46),
                          verify=self.verify)
        r.raise_for_status()

        # Commit the new state only once everything above has succeeded.
        self.auth_header = 'token=' + r.cookies['dcos-acs-auth-cookie']
        self.expiry = expiry


def get_marathon_auth_params(args):
    """Build the ``auth`` value for Marathon requests from parsed *args*.

    Sources are tried in order: credential file, ``--auth-credentials``,
    then DC/OS service account credentials. Only the first ':' separates
    user from password, so passwords may themselves contain ':'.

    Returns:
        A ``(user, password)`` tuple, a ``DCOSAuth`` instance, or ``None``
        when no credentials were supplied.

    Exits the process with status 1 when basic credentials are not in
    ``user:pass`` form.
    """
    marathon_auth = None
    if args.marathon_auth_credential_file:
        with open(args.marathon_auth_credential_file, 'r') as f:
            line = f.readline().rstrip('\r\n')
        if line:
            marathon_auth = tuple(line.split(':', 1))
    elif args.auth_credentials:
        marathon_auth = tuple(args.auth_credentials.split(':', 1))
    elif args.dcos_auth_credentials:
        return DCOSAuth(args.dcos_auth_credentials, args.marathon_ca_cert)

    if marathon_auth and len(marathon_auth) != 2:
        print("Please provide marathon credentials in user:pass format")
        sys.exit(1)

    return marathon_auth


def set_logging_args(parser):
    """Register the logging options on *parser*.

    The default syslog socket is ``/var/run/syslog`` on macOS and
    ``/dev/log`` elsewhere.

    Returns:
        The same parser, to allow chaining.
    """
    default_log_socket = "/dev/log"
    if sys.platform == "darwin":
        default_log_socket = "/var/run/syslog"

    parser.add_argument("--syslog-socket",
                        help="Socket to write syslog messages to. "
                        "Use '/dev/null' to disable logging to syslog",
                        default=default_log_socket)
    parser.add_argument("--log-format",
                        help="Set log message format",
                        default="%(asctime)-15s %(name)s: %(message)s")
    parser.add_argument("--log-level",
                        help="Set log level",
                        default="DEBUG")
    return parser


def cleanup_json(data):
    """Return a copy of *data* with ``None``-valued dict entries removed.

    Dicts and lists are processed recursively; ``None`` items inside lists
    are kept, and any other value is returned unchanged.
    """
    if isinstance(data, dict):
        return {k: cleanup_json(v) for k, v in data.items() if v is not None}
    if isinstance(data, list):
        return [cleanup_json(e) for e in data]
    return data