[prompty] Prompty supports estimating token count (#3210)
# Description
The return value is the prompt token count plus the configured maximum response token count (`max_tokens`).
```python
prompty = Prompty.load(source=f"{PROMPTY_DIR}/prompty_example.prompty")
total_token = prompty.estimate_token_count(question="what is the result of 1+1?")
```
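As a sanity check (mirroring the e2e test added in this PR), the estimate equals the prompt tokens reported by the service plus the configured `max_tokens`. The sketch below assumes the prompty is loaded with `model={"response": "all"}` so usage data is available:

```python
# Sketch: compare the estimate against the usage reported by the model.
prompty = Prompty.load(
    source=f"{PROMPTY_DIR}/prompty_example.prompty",
    model={"response": "all"},  # return the full response so usage is available
)
response = prompty(question="what is the result of 1+1?")
prompt_tokens = response.usage.prompt_tokens

total_token = prompty.estimate_token_count(question="what is the result of 1+1?")
assert total_token == prompt_tokens + prompty._model.parameters.get("max_tokens")
```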

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
lalala123123 authored May 11, 2024
1 parent 5bcfe22 commit 7117995
Showing 3 changed files with 110 additions and 0 deletions.
31 changes: 31 additions & 0 deletions src/promptflow-core/promptflow/core/_flow.py
@@ -10,6 +10,7 @@

from promptflow._constants import DEFAULT_ENCODING, LANGUAGE_KEY, PROMPTY_EXTENSION, FlowLanguage
from promptflow._utils.flow_utils import is_flex_flow, is_prompty_flow, resolve_flow_path
from promptflow._utils.logger_utils import LoggerFactory
from promptflow._utils.yaml_utils import load_yaml_string
from promptflow.contracts.tool import ValueType
from promptflow.core._errors import MissingRequiredInputError
@@ -20,6 +21,7 @@
format_llm_response,
get_open_ai_client_by_connection,
handle_openai_error,
num_tokens_from_messages,
prepare_open_ai_request_params,
resolve_references,
send_request_to_llm,
@@ -31,6 +33,8 @@
from promptflow.tracing._experimental import enrich_prompt_template
from promptflow.tracing._trace import _traced

logger = LoggerFactory.get_logger(name=__name__)


class AbstractFlowBase(abc.ABC):
"""Abstract class for all Flow entities in both core and devkit."""
@@ -469,6 +473,33 @@ def render(self, *args, **kwargs):
# For chat mode, the message generated is list type. Convert to string type and return to user.
return str(prompt)

def estimate_token_count(self, *args, **kwargs):
"""Estimate the token count.
LLM will reject the request when prompt token + response token is greater than the maximum number of
tokens supported by the model. It is used to estimate the number of total tokens in this round of chat.
:param args: positional arguments are not supported.
:param kwargs: prompty inputs with key word arguments.
:return: Estimate total token count
:rtype: int
"""
if args:
raise UserErrorException("Prompty can only be rendered with keyword arguments.")
inputs = self._resolve_inputs(kwargs)
prompt = convert_prompt_template(self._template, inputs, self._model.api)
response_max_token = self._model.parameters.get("max_tokens", None)
if response_max_token is None:
logger.warning(
"The maximum number of tokens that can be generated in the chat completion is not configured. "
"It will directly return prompt token count."
)
elif not isinstance(response_max_token, int):
raise UserErrorException("Max_token needs to be integer.")
elif response_max_token <= 1:
raise UserErrorException(f"{response_max_token} is less than the minimum of max_tokens.")
total_token = num_tokens_from_messages(prompt, self._model._model) + (response_max_token or 0)
return total_token


class AsyncPrompty(Prompty):
"""Async prompty is based on Prompty, which is used to invoke prompty in async mode.
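For illustration, a minimal sketch of the three `max_tokens` paths handled by the new method, based on the e2e tests below (the `PROMPTY_DIR` fixture and the inputs come from those tests; the `promptflow.core` import is assumed):

```python
# Sketch of the three max_tokens paths exercised by the e2e tests (illustrative only).
from promptflow.core import Prompty

# 1. max_tokens configured as an int: estimate = prompt tokens + max_tokens.
prompty = Prompty.load(source=f"{PROMPTY_DIR}/prompty_example.prompty")
prompty.estimate_token_count(question="what is the result of 1+1?")

# 2. max_tokens not configured: a warning is logged and only the prompt token count is returned.
prompty_no_max = Prompty.load(
    source=f"{PROMPTY_DIR}/prompty_example.prompty",
    model={"parameters": {"max_tokens": None}},
)
prompty_no_max.estimate_token_count(question="what is the result of 1+1?")

# 3. max_tokens is not an integer: UserErrorException is raised.
invalid_prompty = Prompty.load(
    source=f"{PROMPTY_DIR}/prompty_example.prompty",
    model={"parameters": {"max_tokens": "invalid_tokens"}},
)
invalid_prompty.estimate_token_count(question="what is the result of 1+1?")  # raises UserErrorException
```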
45 changes: 45 additions & 0 deletions src/promptflow-core/promptflow/core/_prompty_utils.py
@@ -8,6 +8,7 @@
from pathlib import Path
from typing import List, Mapping

import tiktoken
from openai import APIConnectionError, APIStatusError, APITimeoutError, BadRequestError, OpenAIError, RateLimitError

from promptflow._utils.logger_utils import LoggerFactory
@@ -280,6 +281,50 @@ def format_stream(llm_response):
return result


def num_tokens_from_messages(messages, model):
"""Return the number of tokens used by a list of messages."""
# Ref: https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken#6-counting-tokens-for-chat-completions-api-calls # noqa: E501
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logger.warning("Model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
if model in {
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
}:
tokens_per_message = 3
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif "gpt-3.5-turbo" in model or "gpt-35-turbo":
logger.warning("gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
elif "gpt-4" in model:
logger.warning("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return num_tokens_from_messages(messages, model="gpt-4-0613")
else:
raise NotImplementedError(
f"num_tokens_from_messages() is not implemented for model {model}. "
"See https://github.com/openai/openai-python/blob/main/chatml.md for information on "
"how messages are converted to tokens."
)
num_tokens = 0
for message in messages:
num_tokens += tokens_per_message
for key, value in message.items():
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens


def resolve_references(origin, base_path=None):
"""Resolve all reference in the object."""
if isinstance(origin, str):
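For reference, a small self-contained sketch (assuming only `tiktoken`) of the counting rule applied above for gpt-4-0613-style models: 3 tokens per message, +1 when a `name` field is present, and +3 to prime the assistant reply.

```python
# Standalone sketch of the chat token counting rule for gpt-4-0613-style models (illustrative).
import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")
messages = [
    {"role": "system", "content": "You are an AI assistant."},
    {"role": "user", "content": "what is the result of 1+1?"},
]

num_tokens = 3  # every reply is primed with <|start|>assistant<|message|>
for message in messages:
    num_tokens += 3  # tokens_per_message
    for key, value in message.items():
        num_tokens += len(encoding.encode(value))
        if key == "name":
            num_tokens += 1  # tokens_per_name
print(num_tokens)  # prompt-side token estimate for this chat
```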
34 changes: 34 additions & 0 deletions src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_prompty.py
@@ -448,6 +448,40 @@ def test_render_prompty(self):
prompty.render(mock_key="mock_value")
assert "Missing required inputs" in ex.value.message

def test_estimate_token_count(self):
prompty = Prompty.load(
source=f"{PROMPTY_DIR}/prompty_example.prompty",
model={"response": "all"},
)
with pytest.raises(UserErrorException) as ex:
prompty.estimate_token_count("mock_input")
assert "Prompty can only be rendered with keyword arguments." in ex.value.message

with pytest.raises(MissingRequiredInputError) as ex:
prompty.estimate_token_count()
assert "Missing required inputs" in ex.value.message

with pytest.raises(UserErrorException) as ex:
invalid_prompty = Prompty.load(
source=f"{PROMPTY_DIR}/prompty_example.prompty",
model={"parameters": {"max_tokens": "invalid_tokens"}},
)
invalid_prompty.estimate_token_count(question="what is the result of 1+1?")
assert "Max_token needs to be integer." in ex.value.message

response = prompty(question="what is the result of 1+1?")
prompt_tokens = response.usage.prompt_tokens

total_token = prompty.estimate_token_count(question="what is the result of 1+1?")
assert total_token == prompt_tokens + prompty._model.parameters.get("max_tokens")

prompty = Prompty.load(
source=f"{PROMPTY_DIR}/prompty_example.prompty",
model={"parameters": {"max_tokens": None}},
)
total_token = prompty.estimate_token_count(question="what is the result of 1+1?")
assert total_token == prompt_tokens

def test_prompty_with_reference_file(self):
# Test run prompty with reference file
prompty = Prompty.load(source=f"{PROMPTY_DIR}/prompty_with_reference_file.prompty")
