Skip to content

Commit

Permalink
Merge branch 'main' into dev/chenyin/pf_service
Browse files Browse the repository at this point in the history
  • Loading branch information
Ying Chen committed Mar 30, 2024
2 parents faa0e13 + cb770fe commit 7324751
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 134 deletions.
11 changes: 11 additions & 0 deletions src/promptflow-core/promptflow/executor/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,17 @@ class SingleNodeValidationError(UserErrorException):
pass


class AggregationNodeExecutionTimeoutError(UserErrorException):
    """Raised when the aggregation nodes fail to finish within the allowed time.

    Surfaced as a user error so batch-run callers can distinguish a timeout
    from an internal executor failure.
    """

    def __init__(self, timeout):
        # `timeout` (seconds) is interpolated into the message template by the base class.
        template = "Aggregation node execution timeout for exceeding {timeout} seconds"
        super().__init__(message_format=template, timeout=timeout, target=ErrorTarget.EXECUTOR)


class LineExecutionTimeoutError(UserErrorException):
"""Exception raised when single line execution timeout"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# parsers for query parameters
list_line_run_parser = api.parser()
list_line_run_parser.add_argument("session", type=str, required=False)
list_line_run_parser.add_argument("collection", type=str, required=False)
list_line_run_parser.add_argument("run", type=str, required=False)
list_line_run_parser.add_argument("experiment", type=str, required=False)
list_line_run_parser.add_argument("trace_ids", type=str, required=False)
Expand All @@ -26,7 +27,7 @@
# use @dataclass for strong type
@dataclass
class ListLineRunParser:
session_id: typing.Optional[str] = None
collection: typing.Optional[str] = None
runs: typing.Optional[typing.List[str]] = None
experiments: typing.Optional[typing.List[str]] = None
trace_ids: typing.Optional[typing.List[str]] = None
Expand All @@ -41,7 +42,7 @@ def _parse_string_list(value: typing.Optional[str]) -> typing.Optional[typing.Li
def from_request() -> "ListLineRunParser":
args = list_line_run_parser.parse_args()
return ListLineRunParser(
session_id=args.session,
collection=args.collection or args.session,
runs=ListLineRunParser._parse_string_list(args.run),
experiments=ListLineRunParser._parse_string_list(args.experiment),
trace_ids=ListLineRunParser._parse_string_list(args.trace_ids),
Expand Down Expand Up @@ -86,7 +87,7 @@ def get(self):
client: PFClient = get_client_from_request()
args = ListLineRunParser.from_request()
line_runs: typing.List[LineRunEntity] = client._traces.list_line_runs(
collection=args.session_id,
collection=args.collection,
runs=args.runs,
experiments=args.experiments,
trace_ids=args.trace_ids,
Expand Down
11 changes: 8 additions & 3 deletions src/promptflow-devkit/promptflow/_sdk/entities/_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _persist_line_run(self) -> None:
# 1. first span: create, as we cannot identify the first span, so will use a try-catch
# 2. root span: update
if self.parent_id is None:
LineRun._from_root_span(self)._update()
LineRun._from_root_span(self)._try_update()
else:
LineRun._from_non_root_span(self)._try_create()

Expand Down Expand Up @@ -386,8 +386,13 @@ def _try_create(self) -> None:
except LineRunNotFoundError:
self._to_orm_object().persist()

def _update(self) -> None:
self._to_orm_object()._update()
def _try_update(self) -> None:
    """Update the persisted line run, creating it when no row exists yet.

    A trace consisting of only a single root span never hit the non-root-span
    create path, so the lookup below can miss; in that case the line run must
    be persisted (created) instead of updated.
    """
    # Probe for an existing row first; fall back to create on a miss.
    try:
        ORMLineRun.get(line_run_id=self.line_run_id)
        self._to_orm_object()._update()
    except LineRunNotFoundError:
        self._to_orm_object().persist()

@staticmethod
def _get_inputs_from_span(span: Span) -> typing.Optional[typing.Dict]:
Expand Down
108 changes: 69 additions & 39 deletions src/promptflow-devkit/promptflow/batch/_base_executor_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from promptflow.batch._errors import ExecutorServiceUnhealthy
from promptflow.contracts.run_info import FlowRunInfo
from promptflow.exceptions import ErrorTarget, ValidationException
from promptflow.executor._errors import AggregationNodeExecutionTimeoutError, LineExecutionTimeoutError
from promptflow.executor._result import AggregationResult, LineResult
from promptflow.storage._run_storage import AbstractRunStorage

Expand Down Expand Up @@ -327,7 +328,7 @@ def generator():
with httpx.Client() as client:
with client.stream("POST", url, json=payload, timeout=LINE_TIMEOUT_SEC, headers=headers) as response:
if response.status_code != 200:
result = self._process_http_response(response)
result = self._process_error_response(response)
run_info = FlowRunInfo.create_with_error(start_time, inputs, index, run_id, result)
yield LineResult(output={}, aggregation_inputs={}, run_info=run_info, node_run_infos={})
for line in response.iter_lines():
Expand Down Expand Up @@ -362,36 +363,69 @@ async def exec_line_async(
run_id: Optional[str] = None,
) -> LineResult:
if self.enable_stream_output:
# Todo: update to async, will get no result in "async for" of final_generator function in async mode
# TODO: update to async, will get no result in "async for" of final_generator function in async mode
raise NotSupported("Stream output is not supported in async mode for now")

response = None
start_time = datetime.utcnow()
# call execution api to get line results
url = self.api_endpoint + "/execution"
payload = {"run_id": run_id, "line_number": index, "inputs": inputs}

async with httpx.AsyncClient() as client:
response = await client.post(url, json=payload, timeout=LINE_TIMEOUT_SEC)
# process the response
result = self._process_http_response(response)
if response.status_code != 200:
run_info = FlowRunInfo.create_with_error(start_time, inputs, index, run_id, result)
return LineResult(output={}, aggregation_inputs={}, run_info=run_info, node_run_infos={})
return LineResult.deserialize(result)
try:
# Call execution api to get line results
url = self.api_endpoint + "/execution"
payload = {"run_id": run_id, "line_number": index, "inputs": inputs}
async with httpx.AsyncClient() as client:
response = await client.post(url, json=payload, timeout=LINE_TIMEOUT_SEC)
# This will raise an HTTPError for 4xx/5xx responses
response.raise_for_status()
return LineResult.deserialize(response.json())
except httpx.ReadTimeout:
ex = LineExecutionTimeoutError(line_number=index, timeout=LINE_TIMEOUT_SEC)
except httpx.HTTPStatusError:
# For 4xx and 5xx status codes
ex = self._process_error_response(response)
except Exception as e:
ex = UnexpectedError(
target=ErrorTarget.BATCH,
message_format=(
"Unexpected error occurred while executing one line in the batch run. "
"Error: {error_type_and_message}."
),
error_type_and_message=f"({e.__class__.__name__}) {e}",
)
# If any exception occurs, format and return a line result with error
error = ExceptionPresenter.create(ex).to_dict() if isinstance(ex, Exception) else ex
run_info = FlowRunInfo.create_with_error(start_time, inputs, index, run_id, error)
return LineResult(output={}, aggregation_inputs={}, run_info=run_info, node_run_infos={})

async def exec_aggregation_async(
    self,
    batch_inputs: Mapping[str, Any],
    aggregation_inputs: Mapping[str, Any],
    run_id: Optional[str] = None,
) -> AggregationResult:
    """Call the executor service's /aggregation API and return its result.

    :param batch_inputs: inputs of all lines in the batch run.
    :param aggregation_inputs: aggregated line outputs fed to the aggregation nodes.
    :param run_id: optional id of the batch run the aggregation belongs to.
    :return: the deserialized :class:`AggregationResult` from the service response.
    :raises AggregationNodeExecutionTimeoutError: when the HTTP call exceeds
        ``LINE_TIMEOUT_SEC``.
    :raises UnexpectedError: for any other failure, including non-2xx responses.
    """
    try:
        # Call aggregation api to get aggregation result
        url = self.api_endpoint + "/aggregation"
        payload = {"run_id": run_id, "batch_inputs": batch_inputs, "aggregation_inputs": aggregation_inputs}
        async with httpx.AsyncClient() as client:
            response = await client.post(url, json=payload, timeout=LINE_TIMEOUT_SEC)
            # Raises httpx.HTTPStatusError for 4xx/5xx responses.
            response.raise_for_status()
            return AggregationResult.deserialize(response.json())
    except httpx.ReadTimeout:
        # Surface the timeout as a dedicated user-facing executor error.
        raise AggregationNodeExecutionTimeoutError(timeout=LINE_TIMEOUT_SEC)
    except Exception as e:
        ex_msg = f"({e.__class__.__name__}) {e}"
        if isinstance(e, httpx.HTTPStatusError):
            # For 4xx/5xx, prefer the structured error message parsed from the body.
            error_dict = self._process_error_response(e.response)
            ex_msg = error_dict["message"]
        # Chain the original exception so the root cause stays in the traceback.
        raise UnexpectedError(
            target=ErrorTarget.BATCH,
            message_format=(
                "Unexpected error occurred while executing aggregation nodes in the batch run. Error: {ex_msg}"
            ),
            ex_msg=ex_msg,
        ) from e

async def ensure_executor_startup(self, error_file):
"""Ensure the executor service is initialized before calling the API to get the results"""
Expand Down Expand Up @@ -454,23 +488,19 @@ def _check_startup_error_from_file(self, error_file) -> Exception:
return ValidationException(error_response.message, target=ErrorTarget.BATCH)
return None

def _process_http_response(self, response: httpx.Response):
if response.status_code == 200:
# if the status code is 200, the response is the json dict of a line result
return response.json()
else:
# use this instead of response.text to handle streaming response
response_text = response.read().decode(DEFAULT_ENCODING)
# if the status code is not 200, log the error
message_format = "Unexpected error when executing a line, status code: {status_code}, error: {error}"
bulk_logger.error(message_format.format(status_code=response.status_code, error=response_text))
# if response can be parsed as json, return the error dict
# otherwise, wrap the error in an UnexpectedError and return the error dict
try:
error_dict = json.loads(response_text)
return error_dict["error"]
except (JSONDecodeError, KeyError):
unexpected_error = UnexpectedError(
message_format=message_format, status_code=response.status_code, error=response_text
)
return ExceptionPresenter.create(unexpected_error).to_dict()
def _process_error_response(self, response: httpx.Response):
    """Extract an error dict from a failed executor-service response.

    Returns the service-provided error dict when the body is valid JSON with
    an ``error`` key; otherwise wraps the raw body text in an
    ``UnexpectedError`` and returns its dict representation. The response is
    also logged via ``bulk_logger``.
    """
    # response.read() works for streaming responses, unlike response.text.
    body = response.read().decode(DEFAULT_ENCODING)
    message_format = "Unexpected error when executing a line, status code: {status_code}, error: {error}"
    bulk_logger.error(message_format.format(status_code=response.status_code, error=body))
    try:
        # Happy path: the service returned a structured JSON error payload.
        return json.loads(body)["error"]
    except (JSONDecodeError, KeyError):
        wrapped = UnexpectedError(
            message_format=message_format, status_code=response.status_code, error=body
        )
        return ExceptionPresenter.create(wrapped).to_dict()
12 changes: 7 additions & 5 deletions src/promptflow-tracing/promptflow/tracing/_start_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import typing

from opentelemetry import trace
from opentelemetry.sdk.environment_variables import OTEL_EXPORTER_OTLP_ENDPOINT
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.environment_variables import OTEL_EXPORTER_OTLP_ENDPOINT


from ._constants import (
PF_TRACING_SKIP_LOCAL_SETUP_ENVIRON,
Expand All @@ -34,6 +33,12 @@ def start_trace(
:param session: Specify the session id for current tracing session.
:type session: typing.Optional[str]
"""

