Skip to content

Commit

Permalink
Fix o1 support (#31)
Browse files Browse the repository at this point in the history
* Add system role handling and message enhancements in GenericLMFunctionOptimizer

- Define `system_role` based on model type for user prompts.
- Include contextual messages before user inputs to improve interaction flow.
- Update kwargs for completion creation to include necessary parameters conditionally.

* Bump version from 0.0.20 to 0.0.21 in project configuration

* Remove commented-out test function for OpenAI bootstrap few shot

* Bump version to 0.0.22 in project configuration
  • Loading branch information
ammirsm authored Nov 22, 2024
1 parent fc7c3dd commit d276f78
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 209 deletions.
2 changes: 1 addition & 1 deletion py/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "zenbase"
version = "0.0.20"
version = "0.0.22"
description = "LLMs made Zen"
authors = [{ name = "Cyrus Nouroozi", email = "[email protected]" }]
dependencies = [
Expand Down
63 changes: 41 additions & 22 deletions py/src/zenbase/predefined/generic_lm_function/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,13 @@ def __post_init__(self):
def _generate_lm_function(self) -> LMFunction:
@self.zenbase_tracer.trace_function
def generic_function(request):
system_role = "assistant" if self.model.startswith("o1") else "system"
messages = [
{"role": "system", "content": self.prompt},
{"role": system_role, "content": self.prompt},
]

if request.zenbase.task_demos:
messages.append({"role": "system", "content": "Here are some examples:"})
messages.append({"role": system_role, "content": "Here are some examples:"})
for demo in request.zenbase.task_demos:
if demo.inputs == request.inputs:
continue
Expand All @@ -63,17 +64,26 @@ def generic_function(request):
{"role": "assistant", "content": str(demo.outputs)},
]
)
messages.append({"role": "system", "content": "Now, please answer the following question:"})
messages.append({"role": system_role, "content": "Now, please answer the following question:"})

messages.append({"role": "user", "content": str(request.inputs)})
return self.instructor_client.chat.completions.create(
model=self.model,
response_model=self.output_model,
messages=messages,
max_retries=3,
logprobs=True,
top_logprobs=5,
)

kwargs = {
"model": self.model,
"response_model": self.output_model,
"messages": messages,
"max_retries": 3,
}

if not self.model.startswith("o1"):
kwargs.update(
{
"logprobs": True,
"top_logprobs": 5,
}
)

return self.instructor_client.chat.completions.create(**kwargs)

return generic_function

Expand Down Expand Up @@ -134,33 +144,42 @@ def _evaluate_best_function(self, test_evaluator, optimizer_result):
def create_lm_function_with_demos(self, prompt: str, demos: List[dict]) -> LMFunction:
@self.zenbase_tracer.trace_function
def lm_function_with_demos(request):
system_role = "assistant" if self.model.startswith("o1") else "system"
messages = [
{"role": "system", "content": prompt},
{"role": system_role, "content": prompt},
]

# Add demos to the messages
if demos:
messages.append({"role": "system", "content": "Here are some examples:"})
messages.append({"role": system_role, "content": "Here are some examples:"})
for demo in demos:
messages.extend(
[
{"role": "user", "content": str(demo["inputs"])},
{"role": "assistant", "content": str(demo["outputs"])},
]
)
messages.append({"role": "system", "content": "Now, please answer the following question:"})
messages.append({"role": system_role, "content": "Now, please answer the following question:"})

# Add the actual request
messages.append({"role": "user", "content": str(request.inputs)})

return self.instructor_client.chat.completions.create(
model=self.model,
response_model=self.output_model,
messages=messages,
max_retries=3,
logprobs=True,
top_logprobs=5,
)
kwargs = {
"model": self.model,
"response_model": self.output_model,
"messages": messages,
"max_retries": 3,
}

if not self.model.startswith("o1"):
kwargs.update(
{
"logprobs": True,
"top_logprobs": 5,
}
)

return self.instructor_client.chat.completions.create(**kwargs)

return lm_function_with_demos

Expand Down
186 changes: 0 additions & 186 deletions py/tests/adaptors/test_lunary.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import json
import logging
import os

import lunary
import pytest
Expand All @@ -16,7 +14,6 @@
from zenbase.core.managers import ZenbaseTracer
from zenbase.optim.metric.bootstrap_few_shot import BootstrapFewShot
from zenbase.optim.metric.labeled_few_shot import LabeledFewShot
from zenbase.settings import TEST_DIR
from zenbase.types import LMRequest

