Skip to content

Commit

Permalink
Merge branch 'devs/peiwen/fix_extra_dependency' of https://github.com…
Browse files Browse the repository at this point in the history
…/microsoft/promptflow into devs/peiwen/fix_extra_dependency
  • Loading branch information
PeiwenGaoMS committed Mar 27, 2024
2 parents 0f8bc51 + 401304e commit d639353
Show file tree
Hide file tree
Showing 13 changed files with 1,560 additions and 153 deletions.
10 changes: 5 additions & 5 deletions scripts/check_enforcer/check_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@
"scripts/building/**",
".github/workflows/promptflow-sdk-cli-test.yml",
],
# "sdk_cli_global_config_tests": [
# "src/promptflow/**",
# "scripts/building/**",
# ".github/workflows/promptflow-global-config-test.yml",
# ],
"sdk_cli_global_config_tests": [
"src/promptflow/**",
"scripts/building/**",
".github/workflows/promptflow-global-config-test.yml",
],
"sdk_cli_azure_test_replay": [
"src/promptflow/**",
"scripts/building/**",
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ def __init__(self):
self._operations = ConnectionOperations()

def get(self, name: str, **kwargs):
return self._operations.get(name, **kwargs)
# Connection provider here target for execution, so we always get with secrets.
with_secrets = kwargs.pop("with_secrets", True)
return self._operations.get(name, with_secrets=with_secrets, **kwargs)
6 changes: 6 additions & 0 deletions src/promptflow/promptflow/_sdk/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ class ConnectionNameNotSetError(SDKError):
pass


class ConnectionClassNotFoundError(SDKError):
"""Exception raised if relative sdk connection class not found."""

pass


class InvalidRunError(SDKError):
"""Exception raised if run name is not legal."""

Expand Down
2 changes: 1 addition & 1 deletion src/promptflow/promptflow/_sdk/_submitter/run_submitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
from promptflow._sdk.operations._local_storage_operations import LocalStorageOperations
from promptflow._utils.context_utils import _change_working_dir
from promptflow._utils.flow_utils import parse_variant
from promptflow._utils.logger_utils import LoggerFactory
from promptflow.batch import BatchEngine
from promptflow.contracts.run_info import Status
from promptflow.contracts.run_mode import RunMode
from promptflow.exceptions import UserErrorException, ValidationException
from promptflow.tracing._operation_context import OperationContext

from ..._utils.logger_utils import LoggerFactory
from .._configuration import Configuration
from .._load_functions import load_flow
from ..entities._flow import FlexFlow
Expand Down
13 changes: 6 additions & 7 deletions src/promptflow/promptflow/_sdk/_submitter/test_submitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,25 @@

from colorama import Fore, init

from promptflow._constants import LINE_NUMBER_KEY, FlowLanguage
from promptflow._core._errors import NotSupported
from promptflow._internal import ConnectionManager
from promptflow._proxy import ProxyFactory
from promptflow._sdk._constants import PROMPT_FLOW_DIR_NAME
from promptflow._sdk.entities._flow import Flow, FlowContext
from promptflow._sdk.operations._local_storage_operations import LoggerOperations
from promptflow._utils.async_utils import async_run_allowing_running_loop
from promptflow._utils.context_utils import _change_working_dir
from promptflow._utils.dataclass_serializer import convert_eager_flow_output_to_dict
from promptflow._utils.exception_utils import ErrorResponse
from promptflow._utils.flow_utils import parse_variant
from promptflow._utils.flow_utils import dump_flow_result, parse_variant
from promptflow._utils.logger_utils import get_cli_sdk_logger
from promptflow.contracts.flow import Flow as ExecutableFlow
from promptflow.contracts.run_info import RunInfo, Status
from promptflow.exceptions import UserErrorException
from promptflow.executor._result import LineResult
from promptflow.storage._run_storage import DefaultRunStorage

from ..._constants import LINE_NUMBER_KEY, FlowLanguage
from ..._core._errors import NotSupported
from ..._utils.async_utils import async_run_allowing_running_loop
from ..._utils.dataclass_serializer import convert_eager_flow_output_to_dict
from ..._utils.flow_utils import dump_flow_result
from ..._utils.logger_utils import get_cli_sdk_logger
from ...batch import APIBasedExecutorProxy, CSharpExecutorProxy
from .._configuration import Configuration
from ..entities._flow import FlexFlow
Expand Down
2 changes: 2 additions & 0 deletions src/promptflow/promptflow/_sdk/entities/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def _from_run_history_entity(cls, run_entity: dict) -> "Run":
start_time = run_entity.get("startTimeUtc", None)
end_time = run_entity.get("endTimeUtc", None)
duration = run_entity.get("duration", None)
resume_from = run_entity["properties"].get("azureml.promptflow.resume_from_run_id", None)
return Run(
name=run_entity["runId"],
flow=Path(f"azureml://flows/{flow_name}"),
Expand All @@ -319,6 +320,7 @@ def _from_run_history_entity(cls, run_entity: dict) -> "Run":
is_archived=run_entity.get("archived", False), # TODO: Get archived status, depends on run history team
error=run_entity.get("error", None),
run_source=RunInfoSources.RUN_HISTORY,
resume_from=resume_from,
portal_url=run_entity[RunDataKeys.PORTAL_URL],
creation_context=run_entity["createdBy"],
data=run_entity[RunDataKeys.DATA],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from datetime import datetime
from typing import List
from typing import List, Type, TypeVar

from promptflow._sdk._constants import MAX_LIST_CLI_RESULTS
from promptflow._sdk._errors import ConnectionNameNotSetError
from promptflow._sdk._errors import ConnectionClassNotFoundError, ConnectionNameNotSetError
from promptflow._sdk._orm import Connection as ORMConnection
from promptflow._sdk._telemetry import ActivityType, TelemetryMixin, monitor_operation
from promptflow._sdk._utils import safe_parse_object_list
from promptflow._sdk.entities._connection import _Connection
from promptflow._sdk.entities._connection import CustomConnection, _Connection
from promptflow.connections import _Connection as _CoreConnection

T = TypeVar("T", bound="_Connection")


class ConnectionOperations(TelemetryMixin):
Expand Down Expand Up @@ -70,15 +73,37 @@ def delete(self, name: str) -> None:
"""
ORMConnection.delete(name)

@classmethod
def _convert_core_connection_to_sdk_connection(cls, core_conn):
sdk_conn_mapping = _Connection.SUPPORTED_TYPES
sdk_conn_cls = sdk_conn_mapping.get(core_conn.type)
if sdk_conn_cls is None:
raise ConnectionClassNotFoundError(
f"Correspond sdk connection type not found for core connection type: {core_conn.type!r}, "
f"please install the latest 'promptflow-devkit' and 'promptflow-core'."
)
common_args = {
"name": core_conn.name,
"module": core_conn.module,
"expiry_time": core_conn.expiry_time,
"created_date": core_conn.created_date,
"last_modified_date": core_conn.last_modified_date,
}
if sdk_conn_cls is CustomConnection:
return sdk_conn_cls(configs=core_conn.configs, secrets=core_conn.secrets, **common_args)
return sdk_conn_cls(**dict(core_conn), **common_args)

@monitor_operation(activity_name="pf.connections.create_or_update", activity_type=ActivityType.PUBLICAPI)
def create_or_update(self, connection: _Connection, **kwargs):
def create_or_update(self, connection: Type[_Connection], **kwargs):
"""Create or update a connection.
:param connection: Run object to create or update.
:type connection: ~promptflow.sdk.entities._connection._Connection
"""
if not connection.name:
raise ConnectionNameNotSetError("Name is required to create or update connection.")
if isinstance(connection, _CoreConnection) and not isinstance(connection, _Connection):
connection = self._convert_core_connection_to_sdk_connection(connection)
orm_object = connection._to_orm_object()
now = datetime.now().isoformat()
if orm_object.createdDate is None:
Expand Down
8 changes: 4 additions & 4 deletions src/promptflow/promptflow/azure/_entities/_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@

import pydash

from promptflow._constants import FlowLanguage
from promptflow._sdk._constants import DAG_FILE_NAME, SERVICE_FLOW_TYPE_2_CLIENT_FLOW_TYPE, AzureFlowSource, FlowType
from promptflow._sdk._utils import PromptflowIgnoreFile, load_yaml, remove_empty_element_from_dict
from promptflow._utils.flow_utils import dump_flow_dag, load_flow_dag
from promptflow._utils.logger_utils import LoggerFactory
from promptflow.azure._ml import AdditionalIncludesMixin, Code

from ..._constants import FlowLanguage
from ..._sdk._utils import PromptflowIgnoreFile, load_yaml, remove_empty_element_from_dict
from ..._utils.flow_utils import dump_flow_dag, load_flow_dag
from ..._utils.logger_utils import LoggerFactory
from .._constants._flow import ADDITIONAL_INCLUDES, DEFAULT_STORAGE, ENVIRONMENT, PYTHON_REQUIREMENTS_TXT
from .._restclient.flow.models import FlowDto

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ def test_run_resume(self, pf: PFClient, randstr: Callable[[str], str]):
run2 = pf.run(resume_from=run, name=name2)
assert isinstance(run2, Run)
# Enable name assert after PFS released
# assert run2.name == name2
assert run2.name == name2
assert run2._resume_from == run.name

def test_run_bulk_from_yaml(self, pf, runtime: str, randstr: Callable[[str], str]):
run_id = randstr("run_id")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def test_basic_flow_with_script_tool_with_custom_strong_type_connection(
self, install_custom_tool_pkg, local_client, pf
):
# Prepare custom connection
from promptflow._sdk.entities._connection import CustomConnection
from promptflow.connections import CustomConnection

conn = CustomConnection(name="custom_connection_2", secrets={"api_key": "test"}, configs={"api_url": "test"})
local_client.connections.create_or_update(conn)
Expand Down
4 changes: 2 additions & 2 deletions src/promptflow/tests/sdk_cli_test/e2etests/test_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_pf_test_flow_with_package_tool_with_custom_strong_type_connection(self,

def test_pf_test_flow_with_package_tool_with_custom_connection_as_input_value(self, install_custom_tool_pkg):
# Prepare custom connection
from promptflow._sdk.entities._connection import CustomConnection
from promptflow.connections import CustomConnection

conn = CustomConnection(name="custom_connection_3", secrets={"api_key": "test"}, configs={"api_base": "test"})
_client.connections.create_or_update(conn)
Expand All @@ -77,7 +77,7 @@ def test_pf_test_flow_with_package_tool_with_custom_connection_as_input_value(se

def test_pf_test_flow_with_script_tool_with_custom_strong_type_connection(self):
# Prepare custom connection
from promptflow._sdk.entities._connection import CustomConnection
from promptflow.connections import CustomConnection

conn = CustomConnection(name="custom_connection_2", secrets={"api_key": "test"}, configs={"api_url": "test"})
_client.connections.create_or_update(conn)
Expand Down
41 changes: 40 additions & 1 deletion src/promptflow/tests/sdk_cli_test/unittests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from promptflow._cli._pf._connection import validate_and_interactive_get_secrets
from promptflow._sdk._constants import SCRUBBED_VALUE, ConnectionAuthMode, CustomStrongTypeConnectionConfigs
from promptflow._sdk._errors import SDKError
from promptflow._sdk._errors import ConnectionClassNotFoundError, SDKError
from promptflow._sdk._load_functions import _load_env_to_connection
from promptflow._sdk.entities._connection import (
AzureContentSafetyConnection,
Expand All @@ -25,6 +25,7 @@
WeaviateConnection,
_Connection,
)
from promptflow._sdk.operations._connection_operations import ConnectionOperations
from promptflow._utils.yaml_utils import load_yaml
from promptflow.core._connection import RequiredEnvironmentVariablesNotSetError
from promptflow.exceptions import UserErrorException
Expand Down Expand Up @@ -467,3 +468,41 @@ def test_connection_from_env(self):
"organization": "test_org",
"base_url": "test_base",
}

def test_convert_core_connection_to_sdk_connection(self):
# Assert strong type
from promptflow.connections import AzureOpenAIConnection as CoreAzureOpenAIConnection

connection_args = {
"name": "abc",
"api_base": "abc",
"auth_mode": "meid_token",
"api_version": "2023-07-01-preview",
}
connection = CoreAzureOpenAIConnection(**connection_args)
sdk_connection = ConnectionOperations._convert_core_connection_to_sdk_connection(connection)
assert isinstance(sdk_connection, AzureOpenAIConnection)
assert sdk_connection._to_dict() == {
"module": "promptflow.connections",
"type": "azure_open_ai",
"api_type": "azure",
**connection_args,
}
# Assert custom type
from promptflow.connections import CustomConnection as CoreCustomConnection

connection_args = {
"name": "abc",
"configs": {"a": "1"},
"secrets": {"b": "2"},
}
connection = CoreCustomConnection(**connection_args)
sdk_connection = ConnectionOperations._convert_core_connection_to_sdk_connection(connection)
assert isinstance(sdk_connection, CustomConnection)
assert sdk_connection._to_dict() == {"module": "promptflow.connections", "type": "custom", **connection_args}

# Bad case
connection = CoreCustomConnection(**connection_args)
connection.type = "unknown"
with pytest.raises(ConnectionClassNotFoundError):
ConnectionOperations._convert_core_connection_to_sdk_connection(connection)

0 comments on commit d639353

Please sign in to comment.