# When PF_TRACING_SKIP_LOCAL_SETUP_ENVIRON is set to true, the start_trace should be skipped.
# An example is that user call start_trace at cloud mode. Nothing should happen.
if _skip_tracing_local_setup():
return

# prepare resource.attributes and set tracer provider
res_attrs = {ResourceAttributesFieldName.SERVICE_NAME: RESOURCE_ATTRIBUTES_SERVICE_NAME}
if session is not None:
Expand All @@ -43,9 +48,6 @@ def start_trace(
res_attrs[attr_key] = attr_value
_set_tracer_provider(res_attrs)

if _skip_tracing_local_setup():
return

if _is_devkit_installed():
from promptflow._sdk._tracing import start_trace_with_devkit

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,33 +30,35 @@ async def test_exec_line_async(self, has_error):
run_id = "test_run_id"
index = 1
inputs = {"question": "test"}
with patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock:
with patch("httpx.Response.raise_for_status"):
line_result_dict = _get_line_result_dict(run_id, index, inputs, has_error=has_error)
status_code = 400 if has_error else 200
mock.return_value = httpx.Response(status_code, json=line_result_dict)
line_result = await mock_executor_proxy.exec_line_async(inputs, index, run_id)
assert line_result.output == {} if has_error else {"answer": "Hello world!"}
assert line_result.run_info.run_id == run_id
assert line_result.run_info.index == index
assert line_result.run_info.status == Status.Failed if has_error else Status.Completed
assert line_result.run_info.inputs == inputs
assert (line_result.run_info.error is not None) == has_error
response = httpx.Response(status_code=status_code, json=line_result_dict)
with patch("httpx.AsyncClient.post", return_value=response):
line_result = await mock_executor_proxy.exec_line_async(inputs, index, run_id)
assert line_result.output == {} if has_error else {"answer": "Hello world!"}
assert line_result.run_info.run_id == run_id
assert line_result.run_info.index == index
assert line_result.run_info.status == Status.Failed if has_error else Status.Completed
assert line_result.run_info.inputs == inputs
assert (line_result.run_info.error is not None) == has_error

@pytest.mark.asyncio
async def test_exec_aggregation_async(self):
    """exec_aggregation_async deserializes a 200 response into an AggregationResult."""
    mock_executor_proxy = await MockAPIBasedExecutorProxy.create("")
    run_id = "test_run_id"
    batch_inputs = {"question": ["test", "error"]}
    aggregation_inputs = {"${get_answer.output}": ["Incorrect", "Correct"]}
    # Stub raise_for_status so the hand-built Response is accepted as success.
    with patch("httpx.Response.raise_for_status"):
        aggr_result_dict = _get_aggr_result_dict(run_id, aggregation_inputs)
        response = httpx.Response(200, json=aggr_result_dict)
        with patch("httpx.AsyncClient.post", return_value=response):
            aggr_result = await mock_executor_proxy.exec_aggregation_async(batch_inputs, aggregation_inputs, run_id)
            assert aggr_result.metrics == {"accuracy": 0.5}
            assert len(aggr_result.node_run_infos) == 1
            assert aggr_result.node_run_infos["aggregation"].flow_run_id == run_id
            assert aggr_result.node_run_infos["aggregation"].inputs == aggregation_inputs
            assert aggr_result.node_run_infos["aggregation"].status == Status.Completed

@pytest.mark.asyncio
async def test_ensure_executor_startup_when_no_error(self):
Expand Down Expand Up @@ -141,10 +143,6 @@ async def test_check_health(self, mock_value, expected_result):
@pytest.mark.parametrize(
"response, expected_result",
[
(
httpx.Response(200, json={"result": "test"}),
{"result": "test"},
),
(
httpx.Response(500, json={"error": "test error"}),
"test error",
Expand Down Expand Up @@ -189,9 +187,9 @@ async def test_check_health(self, mock_value, expected_result):
),
],
)
async def test_process_http_response(self, response, expected_result):
async def test_process_error_response(self, response, expected_result):
mock_executor_proxy = await MockAPIBasedExecutorProxy.create("")
assert mock_executor_proxy._process_http_response(response) == expected_result
assert mock_executor_proxy._process_error_response(response) == expected_result


class MockAPIBasedExecutorProxy(APIBasedExecutorProxy):
Expand Down
Loading

0 comments on commit 7324751

Please sign in to comment.