Skip to content

Commit

Permalink
feat(clean up) | clean up zenbase_tracer (#23)
Browse files Browse the repository at this point in the history
* Add trace flush method to ZenbaseTracer and integrate it in BootstrapFewShot

* Refactor ZenbaseTracer to limit trace storage

* Replace zenbase_tracer.all_traces = {} with zenbase_tracer.flush()

* Replace zenbase_manager.all_traces = {} with zenbase_manager.flush()

* Up version.
  • Loading branch information
ammirsm authored Jul 24, 2024
1 parent 49020ba commit ed9e24c
Show file tree
Hide file tree
Showing 14 changed files with 148 additions and 34 deletions.
6 changes: 2 additions & 4 deletions py/cookbooks/bootstrap_few_shot/arize.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -550,8 +550,6 @@
},
"outputs": [],
"source": [
"# Empty the traces\n",
"zenbase_tracer.all_traces = {}\n",
"# Run the optimization\n",
"best_fn, candidates = bootstrap_few_shot.perform(\n",
" solver,\n",
Expand All @@ -578,7 +576,7 @@
},
"outputs": [],
"source": [
"zenbase_tracer.all_traces = {}\n",
"zenbase_tracer.flush()\n",
"best_fn({\"question\": \"What is 2+2?\"})"
]
},
Expand Down Expand Up @@ -736,7 +734,7 @@
},
"outputs": [],
"source": [
"zenbase_tracer.all_traces = {}\n",
"zenbase_tracer.flush()\n",
"optimized_function({\"question\": \"If I have 30% of shares, and Mo has 24.5% of shares, how many of our 10M shares are unassigned?\"})\n",
"function_traces = [v for k, v in zenbase_tracer.all_traces.items()][0][\"optimized\"]\n",
"from pprint import pprint\n",
Expand Down
4 changes: 1 addition & 3 deletions py/cookbooks/bootstrap_few_shot/langfuse.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,6 @@
},
"outputs": [],
"source": [
"# Empty the traces\n",
"zenbase_tracer.all_traces = {}\n",
"# Run the optimization\n",
"best_fn, candidates = bootstrap_few_shot.perform(\n",
" solver,\n",
Expand All @@ -525,7 +523,7 @@
},
"outputs": [],
"source": [
"zenbase_tracer.all_traces = {}\n",
"zenbase_tracer.flush()\n",
"best_fn({\"question\": \"What is 2 + 2?\"})"
]
},
Expand Down
6 changes: 2 additions & 4 deletions py/cookbooks/bootstrap_few_shot/langsmith.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -579,9 +579,7 @@
"collapsed": false
},
"outputs": [],
"source": [
"zenbase_tracer.all_traces = {}"
]
"source": "zenbase_tracer.flush()"
},
{
"cell_type": "code",
Expand Down Expand Up @@ -761,7 +759,7 @@
"collapsed": false
},
"source": [
"zenbase_tracer.all_traces = {}\n",
"zenbase_tracer.flush()\n",
"optimized_function({\"question\": \"If I have 30% of shares, and Mo has 24.5% of shares, how many of our 10M shares are unassigned?\"})\n",
"function_traces = [v for k, v in zenbase_tracer.all_traces.items()][0][\"optimized\"]\n",
"from pprint import pprint\n",
Expand Down
6 changes: 2 additions & 4 deletions py/cookbooks/bootstrap_few_shot/lunary.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -542,9 +542,7 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"zenbase_tracer.all_traces = {}"
]
"source": "zenbase_tracer.flush()"
},
{
"cell_type": "code",
Expand Down Expand Up @@ -705,7 +703,7 @@
},
"outputs": [],
"source": [
"zenbase_tracer.all_traces = {}\n",
"zenbase_tracer.flush()\n",
"optimized_function(\"If I have 30% of shares, and Mo has 24.5% of shares, how many of our 10M shares are unassigned?\")\n",
"function_traces = [v for k, v in zenbase_tracer.all_traces.items()][0][\"optimized\"]\n",
"from pprint import pprint\n",
Expand Down
6 changes: 2 additions & 4 deletions py/cookbooks/bootstrap_few_shot/parea.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -590,8 +590,6 @@
},
"outputs": [],
"source": [
"# Empty the traces\n",
"zenbase_tracer.all_traces = {}\n",
"# Run the optimization\n",
"best_fn, candidates = bootstrap_few_shot.perform(\n",
" solver,\n",
Expand All @@ -618,7 +616,7 @@
"metadata": {},
"outputs": [],
"source": [
"zenbase_tracer.all_traces = {}\n",
"zenbase_tracer.flush()\n",
"best_fn({\"question\": \"What is 2 + 2?\"})"
]
},
Expand Down Expand Up @@ -789,7 +787,7 @@
},
"outputs": [],
"source": [
"zenbase_tracer.all_traces = {}\n",
"zenbase_tracer.flush()\n",
"optimized_function({\"question\": \"If I have 30% of shares, and Mo has 24.5% of shares, how many of our 10M shares are unassigned?\"})\n",
"function_traces = [v for k, v in zenbase_tracer.all_traces.items()][0][\"optimized\"]\n",
"from pprint import pprint\n",
Expand Down
2 changes: 1 addition & 1 deletion py/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "zenbase"
version = "0.0.4"
version = "0.0.5"
description = "LLMs made Zen"
authors = [{ name = "Cyrus Nouroozi", email = "[email protected]" }]
dependencies = [
Expand Down
21 changes: 16 additions & 5 deletions py/src/zenbase/core/managers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
from abc import ABC
from collections import OrderedDict
from contextlib import contextmanager
from typing import Any, Callable, Union

Expand All @@ -12,8 +13,9 @@ class BaseTracer(ABC):


class ZenbaseTracer(BaseTracer):
def __init__(self):
self.all_traces = {}
def __init__(self, max_traces=1000):
self.all_traces = OrderedDict()
self.max_traces = max_traces
self.current_trace = None
self.current_key = None
self.optimized_args = {}
Expand All @@ -23,6 +25,17 @@ def __call__(self, function: Callable[[Any], Any] = None, zenbase: LMZenbase = N
return lambda f: self.trace_function(f, zenbase)
return self.trace_function(function, zenbase)

def flush(self):
self.all_traces.clear()

def add_trace(self, run_timestamp: str, func_name: str, trace_data: dict):
if run_timestamp not in self.all_traces:
if len(self.all_traces) >= self.max_traces:
self.all_traces.popitem(last=False) # Remove the oldest item
self.all_traces[run_timestamp] = OrderedDict()
self.all_traces[run_timestamp][func_name] = trace_data
self.all_traces.move_to_end(run_timestamp) # Move this trace to the end (most recent)

def trace_function(self, function: Callable[[Any], Any] = None, zenbase: LMZenbase = None) -> LMFunction:
def wrapper(request, lm_function, *args, **kwargs):
func_name = function.__name__
Expand All @@ -47,9 +60,7 @@ def trace_context(self, func_name, run_timestamp, optimized_args=None):
yield
finally:
if self.current_key == run_timestamp:
if run_timestamp not in self.all_traces:
self.all_traces[run_timestamp] = {}
self.all_traces[run_timestamp][func_name] = self.current_trace
self.add_trace(run_timestamp, func_name, self.current_trace)
self.current_trace = None
self.current_key = None
self.optimized_args = {}
Expand Down
6 changes: 5 additions & 1 deletion py/src/zenbase/optim/metric/bootstrap_few_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,22 @@ def perform(
:param helper_class: The helper class that will be used to fetch the dataset and evaluator
"""
assert trace_manager is not None, "Zenbase is required for this operation"
# Clean up traces
trace_manager.flush()

test_set_evaluator = self.zen_adaptor.get_evaluator(data=self.test_set)
self.base_evaluation = test_set_evaluator(student_lm)

if not teacher_lm:
# Create the base LabeledFewShot teacher model
trace_manager.flush()
teacher_lm = self._create_teacher_model(self.zen_adaptor, student_lm, samples, rounds)

# Evaluate and validate the demo set
validated_training_set_demos = self._validate_demo_set(self.zen_adaptor, teacher_lm)

# Run each validated demo to fill up the traces
trace_manager.all_traces = {}
trace_manager.flush()
self._run_validated_demos(teacher_lm, validated_training_set_demos)

# Consolidate the traces to optimized args
Expand All @@ -86,6 +89,7 @@ def perform(
# Evaluate the optimized function
self.best_evaluation = test_set_evaluator(optimized_fn)

trace_manager.flush()
return self.Result(best_function=optimized_fn)

def _create_teacher_model(
Expand Down
2 changes: 1 addition & 1 deletion py/tests/adaptors/test_arize.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def score_answer(output: str, expected: dict):
)
assert teacher_lm is not None

zenbase_tracer.all_traces = {}
zenbase_tracer.flush()
teacher_lm({"question": "What is 2 + 2?"})

assert [v for k, v in zenbase_tracer.all_traces.items()][0]["optimized"]["planner_chain"]["args"][
Expand Down
6 changes: 3 additions & 3 deletions py/tests/adaptors/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def operation_finder(request: LMRequest):

assert teacher_lm is not None

zenbase_manager.all_traces = {}
zenbase_manager.flush()
teacher_lm({"question": "What is 2 + 2?"})

assert [v for k, v in zenbase_manager.all_traces.items()][0]["optimized"]["planner_chain"]["args"][
Expand Down Expand Up @@ -656,7 +656,7 @@ def operation_finder(request: LMRequest):

assert teacher_lm is not None

zenbase_manager.all_traces = {}
zenbase_manager.flush()
teacher_lm({"question": "What is 2 + 2?"})

assert [v for k, v in zenbase_manager.all_traces.items()][0]["optimized"]["planner_chain"]["args"][
Expand Down Expand Up @@ -836,7 +836,7 @@ def operation_finder(request: LMRequest):

assert teacher_lm is not None

zenbase_manager.all_traces = {}
zenbase_manager.flush()
teacher_lm({"question": "What is 2 + 2?"})

assert [v for k, v in zenbase_manager.all_traces.items()][0]["optimized"]["planner_chain"]["args"][
Expand Down
2 changes: 1 addition & 1 deletion py/tests/adaptors/test_langfuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def operation_finder(request: LMRequest):

assert teacher_lm is not None

zenbase_manager.all_traces = {}
zenbase_manager.flush()
teacher_lm({"question": "What is 2 + 2?"})

assert [v for k, v in zenbase_manager.all_traces.items()][0]["optimized"]["planner_chain"]["args"][
Expand Down
2 changes: 1 addition & 1 deletion py/tests/adaptors/test_lunary.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def operation_finder(request: LMRequest):
)
assert teacher_lm is not None

zenbase_manager.all_traces = {}
zenbase_manager.flush()
teacher_lm("What is 2 + 2?")

assert [v for k, v in zenbase_manager.all_traces.items()][0]["optimized"]["planner_chain"]["args"][
Expand Down
4 changes: 2 additions & 2 deletions py/tests/adaptors/test_parea.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def operation_finder(request: LMRequest):

assert best_lm is not None

zenbase_manager.all_traces = {}
zenbase_manager.flush()
best_lm({"question": "What is 2 + 2?"})

assert [v for k, v in zenbase_manager.all_traces.items()][0]["optimized"]["planner_chain"]["args"][
Expand Down Expand Up @@ -608,7 +608,7 @@ def operation_finder(request: LMRequest):
)
assert best_lm is not None

zenbase_manager.all_traces = {}
zenbase_manager.flush()
best_lm({"question": "What is 2 + 2?"})

assert [v for k, v in zenbase_manager.all_traces.items()][0]["optimized"]["planner_chain"]["args"][
Expand Down
109 changes: 109 additions & 0 deletions py/tests/core/managers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from collections import OrderedDict
from datetime import datetime
from unittest.mock import patch

import pytest

from zenbase.core.managers import ZenbaseTracer
from zenbase.types import LMFunction, LMZenbase


@pytest.fixture
Expand Down Expand Up @@ -132,3 +135,109 @@ def test_trace_layer_functions(zenbase_manager, layer_2_1, layer_1_1, layer_1_2)
assert len(zenbase_manager.all_traces) == 15
for trace in zenbase_manager.all_traces.values():
assert any(func in trace for func in ["_layer_2_1", "_layer_1_1", "_layer_1_2"])


@pytest.fixture
def tracer():
return ZenbaseTracer(max_traces=3)


def test_init(tracer):
assert isinstance(tracer.all_traces, OrderedDict)
assert tracer.max_traces == 3
assert tracer.current_trace is None
assert tracer.current_key is None
assert tracer.optimized_args == {}


def test_flush(tracer):
tracer.all_traces = OrderedDict({"key1": "value1", "key2": "value2"})
tracer.flush()
assert len(tracer.all_traces) == 0


def test_add_trace(tracer):
# Add first trace
tracer.add_trace("timestamp1", "func1", {"data": "trace1"})
assert len(tracer.all_traces) == 1
assert "timestamp1" in tracer.all_traces

# Add second trace
tracer.add_trace("timestamp2", "func2", {"data": "trace2"})
assert len(tracer.all_traces) == 2

# Add third trace
tracer.add_trace("timestamp3", "func3", {"data": "trace3"})
assert len(tracer.all_traces) == 3

# Add fourth trace (should remove oldest)
tracer.add_trace("timestamp4", "func4", {"data": "trace4"})
assert len(tracer.all_traces) == 3
assert "timestamp1" not in tracer.all_traces
assert "timestamp4" in tracer.all_traces


@patch("zenbase.utils.ksuid")
def test_trace_function(mock_ksuid, tracer):
mock_ksuid.return_value = "test_timestamp"

def test_func(request):
return request.inputs[0] + request.inputs[1]

zenbase = LMZenbase()
traced_func = tracer.trace_function(test_func, zenbase)
assert isinstance(traced_func, LMFunction)

result = traced_func(inputs=(2, 3))

assert result == 5
trace = tracer.all_traces[list(tracer.all_traces.keys())[0]]
assert "test_func" in trace["test_func"]
trace_info = trace["test_func"]["test_func"]
assert trace_info["args"]["request"].inputs == (2, 3)
assert trace_info["output"] == 5


def test_trace_context(tracer):
with tracer.trace_context("test_func", "test_timestamp"):
assert tracer.current_key == "test_timestamp"
assert isinstance(tracer.current_trace, dict)

assert tracer.current_trace is None
assert tracer.current_key is None
assert "test_timestamp" in tracer.all_traces


def test_max_traces_limit(tracer):
for i in range(5):
tracer.add_trace(f"timestamp{i}", f"func{i}", {"data": f"trace{i}"})

assert len(tracer.all_traces) == 3
assert "timestamp0" not in tracer.all_traces
assert "timestamp1" not in tracer.all_traces
assert "timestamp2" in tracer.all_traces
assert "timestamp3" in tracer.all_traces
assert "timestamp4" in tracer.all_traces


@patch("zenbase.utils.ksuid")
def test_optimized_args(mock_ksuid, tracer):
mock_ksuid.return_value = "test_timestamp"

def test_func(request, z=3):
x, y = request.inputs
return x + y + z

tracer.optimized_args = {"test_func": {"args": {"z": 5}}}
zenbase = LMZenbase()
traced_func = tracer.trace_function(test_func, zenbase)

result = traced_func(inputs=(2, 10))

assert result == 17 # 2 + 10 + 5
trace = tracer.all_traces[list(tracer.all_traces.keys())[0]]
assert "test_func" in trace["test_func"]
trace_info = trace["test_func"]["test_func"]
assert trace_info["args"]["request"].inputs == (2, 10)
assert trace_info["args"]["z"] == 5
assert trace_info["output"] == 17

0 comments on commit ed9e24c

Please sign in to comment.