SAMPLES = 2
Expand Down Expand Up @@ -109,186 +106,3 @@ def langchain_chain(request: LMRequest):
@pytest.fixture(scope="module")
def lunary_helper():
return ZenLunary(client=lunary)


@pytest.mark.helpers
def test_lunary_openai_bootstrap_few_shot(optim: LabeledFewShot, lunary_helper: ZenLunary, openai):
    """End-to-end check of BootstrapFewShot against Lunary-hosted GSM8K datasets.

    Builds three traced LM functions — ``solver`` (the entry point), which calls
    ``planner_chain`` (step-by-step plan) and ``operation_finder`` (overall math
    operation) — then optimizes ``solver`` with BootstrapFewShot, and finally
    asserts that task demos were injected into every traced function and that
    the optimizer state can be saved to disk.

    NOTE(review): requires live OpenAI access and datasets pre-created in the
    Lunary UI (see the comment above the dataset names below) — this is a
    helpers-marked integration test, not a unit test.
    """
    zenbase_manager = ZenbaseTracer()

    @zenbase_manager.trace_function
    # @retry(
    #     stop=stop_after_attempt(3),
    #     wait=wait_exponential_jitter(max=8),
    #     before_sleep=before_sleep_log(log, logging.WARN),
    # )
    def solver(request: LMRequest):
        # Entry point under optimization: composes the plan and the required
        # math operation into a final numeric answer returned as JSON.
        messages = [
            {
                "role": "system",
                "content": "You are an expert math solver. You have a question that you should answer. You have step by step actions that you should take to solve the problem. You have the operations that you should do to solve the problem. You should come just with the number for the answer, just the actual number like examples that you have. Follow the format of the examples as they have the final answer, you need to come up with the plan for solving them."  # noqa
                'return it with json like return it in the {"answer": " the answer "}',
            }
        ]

        # Few-shot demos injected by the optimizer (question/answer pairs).
        for demo in request.zenbase.task_demos:
            messages += [
                {"role": "user", "content": f"Example Question: {demo.inputs}"},
                {"role": "assistant", "content": f"Example Answer: {demo.outputs}"},
            ]

        # Chain the two helper LM functions: plan first, then derive the
        # arithmetic operation needed from plan + question.
        plan = planner_chain(request.inputs)
        the_plan = plan["plan"]
        # the_plan = 'plan["plan"]'
        the_operation = operation_finder(
            {
                "plan": the_plan,
                "question": request.inputs,
            }
        )
        # the_operation = {"operation": "operation_finder"}

        messages.append({"role": "user", "content": f"Question: {request.inputs}"})
        messages.append({"role": "user", "content": f"Plan: {the_plan}"})
        messages.append(
            {"role": "user", "content": f"Mathematical Operation that needed: {the_operation['operation']}"}
        )
        messages.append(
            {
                "role": "user",
                "content": "Now come with the answer as number, just return the number, nothing else, just NUMBERS.",
            }
        )

        response = openai.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            response_format={"type": "json_object"},
        )

        print("Mathing...")
        answer = json.loads(response.choices[0].message.content)
        return answer["answer"]

    @zenbase_manager.trace_function
    # @retry(
    #     stop=stop_after_attempt(3),
    #     wait=wait_exponential_jitter(max=8),
    #     before_sleep=before_sleep_log(log, logging.WARN),
    # )
    def planner_chain(request: LMRequest):
        # Produces a step-by-step plan for the question as {"plan": ...}.
        messages = [
            {
                "role": "system",
                "content": "You are an expert math solver. You have a question that you should create a step-by-step plan to solve it. Follow the format of the examples and return JSON object."  # noqa
                'return it in the {"plan": " the plan "}',
            }
        ]

        # At most two demos to keep the prompt short.
        if request.zenbase.task_demos:
            for demo in request.zenbase.task_demos[:2]:
                messages += [
                    {"role": "user", "content": demo.inputs},
                    {"role": "assistant", "content": demo.outputs["plan"]},
                ]
        messages.append({"role": "user", "content": request.inputs})

        response = openai.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            response_format={"type": "json_object"},
        )

        print("Planning...")
        answer = json.loads(response.choices[0].message.content)
        # NOTE(review): joins the elements of answer["plan"] with spaces — if
        # the model returns a plain string this space-separates its characters;
        # presumably a list of steps is expected. Verify against model output.
        return {"plan": " ".join(i for i in answer["plan"])}

    @zenbase_manager.trace_function
    # @retry(
    #     stop=stop_after_attempt(3),
    #     wait=wait_exponential_jitter(max=8),
    #     before_sleep=before_sleep_log(log, logging.WARN),
    # )
    def operation_finder(request: LMRequest):
        # Maps a plan + question to a single arithmetic operation,
        # returned as {"operation": ...}. Input is a dict with
        # "question" and "plan" keys (see the call in solver above).
        messages = [
            {
                "role": "system",
                "content": "You are an expert math solver. You have a plan for solving a problem that is step-by-step, you need to find the overall operation in the math to solve it. Just come up with math operation with simple math operations like sum, multiply, division and minus. Follow the format of the examples."  # noqa
                'return it with json like return it in the {"operation": " the operation "}',
            }
        ]

        if request.zenbase.task_demos:
            for demo in request.zenbase.task_demos[:2]:
                messages += [
                    {"role": "user", "content": f"Question: {demo.inputs['question']}"},
                    {"role": "user", "content": f"Plan: {demo.inputs['plan']}"},
                    {"role": "assistant", "content": demo.outputs["operation"]},
                ]

        messages.append({"role": "user", "content": f"Question: {request.inputs['question']}"})
        messages.append({"role": "user", "content": f"Plan: {request.inputs['plan']}"})

        response = openai.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            response_format={"type": "json_object"},
        )

        print("Finding operation...")
        answer = json.loads(response.choices[0].message.content)
        return {"operation": answer["operation"]}

    # Smoke-run the chain once before optimizing — presumably to populate the
    # tracer / fail fast on connectivity; TODO confirm it is required.
    solver("What is 2 + 2?")

    evaluator_kwargs = dict(
        checklist="exact-match",
        concurrency=2,
    )

    # for lunary there is not feature to create dataset with code, so dataset are created
    # manually with UI, if you want to replicate the test on your own, you should put
    # GSM8K examples to dataset name like below:
    # NOTE(review): "gsmk8k" below looks like a typo for "gsm8k", but the name
    # must match the dataset created in the Lunary UI — verify before changing.
    TRAIN_SET = "gsmk8k-train-set"
    TEST_SET = "gsm8k-test-set"
    VALIDATION_SET = "gsm8k-validation-set"

    # Fail early if any of the manually-created datasets is missing.
    assert lunary_helper.fetch_dataset_demos(TRAIN_SET) is not None
    assert lunary_helper.fetch_dataset_demos(TEST_SET) is not None
    assert lunary_helper.fetch_dataset_demos(VALIDATION_SET) is not None

    bootstrap_few_shot = BootstrapFewShot(
        shots=SHOTS,
        training_set=TRAIN_SET,
        test_set=TEST_SET,
        validation_set=VALIDATION_SET,
        evaluator_kwargs=evaluator_kwargs,
        zen_adaptor=lunary_helper,
    )

    # Optimize the solver; returns the best ("teacher") function and the
    # candidate set explored during optimization.
    teacher_lm, candidates = bootstrap_few_shot.perform(
        solver,
        samples=SAMPLES,
        rounds=1,
        trace_manager=zenbase_manager,
    )
    assert teacher_lm is not None

    # Clear pre-optimization traces, then run the optimized function once so
    # all_traces holds only the optimized call graph inspected below.
    zenbase_manager.flush()
    teacher_lm("What is 2 + 2?")

    # Each traced function (including the nested helpers) must have received
    # task demos from the optimizer.
    assert [v for k, v in zenbase_manager.all_traces.items()][0]["optimized"]["planner_chain"]["args"][
        "request"
    ].zenbase.task_demos[0].inputs is not None
    assert [v for k, v in zenbase_manager.all_traces.items()][0]["optimized"]["operation_finder"]["args"][
        "request"
    ].zenbase.task_demos[0].inputs is not None
    assert [v for k, v in zenbase_manager.all_traces.items()][0]["optimized"]["solver"]["args"][
        "request"
    ].zenbase.task_demos[0].inputs is not None

    path_of_the_file = os.path.join(TEST_DIR, "adaptors/bootstrap_few_shot_optimizer_test.zenbase")

    bootstrap_few_shot.save_optimizer_args(path_of_the_file)

    # assert that the file has been saved
    assert os.path.exists(path_of_the_file)

0 comments on commit d276f78

Please sign in to comment.