From 8bf3cb7dbf4bc8db3e042919eeb12901fd82384c Mon Sep 17 00:00:00 2001 From: Tom Rutter Date: Thu, 4 Apr 2024 11:22:11 +0000 Subject: [PATCH] Make handling of connection by fs/adls.py closer to that of WasbHook and add unit tests. fix tests that were missed by keeping path of the account_url unchanged. --- airflow/providers/microsoft/azure/fs/adls.py | 62 ++++++++- .../providers/microsoft/azure/hooks/wasb.py | 34 +---- airflow/providers/microsoft/azure/utils.py | 34 +++++ tests/always/test_project_structure.py | 1 - .../providers/microsoft/azure/fs/__init__.py | 17 +++ .../providers/microsoft/azure/fs/test_adls.py | 124 ++++++++++++++++++ tests/providers/microsoft/azure/test_utils.py | 28 ++++ 7 files changed, 262 insertions(+), 38 deletions(-) create mode 100644 tests/providers/microsoft/azure/fs/__init__.py create mode 100644 tests/providers/microsoft/azure/fs/test_adls.py diff --git a/airflow/providers/microsoft/azure/fs/adls.py b/airflow/providers/microsoft/azure/fs/adls.py index 54c9cef520e02..84a242015f5b0 100644 --- a/airflow/providers/microsoft/azure/fs/adls.py +++ b/airflow/providers/microsoft/azure/fs/adls.py @@ -18,8 +18,10 @@ from typing import TYPE_CHECKING, Any +from azure.identity import ClientSecretCredential + from airflow.hooks.base import BaseHook -from airflow.providers.microsoft.azure.utils import get_field +from airflow.providers.microsoft.azure.utils import get_field, parse_blob_account_url if TYPE_CHECKING: from fsspec import AbstractFileSystem @@ -35,13 +37,61 @@ def get_fs(conn_id: str | None, storage_options: dict[str, Any] | None = None) - conn = BaseHook.get_connection(conn_id) extras = conn.extra_dejson + conn_type = conn.conn_type or "azure_data_lake" - options = {} - fields = ["connection_string", "account_name", "account_key", "sas_token", "tenant"] - for field in fields: - options[field] = get_field( - conn_id=conn_id, conn_type="azure_data_lake", extras=extras, field_name=field + # connection string always overrides everything else + connection_string = get_field( + conn_id=conn_id, conn_type=conn_type, extras=extras, field_name="connection_string" + ) + + if connection_string: + return AzureBlobFileSystem(connection_string=connection_string) + + options: dict[str, Any] = { + "account_url": parse_blob_account_url(conn.host, conn.login), + } + + # mirror handling of custom field "client_secret_auth_config" from extras. Ignore if missing as AzureBlobFileSystem can handle. + tenant_id = get_field(conn_id=conn_id, conn_type=conn_type, extras=extras, field_name="tenant_id") + login = conn.login or "" + password = conn.password or "" + # assumption (from WasbHook) that if tenant_id is set, we want service principal connection + if tenant_id: + client_secret_auth_config = get_field( + conn_id=conn_id, conn_type=conn_type, extras=extras, field_name="client_secret_auth_config" ) + if login: + options["client_id"] = login + if password: + options["client_secret"] = password + if client_secret_auth_config and login and password: + options["credential"] = ClientSecretCredential( + tenant_id=tenant_id, client_id=login, client_secret=password, **client_secret_auth_config + ) + + # if not service principal, then password is taken to be account admin key + if tenant_id is None and password: + options["account_key"] = password + + # now take any fields from extras and overlay on these + # add empty field to remove defaults + fields = [ + "account_name", + "account_key", + "sas_token", + "tenant_id", + "managed_identity_client_id", + "workload_identity_client_id", + "workload_identity_tenant_id", + "anon", + ] + for field in fields: + value = get_field(conn_id=conn_id, conn_type=conn_type, extras=extras, field_name=field) + if value is not None: + if value == "": + options.pop(field, "") + else: + options[field] = value options.update(storage_options or {}) diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py b/airflow/providers/microsoft/azure/hooks/wasb.py index d53e3ccd9285b..237594139e3c0 100644 --- a/airflow/providers/microsoft/azure/hooks/wasb.py +++ b/airflow/providers/microsoft/azure/hooks/wasb.py @@ -30,7 +30,6 @@ import os from functools import cached_property from typing import TYPE_CHECKING, Any, Union -from urllib.parse import urlparse from asgiref.sync import sync_to_async from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceNotFoundError @@ -52,6 +51,7 @@ add_managed_identity_connection_widgets, get_async_default_azure_credential, get_sync_default_azure_credential, + parse_blob_account_url, ) if TYPE_CHECKING: @@ -167,21 +167,7 @@ def get_conn(self) -> BlobServiceClient: # connection_string auth takes priority return BlobServiceClient.from_connection_string(connection_string, **extra) - account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/" - parsed_url = urlparse(account_url) - - if not parsed_url.netloc: - if "." not in parsed_url.path: - # if there's no netloc and no dots in the path, then user only - # provided the Active Directory ID, not the full URL or DNS name - account_url = f"https://{conn.login}.blob.core.windows.net/" - else: - # if there's no netloc but there are dots in the path, then user - # provided the DNS name without the https:// prefix. - # Azure storage account name can only be 3 to 24 characters in length - # https://learn.microsoft.com/en-us/azure/storage/common/storage-account-overview#storage-account-name - acc_name = account_url.split(".")[0][:24] - account_url = f"https://{acc_name}." + ".".join(account_url.split(".")[1:]) + account_url = parse_blob_account_url(conn.host, conn.login) tenant = self._get_field(extra, "tenant_id") if tenant: @@ -587,21 +573,7 @@ async def get_async_conn(self) -> AsyncBlobServiceClient: ) return self.blob_service_client - account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/" - parsed_url = urlparse(account_url) - - if not parsed_url.netloc: - if "." not in parsed_url.path: - # if there's no netloc and no dots in the path, then user only - # provided the Active Directory ID, not the full URL or DNS name - account_url = f"https://{conn.login}.blob.core.windows.net/" - else: - # if there's no netloc but there are dots in the path, then user - # provided the DNS name without the https:// prefix. - # Azure storage account name can only be 3 to 24 characters in length - # https://learn.microsoft.com/en-us/azure/storage/common/storage-account-overview#storage-account-name - acc_name = account_url.split(".")[0][:24] - account_url = f"https://{acc_name}." + ".".join(account_url.split(".")[1:]) + account_url = parse_blob_account_url(conn.host, conn.login) tenant = self._get_field(extra, "tenant_id") if tenant: diff --git a/airflow/providers/microsoft/azure/utils.py b/airflow/providers/microsoft/azure/utils.py index a2f7cd3f98dcd..703c2fb2e0649 100644 --- a/airflow/providers/microsoft/azure/utils.py +++ b/airflow/providers/microsoft/azure/utils.py @@ -19,6 +19,7 @@ import warnings from functools import partial, wraps +from urllib.parse import urlparse, urlunparse from azure.core.pipeline import PipelineContext, PipelineRequest from azure.core.pipeline.policies import BearerTokenCredentialPolicy @@ -171,3 +172,36 @@ def set_token(self): def signed_session(self, azure_session=None): self.set_token() return super().signed_session(azure_session) + + +def parse_blob_account_url(host: str | None, login: str | None) -> str: + account_url = host if host else f"https://{login}.blob.core.windows.net/" + + parsed_url = urlparse(account_url) + + # if there's no netloc then user provided the DNS name without the https:// prefix. + if parsed_url.scheme == "": + account_url = "https://" + account_url + parsed_url = urlparse(account_url) + + netloc = parsed_url.netloc + if "." not in netloc: + # if there's no netloc and no dots in the path, then user only + # provided the Active Directory ID, not the full URL or DNS name + netloc = f"{login}.blob.core.windows.net/" + + # Now enforce 3 to 23 character limit on account name + # https://learn.microsoft.com/en-us/azure/storage/common/storage-account-overview#storage-account-name + host_components = netloc.split(".") + host_components[0] = host_components[0][:24] + netloc = ".".join(host_components) + + url_components = [ + parsed_url.scheme, + netloc, + parsed_url.path, + parsed_url.params, + parsed_url.query, + parsed_url.fragment, + ] + return urlunparse(url_components) diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py index 0b5dcdeae9843..1b965ec998ce4 100644 --- a/tests/always/test_project_structure.py +++ b/tests/always/test_project_structure.py @@ -163,7 +163,6 @@ def test_providers_modules_should_have_tests(self): "tests/providers/google/common/links/test_storage.py", "tests/providers/google/common/test_consts.py", "tests/providers/google/test_go_module_utils.py", - "tests/providers/microsoft/azure/fs/test_adls.py", "tests/providers/microsoft/azure/operators/test_adls.py", "tests/providers/microsoft/azure/transfers/test_azure_blob_to_gcs.py", "tests/providers/mongo/sensors/test_mongo.py", diff --git a/tests/providers/microsoft/azure/fs/__init__.py b/tests/providers/microsoft/azure/fs/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/tests/providers/microsoft/azure/fs/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/microsoft/azure/fs/test_adls.py b/tests/providers/microsoft/azure/fs/test_adls.py new file mode 100644 index 0000000000000..623ef37b6f4b8 --- /dev/null +++ b/tests/providers/microsoft/azure/fs/test_adls.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.models import Connection +from airflow.providers.microsoft.azure.fs.adls import get_fs + + +@pytest.fixture +def mocked_blob_file_system(): + with mock.patch("adlfs.AzureBlobFileSystem") as m: + yield m + + +@pytest.mark.parametrize( + ("mocked_connection", "expected_options"), + [ + ( + Connection( + conn_id="testconn", + conn_type="wasb", + host="testaccountname.blob.core.windows.net", + ), + { + "account_url": "https://testaccountname.blob.core.windows.net", + }, + ), + ( + Connection( + conn_id="testconn", + conn_type="wasb", + login="testaccountname", + ), + { + "account_url": "https://testaccountname.blob.core.windows.net/", + }, + ), + ( + Connection( + conn_id="testconn", + conn_type="wasb", + login="testaccountname", + password="p", + host="testaccountID", + extra={ + "connection_string": "c", + }, + ), + { + "connection_string": "c", + }, + ), + ( + Connection( + conn_id="testconn", + conn_type="wasb", + login="testaccountname", + password="p", + host="testaccountID", + extra={ + "account_name": "n", + "account_key": "k", + "sas_token": "s", + "tenant_id": "t", + "managed_identity_client_id": "m", + "workload_identity_tenant_id": "w", + "anon": False, + "other_field_name": "other_field_value", + }, + ), + { + "account_url": "https://testaccountname.blob.core.windows.net/", + # "account_url": "https://testaccountid.blob.core.windows.net/", + "client_id": "testaccountname", + "client_secret": "p", + "account_name": "n", + "account_key": "k", + "sas_token": "s", + "tenant_id": "t", + "managed_identity_client_id": "m", + "workload_identity_tenant_id": "w", + "anon": False, + }, + ), + ( + Connection( + conn_id="testconn", + conn_type="wasb", + login="testaccountname", + password="p", + host="testaccountID", + extra={}, + ), + { + "account_url": "https://testaccountname.blob.core.windows.net/", + # "account_url": "https://testaccountid.blob.core.windows.net/", + "account_key": "p", + }, + ), + ], + indirect=["mocked_connection"], +) +def test_get_fs(mocked_connection, expected_options, mocked_blob_file_system): + get_fs(mocked_connection.conn_id) + mocked_blob_file_system.assert_called_once_with(**expected_options) diff --git a/tests/providers/microsoft/azure/test_utils.py b/tests/providers/microsoft/azure/test_utils.py index 4102961bbef08..79abbf97f0423 100644 --- a/tests/providers/microsoft/azure/test_utils.py +++ b/tests/providers/microsoft/azure/test_utils.py @@ -29,6 +29,7 @@ get_field, # _get_default_azure_credential get_sync_default_azure_credential, + parse_blob_account_url, ) MODULE = "airflow.providers.microsoft.azure.utils" @@ -131,3 +132,30 @@ def test_init_with_identity(self, mock_default_azure_credential, mock_policy, mo adapter.signed_session() assert adapter.token == {"access_token": "token"} + + +@pytest.mark.parametrize( + "host, login, expected_url", + [ + (None, None, "https://None.blob.core.windows.net/"), # to maintain existing behaviour + (None, "storage_account", "https://storage_account.blob.core.windows.net/"), + ("testaccountname.blob.core.windows.net", None, "https://testaccountname.blob.core.windows.net"), + ( + "testaccountname.blob.core.windows.net", + "service_principal_id", + "https://testaccountname.blob.core.windows.net", + ), + ( + "https://testaccountname.blob.core.windows.net", + None, + "https://testaccountname.blob.core.windows.net", + ), + ( + "https://testaccountname.blob.core.windows.net", + "service_principal_id", + "https://testaccountname.blob.core.windows.net", + ), + ], +) +def test_parse_blob_account_url(host, login, expected_url): + assert parse_blob_account_url(host, login) == expected_url