Skip to content

Commit

Permalink
[perf] Reduce eval local to remote tracking latency by caching the ar…
Browse files Browse the repository at this point in the history
…m token (#3427)

# Description

Based on the investigation, local to remote tracking involves multiple
client to service calls. Each call requires acquiring an ARM token from
AAD, with each token acquisition taking about 2 seconds. By caching the
token, we could reduce the end-to-end time of the evaluate API call with
one evaluator from 76 seconds to 51 seconds, achieving around a 30%
improvement.

For more details, please check out
[here](https://microsoft-my.sharepoint.com/:w:/p/ninhu/ETB_zdMkFrdAuf3Lcg9ssrUB6RVmyuFs5Un1G74O1HlwSA?e=cBVmsw)

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
  • Loading branch information
ninghu authored Jun 19, 2024
1 parent 6d14c4a commit e0a7815
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 181 deletions.
44 changes: 44 additions & 0 deletions src/promptflow-azure/promptflow/azure/_utils/_token_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import time

import jwt

from promptflow.core._connection_provider._utils import get_arm_token


class SingletonMeta(type):
_instances = {}

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
return cls._instances[cls]


class ArmTokenCache(metaclass=SingletonMeta):
TOKEN_REFRESH_THRESHOLD_SECS = 300

def __init__(self):
self._cache = {}

def _is_token_valid(self, entry):
current_time = time.time()
return (entry["expires_at"] - current_time) >= self.TOKEN_REFRESH_THRESHOLD_SECS

def get_token(self, credential):
if credential in self._cache:
entry = self._cache[credential]
if self._is_token_valid(entry):
return entry["token"]

token = self._fetch_token(credential)
decoded_token = jwt.decode(token, options={"verify_signature": False, "verify_aud": False})
expiration_time = decoded_token.get("exp", time.time())
self._cache[credential] = {"token": token, "expires_at": expiration_time}
return token

def _fetch_token(self, credential):
return get_arm_token(credential=credential)
3 changes: 2 additions & 1 deletion src/promptflow-azure/promptflow/azure/_utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jwt

from promptflow.azure._utils._token_cache import ArmTokenCache
from promptflow.core._connection_provider._utils import get_arm_token, get_token


Expand All @@ -24,7 +25,7 @@ def get_aml_token(credential) -> str:


def get_authorization(credential=None) -> str:
token = get_arm_token(credential=credential)
token = ArmTokenCache().get_token(credential=credential)
return "Bearer " + token


Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import time
from unittest.mock import MagicMock, patch

import jwt
import pytest

from promptflow.azure._utils._token_cache import ArmTokenCache
from promptflow.exceptions import UserErrorException


Expand Down Expand Up @@ -50,3 +53,38 @@ def test_user_specified_azure_cli_credential(self):
with patch.dict("os.environ", {EnvironmentVariables.PF_USE_AZURE_CLI_CREDENTIAL: "true"}):
cred = get_credentials_for_cli()
assert isinstance(cred, AzureCliCredential)

@patch.object(ArmTokenCache, "_fetch_token")
def test_arm_token_cache_get_token(self, mock_fetch_token):
expiration_time = time.time() + 3600 # 1 hour in the future
mock_token = jwt.encode({"exp": expiration_time}, "secret", algorithm="HS256")
mock_fetch_token.return_value = mock_token
credential = "test_credential"

cache = ArmTokenCache()

# Test that the token is fetched and cached
token1 = cache.get_token(credential)
assert token1 == mock_token, f"Expected '{mock_token}' but got {token1}"
assert credential in cache._cache, f"Expected '{credential}' to be in cache"
assert cache._cache[credential]["token"] == mock_token, "Expected token in cache to be the mock token"

# Test that the cached token is returned if still valid
token2 = cache.get_token(credential)
assert token2 == mock_token, f"Expected '{mock_token}' but got {token2}"
assert (
mock_fetch_token.call_count == 1
), f"Expected fetch token to be called once, but it was called {mock_fetch_token.call_count} times"

# Test that a new token is fetched if the old one expires
expired_time = time.time() - 10 # Set the token as expired
cache._cache[credential]["expires_at"] = expired_time

new_expiration_time = time.time() + 3600
new_mock_token = jwt.encode({"exp": new_expiration_time}, "secret", algorithm="HS256")
mock_fetch_token.return_value = new_mock_token
token3 = cache.get_token(credential)
assert token3 == new_mock_token, f"Expected '{new_mock_token}' but got {token3}"
assert (
mock_fetch_token.call_count == 2
), f"Expected fetch token to be called twice, but it was called {mock_fetch_token.call_count} times"
9 changes: 3 additions & 6 deletions src/promptflow-evals/promptflow/evals/evaluate/_eval_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from promptflow.azure._utils._token_cache import ArmTokenCache
from promptflow.evals._version import VERSION

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -209,11 +210,7 @@ def get_metrics_url(self):
return f"https://{self._url_base}" "/mlflow/v2.0" f"{self._get_scope()}" f"/api/2.0/mlflow/runs/log-metric"

def _get_token(self):
"""The simple method to get token from the MLClient."""
# This behavior mimics how the authority is taken in azureml-mlflow.
# Note, that here we are taking authority for public cloud, however,
# it will not work for non-public clouds.
return self._ml_client._credential.get_token(EvalRun._SCOPE)
return ArmTokenCache().get_token(self._ml_client._credential)

def request_with_retry(
self, url: str, method: str, json_dict: Dict[str, Any], headers: Optional[Dict[str, str]] = None
Expand All @@ -234,7 +231,7 @@ def request_with_retry(
if headers is None:
headers = {}
headers["User-Agent"] = f"promptflow/{VERSION}"
headers["Authorization"] = f"Bearer {self._get_token().token}"
headers["Authorization"] = f"Bearer {self._get_token()}"
retry = Retry(
total=EvalRun._MAX_RETRIES,
connect=EvalRun._MAX_RETRIES,
Expand Down
Loading

0 comments on commit e0a7815

Please sign in to comment.