[OPIK-99] Implement batching for spans creation (#298)
* Batch processing draft

* Batching draft

* [DRAFT] Update backend emulator for tests, update message processing to use batching

* Fix lint errors

* Implement _process_create_span_batch_message method in MessageProcessor

* Remove assert, add debug log message

* Add base batcher unit test

* Make batch manager flush all batches if streamer.flush() was called, add tests for span batcher

* Add tests for batch manager

* Add test for flushing thread

* Add one more test for batch manager

* Fix lint errors

* Fix lint errors

* Fix bug where the backend emulator didn't add feedback scores to traces

* Fix lint errors

* Rename flush_interval to flush_interval_seconds

* Enable debug logs for e2e tests. Update debug messages in message_processors.py

* Remove export statement from e2e workflow

* Enable backend build. Enable debug logs for e2e tests

* Update docker compose and e2e workflow file

* Make batching disabled by default in the Opik client. It is now enabled manually only in the Opik clients created under the hood of the SDK

* Add more unit tests for message processing

* Add docstring for _use_batching parameter

* Add missing _SECONDS suffix

* Rename constant

* Undo e2e test infra changes

---------

Co-authored-by: Andres Cruz <[email protected]>
alexkuzmik and andrescrz authored Sep 30, 2024
1 parent 5f99fae commit 939accf
Showing 22 changed files with 623 additions and 31 deletions.
1 change: 1 addition & 0 deletions .github/workflows/sdk-e2e-tests.yaml
@@ -56,6 +56,7 @@ jobs:
        run: |
          cd ${{ github.workspace }}/sdks/python
          export OPIK_URL_OVERRIDE=http://localhost:5173/api
          export OPIK_CONSOLE_LOGGING_LEVEL=DEBUG
          pytest tests/e2e -vv
      - name: Keep BE log in case of failure
2 changes: 2 additions & 0 deletions deployment/docker-compose/docker-compose.yaml
@@ -62,6 +62,8 @@ services:
    build:
      context: ../../apps/opik-backend
      dockerfile: Dockerfile
      args:
        OPIK_VERSION: ${OPIK_VERSION:-latest}
    platform: linux/amd64
    hostname: backend
    command: [ "bash", "-c", "./run_db_migrations.sh && ./entrypoint.sh" ]
13 changes: 11 additions & 2 deletions sdks/python/src/opik/api_objects/opik_client.py
@@ -29,6 +29,7 @@ def __init__(
        project_name: Optional[str] = None,
        workspace: Optional[str] = None,
        host: Optional[str] = None,
        _use_batching: bool = False,
    ) -> None:
        """
        Initialize an Opik object that can be used to log traces and spans manually to the Opik server.
@@ -37,6 +38,8 @@ def __init__(
            project_name: The name of the project. If not provided, traces and spans will be logged to the `Default Project`.
            workspace: The name of the workspace. If not provided, `default` will be used.
            host: The host URL for the Opik server. If not provided, it will default to `https://www.comet.com/opik/api`.
            _use_batching: Intended for internal use under specific conditions only.
                Enabling it elsewhere is unsafe and can lead to data loss.
        Returns:
            None
        """
@@ -51,11 +54,16 @@ def __init__(
            base_url=config_.url_override,
            workers=config_.background_workers,
            api_key=config_.api_key,
            use_batching=_use_batching,
        )
        atexit.register(self.end, timeout=self._flush_timeout)

    def _initialize_streamer(
        self, base_url: str, workers: int, api_key: Optional[str]
        self,
        base_url: str,
        workers: int,
        api_key: Optional[str],
        use_batching: bool,
    ) -> None:
        httpx_client_ = httpx_client.get(workspace=self._workspace, api_key=api_key)
        self._rest_client = rest_api_client.OpikApi(
@@ -66,6 +74,7 @@ def _initialize_streamer(
        self._streamer = streamer_constructors.construct_online_streamer(
            n_consumers=workers,
            rest_client=self._rest_client,
            use_batching=use_batching,
        )

    def trace(
@@ -437,6 +446,6 @@ def get_span_content(self, id: str) -> span_public.SpanPublic:

@functools.lru_cache()
def get_client_cached() -> Opik:
    client = Opik()
    client = Opik(_use_batching=True)

    return client
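For orientation (not part of the diff): a minimal sketch of what this wiring means for callers, assuming only the public pieces shown above. Manually created clients keep batching off by default; the SDK's cached client turns it on internally.

from opik.api_objects.opik_client import Opik, get_client_cached

# A manually constructed client: batching stays off (the safe default).
manual_client = Opik(project_name="my-project")  # project name is illustrative

# The cached client used under the hood of the SDK: batching enabled.
cached_client = get_client_cached()  # internally calls Opik(_use_batching=True)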
sdks/python/src/opik/message_processing/batching/__init__.py (new, empty file)
49 changes: 49 additions & 0 deletions sdks/python/src/opik/message_processing/batching/base_batcher.py
@@ -0,0 +1,49 @@
import threading
import time
import abc

from typing import List, Callable
from .. import messages


class BaseBatcher(abc.ABC):
    def __init__(
        self,
        flush_callback: Callable[[messages.BaseMessage], None],
        max_batch_size: int,
        flush_interval_seconds: float,
    ):
        self._flush_interval_seconds: float = flush_interval_seconds
        self._flush_callback: Callable[[messages.BaseMessage], None] = flush_callback
        self._accumulated_messages: List[messages.BaseMessage] = []
        self._max_batch_size: int = max_batch_size

        self._last_time_flush_callback_called: float = time.time()
        self._lock = threading.RLock()

    def add(self, message: messages.BaseMessage) -> None:
        with self._lock:
            self._accumulated_messages.append(message)
            if len(self._accumulated_messages) == self._max_batch_size:
                self.flush()

    def flush(self) -> None:
        with self._lock:
            if len(self._accumulated_messages) > 0:
                batch_message = self._create_batch_from_accumulated_messages()
                self._accumulated_messages = []

                self._flush_callback(batch_message)
                self._last_time_flush_callback_called = time.time()

    def is_ready_to_flush(self) -> bool:
        return (
            time.time() - self._last_time_flush_callback_called
        ) >= self._flush_interval_seconds

    def is_empty(self) -> bool:
        with self._lock:
            return len(self._accumulated_messages) == 0

    @abc.abstractmethod
    def _create_batch_from_accumulated_messages(self) -> messages.BaseMessage: ...
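To make the BaseBatcher contract concrete, here is a hedged sketch of a toy subclass. EchoBatcher and the print callback are illustrative, not part of this commit, and the sketch assumes BaseMessage is a field-less dataclass that can be instantiated directly.

from opik.message_processing import messages
from opik.message_processing.batching import base_batcher


class EchoBatcher(base_batcher.BaseBatcher):  # hypothetical, for illustration only
    def _create_batch_from_accumulated_messages(self) -> messages.BaseMessage:
        # Wrap everything accumulated so far into one batch message.
        return messages.CreateSpansBatchMessage(batch=self._accumulated_messages)  # type: ignore


batcher = EchoBatcher(
    flush_callback=print,  # the real SDK passes message_queue.put here
    max_batch_size=2,
    flush_interval_seconds=1.0,
)
batcher.add(messages.BaseMessage())  # accumulated; below max_batch_size, no flush
batcher.add(messages.BaseMessage())  # reaches max_batch_size: flush() fires the callback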
41 changes: 41 additions & 0 deletions sdks/python/src/opik/message_processing/batching/batch_manager.py
@@ -0,0 +1,41 @@
from typing import Type, Dict
from .. import messages
from . import base_batcher
from . import flushing_thread


class BatchManager:
    def __init__(
        self,
        message_to_batcher_mapping: Dict[
            Type[messages.BaseMessage], base_batcher.BaseBatcher
        ],
    ) -> None:
        self._message_to_batcher_mapping = message_to_batcher_mapping
        self._flushing_thread = flushing_thread.FlushingThread(
            batchers=list(self._message_to_batcher_mapping.values())
        )

    def start(self) -> None:
        self._flushing_thread.start()

    def stop(self) -> None:
        self._flushing_thread.close()

    def message_supports_batching(self, message: messages.BaseMessage) -> bool:
        return message.__class__ in self._message_to_batcher_mapping

    def process_message(self, message: messages.BaseMessage) -> None:
        self._message_to_batcher_mapping[type(message)].add(message)

    def is_empty(self) -> bool:
        return all(
            [
                batcher.is_empty()
                for batcher in self._message_to_batcher_mapping.values()
            ]
        )

    def flush(self) -> None:
        for batcher in self._message_to_batcher_mapping.values():
            batcher.flush()
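A sketch of the BatchManager lifecycle, wired by hand the same way the constructor module below does it (the bare queue and the interval/size values are illustrative):

import queue

from opik.message_processing import messages
from opik.message_processing.batching import batch_manager, create_span_message_batcher

message_queue: "queue.Queue" = queue.Queue()

# Batched span messages are re-enqueued as one CreateSpansBatchMessage.
span_batcher = create_span_message_batcher.CreateSpanMessageBatcher(
    flush_interval_seconds=1.0,
    max_batch_size=1000,
    flush_callback=message_queue.put,
)
manager = batch_manager.BatchManager(
    message_to_batcher_mapping={messages.CreateSpanMessage: span_batcher}
)

manager.start()  # background FlushingThread probes batchers every 0.1 s
manager.flush()  # force out anything accumulated right now
manager.stop()   # flushes remaining batches, then lets the thread exit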
@@ -0,0 +1,29 @@
import queue
from typing import Type, Dict

from .. import messages

from . import base_batcher
from . import create_span_message_batcher
from . import batch_manager

CREATE_SPANS_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS = 1.0
CREATE_SPANS_MESSAGE_BATCHER_MAX_BATCH_SIZE = 1000


def create_batch_manager(message_queue: queue.Queue) -> batch_manager.BatchManager:
    create_span_message_batcher_ = create_span_message_batcher.CreateSpanMessageBatcher(
        flush_interval_seconds=CREATE_SPANS_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS,
        max_batch_size=CREATE_SPANS_MESSAGE_BATCHER_MAX_BATCH_SIZE,
        flush_callback=message_queue.put,
    )

    MESSAGE_TO_BATCHER_MAPPING: Dict[
        Type[messages.BaseMessage], base_batcher.BaseBatcher
    ] = {messages.CreateSpanMessage: create_span_message_batcher_}

    batch_manager_ = batch_manager.BatchManager(
        message_to_batcher_mapping=MESSAGE_TO_BATCHER_MAPPING
    )

    return batch_manager_
9 changes: 9 additions & 0 deletions sdks/python/src/opik/message_processing/batching/create_span_message_batcher.py
@@ -0,0 +1,9 @@
from . import base_batcher
from .. import messages


class CreateSpanMessageBatcher(base_batcher.BaseBatcher):
    def _create_batch_from_accumulated_messages(
        self,
    ) -> messages.CreateSpansBatchMessage:
        return messages.CreateSpansBatchMessage(batch=self._accumulated_messages)  # type: ignore
30 changes: 30 additions & 0 deletions sdks/python/src/opik/message_processing/batching/flushing_thread.py
@@ -0,0 +1,30 @@
import threading
import time
from typing import List

from . import base_batcher


class FlushingThread(threading.Thread):
    def __init__(
        self,
        batchers: List[base_batcher.BaseBatcher],
        probe_interval_seconds: float = 0.1,
    ) -> None:
        threading.Thread.__init__(self, daemon=True)
        self._batchers = batchers
        self._probe_interval_seconds = probe_interval_seconds
        self._closed = False

    def close(self) -> None:
        for batcher in self._batchers:
            batcher.flush()

        self._closed = True

    def run(self) -> None:
        while not self._closed:
            for batcher in self._batchers:
                if batcher.is_ready_to_flush():
                    batcher.flush()
            time.sleep(self._probe_interval_seconds)
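A small timing sketch of the probe loop, reusing the hypothetical EchoBatcher from the BaseBatcher sketch above (the intervals are illustrative):

import time

from opik.message_processing import messages
from opik.message_processing.batching import flushing_thread

batcher = EchoBatcher(
    flush_callback=print,
    max_batch_size=1000,        # large on purpose: only the timer can trigger a flush
    flush_interval_seconds=0.5,
)
thread = flushing_thread.FlushingThread(batchers=[batcher], probe_interval_seconds=0.1)
thread.start()

batcher.add(messages.BaseMessage())
time.sleep(1.0)            # well past flush_interval_seconds
assert batcher.is_empty()  # the probe loop has flushed the batch

thread.close()  # flushes once more; the daemon loop then exits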
37 changes: 34 additions & 3 deletions sdks/python/src/opik/message_processing/message_processors.py
@@ -1,13 +1,14 @@
import abc
import logging
from typing import Callable, Dict, Type
from typing import Callable, Dict, Type, List

from opik import logging_messages
from . import messages
from .jsonable_encoder import jsonable_encoder
from .. import dict_utils
from ..rest_api import client as rest_api_client
from ..rest_api.types import feedback_score_batch_item
from ..rest_api.types import span_write

LOGGER = logging.getLogger(__name__)

@@ -29,6 +30,7 @@ def __init__(self, rest_client: rest_api_client.OpikApi):
            messages.UpdateTraceMessage: self._process_update_trace_message,  # type: ignore
            messages.AddTraceFeedbackScoresBatchMessage: self._process_add_trace_feedback_scores_batch_message,  # type: ignore
            messages.AddSpanFeedbackScoresBatchMessage: self._process_add_span_feedback_scores_batch_message,  # type: ignore
            messages.CreateSpansBatchMessage: self._process_create_span_batch_message,  # type: ignore
        }

    def process(self, message: messages.BaseMessage) -> None:
@@ -144,7 +146,7 @@ def _process_add_span_feedback_scores_batch_message(
            for score_message in message.batch
        ]

        LOGGER.debug("Score batch of spans request: %s", scores)
        LOGGER.debug("Score batch of span feedback scores request: %s", scores)

        self._rest_client.spans.score_batch_of_spans(
            scores=scores,
        )
@@ -158,8 +160,37 @@ def _process_add_trace_feedback_scores_batch_message(
            for score_message in message.batch
        ]

        LOGGER.debug("Score batch of traces request: %s", scores)
        LOGGER.debug("Score batch of trace feedback scores request: %s", scores)

        self._rest_client.traces.score_batch_of_traces(
            scores=scores,
        )

    def _process_create_span_batch_message(
        self, message: messages.CreateSpansBatchMessage
    ) -> None:
        span_write_batch: List[span_write.SpanWrite] = []
        for item in message.batch:
            span_write_kwargs = {
                "id": item.span_id,
                "trace_id": item.trace_id,
                "project_name": item.project_name,
                "parent_span_id": item.parent_span_id,
                "name": item.name,
                "start_time": item.start_time,
                "end_time": item.end_time,
                "type": item.type,
                "input": item.input,
                "output": item.output,
                "metadata": item.metadata,
                "tags": item.tags,
                "usage": item.usage,
            }
            cleaned_span_write_kwargs = dict_utils.remove_none_from_dict(
                span_write_kwargs
            )
            cleaned_span_write_kwargs = jsonable_encoder(cleaned_span_write_kwargs)
            span_write_batch.append(span_write.SpanWrite(**cleaned_span_write_kwargs))

        LOGGER.debug("Create spans batch request: %s", span_write_batch)
        self._rest_client.spans.create_spans(spans=span_write_batch)
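A quick sketch of the kwargs clean-up this handler applies before constructing each SpanWrite (field values are illustrative):

from opik import dict_utils
from opik.message_processing.jsonable_encoder import jsonable_encoder

span_write_kwargs = {
    "id": "span-id",
    "trace_id": "trace-id",
    "name": "llm-call",
    "end_time": None,  # span not finished yet
    "metadata": None,  # never set
}

cleaned = dict_utils.remove_none_from_dict(span_write_kwargs)
# -> {"id": "span-id", "trace_id": "trace-id", "name": "llm-call"}
cleaned = jsonable_encoder(cleaned)  # makes the values JSON-safe, e.g. datetimes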
5 changes: 5 additions & 0 deletions sdks/python/src/opik/message_processing/messages.py
@@ -88,3 +88,8 @@ class AddTraceFeedbackScoresBatchMessage(BaseMessage):
@dataclasses.dataclass
class AddSpanFeedbackScoresBatchMessage(BaseMessage):
    batch: List[FeedbackScoreMessage]


@dataclasses.dataclass
class CreateSpansBatchMessage(BaseMessage):
    batch: List[CreateSpanMessage]
29 changes: 26 additions & 3 deletions sdks/python/src/opik/message_processing/streamer.py
@@ -4,24 +4,39 @@

from . import messages, queue_consumer
from .. import synchronization
from .batching import batch_manager


class Streamer:
    def __init__(
        self,
        message_queue: "queue.Queue[Any]",
        queue_consumers: List[queue_consumer.QueueConsumer],
        batch_manager: Optional[batch_manager.BatchManager],
    ) -> None:
        self._lock = threading.Lock()
        self._lock = threading.RLock()
        self._message_queue = message_queue
        self._queue_consumers = queue_consumers
        self._batch_manager = batch_manager

        self._drain = False

        self._start_queue_consumers()

        if self._batch_manager is not None:
            self._batch_manager.start()

    def put(self, message: messages.BaseMessage) -> None:
        with self._lock:
            if not self._drain:
            if self._drain:
                return

            if (
                self._batch_manager is not None
                and self._batch_manager.message_supports_batching(message)
            ):
                self._batch_manager.process_message(message)
            else:
                self._message_queue.put(message)

    def close(self, timeout: Optional[int]) -> bool:
@@ -31,15 +46,23 @@ def close(self, timeout: Optional[int]) -> bool:
        with self._lock:
            self._drain = True

        if self._batch_manager is not None:
            self._batch_manager.stop()  # stopping flushes any remaining batched messages to the queue

        self.flush(timeout)
        self._close_queue_consumers()

        return self._message_queue.empty()

    def flush(self, timeout: Optional[int]) -> None:
        if self._batch_manager is not None:
            self._batch_manager.flush()

        synchronization.wait_for_done(
            check_function=lambda: (
                self.workers_waiting() and self._message_queue.empty()
                self.workers_waiting()
                and self._message_queue.empty()
                and (self._batch_manager is None or self._batch_manager.is_empty())
            ),
            timeout=timeout,
            sleep_time=0.1,
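Putting the pieces together, a sketch of how put() now routes messages when a batch manager is attached. The zero-consumer Streamer is purely illustrative, and it assumes a freshly wired BatchManager as in the earlier sketch (a stopped manager cannot be restarted):

import queue

from opik.message_processing import messages, streamer

message_queue: "queue.Queue" = queue.Queue()

s = streamer.Streamer(
    message_queue=message_queue,
    queue_consumers=[],     # no workers: we only inspect the queue
    batch_manager=manager,  # a freshly wired BatchManager, as sketched earlier
)

s.put(messages.BaseMessage())     # no batcher registered for this type
assert not message_queue.empty()  # so it went straight to the queue

# A messages.CreateSpanMessage would instead be absorbed by the span batcher
# and reach the queue later, wrapped in a single CreateSpansBatchMessage.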