Skip to content

Commit

Permalink
[pfs] Update run api in prompt flow service (#1215)
Browse files Browse the repository at this point in the history
# Description

### Submit run

![image](https://github.com/microsoft/promptflow/assets/17938940/c6c0e5be-ee90-41a3-8c18-87de0ece8425)

### Get Run

![image](https://github.com/microsoft/promptflow/assets/17938940/d2c78f8c-b0e3-4afb-a5cc-514f9d64adb6)

### Update Run

![image](https://github.com/microsoft/promptflow/assets/17938940/83a1d21c-cd82-40be-9046-eec53284f6bb)

### List Run

![image](https://github.com/microsoft/promptflow/assets/17938940/6366e52a-b9dd-4b98-aea1-a56346a0ac51)

### Child runs

![image](https://github.com/microsoft/promptflow/assets/17938940/80c1281d-27e0-4e54-b7e0-340c60907043)

### Node runs

![image](https://github.com/microsoft/promptflow/assets/17938940/76e5a24c-ecc5-4608-ba3e-183643fb0b82)

### Visualize

![image](https://github.com/microsoft/promptflow/assets/17938940/83f7538c-2894-497a-83df-36db93ad9353)


Please add an informative description that covers that changes made by
the pull request and link all relevant issues.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
  • Loading branch information
lalala123123 authored Nov 23, 2023
1 parent d7bd87b commit 372c8f7
Show file tree
Hide file tree
Showing 6 changed files with 604 additions and 102 deletions.
16 changes: 5 additions & 11 deletions src/promptflow/promptflow/_sdk/_service/apis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import json

from flask import jsonify, request
from flask_restx import Namespace, Resource, fields
Expand All @@ -14,10 +13,6 @@

api = Namespace("Connections", description="Connections Management")

# Define create or update connection request parsing
create_or_update_parser = api.parser()
create_or_update_parser.add_argument("connection_dict", type=str, location="args", required=True)

# Response model of list connections
list_connection_field = api.model(
"Connection",
Expand Down Expand Up @@ -68,25 +63,24 @@ def get(self, name: str):
connection_dict = connection._to_dict()
return jsonify(connection_dict)

@api.doc(parser=create_or_update_parser, description="Create connection")
@api.doc(body=dict_field, description="Create connection")
@api.response(code=200, description="Connection details", model=dict_field)
@local_user_only
def post(self, name: str):
connection_op = ConnectionOperations()
args = create_or_update_parser.parse_args()
connection_data = json.loads(args["connection_dict"])
connection_data = request.get_json(force=True)
connection_data["name"] = name
connection = _Connection._load(data=connection_data)
connection = connection_op.create_or_update(connection)
return jsonify(connection._to_dict())

@api.doc(parser=create_or_update_parser, description="Update connection")
@api.doc(body=dict_field, description="Update connection")
@api.response(code=200, description="Connection details", model=dict_field)
@local_user_only
def put(self, name: str):
connection_op = ConnectionOperations()
args = create_or_update_parser.parse_args()
params_override = [{k: v} for k, v in json.loads(args["connection_dict"]).items()]
connection_dict = request.get_json(force=True)
params_override = [{k: v} for k, v in connection_dict.items()]
existing_connection = connection_op.get(name)
connection = _Connection._load(data=existing_connection._to_dict(), params_override=params_override)
connection._secrets = existing_connection._secrets
Expand Down
156 changes: 144 additions & 12 deletions src/promptflow/promptflow/_sdk/_service/apis/run.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,38 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import json
import subprocess
import tempfile
from dataclasses import asdict
from pathlib import Path

from flask import jsonify, request
from flask_restx import Namespace, Resource
from flask import Response, jsonify, request
from flask_restx import Namespace, Resource, fields

from promptflow._sdk._constants import FlowRunProperties, get_list_view_type
from promptflow._sdk._errors import RunNotFoundError
from promptflow._sdk.entities import Run as RunEntity
from promptflow._sdk.operations._local_storage_operations import LocalStorageOperations
from promptflow._sdk.operations._run_operations import RunOperations
from promptflow.contracts._run_management import RunMetadata

api = Namespace("Runs", description="Runs Management")

# Define update run request parsing
update_run_parser = api.parser()
update_run_parser.add_argument("display_name", type=str, location="form", required=False)
update_run_parser.add_argument("description", type=str, location="form", required=False)
update_run_parser.add_argument("tags", type=str, location="form", required=False)

# Define visualize request parsing
visualize_parser = api.parser()
visualize_parser.add_argument("html", type=str, location="form", required=False)

# Response model of run operation
dict_field = api.schema_model("RunDict", {"additionalProperties": True, "type": "object"})
list_field = api.schema_model("RunList", {"type": "array", "items": {"$ref": "#/definitions/RunDict"}})


@api.errorhandler(RunNotFoundError)
def handle_run_not_found_exception(error):
Expand All @@ -24,6 +42,7 @@ def handle_run_not_found_exception(error):

@api.route("/")
class RunList(Resource):
@api.response(code=200, description="Runs", model=list_field)
@api.doc(description="List all runs")
def get(self):
# parse query parameters
Expand All @@ -42,17 +61,80 @@ def get(self):
return jsonify(runs_dict)


@api.route("/submit")
class RunSubmit(Resource):
@api.response(code=200, description="Submit run info", model=dict_field)
@api.doc(body=dict_field, description="Submit run")
def post(self):
run_dict = request.get_json(force=True)
run_name = run_dict.get("name", None)
if not run_name:
run = RunEntity(**run_dict)
run_name = run._generate_run_name()
run_dict["name"] = run_name
with tempfile.TemporaryDirectory() as temp_dir:
run_file = Path(temp_dir) / "batch_run.json"
with open(run_file, "w") as f:
json.dump(run_dict, f)
cmd = f"pf run create --file {run_file}"
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
_, stderr = process.communicate()
if process.returncode == 0:
run_op = RunOperations()
run = run_op.get(name=run_name)
return jsonify(run._to_dict())
else:
raise Exception(f"Create batch run failed: {stderr}")


@api.route("/<string:name>")
class Run(Resource):
@api.response(code=200, description="Update run info", model=dict_field)
@api.doc(parser=update_run_parser, description="Update run")
def put(self, name: str):
args = update_run_parser.parse_args()
run_op = RunOperations()
tags = json.loads(args.tags) if args.tags else None
run = run_op.update(name=name, display_name=args.display_name, description=args.description, tags=tags)
return jsonify(run._to_dict())

@api.response(code=200, description="Get run info", model=dict_field)
@api.doc(description="Get run")
def get(self, name: str):
op = RunOperations()
run = op.get(name=name)
run_dict = run._to_dict()
return jsonify(run_dict)
run_op = RunOperations()
run = run_op.get(name=name)
return jsonify(run._to_dict())


@api.route("/<string:name>/metadata")
@api.route("/<string:name>/childRuns")
class FlowChildRuns(Resource):
@api.response(code=200, description="Child runs", model=list_field)
@api.doc(description="Get child runs")
def get(self, name: str):
run_op = RunOperations()
run = run_op.get(name=name)
local_storage_op = LocalStorageOperations(run=run)
detail_dict = local_storage_op.load_detail()
return jsonify(detail_dict["flow_runs"])


@api.route("/<string:name>/nodeRuns/<string:node_name>")
class FlowNodeRuns(Resource):
@api.response(code=200, description="Node runs", model=list_field)
@api.doc(description="Get node runs info")
def get(self, name: str, node_name: str):
run_op = RunOperations()
run = run_op.get(name=name)
local_storage_op = LocalStorageOperations(run=run)
detail_dict = local_storage_op.load_detail()
node_runs = [item for item in detail_dict["node_runs"] if item["node"] == node_name]
return jsonify(node_runs)


@api.route("/<string:name>/metaData")
class MetaData(Resource):
@api.doc(description="Get metadata of run")
@api.response(code=200, description="Run metadata", model=dict_field)
def get(self, name: str):
run_op = RunOperations()
run = run_op.get(name=name)
Expand All @@ -72,11 +154,61 @@ def get(self, name: str):
return jsonify(asdict(metadata))


@api.route("/<string:name>/detail")
class Detail(Resource):
@api.route("/<string:name>/logContent")
class LogContent(Resource):
@api.doc(description="Get run log content")
@api.response(code=200, description="Log content", model=fields.String)
def get(self, name: str):
run_op = RunOperations()
run = run_op.get(name=name)
local_storage_op = LocalStorageOperations(run=run)
detail_dict = local_storage_op.load_detail()
return jsonify(detail_dict)
log_content = local_storage_op.logger.get_logs()
return log_content


@api.route("/<string:name>/metrics")
class Metrics(Resource):
@api.doc(description="Get run metrics")
@api.response(code=200, description="Run metrics", model=dict_field)
def get(self, name: str):
run_op = RunOperations()
run = run_op.get(name=name)
local_storage_op = LocalStorageOperations(run=run)
metrics = local_storage_op.load_metrics()
return jsonify(metrics)


@api.route("/<string:name>/visualize")
class VisualizeRun(Resource):
@api.doc(description="Visualize run")
@api.response(code=200, description="Visualize run", model=fields.String)
@api.produces(["text/html"])
def get(self, name: str):
with tempfile.TemporaryDirectory() as temp_dir:
run_op = RunOperations()
run = run_op.get(name=name)
html_path = Path(temp_dir) / "visualize_run.html"
run_op.visualize(run, html_path=html_path)

with open(html_path, "r") as f:
return Response(f.read(), mimetype="text/html")


@api.route("/<string:name>/archive")
class ArchiveRun(Resource):
@api.doc(description="Archive run")
@api.response(code=200, description="Archived run", model=dict_field)
def get(self, name: str):
run_op = RunOperations()
run = run_op.archive(name=name)
return jsonify(run._to_dict())


@api.route("/<string:name>/restore")
class RestoreRun(Resource):
@api.doc(description="Restore run")
@api.response(code=200, description="Restored run", model=dict_field)
def get(self, name: str):
run_op = RunOperations()
run = run_op.restore(name=name)
return jsonify(run._to_dict())
Loading

0 comments on commit 372c8f7

Please sign in to comment.