diff --git a/requirements.txt b/requirements.txt index 6bd25ba..bc136e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ pytest==7.1.2 pytest-httpserver==1.0.5 +Werkzeug==2.2.3 diff --git a/scripts/fetch_vault_credentials.py b/scripts/fetch_vault_credentials.py index 7b541e7..ceb12a4 100755 --- a/scripts/fetch_vault_credentials.py +++ b/scripts/fetch_vault_credentials.py @@ -11,8 +11,13 @@ import re import shlex import sys +import time +import urllib.error import urllib.request +MAX_CONSISTENCY_RETRIES = 5 +RETRY_BACKOFF_FACTOR = 2 + def main(): for env_var in [ @@ -35,11 +40,20 @@ def main(): VAULT_KV_STORE = os.getenv("VAULT_KV_STORE") VAULT_SECRET = os.getenv("VAULT_SECRET") - vault_token = fetch_token( + vault_token_payload = fetch_token( VAULT_ADDRESS, VAULT_ROLE_NAMESPACE, VAULT_ROLE_ID, VAULT_ROLE_SECRET ) + + vault_token = vault_token_payload["client_token"] + x_vault_index = vault_token_payload.get("x_vault_index") + secret = fetch_secret( - VAULT_ADDRESS, vault_token, VAULT_NAMESPACE, VAULT_KV_STORE, VAULT_SECRET + VAULT_ADDRESS, + vault_token, + VAULT_NAMESPACE, + VAULT_KV_STORE, + VAULT_SECRET, + x_vault_index=x_vault_index, ) if not sys.stdout.isatty(): @@ -86,30 +100,72 @@ def fetch_token(vault_address, role_namespace, role_id, role_secret): "utf-8" ), ) - vault_token_resp = urllib.request.urlopen(req).read() - return json.loads(vault_token_resp.decode("utf-8"))["auth"]["client_token"] + vault_token_opener = urllib.request.urlopen(req, timeout=10) + vault_token_response_headers = dict(vault_token_opener.getheaders()) + vault_token_response_body = vault_token_opener.read() + + client_token = json.loads(vault_token_response_body.decode("utf-8"))["auth"][ + "client_token" + ] + token_payload = { + "client_token": client_token, + "x_vault_index": vault_token_response_headers.get("X-Vault-Index"), + } + + return token_payload + except json.JSONDecodeError: raise Exception("Error decoding Vault AppRole token from response") except Exception as e: raise Exception(f"Error fetching Vault AppRole token: {e}") -def fetch_secret(vault_address, vault_token, namespace, kv_store, secret): - try: - req = urllib.request.Request( - method="GET", - url=f"{vault_address}/v1/{kv_store}/data/{secret}", - headers={ - "X-Vault-Namespace": namespace, - "X-Vault-Token": vault_token, - }, - ) - vault_secret_response = urllib.request.urlopen(req) - return json.loads(vault_secret_response.read().decode("utf-8"))["data"]["data"] - except json.JSONDecodeError: - raise Exception("Error decoding Vault secret from response") - except Exception as e: - raise Exception(f"Error fetching Vault secret: {e}") +def fetch_secret( + vault_address, vault_token, namespace, kv_store, secret, x_vault_index=None +): + + retry_count = 0 + + secret_request_headers = { + "X-Vault-Namespace": namespace, + "X-Vault-Token": vault_token, + } + + if x_vault_index: + secret_request_headers.update({"X-Vault-Index": x_vault_index}) + + while retry_count < MAX_CONSISTENCY_RETRIES: + + try: + req = urllib.request.Request( + method="GET", + url=f"{vault_address}/v1/{kv_store}/data/{secret}", + headers=secret_request_headers, + ) + vault_secret_response = urllib.request.urlopen(req, timeout=10) + return json.loads(vault_secret_response.read().decode("utf-8"))["data"][ + "data" + ] + except urllib.error.HTTPError as e: + if e.code == 412: + print( + f"Vault cluster not yet consistent, retry attempt number: {retry_count}" # noqa: E501 + ) + retry_count += 1 + sleep_time = RETRY_BACKOFF_FACTOR * retry_count + if retry_count < MAX_CONSISTENCY_RETRIES: + time.sleep(sleep_time) # sleep before retrying + else: + raise Exception( + f"Error fetching Vault secret after {MAX_CONSISTENCY_RETRIES} attempts: {e}" # noqa: E501 + ) + else: + raise Exception(f"Error fetching Vault secret: {e}") + + except json.JSONDecodeError: + raise Exception("Error decoding Vault secret from response") + except Exception as e: + raise Exception(f"Error fetching Vault secret: {e}") if __name__ == "__main__": diff --git a/scripts/fetch_vault_credentials_test.py b/scripts/fetch_vault_credentials_test.py index b9d07cf..2f40033 100644 --- a/scripts/fetch_vault_credentials_test.py +++ b/scripts/fetch_vault_credentials_test.py @@ -37,10 +37,10 @@ def test_token_fetched_correctly(mock_urlopen): token_mock.read.return_value = b'{ "auth": { "client_token": "xyz-789" } }' mock_urlopen.return_value = token_mock - token = fetch_vault_credentials.fetch_token( + token_payload = fetch_vault_credentials.fetch_token( "https://vault.test:1000", "namespace", "abc-123", "def-345" ) - assert token == "xyz-789" + assert token_payload == {"client_token": "xyz-789", "x_vault_index": None} @patch("urllib.request.urlopen") @@ -125,7 +125,10 @@ def test_secret_json_decoding_fails(mock_urlopen): ) -@patch("fetch_vault_credentials.fetch_token", return_value="xyz-789") +@patch( + "fetch_vault_credentials.fetch_token", + return_value={"client_token": "xyz-789", "x_vault_index": None}, +) @patch( "fetch_vault_credentials.fetch_secret", return_value={"KEY1": "VALUE1", "KEY2": "VALUE2"}, @@ -146,7 +149,10 @@ def test_secret_is_printed_when_stdout_is_not_a_tty(mock_stdout, mock_print, _, ) -@patch("fetch_vault_credentials.fetch_token", return_value="xyz-789") +@patch( + "fetch_vault_credentials.fetch_token", + return_value={"client_token": "xyz-789", "x_vault_index": None}, +) @patch( "fetch_vault_credentials.fetch_secret", return_value={"KEY1": "VALUE1", "KEY2": "VALUE2"}, @@ -162,7 +168,10 @@ def test_secret_is_not_printed_when_stdout_is_a_tty(mock_stdout, mock_print, _, mock_print.assert_not_called() -@patch("fetch_vault_credentials.fetch_token", return_value="xyz-789") +@patch( + "fetch_vault_credentials.fetch_token", + return_value={"client_token": "xyz-789", "x_vault_index": None}, +) @patch( "fetch_vault_credentials.fetch_secret", return_value={"KEY1": "VALUE1", "Key2": "VALUE2"},