Skip to content

Commit

Permalink
[prompty] Prompty supports image as input (#3303)
Browse files Browse the repository at this point in the history
# Description
Example prompty file that takes an image as input:
```
---
name: Basic Prompt with Image
description: A basic prompt that uses the GPT-3 chat API to answer questions
model:
    api: chat
    configuration:
      type: azure_openai
      azure_deployment: gpt-4-vision-preview
      connection: azure_open_ai_connection
    parameters:
      temperature: 0.2
sample:
  image: "data:image/jpg;path:../datas/logo.jpg"
  question: what is it
---
system:
As an AI assistant, your task involves interpreting images and responding to questions about the image.
Remember to provide accurate answers based on the information present in the image.
Directly give the answer, no more explanation is needed.

# user:
{{question}}
![image]({{image}})
```

```python
prompty = Prompty.load(source="path/to/prompty.prompty")

# Local image path
prompty(image="data:image/jpg;path:path/to/image.jpg")
# image url
prompty(image="data:image/jpg;url:http://link.to.image.jpg")

# or image obj
from promptflow.contracts.multimedia import Image

image = Image(image_buffer, mime_type="image/png")
prompty(image=image)
```


Trace:

![image](https://github.com/microsoft/promptflow/assets/17938940/5ffe8dcd-6fcd-4c0b-a28b-fa3c9ca072a0)

Please add an informative description that covers the changes made by
the pull request and link all relevant issues.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There is a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
  • Loading branch information
lalala123123 authored May 22, 2024
1 parent 71b308b commit 52f6b90
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 6 deletions.
10 changes: 9 additions & 1 deletion src/promptflow-core/promptflow/core/_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from promptflow.core._errors import MissingRequiredInputError
from promptflow.core._model_configuration import PromptyModelConfiguration
from promptflow.core._prompty_utils import (
_get_image_obj,
convert_model_configuration_to_connection,
convert_prompt_template,
format_llm_response,
Expand Down Expand Up @@ -391,6 +392,11 @@ def _resolve_inputs(self, input_values):
resolved_inputs[input_name] = input_values.get(input_name, value.get("default", None))
if missing_inputs:
raise MissingRequiredInputError(f"Missing required inputs: {missing_inputs}")

# Resolve image input
for k, v in resolved_inputs.items():
if isinstance(v, str):
resolved_inputs[k] = _get_image_obj(v, working_dir=self.path.parent)
return resolved_inputs

def _get_input_signature(self):
Expand Down Expand Up @@ -497,7 +503,9 @@ def estimate_token_count(self, *args, **kwargs):
raise UserErrorException("Max_token needs to be integer.")
elif response_max_token <= 1:
raise UserErrorException(f"{response_max_token} is less than the minimum of max_tokens.")
total_token = num_tokens_from_messages(prompt, self._model._model) + (response_max_token or 0)
total_token = num_tokens_from_messages(prompt, self._model._model, working_dir=self.path.parent) + (
response_max_token or 0
)
return total_token


Expand Down
89 changes: 84 additions & 5 deletions src/promptflow-core/promptflow/core/_prompty_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from openai import APIConnectionError, APIStatusError, APITimeoutError, BadRequestError, OpenAIError, RateLimitError

from promptflow._utils.logger_utils import LoggerFactory
from promptflow._utils.multimedia_utils import MIME_PATTERN, ImageProcessor
from promptflow._utils.yaml_utils import load_yaml
from promptflow.contracts.types import PromptTemplate
from promptflow.core._connection import AzureOpenAIConnection, OpenAIConnection, _Connection
Expand Down Expand Up @@ -137,7 +138,8 @@ def convert_prompt_template(template, inputs, api):
template_content=prompt, trim_blocks=True, keep_trailing_newline=True, **inputs
)
else:
rendered_prompt = build_messages(prompt=prompt, **inputs)
reference_images = find_referenced_image_set(inputs)
rendered_prompt = build_messages(prompt=prompt, images=reference_images, **inputs)
return rendered_prompt


Expand Down Expand Up @@ -284,7 +286,7 @@ def format_stream(llm_response):
return result


def num_tokens_from_messages(messages, model):
def num_tokens_from_messages(messages, model, working_dir):
"""Return the number of tokens used by a list of messages."""
# Ref: https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken#6-counting-tokens-for-chat-completions-api-calls # noqa: E501
try:
Expand All @@ -307,10 +309,10 @@ def num_tokens_from_messages(messages, model):
tokens_per_name = -1 # if there's a name, the role is omitted
elif "gpt-3.5-turbo" in model or "gpt-35-turbo":
logger.warning("gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613", working_dir=working_dir)
elif "gpt-4" in model:
logger.warning("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return num_tokens_from_messages(messages, model="gpt-4-0613")
return num_tokens_from_messages(messages, model="gpt-4-0613", working_dir=working_dir)
else:
raise NotImplementedError(
f"num_tokens_from_messages() is not implemented for model {model}. "
Expand All @@ -321,13 +323,81 @@ def num_tokens_from_messages(messages, model):
for message in messages:
num_tokens += tokens_per_message
for key, value in message.items():
num_tokens += len(encoding.encode(value))
if isinstance(value, str):
num_tokens += len(encoding.encode(value))
elif isinstance(value, list):
for item in value:
value_type = item.get("type", "text")
if value_type == "text":
# Calculate content tokens
num_tokens += len(encoding.encode(item["text"]))
elif value_type == "image_url":
image_content = item["image_url"]["url"]
if ImageProcessor.is_url(image_content):
image_obj = ImageProcessor.create_image_from_url(image_content)
num_tokens += _num_tokens_for_image(image_obj.to_base64())
elif ImageProcessor.is_base64(image_content):
image_obj = ImageProcessor.create_image_from_base64(image_content)
num_tokens += _num_tokens_for_image(image_obj.to_base64())
else:
# Calculate image input as content
num_tokens += len(encoding.encode(item["image_url"]["url"]))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens


def _get_image_obj(image_str, working_dir):
    """Turn an image reference string into an image object when possible.

    The string is matched against a pattern derived from ``MIME_PATTERN``
    (its last character is dropped and a ``:<content>`` tail is appended).
    On a match, group 1 is used as the image subtype, group 2 as the kind of
    source (``path``, ``base64`` or ``url``) and group 3 as the content.

    :param image_str: Candidate string, e.g. ``data:image/jpg;path:logo.jpg``.
    :param working_dir: Directory that relative ``path`` contents are resolved against.
    :return: The object built by ``ImageProcessor`` for a recognised source,
        otherwise ``image_str`` unchanged (no match, unknown source kind, or a
        ``path`` source whose file does not exist).
    """
    mime_pattern_with_content = MIME_PATTERN.pattern[:-1] + r":\s*(.*)$"
    match = re.match(mime_pattern_with_content, image_str)
    if not match:
        return image_str

    mime_type = f"image/{match.group(1)}"
    image_type, image_content = match.group(2), match.group(3)

    if image_type == "path":
        if not Path(image_content).is_absolute():
            image_content = Path(working_dir) / image_content
        if not Path(image_content).exists():
            logger.warning(f"Cannot find the image path {image_content}, it will be regarded as {type(image_str)}.")
            # Honour the warning: keep the raw string instead of trying to
            # read a file that is known to be missing.
            return image_str
        return ImageProcessor.create_image_from_file(image_content, mime_type)
    if image_type == "base64":
        return ImageProcessor.create_image_from_base64(image_content, mime_type)
    if image_type == "url":
        return ImageProcessor.create_image_from_url(image_content, mime_type)

    logger.warning(f"Invalid mine type {mime_type}, it will be regarded as {type(image_str)}.")
    return image_str


def _num_tokens_for_image(base64_str: str) -> int:
    """Calculate the token cost of an image input.

    The image is decoded only to read its pixel dimensions. Those dimensions
    are then scaled down in two steps (fit within 2048 x 2048, then bring the
    shorter side down to 768 when it is larger), the result is split into
    512-pixel tiles, and the cost is ``85 + 170 * tiles``.

    :param base64_str: Base64-encoded image bytes.
    :return: Number of tokens charged for the image.
    """
    # https://platform.openai.com/docs/guides/vision/calculating-costs
    import base64
    from io import BytesIO
    from math import ceil

    from PIL import Image

    imgdata = base64.b64decode(base64_str)
    # Close the image as soon as its size has been read.
    with Image.open(BytesIO(imgdata)) as image:
        width, height = image.size

    # Step 1: fit inside a 2048 x 2048 square, keeping the aspect ratio.
    if width > 2048 or height > 2048:
        aspect_ratio = width / height
        if aspect_ratio > 1:
            width, height = 2048, int(2048 / aspect_ratio)
        else:
            width, height = int(2048 * aspect_ratio), 2048

    # Step 2: shrink so that the shorter side is at most 768 pixels.
    if width >= height and height > 768:
        width, height = int((768 / height) * width), 768
    elif height > width and width > 768:
        width, height = 768, int((768 / width) * height)

    # Step 3: count the 512-pixel tiles needed to cover the scaled image.
    tiles_width = ceil(width / 512)
    tiles_height = ceil(height / 512)
    image_tokens = 85 + 170 * (tiles_width * tiles_height)
    return image_tokens


def resolve_references(origin, base_path=None):
"""Resolve all reference in the object."""
if isinstance(origin, str):
Expand Down Expand Up @@ -415,6 +485,15 @@ def merge_escape_mapping_of_flow_inputs(self, _inputs_to_escape: list, **kwargs)
self.escaped_mapping.update(flow_inputs_escape_dict)


def convert_to_chat_list(obj):
    """Recursively wrap every list found in ``obj`` in a ``ChatInputList``.

    Dicts are rebuilt with each value converted, lists become a
    ``ChatInputList`` of converted items, and any other value is returned
    as it is.
    """
    if isinstance(obj, list):
        return ChatInputList([convert_to_chat_list(element) for element in obj])
    if not isinstance(obj, dict):
        return obj

    converted = {}
    for name, child in obj.items():
        converted[name] = convert_to_chat_list(child)
    return converted


def normalize_connection_config(connection):
"""
Normalizes the configuration of a given connection object for compatibility.
Expand Down
29 changes: 29 additions & 0 deletions src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_prompty.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from openai.types.chat import ChatCompletion

from promptflow._sdk._pf_client import PFClient
from promptflow._utils.multimedia_utils import ImageProcessor
from promptflow._utils.yaml_utils import load_yaml
from promptflow.client import load_flow
from promptflow.core import AsyncPrompty, Flow, Prompty
Expand Down Expand Up @@ -574,3 +575,31 @@ def test_tools_in_prompty(self):
result = prompty(chat_history=chat_history, question="No, predict me in next 3 days")
expect_argument = {"format": "json", "location": "Suzhou", "num_days": "3"}
assert expect_argument == json.loads(result["tool_calls"][0]["function"]["arguments"])

@pytest.mark.skip("Connection doesn't support vision model.")
def test_prompty_with_image_input(self, pf):
    """Exercise a prompty that takes an image: run, flow test, render, token estimate."""
    prompty_file = f"{PROMPTY_DIR}/prompty_with_image.prompty"

    # Run with the sample inputs declared in the prompty file, asking for the full response.
    full_response_prompty = Prompty.load(source=prompty_file, model={"response": "all"})
    response_result = full_response_prompty()
    assert "Microsoft" in response_result.choices[0].message.content

    # Image passed as a "data:image/...;path:" string with an absolute path.
    logo_file = DATA_DIR / "logo.jpg"
    path_inputs = {"question": "what is it", "image": f"data:image/jpg;path:{logo_file.absolute()}"}
    answer = pf.test(flow=prompty_file, inputs=path_inputs)
    assert "Microsoft" in answer

    # Input with image object
    logo_image = ImageProcessor.create_image_from_string(str(logo_file))
    object_inputs = {"question": "what is it", "image": logo_image}
    answer = pf.test(flow=prompty_file, inputs=object_inputs)
    assert "Microsoft" in answer

    # Test prompty render
    default_prompty = Prompty.load(source=prompty_file)
    rendered = default_prompty.render(question="what is it", image=logo_image)
    assert f"data:image/jpeg;base64,{logo_image.to_base64()}" in rendered

    # Test estimate prompt token
    estimated = default_prompty.estimate_token_count(question="what is it", image=logo_image)
    assert estimated == response_result.usage.prompt_tokens
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
---
name: Basic Prompt with Image
description: A basic prompt that uses the GPT-3 chat API to answer questions
model:
api: chat
configuration:
type: azure_openai
azure_deployment: gpt-4-vision-preview
connection: azure_open_ai_connection
parameters:
temperature: 0.2
sample:
image: "data:image/jpg;path:../datas/logo.jpg"
question: what is it
---
system:
As an AI assistant, your task involves interpreting images and responding to questions about the image.
Remember to provide accurate answers based on the information present in the image.
Directly give the answer, no more explanation is needed.

# user:
{{question}}
![image]({{image}})

0 comments on commit 52f6b90

Please sign in to comment.