Skip to content

Commit

Permalink
Bring parity between eval and sim for azure_ai_project (#3378)
Browse files Browse the repository at this point in the history
# Description

Please add an informative description that covers the changes made by
the pull request and link all relevant issues.

# All Promptflow Contribution checklist:
- [x] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
  • Loading branch information
nagkumar91 authored Jun 27, 2024
1 parent 9a8c10d commit ea1dd04
Show file tree
Hide file tree
Showing 10 changed files with 528,165 additions and 393,709 deletions.
5 changes: 2 additions & 3 deletions src/promptflow-evals/promptflow/evals/synthetic/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ import asyncio
azure_ai_project = {
"subscription_id": <subscription_id>,
"resource_group_name": <resource_group_name>,
"project_name": <project_name>,
"credential": DefaultAzureCredential(),
"project_name": <project_name>
}

async def callback(
Expand Down Expand Up @@ -90,7 +89,7 @@ Make sure you change the snippets below to remove the `asyncio.run` wrapper and
### Adversarial QA:
```python
scenario = AdversarialScenario.ADVERSARIAL_QA
simulator = AdversarialSimulator(azure_ai_project=azure_ai_project)
simulator = AdversarialSimulator(azure_ai_project=azure_ai_project, credential=DefaultAzureCredential())

outputs = asyncio.run(
simulator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import random
from typing import Any, Callable, Dict, List

from azure.identity import DefaultAzureCredential
from tqdm import tqdm

from promptflow._sdk._telemetry import ActivityType, monitor_operation
Expand Down Expand Up @@ -56,33 +57,31 @@ class AdversarialSimulator:
Initializes the adversarial simulator with a project scope.
:param azure_ai_project: Dictionary defining the scope of the project. It must include the following keys:
* "subscription_id": Azure subscription ID.
* "resource_group_name": Name of the Azure resource group.
* "project_name": Name of the Azure Machine Learning workspace.
* "credential": Azure credentials object for authentication.
- "subscription_id": Azure subscription ID.
- "resource_group_name": Name of the Azure resource group.
- "project_name": Name of the Azure Machine Learning workspace.
:param credential: The credential for connecting to Azure AI project.
:type credential: TokenCredential
:type azure_ai_project: Dict[str, Any]
"""

def __init__(self, *, azure_ai_project: Dict[str, Any]):
def __init__(self, *, azure_ai_project: Dict[str, Any], credential=None):
"""Constructor."""
# check if azure_ai_project has the keys: subscription_id, resource_group_name, project_name, credential
if not all(
key in azure_ai_project for key in ["subscription_id", "resource_group_name", "project_name", "credential"]
):
raise ValueError(
"azure_ai_project must contain keys: subscription_id, resource_group_name, project_name, credential"
)
# check if azure_ai_project has the keys: subscription_id, resource_group_name and project_name
if not all(key in azure_ai_project for key in ["subscription_id", "resource_group_name", "project_name"]):
raise ValueError("azure_ai_project must contain keys: subscription_id, resource_group_name, project_name")
# check the value of the keys in azure_ai_project is not none
if not all(
azure_ai_project[key] for key in ["subscription_id", "resource_group_name", "project_name", "credential"]
):
raise ValueError("subscription_id, resource_group_name, project_name, and credential must not be None")
if not all(azure_ai_project[key] for key in ["subscription_id", "resource_group_name", "project_name"]):
raise ValueError("subscription_id, resource_group_name and project_name must not be None")
if "credential" not in azure_ai_project and not credential:
credential = DefaultAzureCredential()
elif "credential" in azure_ai_project:
credential = azure_ai_project["credential"]
self.azure_ai_project = azure_ai_project
self.token_manager = ManagedIdentityAPITokenManager(
token_scope=TokenScope.DEFAULT_AZURE_MANAGEMENT,
logger=logging.getLogger("AdversarialSimulator"),
credential=self.azure_ai_project["credential"],
credential=credential,
)
self.rai_client = RAIClient(azure_ai_project=azure_ai_project, token_manager=self.token_manager)
self.adversarial_template_handler = AdversarialTemplateHandler(
Expand Down
47 changes: 42 additions & 5 deletions src/promptflow-evals/tests/evals/e2etests/test_adv_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ async def callback(
)
)
assert len(outputs) == 1
print(outputs)
assert len(outputs[0]["messages"]) == 4

@pytest.mark.usefixtures("vcr_recording")
Expand Down Expand Up @@ -182,8 +181,6 @@ async def callback(
concurrent_async_task=1,
)
)
print(outputs.to_json_lines())
print("*****************************")
assert len(outputs) == 1

@pytest.mark.usefixtures("vcr_recording")
Expand Down Expand Up @@ -227,6 +224,46 @@ async def callback(
jailbreak=True,
)
)
print(outputs.to_json_lines())
print("*****************************")
assert len(outputs) == 1

@pytest.mark.usefixtures("vcr_recording")
def test_adv_rewrite_sim_responds_with_responses(self, azure_cred, project_scope):
    """Run one ADVERSARIAL_REWRITE simulation with the credential passed separately.

    The project dictionary built here carries only the three scope keys; the
    credential is handed to the simulator through its ``credential`` keyword.
    The test passes when exactly one simulation result comes back.
    """
    # Make sure RAI_SVC_URL is not set in the environment before the import below.
    os.environ.pop("RAI_SVC_URL", None)
    from promptflow.evals.synthetic import AdversarialScenario, AdversarialSimulator

    scope_keys = ("subscription_id", "resource_group_name", "project_name")
    azure_ai_project = {key: project_scope[key] for key in scope_keys}

    async def callback(
        messages: List[Dict], stream: bool = False, session_state: Any = None, context: Dict[str, Any] = None
    ) -> dict:
        # Echo the content of the first message back as the assistant's reply.
        conversation = messages["messages"]
        conversation.append({"content": conversation[0]["content"], "role": "assistant"})
        return {
            "messages": conversation,
            "stream": stream,
            "session_state": session_state,
            "context": context,
        }

    simulator = AdversarialSimulator(azure_ai_project=azure_ai_project, credential=azure_cred)

    simulation = simulator(
        scenario=AdversarialScenario.ADVERSARIAL_REWRITE,
        max_conversation_turns=1,
        max_simulation_results=1,
        target=callback,
        api_call_retry_limit=3,
        api_call_retry_sleep_sec=1,
        api_call_delay_sec=30,
        concurrent_async_task=1,
        jailbreak=True,
    )
    outputs = asyncio.run(simulation)
    assert len(outputs) == 1
34 changes: 34 additions & 0 deletions src/promptflow-evals/tests/evals/unittests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,37 @@ async def callback(x):
scenario="unknown-scenario", max_conversation_turns=1, max_simulation_results=3, target=callback
)
)

@patch("promptflow.evals.synthetic._model_tools._rai_client.RAIClient._get_service_discovery_url")
@patch("promptflow.evals.synthetic._model_tools.AdversarialTemplateHandler._get_content_harm_template_collections")
@patch("promptflow.evals.synthetic.adversarial_simulator.AdversarialSimulator._simulate_async")
@patch("promptflow.evals.synthetic.adversarial_simulator.AdversarialSimulator._ensure_service_dependencies")
def test_initialization_parity_with_evals(
    self,
    mock_ensure_service_dependencies,
    mock_get_content_harm_template_collections,
    mock_simulate_async,
    mock_get_service_discovery_url,
):
    """Build and invoke the simulator for every scenario with a credential-free project dict.

    The project dictionary holds only the subscription, resource group and
    project name; the credential is supplied through the separate
    ``credential`` keyword. All four patched dependencies are given canned
    return values.
    """
    # Canned return values for the patched dependencies.
    mock_ensure_service_dependencies.return_value = True
    mock_get_content_harm_template_collections.return_value = ["t1", "t2", "t3", "t4", "t5", "t6", "t7"]
    mock_simulate_async.return_value = MagicMock()
    mock_get_service_discovery_url.return_value = "http://some.url/discovery/"

    project = {
        "subscription_id": "test_subscription",
        "resource_group_name": "test_resource_group",
        "project_name": "test_workspace",
    }
    call_options = {
        "max_conversation_turns": 1,
        "max_simulation_results": 3,
        "target": async_callback,
    }
    scenarios = (
        AdversarialScenario.ADVERSARIAL_CONVERSATION,
        AdversarialScenario.ADVERSARIAL_QA,
        AdversarialScenario.ADVERSARIAL_SUMMARIZATION,
        AdversarialScenario.ADVERSARIAL_SEARCH,
        AdversarialScenario.ADVERSARIAL_REWRITE,
        AdversarialScenario.ADVERSARIAL_CONTENT_GEN_UNGROUNDED,
        AdversarialScenario.ADVERSARIAL_CONTENT_GEN_GROUNDED,
    )
    for scenario in scenarios:
        simulator = AdversarialSimulator(azure_ai_project=project, credential="test_credential")
        assert callable(simulator)
        # NOTE(review): the result of this call is discarded without being awaited
        # or run; if the call yields a coroutine, the simulation body never
        # executes here - confirm that only construction is meant to be checked.
        simulator(scenario=scenario, **call_options)
Loading

0 comments on commit ea1dd04

Please sign in to comment.