openai[patch]: image token counting (#23147)
Resolves #23000

Co-authored-by: isaac hershenson <[email protected]>
Co-authored-by: ccurme <[email protected]>
1 parent b3e53ff · commit 0a4ee86
Showing 6 changed files with 487 additions and 101 deletions.
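For orientation, here is a rough usage sketch of the behavior this commit adds. The model name, image URL, and environment setup are illustrative assumptions rather than part of the commit; only the content-block shapes come from the diff that follows.

from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

# Assumes OPENAI_API_KEY is set; token counting itself runs locally via tiktoken.
llm = ChatOpenAI(model="gpt-4o")  # placeholder vision-capable model
message = HumanMessage(
    content=[
        {"type": "text", "text": "What is in this image?"},
        {
            "type": "image_url",
            "image_url": {"url": "https://example.com/cat.png", "detail": "low"},
        },
    ]
)
# detail="low" contributes a flat 85 tokens for the image; the text block and
# per-message overhead are counted with tiktoken as before.
print(llm.get_num_tokens_from_messages([message]))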
@@ -2,10 +2,13 @@
 from __future__ import annotations

+import base64
 import json
 import logging
 import os
 import sys
+from io import BytesIO
+from math import ceil
 from operator import itemgetter
 from typing import (
     Any,
@@ -26,6 +29,7 @@
     cast,
     overload,
 )
+from urllib.parse import urlparse

 import openai
 import tiktoken
@@ -736,7 +740,13 @@ def get_token_ids(self, text: str) -> List[int]:
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

-        Official documentation: https://github.com/openai/openai-cookbook/blob/
+        **Requirements**: You must have ``pillow`` installed if you want to count
+        image tokens when specifying the image as a base64 string, and you must
+        have both ``pillow`` and ``httpx`` installed when specifying the image
+        as a URL. If these aren't installed, image inputs will be ignored in token
+        counting.
+
+        OpenAI reference: https://github.com/openai/openai-cookbook/blob/
         main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
         if sys.version_info[1] <= 7:
             return super().get_num_tokens_from_messages(messages)
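The base64 path mentioned in the new docstring expects a data URL. Below is a minimal sketch of producing one, assuming ``pillow`` is installed; the generated image is a synthetic placeholder.

import base64
from io import BytesIO

from PIL import Image

# Build a data URL of the form the counting code recognizes ("data:image...").
img = Image.new("RGB", (512, 512), color="white")
buf = BytesIO()
img.save(buf, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

# A content block using this URL is sized with PIL alone; httpx is only needed
# when the image is given as an http(s) URL.
image_block = {"type": "image_url", "image_url": {"url": data_url}}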
@@ -753,17 +763,35 @@ def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
             raise NotImplementedError(
                 f"get_num_tokens_from_messages() is not presently implemented "
                 f"for model {model}. See "
-                "https://platform.openai.com/docs/guides/text-generation/managing-tokens"
+                "https://platform.openai.com/docs/guides/text-generation/managing-tokens"  # noqa: E501
                 " for information on how messages are converted to tokens."
             )
         num_tokens = 0
         messages_dict = [_convert_message_to_dict(m) for m in messages]
         for message in messages_dict:
             num_tokens += tokens_per_message
             for key, value in message.items():
-                # Cast str(value) in case the message value is not a string
-                # This occurs with function messages
-                num_tokens += len(encoding.encode(str(value)))
+                if isinstance(value, list):
+                    for val in value:
+                        if isinstance(val, str) or val["type"] == "text":
+                            text = val["text"] if isinstance(val, dict) else val
+                            num_tokens += len(encoding.encode(text))
+                        elif val["type"] == "image_url":
+                            if val["image_url"].get("detail") == "low":
+                                num_tokens += 85
+                            else:
+                                image_size = _url_to_size(val["image_url"]["url"])
+                                if not image_size:
+                                    continue
+                                num_tokens += _count_image_tokens(*image_size)
+                        else:
+                            raise ValueError(
+                                f"Unrecognized content block type\n\n{val}"
+                            )
+                else:
+                    # Cast str(value) in case the message value is not a string
+                    # This occurs with function messages
+                    num_tokens += len(encoding.encode(value))
                 if key == "name":
                     num_tokens += tokens_per_name
         # every reply is primed with <im_start>assistant

cefe-yalo commented on the new `num_tokens += len(encoding.encode(value))` line: Is this a bug that there is no `str()` here?
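For reference, after `_convert_message_to_dict` a multimodal message arrives at this loop in roughly the following shape (a sketch with placeholder values, annotated with the branch that counts each entry):

message = {
    "role": "user",  # plain string: encoded directly with tiktoken
    "content": [
        {"type": "text", "text": "hi"},  # text block: encoded with tiktoken
        {
            # image block with detail="low": flat 85 tokens
            "type": "image_url",
            "image_url": {"url": "data:image/png;base64,...", "detail": "low"},
        },
        {
            # image block without "low" detail: sized via _url_to_size(),
            # then priced with _count_image_tokens()
            "type": "image_url",
            "image_url": {"url": "https://example.com/cat.png"},
        },
    ],
}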
@@ -1541,3 +1569,75 @@ def _lc_invalid_tool_call_to_openai_tool_call(
             "arguments": invalid_tool_call["args"],
         },
     }
+
+
+def _url_to_size(image_source: str) -> Optional[Tuple[int, int]]:
+    try:
+        from PIL import Image  # type: ignore[import]
+    except ImportError:
+        logger.info(
+            "Unable to count image tokens. To count image tokens please run "
+            "`pip install -U pillow httpx`."
+        )
+        return None
+    if _is_url(image_source):
+        try:
+            import httpx
+        except ImportError:
+            logger.info(
+                "Unable to count image tokens. To count image tokens please run "
+                "`pip install -U httpx`."
+            )
+            return None
+        response = httpx.get(image_source)
+        response.raise_for_status()
+        width, height = Image.open(BytesIO(response.content)).size
+        return width, height
+    elif _is_b64(image_source):
+        _, encoded = image_source.split(",", 1)
+        data = base64.b64decode(encoded)
+        width, height = Image.open(BytesIO(data)).size
+        return width, height
+    else:
+        return None
+
+
+def _count_image_tokens(width: int, height: int) -> int:
+    # Reference: https://platform.openai.com/docs/guides/vision/calculating-costs
+    width, height = _resize(width, height)
+    h = ceil(height / 512)
+    w = ceil(width / 512)
+    return (170 * h * w) + 85
+
+
+def _is_url(s: str) -> bool:
+    try:
+        result = urlparse(s)
+        return all([result.scheme, result.netloc])
+    except Exception as e:
+        logger.debug(f"Unable to parse URL: {e}")
+        return False
+
+
+def _is_b64(s: str) -> bool:
+    return s.startswith("data:image")
+
+
+def _resize(width: int, height: int) -> Tuple[int, int]:
+    # larger side must be <= 2048
+    if width > 2048 or height > 2048:
+        if width > height:
+            height = (height * 2048) // width
+            width = 2048
+        else:
+            width = (width * 2048) // height
+            height = 2048
+    # smaller side must be <= 768
+    if width > 768 and height > 768:
+        if width > height:
+            width = (width * 768) // height
+            height = 768
+        else:
+            height = (height * 768) // width
+            width = 768
+    return width, height
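As a quick check of the resize-and-tile arithmetic above, the sketch below re-derives the same numbers independently (using float scaling instead of integer division); the two example sizes are illustrative and should agree with OpenAI's published vision pricing examples.

from math import ceil

def expected_image_tokens(width: int, height: int) -> int:
    # Independent re-derivation of _resize + _count_image_tokens, for checking only.
    if width > 2048 or height > 2048:  # cap the larger side at 2048
        scale = 2048 / max(width, height)
        width, height = int(width * scale), int(height * scale)
    if min(width, height) > 768:  # then scale the smaller side down to 768
        scale = 768 / min(width, height)
        width, height = int(width * scale), int(height * scale)
    tiles = ceil(width / 512) * ceil(height / 512)
    return 170 * tiles + 85

assert expected_image_tokens(1024, 1024) == 765   # 768x768 -> 2x2 tiles
assert expected_image_tokens(2048, 4096) == 1105  # 1024x2048 -> 768x1536 -> 2x3 tiles
# detail="low" images skip this path entirely and cost a flat 85 tokens.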