Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix o1 support #31

Merged
merged 4 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion py/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "zenbase"
version = "0.0.20"
version = "0.0.22"
description = "LLMs made Zen"
authors = [{ name = "Cyrus Nouroozi", email = "[email protected]" }]
dependencies = [
Expand Down
63 changes: 41 additions & 22 deletions py/src/zenbase/predefined/generic_lm_function/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,13 @@ def __post_init__(self):
def _generate_lm_function(self) -> LMFunction:
@self.zenbase_tracer.trace_function
def generic_function(request):
system_role = "assistant" if self.model.startswith("o1") else "system"
messages = [
{"role": "system", "content": self.prompt},
{"role": system_role, "content": self.prompt},
]

if request.zenbase.task_demos:
messages.append({"role": "system", "content": "Here are some examples:"})
messages.append({"role": system_role, "content": "Here are some examples:"})
for demo in request.zenbase.task_demos:
if demo.inputs == request.inputs:
continue
Expand All @@ -63,17 +64,26 @@ def generic_function(request):
{"role": "assistant", "content": str(demo.outputs)},
]
)
messages.append({"role": "system", "content": "Now, please answer the following question:"})
messages.append({"role": system_role, "content": "Now, please answer the following question:"})

messages.append({"role": "user", "content": str(request.inputs)})
return self.instructor_client.chat.completions.create(
model=self.model,
response_model=self.output_model,
messages=messages,
max_retries=3,
logprobs=True,
top_logprobs=5,
)

kwargs = {
"model": self.model,
"response_model": self.output_model,
"messages": messages,
"max_retries": 3,
}

if not self.model.startswith("o1"):
kwargs.update(
{
"logprobs": True,
"top_logprobs": 5,
}
)

return self.instructor_client.chat.completions.create(**kwargs)

return generic_function

Expand Down Expand Up @@ -134,33 +144,42 @@ def _evaluate_best_function(self, test_evaluator, optimizer_result):
def create_lm_function_with_demos(self, prompt: str, demos: List[dict]) -> LMFunction:
    """Return a traced LM function that answers a request using ``prompt`` and few-shot ``demos``.

    Args:
        prompt: Instruction text sent as the first message of every call.
        demos: Few-shot examples. Each item is indexed with ``"inputs"`` and
            ``"outputs"``; both values are passed through ``str()`` before sending.

    Returns:
        A function wrapped by ``self.zenbase_tracer.trace_function``. When called
        with a request it builds the chat messages and returns whatever
        ``self.instructor_client.chat.completions.create`` returns.
    """

    @self.zenbase_tracer.trace_function
    def lm_function_with_demos(request):
        # Evaluated on every call so the current value of self.model is used.
        is_o1 = self.model.startswith("o1")
        # Models whose name starts with "o1" get the instruction messages under the
        # "assistant" role instead of "system".
        # NOTE(review): presumably because those models reject "system" messages — confirm.
        instruction_role = "assistant" if is_o1 else "system"

        messages = [{"role": instruction_role, "content": prompt}]

        # Few-shot section: each demo becomes a user/assistant exchange, framed by
        # two instruction messages.
        if demos:
            messages.append({"role": instruction_role, "content": "Here are some examples:"})
            for demo in demos:
                messages.append({"role": "user", "content": str(demo["inputs"])})
                messages.append({"role": "assistant", "content": str(demo["outputs"])})
            messages.append({"role": instruction_role, "content": "Now, please answer the following question:"})

        # The actual request always comes last.
        messages.append({"role": "user", "content": str(request.inputs)})

        kwargs = {
            "model": self.model,
            "response_model": self.output_model,
            "messages": messages,
            "max_retries": 3,
        }
        # Log-probability options are only sent for non-"o1" models.
        if not is_o1:
            kwargs["logprobs"] = True
            kwargs["top_logprobs"] = 5

        return self.instructor_client.chat.completions.create(**kwargs)

    return lm_function_with_demos

Expand Down
186 changes: 0 additions & 186 deletions py/tests/adaptors/test_lunary.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import json
import logging
import os

import lunary
import pytest
Expand All @@ -16,7 +14,6 @@
from zenbase.core.managers import ZenbaseTracer
from zenbase.optim.metric.bootstrap_few_shot import BootstrapFewShot
from zenbase.optim.metric.labeled_few_shot import LabeledFewShot
from zenbase.settings import TEST_DIR
from zenbase.types import LMRequest

SAMPLES = 2
Expand Down Expand Up @@ -109,186 +106,3 @@ def langchain_chain(request: LMRequest):
@pytest.fixture(scope="module")
def lunary_helper():
    """Module-scoped fixture returning a ``ZenLunary`` built around the imported ``lunary`` client."""
    adaptor = ZenLunary(client=lunary)
    return adaptor


