From e0a781537433f9f974a1d40db140de50216f9825 Mon Sep 17 00:00:00 2001
From: Billy Hu
Date: Tue, 18 Jun 2024 23:44:23 -0700
Subject: [PATCH] [perf] Reduce eval local to remote tracking latency by caching the arm token (#3427)

# Description

Based on the investigation, local-to-remote tracking involves multiple client-to-service calls, and each call acquires a fresh Azure Resource Manager (ARM) token from AAD, with each acquisition taking about 2 seconds. Caching the token reduces the end-to-end time of an evaluate API call with a single evaluator from 76 seconds to 51 seconds, roughly a 33% improvement. For more details, please see [here](https://microsoft-my.sharepoint.com/:w:/p/ninhu/ETB_zdMkFrdAuf3Lcg9ssrUB6RVmyuFs5Un1G74O1HlwSA?e=cBVmsw).

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which has an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
---
 .../promptflow/azure/_utils/_token_cache.py   |  44 +++
 .../promptflow/azure/_utils/general.py        |   3 +-
 .../unittests/test_utils.py                   |  38 ++
 .../promptflow/evals/evaluate/_eval_run.py    |   9 +-
 .../tests/evals/unittests/test_eval_run.py    | 336 +++++++++---------
 5 files changed, 249 insertions(+), 181 deletions(-)
 create mode 100644 src/promptflow-azure/promptflow/azure/_utils/_token_cache.py

diff --git a/src/promptflow-azure/promptflow/azure/_utils/_token_cache.py b/src/promptflow-azure/promptflow/azure/_utils/_token_cache.py
new file mode 100644
index 00000000000..be31021d9a2
--- /dev/null
+++ b/src/promptflow-azure/promptflow/azure/_utils/_token_cache.py
@@ -0,0 +1,44 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import time
+
+import jwt
+
+from promptflow.core._connection_provider._utils import get_arm_token
+
+
+class SingletonMeta(type):
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            instance = super().__call__(*args, **kwargs)
+            cls._instances[cls] = instance
+        return cls._instances[cls]
+
+
+class ArmTokenCache(metaclass=SingletonMeta):
+    TOKEN_REFRESH_THRESHOLD_SECS = 300
+
+    def __init__(self):
+        self._cache = {}
+
+    def _is_token_valid(self, entry):
+        current_time = time.time()
+        return (entry["expires_at"] - current_time) >= self.TOKEN_REFRESH_THRESHOLD_SECS
+
+    def get_token(self, credential):
+        if credential in self._cache:
+            entry = self._cache[credential]
+            if self._is_token_valid(entry):
+                return entry["token"]
+
+        token = self._fetch_token(credential)
+        decoded_token = jwt.decode(token, options={"verify_signature": False, "verify_aud": False})
+        expiration_time = decoded_token.get("exp", time.time())
+        self._cache[credential] = {"token": token, "expires_at": expiration_time}
+        return token
+
+    def _fetch_token(self, credential):
+        return get_arm_token(credential=credential)
diff --git a/src/promptflow-azure/promptflow/azure/_utils/general.py b/src/promptflow-azure/promptflow/azure/_utils/general.py
index 8e5ab801e21..a090173fe4e 100644
--- a/src/promptflow-azure/promptflow/azure/_utils/general.py
+++ b/src/promptflow-azure/promptflow/azure/_utils/general.py
@@ -4,6 +4,7 @@

 import jwt

+from promptflow.azure._utils._token_cache import ArmTokenCache
 from promptflow.core._connection_provider._utils import get_arm_token, get_token


@@ -24,7 +25,7 @@ def get_aml_token(credential) -> str:


 def get_authorization(credential=None) -> str:
-    token = get_arm_token(credential=credential)
+    token = ArmTokenCache().get_token(credential=credential)
     return "Bearer " + token


diff --git a/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_utils.py b/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_utils.py
index 8cb2f342fee..bb014c92758 100644
--- a/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_utils.py
+++ b/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_utils.py
@@ -1,10 +1,13 @@
 # ---------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
+import time
 from unittest.mock import MagicMock, patch

+import jwt
 import pytest

+from promptflow.azure._utils._token_cache import ArmTokenCache
 from promptflow.exceptions import UserErrorException


@@ -50,3 +53,38 @@ def test_user_specified_azure_cli_credential(self):
         with patch.dict("os.environ", {EnvironmentVariables.PF_USE_AZURE_CLI_CREDENTIAL: "true"}):
             cred = get_credentials_for_cli()
             assert isinstance(cred, AzureCliCredential)
+
+    @patch.object(ArmTokenCache, "_fetch_token")
+    def test_arm_token_cache_get_token(self, mock_fetch_token):
+        expiration_time = time.time() + 3600  # 1 hour in the future
+        mock_token = jwt.encode({"exp": expiration_time}, "secret", algorithm="HS256")
+        mock_fetch_token.return_value = mock_token
+        credential = "test_credential"
+
+        cache = ArmTokenCache()
+
+        # Test that the token is fetched and cached
+        token1 = cache.get_token(credential)
+        assert token1 == mock_token, f"Expected '{mock_token}' but got {token1}"
+        assert credential in cache._cache, f"Expected '{credential}' to be in cache"
+        assert cache._cache[credential]["token"] == mock_token, "Expected token in cache to be the mock token"
+
+        # Test that the cached token is returned if still valid
+        token2 = cache.get_token(credential)
+        assert token2 == mock_token, f"Expected '{mock_token}' but got {token2}"
+        assert (
+            mock_fetch_token.call_count == 1
+        ), f"Expected fetch token to be called once, but it was called {mock_fetch_token.call_count} times"
+
+        # Test that a new token is fetched if the old one expires
+        expired_time = time.time() - 10  # Set the token as expired
+        cache._cache[credential]["expires_at"] = expired_time
+
+        new_expiration_time = time.time() + 3600
+        new_mock_token = jwt.encode({"exp": new_expiration_time}, "secret", algorithm="HS256")
+        mock_fetch_token.return_value = new_mock_token
+        token3 = cache.get_token(credential)
+        assert token3 == new_mock_token, f"Expected '{new_mock_token}' but got {token3}"
+        assert (
+            mock_fetch_token.call_count == 2
+        ), f"Expected fetch token to be called twice, but it was called {mock_fetch_token.call_count} times"
diff --git a/src/promptflow-evals/promptflow/evals/evaluate/_eval_run.py b/src/promptflow-evals/promptflow/evals/evaluate/_eval_run.py
index 7ba60c72c50..0f015ebeb19 100644
--- a/src/promptflow-evals/promptflow/evals/evaluate/_eval_run.py
+++ b/src/promptflow-evals/promptflow/evals/evaluate/_eval_run.py
@@ -16,6 +16,7 @@
 from requests.adapters import HTTPAdapter
 from urllib3.util.retry import Retry

+from promptflow.azure._utils._token_cache import ArmTokenCache
 from promptflow.evals._version import VERSION

 LOGGER = logging.getLogger(__name__)
@@ -209,11 +210,7 @@ def get_metrics_url(self):
         return f"https://{self._url_base}" "/mlflow/v2.0" f"{self._get_scope()}" f"/api/2.0/mlflow/runs/log-metric"

     def _get_token(self):
-        """The simple method to get token from the MLClient."""
-        # This behavior mimics how the authority is taken in azureml-mlflow.
-        # Note, that here we are taking authority for public cloud, however,
-        # it will not work for non-public clouds.
- return self._ml_client._credential.get_token(EvalRun._SCOPE) + return ArmTokenCache().get_token(self._ml_client._credential) def request_with_retry( self, url: str, method: str, json_dict: Dict[str, Any], headers: Optional[Dict[str, str]] = None @@ -234,7 +231,7 @@ def request_with_retry( if headers is None: headers = {} headers["User-Agent"] = f"promptflow/{VERSION}" - headers["Authorization"] = f"Bearer {self._get_token().token}" + headers["Authorization"] = f"Bearer {self._get_token()}" retry = Retry( total=EvalRun._MAX_RETRIES, connect=EvalRun._MAX_RETRIES, diff --git a/src/promptflow-evals/tests/evals/unittests/test_eval_run.py b/src/promptflow-evals/tests/evals/unittests/test_eval_run.py index 0288b1ddce4..c57127dae0e 100644 --- a/src/promptflow-evals/tests/evals/unittests/test_eval_run.py +++ b/src/promptflow-evals/tests/evals/unittests/test_eval_run.py @@ -1,13 +1,16 @@ import json import logging import os -import pytest - +import time from unittest.mock import MagicMock, patch - -from promptflow.evals.evaluate._eval_run import Singleton, EvalRun from uuid import uuid4 + +import jwt +import pytest + import promptflow.evals.evaluate._utils as ev_utils +from promptflow.azure._utils._token_cache import ArmTokenCache +from promptflow.evals.evaluate._eval_run import EvalRun, Singleton @pytest.fixture @@ -17,26 +20,26 @@ def setup_data(): Singleton._instances.clear() +def generate_mock_token(): + expiration_time = time.time() + 3600 # 1 hour in the future + return jwt.encode({"exp": expiration_time}, "secret", algorithm="HS256") + + @pytest.mark.unittest +@patch.object(ArmTokenCache, "_fetch_token", return_value=generate_mock_token()) class TestEvalRun: """Unit tests for the eval-run object.""" @pytest.mark.parametrize( - 'status,should_raise', - [ - ("KILLED", False), - ("WRONG_STATUS", True), - ("FINISHED", False), - ("FAILED", False) - ] + "status,should_raise", [("KILLED", False), ("WRONG_STATUS", True), ("FINISHED", False), ("FAILED", False)] ) - def test_end_raises(self, setup_data, status, should_raise, caplog): + def test_end_raises(self, token_mock, setup_data, status, should_raise, caplog): """Test that end run raises exception if incorrect status is set.""" mock_session = MagicMock() mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = { - 'run': { + "run": { "info": { "run_id": str(uuid4()), "experiment_id": str(uuid4()), @@ -44,14 +47,14 @@ def test_end_raises(self, setup_data, status, should_raise, caplog): } } mock_session.request.return_value = mock_response - with patch('promptflow.evals.evaluate._eval_run.requests.Session', return_value=mock_session): + with patch("promptflow.evals.evaluate._eval_run.requests.Session", return_value=mock_session): run = EvalRun( run_name=None, - tracking_uri='www.microsoft.com', - subscription_id='mock', - group_name='mock', - workspace_name='mock', - ml_client=MagicMock() + tracking_uri="www.microsoft.com", + subscription_id="mock", + group_name="mock", + workspace_name="mock", + ml_client=MagicMock(), ) if should_raise: with pytest.raises(ValueError) as cm: @@ -61,13 +64,13 @@ def test_end_raises(self, setup_data, status, should_raise, caplog): run.end_run(status) assert len(caplog.records) == 0 - def test_run_logs_if_terminated(self, setup_data, caplog): + def test_run_logs_if_terminated(self, token_mock, setup_data, caplog): """Test that run warn user if we are trying to terminate it twice.""" mock_session = MagicMock() mock_response = MagicMock() mock_response.status_code = 200 
mock_response.json.return_value = { - 'run': { + "run": { "info": { "run_id": str(uuid4()), "experiment_id": str(uuid4()), @@ -75,7 +78,7 @@ def test_run_logs_if_terminated(self, setup_data, caplog): } } mock_session.request.return_value = mock_response - with patch('promptflow.evals.evaluate._eval_run.requests.Session', return_value=mock_session): + with patch("promptflow.evals.evaluate._eval_run.requests.Session", return_value=mock_session): logger = logging.getLogger(EvalRun.__module__) # All loggers, having promptflow. prefix will have "promptflow" logger # as a parent. This logger does not propagate the logs and cannot be @@ -83,24 +86,24 @@ def test_run_logs_if_terminated(self, setup_data, caplog): logger.parent = logging.root run = EvalRun( run_name=None, - tracking_uri='www.microsoft.com', - subscription_id='mock', - group_name='mock', - workspace_name='mock', - ml_client=MagicMock() + tracking_uri="www.microsoft.com", + subscription_id="mock", + group_name="mock", + workspace_name="mock", + ml_client=MagicMock(), ) run.end_run("KILLED") run.end_run("KILLED") assert len(caplog.records) == 1 assert "Unable to stop run because it was already terminated." in caplog.records[0].message - def test_end_logs_if_fails(self, setup_data, caplog): + def test_end_logs_if_fails(self, token_mock, setup_data, caplog): """Test that if the terminal status setting was failed, it is logged.""" mock_session = MagicMock() mock_response_start = MagicMock() mock_response_start.status_code = 200 mock_response_start.json.return_value = { - 'run': { + "run": { "info": { "run_id": str(uuid4()), "experiment_id": str(uuid4()), @@ -110,7 +113,7 @@ def test_end_logs_if_fails(self, setup_data, caplog): mock_response_end = MagicMock() mock_response_end.status_code = 500 mock_session.request.side_effect = [mock_response_start, mock_response_end] - with patch('promptflow.evals.evaluate._eval_run.requests.Session', return_value=mock_session): + with patch("promptflow.evals.evaluate._eval_run.requests.Session", return_value=mock_session): logger = logging.getLogger(EvalRun.__module__) # All loggers, having promptflow. prefix will have "promptflow" logger # as a parent. This logger does not propagate the logs and cannot be @@ -118,24 +121,24 @@ def test_end_logs_if_fails(self, setup_data, caplog): logger.parent = logging.root run = EvalRun( run_name=None, - tracking_uri='www.microsoft.com', - subscription_id='mock', - group_name='mock', - workspace_name='mock', - ml_client=MagicMock() + tracking_uri="www.microsoft.com", + subscription_id="mock", + group_name="mock", + workspace_name="mock", + ml_client=MagicMock(), ) run.end_run("FINISHED") assert len(caplog.records) == 1 assert "Unable to terminate the run." in caplog.records[0].message - def test_start_run_fails(self, setup_data, caplog): + def test_start_run_fails(self, token_mock, setup_data, caplog): """Test that there are log messges if run was not started.""" mock_session = MagicMock() mock_response_start = MagicMock() mock_response_start.status_code = 500 mock_response_start.text = "Mock internal service error." mock_session.request.return_value = mock_response_start - with patch('promptflow.evals.evaluate._eval_run.requests.Session', return_value=mock_session): + with patch("promptflow.evals.evaluate._eval_run.requests.Session", return_value=mock_session): logger = logging.getLogger(EvalRun.__module__) # All loggers, having promptflow. prefix will have "promptflow" logger # as a parent. 
This logger does not propagate the logs and cannot be @@ -143,24 +146,24 @@ def test_start_run_fails(self, setup_data, caplog): logger.parent = logging.root run = EvalRun( run_name=None, - tracking_uri='www.microsoft.com', - subscription_id='mock', - group_name='mock', - workspace_name='mock', - ml_client=MagicMock() + tracking_uri="www.microsoft.com", + subscription_id="mock", + group_name="mock", + workspace_name="mock", + ml_client=MagicMock(), ) assert len(caplog.records) == 1 assert "500" in caplog.records[0].message assert mock_response_start.text in caplog.records[0].message - assert 'The results will be saved locally' in caplog.records[0].message + assert "The results will be saved locally" in caplog.records[0].message caplog.clear() # Log artifact - run.log_artifact('test') + run.log_artifact("test") assert len(caplog.records) == 1 assert "Unable to log artifact because the run failed to start." in caplog.records[0].message caplog.clear() # Log metric - run.log_metric('a', 42) + run.log_metric("a", 42) assert len(caplog.records) == 1 assert "Unable to log metric because the run failed to start." in caplog.records[0].message caplog.clear() @@ -170,21 +173,15 @@ def test_start_run_fails(self, setup_data, caplog): assert "Unable to stop run because the run failed to start." in caplog.records[0].message caplog.clear() - @pytest.mark.parametrize( - 'destroy_run,runs_are_the_same', - [ - (False, True), - (True, False) - ] - ) - @patch('promptflow.evals.evaluate._eval_run.requests.Session') - def test_singleton(self, mock_session_cls, setup_data, destroy_run, runs_are_the_same): + @pytest.mark.parametrize("destroy_run,runs_are_the_same", [(False, True), (True, False)]) + @patch("promptflow.evals.evaluate._eval_run.requests.Session") + def test_singleton(self, mock_session_cls, token_mock, setup_data, destroy_run, runs_are_the_same): """Test that the EvalRun is actually a singleton.""" mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.side_effect = [ { - 'run': { + "run": { "info": { "run_id": str(uuid4()), "experiment_id": str(uuid4()), @@ -192,7 +189,7 @@ def test_singleton(self, mock_session_cls, setup_data, destroy_run, runs_are_the } }, { - 'run': { + "run": { "info": { "run_id": str(uuid4()), "experiment_id": str(uuid4()), @@ -204,35 +201,35 @@ def test_singleton(self, mock_session_cls, setup_data, destroy_run, runs_are_the mock_session.request.return_value = mock_response mock_session_cls.return_value = mock_session run = EvalRun( - run_name='run', - tracking_uri='www.microsoft.com', - subscription_id='mock', - group_name='mock', - workspace_name='mock', - ml_client=MagicMock() + run_name="run", + tracking_uri="www.microsoft.com", + subscription_id="mock", + group_name="mock", + workspace_name="mock", + ml_client=MagicMock(), ) id1 = id(run) if destroy_run: run.end_run("FINISHED") id2 = id( EvalRun( - run_name='run', - tracking_uri='www.microsoft.com', - subscription_id='mock', - group_name='mock', - workspace_name='mock', - ml_client=MagicMock() + run_name="run", + tracking_uri="www.microsoft.com", + subscription_id="mock", + group_name="mock", + workspace_name="mock", + ml_client=MagicMock(), ) ) assert (id1 == id2) == runs_are_the_same - @patch('promptflow.evals.evaluate._eval_run.requests.Session') - def test_run_name(self, mock_session_cls, setup_data): + @patch("promptflow.evals.evaluate._eval_run.requests.Session") + def test_run_name(self, mock_session_cls, token_mock, setup_data): """Test that the run name is the same as ID if name is not 
given.""" mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = { - 'run': { + "run": { "info": { "run_id": str(uuid4()), "experiment_id": str(uuid4()), @@ -244,24 +241,23 @@ def test_run_name(self, mock_session_cls, setup_data): mock_session_cls.return_value = mock_session run = EvalRun( run_name=None, - tracking_uri='www.microsoft.com', - subscription_id='mock', - group_name='mock', - workspace_name='mock', - ml_client=MagicMock() + tracking_uri="www.microsoft.com", + subscription_id="mock", + group_name="mock", + workspace_name="mock", + ml_client=MagicMock(), ) - assert run.info.run_id == mock_response.json.return_value['run']['info']['run_id'] - assert run.info.experiment_id == mock_response.json.return_value[ - 'run']['info']['experiment_id'] + assert run.info.run_id == mock_response.json.return_value["run"]["info"]["run_id"] + assert run.info.experiment_id == mock_response.json.return_value["run"]["info"]["experiment_id"] assert run.name == run.info.run_id - @patch('promptflow.evals.evaluate._eval_run.requests.Session') - def test_run_with_name(self, mock_session_cls, setup_data): + @patch("promptflow.evals.evaluate._eval_run.requests.Session") + def test_run_with_name(self, mock_session_cls, token_mock, setup_data): """Test that the run name is not the same as id if it is given.""" mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = { - 'run': { + "run": { "info": { "run_id": str(uuid4()), "experiment_id": str(uuid4()), @@ -272,26 +268,25 @@ def test_run_with_name(self, mock_session_cls, setup_data): mock_session.request.return_value = mock_response mock_session_cls.return_value = mock_session run = EvalRun( - run_name='test', - tracking_uri='www.microsoft.com', - subscription_id='mock', - group_name='mock', - workspace_name='mock', - ml_client=MagicMock() + run_name="test", + tracking_uri="www.microsoft.com", + subscription_id="mock", + group_name="mock", + workspace_name="mock", + ml_client=MagicMock(), ) - assert run.info.run_id == mock_response.json.return_value['run']['info']['run_id'] - assert run.info.experiment_id == mock_response.json.return_value[ - 'run']['info']['experiment_id'] - assert run.name == 'test' + assert run.info.run_id == mock_response.json.return_value["run"]["info"]["run_id"] + assert run.info.experiment_id == mock_response.json.return_value["run"]["info"]["experiment_id"] + assert run.name == "test" assert run.name != run.info.run_id - @patch('promptflow.evals.evaluate._eval_run.requests.Session') - def test_get_urls(self, mock_session_cls, setup_data): + @patch("promptflow.evals.evaluate._eval_run.requests.Session") + def test_get_urls(self, mock_session_cls, token_mock, setup_data): """Test getting url-s from eval run.""" mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = { - 'run': { + "run": { "info": { "run_id": str(uuid4()), "experiment_id": str(uuid4()), @@ -302,56 +297,51 @@ def test_get_urls(self, mock_session_cls, setup_data): mock_session.request.return_value = mock_response mock_session_cls.return_value = mock_session run = EvalRun( - run_name='test', + run_name="test", tracking_uri=( - 'https://region.api.azureml.ms/mlflow/v2.0/subscriptions' - '/000000-0000-0000-0000-0000000/resourceGroups/mock-rg-region' - '/providers/Microsoft.MachineLearningServices' - '/workspaces/mock-ws-region'), - subscription_id='000000-0000-0000-0000-0000000', - group_name='mock-rg-region', - workspace_name='mock-ws-region', - 
ml_client=MagicMock() + "https://region.api.azureml.ms/mlflow/v2.0/subscriptions" + "/000000-0000-0000-0000-0000000/resourceGroups/mock-rg-region" + "/providers/Microsoft.MachineLearningServices" + "/workspaces/mock-ws-region" + ), + subscription_id="000000-0000-0000-0000-0000000", + group_name="mock-rg-region", + workspace_name="mock-ws-region", + ml_client=MagicMock(), ) assert run.get_run_history_uri() == ( - 'https://region.api.azureml.ms/history/v1.0/subscriptions' - '/000000-0000-0000-0000-0000000/resourceGroups/mock-rg-region' - '/providers/Microsoft.MachineLearningServices' - '/workspaces/mock-ws-region/experimentids/' - f'{run.info.experiment_id}/runs/{run.info.run_id}'), 'Wrong RunHistory URL' + "https://region.api.azureml.ms/history/v1.0/subscriptions" + "/000000-0000-0000-0000-0000000/resourceGroups/mock-rg-region" + "/providers/Microsoft.MachineLearningServices" + "/workspaces/mock-ws-region/experimentids/" + f"{run.info.experiment_id}/runs/{run.info.run_id}" + ), "Wrong RunHistory URL" assert run.get_artifacts_uri() == ( - 'https://region.api.azureml.ms/history/v1.0/subscriptions' - '/000000-0000-0000-0000-0000000/resourceGroups/mock-rg-region' - '/providers/Microsoft.MachineLearningServices' - '/workspaces/mock-ws-region/experimentids/' - f'{run.info.experiment_id}/runs/{run.info.run_id}' - '/artifacts/batch/metadata' - ), 'Wrong Artifacts URL' + "https://region.api.azureml.ms/history/v1.0/subscriptions" + "/000000-0000-0000-0000-0000000/resourceGroups/mock-rg-region" + "/providers/Microsoft.MachineLearningServices" + "/workspaces/mock-ws-region/experimentids/" + f"{run.info.experiment_id}/runs/{run.info.run_id}" + "/artifacts/batch/metadata" + ), "Wrong Artifacts URL" assert run.get_metrics_url() == ( - 'https://region.api.azureml.ms/mlflow/v2.0/subscriptions' - '/000000-0000-0000-0000-0000000/resourceGroups/mock-rg-region' - '/providers/Microsoft.MachineLearningServices' - '/workspaces/mock-ws-region/api/2.0/mlflow/runs/log-metric' - ), 'Wrong Metrics URL' + "https://region.api.azureml.ms/mlflow/v2.0/subscriptions" + "/000000-0000-0000-0000-0000000/resourceGroups/mock-rg-region" + "/providers/Microsoft.MachineLearningServices" + "/workspaces/mock-ws-region/api/2.0/mlflow/runs/log-metric" + ), "Wrong Metrics URL" @pytest.mark.parametrize( - 'log_function,expected_str', - [ - ('log_artifact', 'allocate Blob for the artifact'), - ('log_metric', 'save metrics') - ] + "log_function,expected_str", + [("log_artifact", "allocate Blob for the artifact"), ("log_metric", "save metrics")], ) - def test_log_artifacts_logs_error( - self, - setup_data, tmp_path, caplog, - log_function, expected_str - ): + def test_log_artifacts_logs_error(self, token_mock, setup_data, tmp_path, caplog, log_function, expected_str): """Test that the error is logged.""" mock_session = MagicMock() mock_create_response = MagicMock() mock_create_response.status_code = 200 mock_create_response.json.return_value = { - 'run': { + "run": { "info": { "run_id": str(uuid4()), "experiment_id": str(uuid4()), @@ -360,27 +350,25 @@ def test_log_artifacts_logs_error( } mock_response = MagicMock() mock_response.status_code = 404 - mock_response.text = 'Mock not found error.' + mock_response.text = "Mock not found error." 
- if log_function == 'log_artifact': - with open(os.path.join(tmp_path, 'test.json'), 'w') as fp: - json.dump({'f1': 0.5}, fp) - mock_session.request.side_effect = [ - mock_create_response, - mock_response - ] - with patch('promptflow.evals.evaluate._eval_run.requests.Session', return_value=mock_session): + if log_function == "log_artifact": + with open(os.path.join(tmp_path, "test.json"), "w") as fp: + json.dump({"f1": 0.5}, fp) + mock_session.request.side_effect = [mock_create_response, mock_response] + with patch("promptflow.evals.evaluate._eval_run.requests.Session", return_value=mock_session): run = EvalRun( - run_name='test', + run_name="test", tracking_uri=( - 'https://region.api.azureml.ms/mlflow/v2.0/subscriptions' - '/000000-0000-0000-0000-0000000/resourceGroups/mock-rg-region' - '/providers/Microsoft.MachineLearningServices' - '/workspaces/mock-ws-region'), - subscription_id='000000-0000-0000-0000-0000000', - group_name='mock-rg-region', - workspace_name='mock-ws-region', - ml_client=MagicMock() + "https://region.api.azureml.ms/mlflow/v2.0/subscriptions" + "/000000-0000-0000-0000-0000000/resourceGroups/mock-rg-region" + "/providers/Microsoft.MachineLearningServices" + "/workspaces/mock-ws-region" + ), + subscription_id="000000-0000-0000-0000-0000000", + group_name="mock-rg-region", + workspace_name="mock-ws-region", + ml_client=MagicMock(), ) logger = logging.getLogger(EvalRun.__module__) @@ -389,28 +377,29 @@ def test_log_artifacts_logs_error( # captured by caplog. Here we will skip this logger to capture logs. logger.parent = logging.root fn = getattr(run, log_function) - if log_function == 'log_artifact': - kwargs = {'artifact_folder': tmp_path} + if log_function == "log_artifact": + kwargs = {"artifact_folder": tmp_path} else: - kwargs = {'key': 'f1', 'value': 0.5} + kwargs = {"key": "f1", "value": 0.5} fn(**kwargs) assert len(caplog.records) == 1 assert mock_response.text in caplog.records[0].message - assert '404' in caplog.records[0].message + assert "404" in caplog.records[0].message assert expected_str in caplog.records[0].message @pytest.mark.parametrize( - 'dir_exists,expected_error', [ + "dir_exists,expected_error", + [ (True, "The path to the artifact is empty."), - (False, "The path to the artifact is either not a directory or does not exist.") - ] + (False, "The path to the artifact is either not a directory or does not exist."), + ], ) - def test_wrong_artifact_path(self, tmp_path, caplog, dir_exists, expected_error): + def test_wrong_artifact_path(self, token_mock, setup_data, tmp_path, caplog, dir_exists, expected_error): """Test that if artifact path is empty, or dies not exist we are logging the error.""" mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = { - 'run': { + "run": { "info": { "run_id": str(uuid4()), "experiment_id": str(uuid4()), @@ -419,18 +408,19 @@ def test_wrong_artifact_path(self, tmp_path, caplog, dir_exists, expected_error) } mock_session = MagicMock() mock_session.request.return_value = mock_response - with patch('promptflow.evals.evaluate._eval_run.requests.Session', return_value=mock_session): + with patch("promptflow.evals.evaluate._eval_run.requests.Session", return_value=mock_session): run = EvalRun( - run_name='test', + run_name="test", tracking_uri=( - 'https://region.api.azureml.ms/mlflow/v2.0/subscriptions' - '/000000-0000-0000-0000-0000000/resourceGroups/mock-rg-region' - '/providers/Microsoft.MachineLearningServices' - '/workspaces/mock-ws-region'), - 
subscription_id='000000-0000-0000-0000-0000000', - group_name='mock-rg-region', - workspace_name='mock-ws-region', - ml_client=MagicMock() + "https://region.api.azureml.ms/mlflow/v2.0/subscriptions" + "/000000-0000-0000-0000-0000000/resourceGroups/mock-rg-region" + "/providers/Microsoft.MachineLearningServices" + "/workspaces/mock-ws-region" + ), + subscription_id="000000-0000-0000-0000-0000000", + group_name="mock-rg-region", + workspace_name="mock-ws-region", + ml_client=MagicMock(), ) logger = logging.getLogger(EvalRun.__module__) # All loggers, having promptflow. prefix will have "promptflow" logger @@ -442,7 +432,7 @@ def test_wrong_artifact_path(self, tmp_path, caplog, dir_exists, expected_error) assert len(caplog.records) == 1 assert expected_error in caplog.records[0].message - def test_log_metrics_and_instance_results_logs_error(self, caplog): + def test_log_metrics_and_instance_results_logs_error(self, token_mock, setup_data, caplog): """Test that we are logging the error when there is no trace destination.""" logger = logging.getLogger(ev_utils.__name__) # All loggers, having promptflow. prefix will have "promptflow" logger @@ -450,9 +440,7 @@ def test_log_metrics_and_instance_results_logs_error(self, caplog): # captured by caplog. Here we will skip this logger to capture logs. logger.parent = logging.root ev_utils._log_metrics_and_instance_results( - metrics=None, - instance_results=None, - trace_destination=None, - run=None) + metrics=None, instance_results=None, trace_destination=None, run=None + ) assert len(caplog.records) == 1 assert "Unable to log traces as trace destination was not defined." in caplog.records[0].message
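
Reviewer note (not part of the patch): the sketch below illustrates how the cached token is expected to behave end to end. It is a minimal example assuming `promptflow-azure` is installed and that `azure.identity.DefaultAzureCredential` can resolve a credential able to obtain an ARM token in the current environment; the timing prints are illustrative, not a benchmark.

```python
import time

from azure.identity import DefaultAzureCredential  # assumption: a usable credential exists

from promptflow.azure._utils._token_cache import ArmTokenCache
from promptflow.azure._utils.general import get_authorization

credential = DefaultAzureCredential()
cache = ArmTokenCache()  # SingletonMeta: every call site shares this one instance

# First call goes to AAD (about 2 seconds per the PR description); the token's
# decoded "exp" claim is stored in the cache next to the token itself.
start = time.time()
first_token = cache.get_token(credential)
print(f"first acquisition: {time.time() - start:.2f}s")

# While the cached token is more than TOKEN_REFRESH_THRESHOLD_SECS (300 s)
# away from expiry, subsequent calls are served from the cache.
start = time.time()
second_token = cache.get_token(credential)
print(f"cached acquisition: {time.time() - start:.4f}s")
assert second_token == first_token

# get_authorization() now routes through the same cache, so the repeated
# client-to-service calls made during local-to-remote tracking reuse one token.
assert get_authorization(credential=credential) == f"Bearer {first_token}"
```

Because the cache is keyed by the credential object, every caller that passes the same credential (including `get_authorization()` and `EvalRun._get_token()`) shares a single AAD round trip until the token comes within 300 seconds of expiry, at which point it is refreshed.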