Add caching for generate() #751

Merged Jan 16, 2024 (4 commits)
134 changes: 123 additions & 11 deletions docs/ai/text/generation.md
@@ -11,23 +11,23 @@

!!! example

=== "Names"
=== "Names (`str`)"

- We can generate a variety of names by providing instructions:
+ We can generate a variety of names by providing instructions. Note the default behavior is to generate a list of strings:

```python
import marvin

names = marvin.generate(
str, n=4, instructions="first names"
n=4, instructions="first names"
)

french_names = marvin.generate(
-     str, n=4, instructions="first names from France"
+     n=4, instructions="first names from France"
)

star_wars_names = marvin.generate(
-     str, n=4, instructions="first names from Star Wars"
+     n=4, instructions="first names from Star Wars"
)

```
@@ -42,9 +42,33 @@
assert star_wars_names == ['Luke', 'Leia', 'Han', 'Anakin']
```

=== "Locations"
=== "Populations (`dict[str, int]`)"

- We can also generate structured data, such as locations:
+ By providing a target type, we can generate dictionaries that map countries to their populations:

```python
import marvin

populations = marvin.generate(
target=dict[str, int],
n=4,
instructions="a map of country: population",
)
```

!!! success "Result"

```python
assert populations == [
{'China': 1444216107},
{'India': 1380004385},
{'United States': 331893745},
{'Indonesia': 276361783},
]
```
=== "Locations (Pydantic model)"

Pydantic models can also be used as targets. Here's a list of US cities named for presidents:

```python
from pydantic import BaseModel
@@ -54,9 +78,9 @@
state: str

locations = marvin.generate(
-     Location,
+     target=Location,
n=4,
instructions="cities in the United States named after famous people"
instructions="cities in the United States named after presidents"
)
```

@@ -84,14 +108,102 @@
The `generate` function is the primary tool for generating synthetic data. It accepts a `type` argument, which can be any Python type, Pydantic model, or `Literal`. It also has an argument `n`, which specifies the number of samples to generate. Finally, it accepts an `instructions` argument, which is a natural language description of the desired output. The LLM will use these instructions, in addition to the provided type, to guide its generation process. Instructions are especially important for types that are not self documenting, such as Python builtins like `str` and `int`.
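
For example, a minimal call might look like the sketch below. The instruction string is illustrative and the exact output will vary from run to run:

```python
import marvin

# a minimal sketch: a builtin type plus instructions guide the output
fruits = marvin.generate(str, n=3, instructions="names of tropical fruits")
# fruits is a list of 3 strings, e.g. ['Mango', 'Papaya', 'Guava']
```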


- ## Supported types
+ ## Supported targets

`generate` supports almost all builtin Python types, plus Pydantic models, Python's `Literal`, and `TypedDict`. Pydantic models are especially useful for specifying specific features of the generated data, such as locations, dates, or more complex types. Builtin types are most useful in conjunction with instructions that provide more precise criteria for generation.

- Note that `generate` will always return a list of type you provide.
+ To specify the output type, pass it as the `target` argument to `generate`. `generate` will always return a list of `n` items of the specified type. If no target is provided, `generate` will return a list of strings.
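
For instance, a `Literal` target constrains each generated item to a fixed set of choices. This is a sketch; the labels and instructions are illustrative:

```python
from typing import Literal

import marvin

# each generated item must be one of the three literal values
moods = marvin.generate(
    target=Literal["happy", "sad", "neutral"],
    n=5,
    instructions="plausible moods for customer support tickets",
)
```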

!!! warning "Avoid tuples"
OpenAI models currently have trouble parsing the API representation of tuples. Therefore we recommend using lists or Pydantic models (for more strict typing) instead. Tuple support will be added in a future release.
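
For example, a pair that might otherwise be modeled as a tuple can be expressed as a small Pydantic model instead. This is a sketch, and `CityPopulation` is a hypothetical model, not part of Marvin:

```python
import marvin
from pydantic import BaseModel

# instead of tuple[str, int], use named fields for strict typing
class CityPopulation(BaseModel):
    city: str
    population: int

pairs = marvin.generate(
    target=CityPopulation,
    n=3,
    instructions="large world cities and their approximate populations",
)
```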

## Instructions

Data generation relies even more heavily on instructions than other Marvin tools do, as the potential for variation is much greater. Therefore, you should provide as much detail as possible in your instructions, in addition to any implicit documentation in your requested type.

Instructions are freeform natural language and can be as general or specific as you like. The LLM will do its best to comply with any instructions you give.
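
As an illustration, compare a vague prompt with a detailed one. This is a sketch; both calls return a list of strings, but the second leaves far less to chance:

```python
import marvin

# vague: leaves most decisions to the model
animals = marvin.generate(n=3, instructions="animals")

# detailed: pins down domain, style, and formatting constraints
nocturnal = marvin.generate(
    n=3,
    instructions=(
        "common names of nocturnal mammals native to North America, "
        "in lowercase, one or two words each"
    ),
)
```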

## Caching

Normally, each `generate` call is independent. For some prompts, each call therefore produces results very similar to those of other calls, which means that generating, say, 10 items in a single call yields a much more varied and higher-quality result than generating 10 items in 5 calls of 2 items each.

To mitigate this issue, Marvin maintains an in-memory cache of the last 100 results produced by each `generate` prompt. These responses are shown to the LLM during generation to encourage variation. Note that the cache is not persisted across Python sessions. Cached results are also subject to a token cap to avoid flooding the LLM's context window. The token cap can be set with `MARVIN_AI_TEXT_GENERATE_CACHE_TOKEN_CAP` and defaults to 600.

To disable this behavior, pass `use_cache=False` to `generate`.
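
Both knobs are shown in the sketch below. The settings attribute mirrors the environment variable and is assumed to be adjustable at runtime:

```python
import marvin

# raise the token cap on cached examples shown to the LLM
# (equivalent to MARVIN_AI_TEXT_GENERATE_CACHE_TOKEN_CAP=1200)
marvin.settings.ai.text.generate_cache_token_cap = 1200

# or opt out of the cache entirely for a single call
cities = marvin.generate(n=2, instructions="major US cities", use_cache=False)
```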

Here is an example of how the cache improves generation. The first tab shows 10 cities generated in a single call; the second shows 10 cities generated in 5 calls of 2 cities each; and the third shows 10 cities generated in 5 calls but with the cache disabled.

The first and second tabs both show high-quality, varied results. The third tab is more disappointing, as it shows almost no variation.

=== "Single call"
Generate 10 cities in a single call, which produces a varied list:

```python
cities = marvin.generate(n=10, instructions='major US cities')
```

!!! success "Result"
```python
assert cities == [
'New York',
'Los Angeles',
'Chicago',
'Houston',
'Phoenix',
'Philadelphia',
'San Antonio',
'San Diego',
'Dallas',
'San Jose'
]
```

=== "Five calls, with caching"
Generate 10 cities in five calls, using the cache. This also produces a varied list:
```python
cities = []
for _ in range(5):
    cities.extend(marvin.generate(n=2, instructions='major US cities'))
```
!!! success "Result"
```python
assert cities == [
'Chicago',
'San Francisco',
'Seattle',
'New York City',
'Los Angeles',
'Houston',
'Miami',
'Dallas',
'Atlanta',
'Boston'
]
```
=== "Five calls, without caching"
Generate 10 cities in five calls, without the cache. This produces a list with almost no variation, since each call is independent:

```python
cities = []
for _ in range(5):
    cities.extend(marvin.generate(
        n=2,
        instructions='major US cities',
        use_cache=False,
    ))
```
!!! failure "Result"
```python
assert cities == [
'Houston',
'Seattle',
'Chicago',
'Houston',
'Chicago',
'Houston',
'Chicago',
'Houston',
'Los Angeles',
'Houston'
]
```
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -65,6 +65,7 @@ nav:
- marvin.utilities.jinja: api_reference/utilities/jinja.md
- marvin.utilities.logging: api_reference/utilities/logging.md
- marvin.utilities.tools: api_reference/utilities/tools.md
- marvin.utilities.strings: api_reference/utilities/strings.md
- Cookbook:
- Entity deduplication: examples/deduplication.md
# - GitHub Activity Digest: examples/github_digest.md
29 changes: 15 additions & 14 deletions pyproject.toml
@@ -8,22 +8,23 @@ classifiers = [
"Programming Language :: Python :: 3 :: Only",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
keywords = ["ai", "chatbot", "llm"]
keywords = ["ai", "chatbot", "llm", "NLP", "natural language processing"]
requires-python = ">=3.9"
dependencies = [
"fastapi",
"httpx>=0.24.1",
"jinja2>=3.1.2",
"jsonpatch>=1.33",
"openai>=1.1.0",
"pydantic>=2.4.2",
"pydantic_settings",
"rich>=12",
"tiktoken>=0.4.0",
"typer>=0.9.0",
"typing_extensions>=4.0.0",
"tzdata>=2023.3", # need for windows
"uvicorn>=0.22.0"
"cachetools>=5",
"fastapi",
"httpx>=0.24.1",
"jinja2>=3.1.2",
"jsonpatch>=1.33",
"openai>=1.1.0",
"pydantic>=2.4.2",
"pydantic_settings",
"rich>=12",
"tiktoken>=0.4.0",
"typer>=0.9.0",
"typing_extensions>=4.0.0",
"tzdata>=2023.3", # need for windows
"uvicorn>=0.22.0"
]

[project.optional-dependencies]
19 changes: 16 additions & 3 deletions src/marvin/ai/prompts/text_prompts.py
@@ -87,9 +87,9 @@
whenever necessary to supply missing or omitted data. You will be given
instructions or a type format, as well as a number of entities to generate.

-    Unless the user explicitly says otherwise, assume they are request a HIGHLY
-    RANDOM and DIVERSE but EXTREMELY REALISTIC selection of useful outputs that
-    meet their criteria.
+    Unless the user explicitly says otherwise, assume they are requesting a VARIED
+    and REALISTIC selection of useful outputs that meet their criteria. However,
+    you should prefer common responses to uncommon ones.

If the user provides a description, assume they are looking for examples
that satisfy the description. Do not provide more information than the user
@@ -116,6 +116,19 @@
Call the `FormatResponse` tool to validate your response, and use the
following schema: {{ response_format }}

{% if previous_responses -%}
## Previous responses

You have been asked to generate this data before, and these were your
responses (ordered by most recently seen to least recently seen). Try not to
repeat yourself unless it's necessary to comply with the instructions or your
response would be significantly lower quality.

{% for response in previous_responses -%}
- {{response}}
{% endfor %}
{% endif %}

"""
)

50 changes: 42 additions & 8 deletions src/marvin/ai/text.py
@@ -3,6 +3,7 @@
"""

import inspect
from collections import deque
from enum import Enum
from functools import partial, wraps
from typing import (
@@ -17,6 +18,7 @@
get_origin,
)

from cachetools import LRUCache
from pydantic import BaseModel

import marvin
@@ -38,12 +40,15 @@
from marvin.utilities.jinja import Transcript
from marvin.utilities.logging import get_logger
from marvin.utilities.python import PythonFunction
from marvin.utilities.strings import count_tokens

T = TypeVar("T")
M = TypeVar("M", bound=BaseModel)

logger = get_logger(__name__)

GENERATE_CACHE = LRUCache(maxsize=1000)


class EjectRequest(Exception):
    def __init__(self, request):
@@ -343,9 +348,10 @@ def classify(


def generate(
-    type_: Optional[type[T]] = None,
+    target: Optional[type[T]] = None,
     instructions: Optional[str] = None,
     n: int = 1,
+    use_cache: bool = True,
     temperature: float = 1,
     model_kwargs: Optional[dict] = None,
     client: Optional[MarvinClient] = None,
@@ -358,9 +364,12 @@ def generate(
least 'n' items.

Args:
-        type_ (type, optional): The type of items to generate. Defaults to None.
+        target (type, optional): The type of items to generate. Defaults to None.
instructions (str, optional): Instructions for the generation. Defaults to None.
n (int, optional): The number of items to generate. Defaults to 1.
use_cache (bool, optional): If True, the function will cache the last
100 responses for each (target, instructions, and temperature) and use
those to avoid repetition on subsequent calls. Defaults to True.
temperature (float, optional): The temperature for the generation. Defaults to 1.
model_kwargs (dict, optional): Additional keyword arguments for the
language model. Defaults to None.
@@ -370,24 +379,49 @@
list: A list of generated items.
"""

-    if type_ is None and instructions is None:
-        raise ValueError("Must provide either a type or instructions.")
-    elif type_ is None:
-        type_ = str
+    if target is None and instructions is None:
+        raise ValueError("Must provide either a target type or instructions.")
+    elif target is None:
+        target = str

    # cache the last 100 responses for each (target, instructions, and temperature)
    # to avoid repetition and encourage variation
    cache_key = (target, instructions, temperature)
    cached_responses = GENERATE_CACHE.setdefault(cache_key, deque(maxlen=100))
    previous_responses = []
    tokens = 0
    model = model_kwargs.get("model", None) if model_kwargs else None
    # use a token cap to avoid flooding the prompt with previous responses
    for r in list(cached_responses) if use_cache else []:
        if tokens > marvin.settings.ai.text.generate_cache_token_cap:
            continue
        tokens += count_tokens(str(r), model=model)
        previous_responses.append(r)

    # make sure we generate at least n items
    result = [0] * (n + 1)
    while len(result) != n:
        result = _generate_typed_llm_response_with_tool(
            prompt_template=GENERATE_PROMPT,
-           prompt_kwargs=dict(type_=type_, n=n, instructions=instructions),
-           type_=list[type_],
+           prompt_kwargs=dict(
+               type_=target,
+               n=n,
+               instructions=instructions,
+               previous_responses=previous_responses,
+           ),
+           type_=list[target],
            model_kwargs=(model_kwargs or {}) | dict(temperature=temperature),
            client=client,
        )

    if len(result) > n:
        result = result[:n]

    # don't cache the responses if we're not using the cache, because the AI will
    # see repeats and conclude they're ok
    if use_cache:
        for r in result:
            cached_responses.appendleft(r)
    return result


10 changes: 10 additions & 0 deletions src/marvin/settings.py
@@ -198,6 +198,15 @@ def discover_api_key(cls, v):
        return v


class TextAISettings(MarvinSettings):
    model_config = SettingsConfigDict(env_prefix="marvin_ai_text_")
    generate_cache_token_cap: int = Field(600)


class AISettings(MarvinSettings):
    text: TextAISettings = Field(default_factory=TextAISettings)


class Settings(MarvinSettings):
"""Settings for `marvin`.

@@ -224,6 +233,7 @@ class Settings(MarvinSettings):
)

    openai: OpenAISettings = Field(default_factory=OpenAISettings)
    ai: AISettings = Field(default_factory=AISettings)

    log_level: str = Field(
        default="INFO",