Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(providers/microsoft): add DefaultAzureCredential support to AzureContainerVolumeHook #33822

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions airflow/providers/microsoft/azure/hooks/container_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

from typing import Any

from azure.identity import DefaultAzureCredential
from azure.mgmt.containerinstance.models import AzureFileVolume, Volume
from azure.mgmt.storage import StorageManagementClient

from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import get_field
Expand Down Expand Up @@ -54,14 +56,22 @@ def _get_field(self, extras, name):
@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
from flask_babel import lazy_gettext
from wtforms import PasswordField
from wtforms import PasswordField, StringField

return {
"connection_string": PasswordField(
lazy_gettext("Blob Storage Connection String (optional)"), widget=BS3PasswordFieldWidget()
),
"subscription_id": StringField(
lazy_gettext("Subscription ID (optional)"),
widget=BS3TextFieldWidget(),
),
"resource_group": StringField(
lazy_gettext("Resource group name (optional)"),
widget=BS3TextFieldWidget(),
),
}

@staticmethod
Expand All @@ -77,10 +87,12 @@ def get_ui_field_behaviour() -> dict[str, Any]:
"login": "client_id (token credentials auth)",
"password": "secret (token credentials auth)",
"connection_string": "connection string auth",
"subscription_id": "Subscription id (required for Azure AD authentication)",
"resource_group": "Resource group name (required for Azure AD authentication)",
},
}

def get_storagekey(self) -> str:
def get_storagekey(self, *, storage_account_name: str | None = None) -> str:
"""Get Azure File Volume storage key."""
conn = self.get_connection(self.conn_id)
extras = conn.extra_dejson
Expand All @@ -90,6 +102,17 @@ def get_storagekey(self) -> str:
key, value = keyvalue.split("=", 1)
if key == "AccountKey":
return value

subscription_id = self._get_field(extras, "subscription_id")
resource_group = self._get_field(extras, "resource_group")
if subscription_id and storage_account_name and resource_group:
credentials = DefaultAzureCredential()
storage_client = StorageManagementClient(credentials, subscription_id)
storage_account_list_keys_result = storage_client.storage_accounts.list_keys(
resource_group, storage_account_name
)
return storage_account_list_keys_result.as_dict()["keys"][0]["value"]

return conn.password

def get_file_volume(
Expand All @@ -102,6 +125,6 @@ def get_file_volume(
share_name=share_name,
storage_account_name=storage_account_name,
read_only=read_only,
storage_account_key=self.get_storagekey(),
storage_account_key=self.get_storagekey(storage_account_name=storage_account_name),
),
)
5 changes: 2 additions & 3 deletions airflow/providers/microsoft/azure/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

---
package-name: apache-airflow-providers-microsoft-azure
name: Microsoft Azure
description: |
`Microsoft Azure <https://azure.microsoft.com/>`__
`Microsoft Azure <https://azure.microsoft.com/>`__
suspended: false
versions:
- 6.3.0
Expand Down Expand Up @@ -76,6 +75,7 @@ dependencies:
- azure-storage-blob>=12.14.0
- azure-storage-common>=2.1.0
- azure-storage-file>=2.1.0
- azure-mgmt-storage>=16.0.0
- azure-servicebus>=7.6.1
- azure-synapse-spark
- adal>=1.2.7
Expand Down Expand Up @@ -258,7 +258,6 @@ transfers:
how-to-guide: /docs/apache-airflow-providers-microsoft-azure/transfer/azure_blob_to_gcs.rst
python-module: airflow.providers.microsoft.azure.transfers.azure_blob_to_gcs


connection-types:
- hook-class-name: airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook
connection-type: azure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ Extra (optional)
The following parameters are all optional:

* ``connection_string``: Connection string for use with connection string authentication.
* ``subscription_id``: The ID of the subscription used for the initial connection. This is needed for Azure Active Directory (Azure AD) authentication.
* ``resource_group``: Azure Resource Group Name under which the desired Azure file volume resides. This is needed for Azure Active Directory (Azure AD) authentication.

When specifying the connection in environment variable you should specify
it using URI syntax.
Expand Down
1 change: 1 addition & 0 deletions generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@
"azure-mgmt-datafactory>=1.0.0,<2.0",
"azure-mgmt-datalake-store>=0.5.0",
"azure-mgmt-resource>=2.2.0",
"azure-mgmt-storage>=16.0.0",
"azure-servicebus>=7.6.1",
"azure-storage-blob>=12.14.0",
"azure-storage-common>=2.1.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
# under the License.
from __future__ import annotations

from unittest import mock

import pytest

from airflow.models import Connection
Expand Down Expand Up @@ -71,6 +73,48 @@ def test_get_file_volume_connection_string(self, mocked_connection):
assert volume.azure_file.storage_account_name == "storage"
assert volume.azure_file.read_only is True

@pytest.mark.parametrize(
"mocked_connection",
[
Connection(
conn_id="azure_container_volume_test_default_azure-credential",
conn_type="wasb",
login="",
password="",
extra={"subscription_id": "subscription_id", "resource_group": "resource_group"},
)
],
indirect=True,
)
@mock.patch("airflow.providers.microsoft.azure.hooks.container_volume.StorageManagementClient")
@mock.patch("airflow.providers.microsoft.azure.hooks.container_volume.DefaultAzureCredential")
def test_get_file_volume_default_azure_credential(
self, mocked_default_azure_credential, mocked_client, mocked_connection
):
mocked_client.return_value.storage_accounts.list_keys.return_value.as_dict.return_value = {
"keys": [
{
"key_name": "key1",
"value": "value",
"permissions": "FULL",
"creation_time": "2023-07-13T16:16:10.474107Z",
},
]
}

hook = AzureContainerVolumeHook(azure_container_volume_conn_id=mocked_connection.conn_id)
volume = hook.get_file_volume(
mount_name="mount", share_name="share", storage_account_name="storage", read_only=True
)
assert volume is not None
assert volume.name == "mount"
assert volume.azure_file.share_name == "share"
assert volume.azure_file.storage_account_key == "value"
assert volume.azure_file.storage_account_name == "storage"
assert volume.azure_file.read_only is True

mocked_default_azure_credential.assert_called_with()

def test_get_ui_field_behaviour_placeholders(self):
"""
Check that ensure_prefixes decorator working properly
Expand All @@ -81,6 +125,8 @@ def test_get_ui_field_behaviour_placeholders(self):
"login",
"password",
"connection_string",
"subscription_id",
"resource_group",
]
if get_provider_min_airflow_version("apache-airflow-providers-microsoft-azure") >= (2, 5):
raise Exception(
Expand Down