From 56c53a5a68df42b1f45c3233556bb8e41f79af73 Mon Sep 17 00:00:00 2001 From: Davi Arnaut Date: Tue, 11 Jun 2024 10:42:30 -0700 Subject: [PATCH] fix(azure_ad): fix infinite loop on request error --- .../datahub/ingestion/source/identity/azure_ad.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/identity/azure_ad.py b/metadata-ingestion/src/datahub/ingestion/source/identity/azure_ad.py index 20b313474d174..885b6514779cc 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/identity/azure_ad.py +++ b/metadata-ingestion/src/datahub/ingestion/source/identity/azure_ad.py @@ -9,6 +9,7 @@ import click import requests from pydantic.fields import Field +from requests.adapters import HTTPAdapter, Retry from datahub.configuration.common import AllowDenyPattern from datahub.configuration.source_common import DatasetSourceConfigMixin @@ -268,6 +269,14 @@ def __init__(self, config: AzureADConfig, ctx: PipelineContext): self.report = AzureADSourceReport( filtered_tracking=self.config.filtered_tracking ) + session = requests.Session() + retries = Retry( + total=5, backoff_factor=1, status_forcelist=[429, 500, 502, 503, 504] + ) + adapter = HTTPAdapter(max_retries=retries) + session.mount("http://", adapter) + session.mount("https://", adapter) + self.session = session self.token_data = { "grant_type": "client_credentials", "client_id": self.config.client_id, @@ -494,7 +503,7 @@ def _get_azure_ad_data(self, kind: str) -> Iterable[List]: while True: if not url: break - response = requests.get(url, headers=headers) + response = self.session.get(url, headers=headers) if response.status_code == 200: json_data = json.loads(response.text) try: @@ -512,7 +521,7 @@ def _get_azure_ad_data(self, kind: str) -> Iterable[List]: logger.debug(f"URL = {url}") logger.error(error_str) self.report.report_failure("_get_azure_ad_data_", error_str) - continue + raise Exception(f"Unable to get {url}, error {response.status_code}") def _map_identity_to_urn(self, func, id_to_extract, mapping_identifier, id_type): result, error_str = None, None