@pytest.mark.helpers
def test_lunary_openai_bootstrap_few_shot(optim: LabeledFewShot, lunary_helper: ZenLunary, openai):
    """Run ``BootstrapFewShot`` over a three-step traced math pipeline backed by Lunary datasets.

    The pipeline is ``solver``, which calls ``planner_chain`` and ``operation_finder``;
    all three are wrapped with ``ZenbaseTracer.trace_function`` and call
    ``openai.chat.completions.create`` with ``gpt-4o-mini`` in JSON mode.

    The test asserts that:
      * the three named Lunary datasets can be fetched,
      * ``perform`` returns a non-None teacher function,
      * after calling the teacher, the first recorded trace holds task demos for
        every step of the pipeline,
      * ``save_optimizer_args`` writes the expected file under ``TEST_DIR``.
    """
    zenbase_manager = ZenbaseTracer()

    @zenbase_manager.trace_function
    def solver(request: LMRequest):
        messages = [
            {
                "role": "system",
                "content": "You are an expert math solver. You have a question that you should answer. You have step by step actions that you should take to solve the problem. You have the operations that you should do to solve the problem. You should come just with the number for the answer, just the actual number like examples that you have. Follow the format of the examples as they have the final answer, you need to come up with the plan for solving them."  # noqa
                'return it with json like return it in the {"answer": " the answer "}',
            }
        ]

        for demo in request.zenbase.task_demos:
            messages += [
                {"role": "user", "content": f"Example Question: {demo.inputs}"},
                {"role": "assistant", "content": f"Example Answer: {demo.outputs}"},
            ]

        # Sub-steps: first obtain a plan, then derive the operation from plan + question.
        plan = planner_chain(request.inputs)
        the_plan = plan["plan"]
        the_operation = operation_finder(
            {
                "plan": the_plan,
                "question": request.inputs,
            }
        )

        messages.append({"role": "user", "content": f"Question: {request.inputs}"})
        messages.append({"role": "user", "content": f"Plan: {the_plan}"})
        messages.append(
            {"role": "user", "content": f"Mathematical Operation that needed: {the_operation['operation']}"}
        )
        messages.append(
            {
                "role": "user",
                "content": "Now come with the answer as number, just return the number, nothing else, just NUMBERS.",
            }
        )

        response = openai.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            response_format={"type": "json_object"},
        )

        print("Mathing...")
        answer = json.loads(response.choices[0].message.content)
        return answer["answer"]

    @zenbase_manager.trace_function
    def planner_chain(request: LMRequest):
        messages = [
            {
                "role": "system",
                "content": "You are an expert math solver. You have a question that you should create a step-by-step plan to solve it. Follow the format of the examples and return JSON object."  # noqa
                'return it in the {"plan": " the plan "}',
            }
        ]

        # At most two demos are used as few-shot examples.
        if request.zenbase.task_demos:
            for demo in request.zenbase.task_demos[:2]:
                messages += [
                    {"role": "user", "content": demo.inputs},
                    {"role": "assistant", "content": demo.outputs["plan"]},
                ]
        messages.append({"role": "user", "content": request.inputs})

        response = openai.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            response_format={"type": "json_object"},
        )

        print("Planning...")
        answer = json.loads(response.choices[0].message.content)
        # The items of answer["plan"] are joined with single spaces into one string.
        return {"plan": " ".join(answer["plan"])}

    @zenbase_manager.trace_function
    def operation_finder(request: LMRequest):
        messages = [
            {
                "role": "system",
                "content": "You are an expert math solver. You have a plan for solving a problem that is step-by-step, you need to find the overall operation in the math to solve it. Just come up with math operation with simple math operations like sum, multiply, division and minus. Follow the format of the examples."  # noqa
                'return it with json like return it in the {"operation": " the operation "}',
            }
        ]

        # At most two demos are used as few-shot examples.
        if request.zenbase.task_demos:
            for demo in request.zenbase.task_demos[:2]:
                messages += [
                    {"role": "user", "content": f"Question: {demo.inputs['question']}"},
                    {"role": "user", "content": f"Plan: {demo.inputs['plan']}"},
                    {"role": "assistant", "content": demo.outputs["operation"]},
                ]

        messages.append({"role": "user", "content": f"Question: {request.inputs['question']}"})
        messages.append({"role": "user", "content": f"Plan: {request.inputs['plan']}"})

        response = openai.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            response_format={"type": "json_object"},
        )

        print("Finding operation...")
        answer = json.loads(response.choices[0].message.content)
        return {"operation": answer["operation"]}

    # Smoke call of the un-optimized pipeline before optimization starts.
    solver("What is 2 + 2?")

    evaluator_kwargs = dict(
        checklist="exact-match",
        concurrency=2,
    )

    # Lunary has no feature for creating datasets from code, so the datasets were
    # created manually in the UI. To replicate this test yourself, load GSM8K
    # examples into datasets with the names below.
    # NOTE(review): "gsmk8k" looks like a typo for "gsm8k", but the value must match
    # the dataset name that exists in Lunary — confirm before changing it.
    TRAIN_SET = "gsmk8k-train-set"
    TEST_SET = "gsm8k-test-set"
    VALIDATION_SET = "gsm8k-validation-set"

    assert lunary_helper.fetch_dataset_demos(TRAIN_SET) is not None
    assert lunary_helper.fetch_dataset_demos(TEST_SET) is not None
    assert lunary_helper.fetch_dataset_demos(VALIDATION_SET) is not None

    bootstrap_few_shot = BootstrapFewShot(
        shots=SHOTS,
        training_set=TRAIN_SET,
        test_set=TEST_SET,
        validation_set=VALIDATION_SET,
        evaluator_kwargs=evaluator_kwargs,
        zen_adaptor=lunary_helper,
    )

    teacher_lm, candidates = bootstrap_few_shot.perform(
        solver,
        samples=SAMPLES,
        rounds=1,
        trace_manager=zenbase_manager,
    )
    assert teacher_lm is not None

    zenbase_manager.flush()
    teacher_lm("What is 2 + 2?")

    # Only the first recorded trace is inspected; every pipeline step in it must
    # have received at least one demo.
    first_trace = next(iter(zenbase_manager.all_traces.values()))
    optimized = first_trace["optimized"]
    for step_name in ("planner_chain", "operation_finder", "solver"):
        assert optimized[step_name]["args"]["request"].zenbase.task_demos[0].inputs is not None

    path_of_the_file = os.path.join(TEST_DIR, "adaptors/bootstrap_few_shot_optimizer_test.zenbase")

    bootstrap_few_shot.save_optimizer_args(path_of_the_file)

    # The optimizer arguments must have been written to disk.
    assert os.path.exists(path_of_the_file)
Loading