Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring parity between eval and sim for azure_ai_project #3378

Merged
merged 13 commits into from
Jun 27, 2024
5 changes: 2 additions & 3 deletions src/promptflow-evals/promptflow/evals/synthetic/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ import asyncio
azure_ai_project = {
"subscription_id": <subscription_id>,
"resource_group_name": <resource_group_name>,
"project_name": <project_name>,
"credential": DefaultAzureCredential(),
"project_name": <project_name>
}

async def callback(
Expand Down Expand Up @@ -90,7 +89,7 @@ Make sure you change the snippets below to remove the `asyncio.run` wrapper and
### Adversarial QA:
```python
scenario = AdversarialScenario.ADVERSARIAL_QA
simulator = AdversarialSimulator(azure_ai_project=azure_ai_project)
simulator = AdversarialSimulator(azure_ai_project=azure_ai_project, credential=DefaultAzureCredential())

outputs = asyncio.run(
simulator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,34 +52,31 @@ def wrapper(*args, **kwargs):


class AdversarialSimulator:
def __init__(self, *, azure_ai_project: Dict[str, Any]):
def __init__(self, *, azure_ai_project: Dict[str, Any], credential=None):
"""
Initializes the adversarial simulator with a project scope.

:param azure_ai_project: Dictionary defining the scope of the project. It must include the following keys:
- "subscription_id": Azure subscription ID.
- "resource_group_name": Name of the Azure resource group.
- "project_name": Name of the Azure Machine Learning workspace.
- "credential": Azure credentials object for authentication.
:param credential: The credential for connecting to Azure AI project.
:type credential: TokenCredential
:type azure_ai_project: Dict[str, Any]
"""
# check if azure_ai_project has the keys: subscription_id, resource_group_name, project_name, credential
if not all(
key in azure_ai_project for key in ["subscription_id", "resource_group_name", "project_name", "credential"]
):
raise ValueError(
"azure_ai_project must contain keys: subscription_id, resource_group_name, project_name, credential"
)
if not all(key in azure_ai_project for key in ["subscription_id", "resource_group_name", "project_name"]):
raise ValueError("azure_ai_project must contain keys: subscription_id, resource_group_name, project_name")
# check the value of the keys in azure_ai_project is not none
if not all(
azure_ai_project[key] for key in ["subscription_id", "resource_group_name", "project_name", "credential"]
):
raise ValueError("subscription_id, resource_group_name, project_name, and credential must not be None")
if not all(azure_ai_project[key] for key in ["subscription_id", "resource_group_name", "project_name"]):
raise ValueError("subscription_id, resource_group_name and project_name must not be None")
if "credential" in azure_ai_project and not credential:
nagkumar91 marked this conversation as resolved.
Show resolved Hide resolved
credential = azure_ai_project["credential"]
self.azure_ai_project = azure_ai_project
self.token_manager = ManagedIdentityAPITokenManager(
token_scope=TokenScope.DEFAULT_AZURE_MANAGEMENT,
logger=logging.getLogger("AdversarialSimulator"),
credential=self.azure_ai_project["credential"],
credential=credential,
nagkumar91 marked this conversation as resolved.
Show resolved Hide resolved
)
self.rai_client = RAIClient(azure_ai_project=azure_ai_project, token_manager=self.token_manager)
self.adversarial_template_handler = AdversarialTemplateHandler(
Expand Down
44 changes: 44 additions & 0 deletions src/promptflow-evals/tests/evals/e2etests/test_adv_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,47 @@ async def callback(
print(outputs.to_json_lines())
print("*****************************")
assert len(outputs) == 1

@pytest.mark.usefixtures("vcr_recording")
def test_adv_rewrite_sim_responds_with_responses(self, azure_cred, project_scope):
os.environ.pop("RAI_SVC_URL", None)
from promptflow.evals.synthetic import AdversarialScenario, AdversarialSimulator

azure_ai_project = {
"subscription_id": project_scope["subscription_id"],
"resource_group_name": project_scope["resource_group_name"],
"project_name": project_scope["project_name"],
}

async def callback(
messages: List[Dict], stream: bool = False, session_state: Any = None, context: Dict[str, Any] = None
) -> dict:
question = messages["messages"][0]["content"]

formatted_response = {"content": question, "role": "assistant"}
messages["messages"].append(formatted_response)
return {
"messages": messages["messages"],
"stream": stream,
"session_state": session_state,
"context": context,
}

simulator = AdversarialSimulator(azure_ai_project=azure_ai_project, credential=azure_cred)

outputs = asyncio.run(
simulator(
scenario=AdversarialScenario.ADVERSARIAL_REWRITE,
max_conversation_turns=1,
max_simulation_results=1,
target=callback,
api_call_retry_limit=3,
api_call_retry_sleep_sec=1,
api_call_delay_sec=30,
concurrent_async_task=1,
jailbreak=True,
)
)
print(outputs.to_json_lines())
nagkumar91 marked this conversation as resolved.
Show resolved Hide resolved
print("*****************************")
assert len(outputs) == 1
34 changes: 34 additions & 0 deletions src/promptflow-evals/tests/evals/unittests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,37 @@ async def callback(x):
scenario="unknown-scenario", max_conversation_turns=1, max_simulation_results=3, target=callback
)
)

@patch("promptflow.evals.synthetic._model_tools._rai_client.RAIClient._get_service_discovery_url")
@patch("promptflow.evals.synthetic._model_tools.AdversarialTemplateHandler._get_content_harm_template_collections")
@patch("promptflow.evals.synthetic.adversarial_simulator.AdversarialSimulator._simulate_async")
@patch("promptflow.evals.synthetic.adversarial_simulator.AdversarialSimulator._ensure_service_dependencies")
def test_initialization_parity_with_evals(
self,
mock_ensure_service_dependencies,
mock_get_content_harm_template_collections,
mock_simulate_async,
mock_get_service_discovery_url,
):
mock_get_service_discovery_url.return_value = "http://some.url/discovery/"
mock_simulate_async.return_value = MagicMock()
mock_get_content_harm_template_collections.return_value = ["t1", "t2", "t3", "t4", "t5", "t6", "t7"]
mock_ensure_service_dependencies.return_value = True
azure_ai_project = {
"subscription_id": "test_subscription",
"resource_group_name": "test_resource_group",
"project_name": "test_workspace",
}
available_scenarios = [
AdversarialScenario.ADVERSARIAL_CONVERSATION,
AdversarialScenario.ADVERSARIAL_QA,
AdversarialScenario.ADVERSARIAL_SUMMARIZATION,
AdversarialScenario.ADVERSARIAL_SEARCH,
AdversarialScenario.ADVERSARIAL_REWRITE,
AdversarialScenario.ADVERSARIAL_CONTENT_GEN_UNGROUNDED,
AdversarialScenario.ADVERSARIAL_CONTENT_GEN_GROUNDED,
]
for scenario in available_scenarios:
simulator = AdversarialSimulator(azure_ai_project=azure_ai_project, credential="test_credential")
assert callable(simulator)
simulator(scenario=scenario, max_conversation_turns=1, max_simulation_results=3, target=async_callback)
Loading
Loading