From 939accfacc43babf7cf6180b009c8e17028f29a8 Mon Sep 17 00:00:00 2001
From: Aliaksandr Kuzmik <98702584+alexkuzmik@users.noreply.github.com>
Date: Mon, 30 Sep 2024 17:16:24 +0300
Subject: [PATCH] [OPIK-99] Implement batching for spans creation (#298)

* Batch processing draft
* Batching draft
* [DRAFT] Update backend emulator for tests, update message processing to use batching
* Fix lint errors
* Implement _process_create_span_batch_message method in MessageProcessor
* Remove assert, add debug log message
* Add base batcher unit test
* Make batch manager flush all batches when streamer.flush() is called, add tests for span batcher
* Add tests for batch manager
* Add test for flushing thread
* Add one more test for batch manager
* Fix lint errors
* Fix lint errors
* Fix bug where backend emulator didn't add feedback scores to traces
* Fix lint errors
* Rename flush_interval to flush_interval_seconds
* Enable debug logs for e2e tests. Update debug messages in message_processors.py
* Remove export statement from e2e workflow
* Enable backend build. Enable debug logs for e2e tests
* Update docker compose and e2e workflow file
* Make batching disabled by default in Opik client. It is now enabled manually in Opik clients created under the hood of the SDK
* Add more unit tests for message processing
* Add docstring for _use_batching parameter
* Add missing _SECONDS suffix
* Rename constant
* Undo e2e tests infra changes

---------

Co-authored-by: Andres Cruz

---
 .github/workflows/sdk-e2e-tests.yaml | 1 +
 deployment/docker-compose/docker-compose.yaml | 2 +
 .../src/opik/api_objects/opik_client.py | 13 +-
 .../message_processing/batching/__init__.py | 0
 .../batching/base_batcher.py | 49 ++++++++
 .../batching/batch_manager.py | 41 +++++++
 .../batching/batch_manager_constuctors.py | 29 +++++
 .../batching/create_span_message_batcher.py | 9 ++
 .../batching/flushing_thread.py | 30 +++++
 .../message_processing/message_processors.py | 37 +++++-
 .../src/opik/message_processing/messages.py | 5 +
 .../src/opik/message_processing/streamer.py | 29 ++++-
 .../streamer_constructors.py | 18 ++-
 sdks/python/tests/conftest.py | 1 +
 sdks/python/tests/e2e/test_experiment.py | 1 -
 sdks/python/tests/e2e/test_feedback_scores.py | 3 +-
 .../backend_emulator_message_processor.py | 36 ++++--
 .../message_processing/batching/__init__.py | 0
 .../batching/test_batch_manager.py | 116 ++++++++++++++++++
 .../batching/test_flushing_thread.py | 34 +++++
 .../batching/test_span_batcher.py | 109 ++++++++++++++++
 .../test_message_streaming.py | 91 +++++++++++++-
 22 files changed, 623 insertions(+), 31 deletions(-)
 create mode 100644 sdks/python/src/opik/message_processing/batching/__init__.py
 create mode 100644 sdks/python/src/opik/message_processing/batching/base_batcher.py
 create mode 100644 sdks/python/src/opik/message_processing/batching/batch_manager.py
 create mode 100644 sdks/python/src/opik/message_processing/batching/batch_manager_constuctors.py
 create mode 100644 sdks/python/src/opik/message_processing/batching/create_span_message_batcher.py
 create mode 100644 sdks/python/src/opik/message_processing/batching/flushing_thread.py
 create mode 100644 sdks/python/tests/unit/message_processing/batching/__init__.py
 create mode 100644 sdks/python/tests/unit/message_processing/batching/test_batch_manager.py
 create mode 100644 sdks/python/tests/unit/message_processing/batching/test_flushing_thread.py
 create mode 100644 sdks/python/tests/unit/message_processing/batching/test_span_batcher.py
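In short: with batching enabled, span-creation messages no longer go straight to the background queue; they accumulate in a batcher and are enqueued as a single CreateSpansBatchMessage once 1000 spans have accumulated or one second has passed since the last flush. A minimal sketch of that flow using the classes added below (illustrative only: message_1 and message_2 stand in for real CreateSpanMessage instances, and the production wiring passes message_queue.put as the flush callback):

    from opik.message_processing.batching import create_span_message_batcher

    collected = []  # stands in for the streamer's background message queue

    batcher = create_span_message_batcher.CreateSpanMessageBatcher(
        flush_callback=collected.append,  # production code passes message_queue.put here
        max_batch_size=2,
        flush_interval_seconds=1.0,
    )

    batcher.add(message_1)
    batcher.add(message_2)
    # max_batch_size reached: `collected` now holds one CreateSpansBatchMessage
    # wrapping both span messages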
diff --git a/.github/workflows/sdk-e2e-tests.yaml b/.github/workflows/sdk-e2e-tests.yaml
index d8b205ad05..1329d94d0b 100644
--- a/.github/workflows/sdk-e2e-tests.yaml
+++ b/.github/workflows/sdk-e2e-tests.yaml
@@ -56,6 +56,7 @@ jobs:
         run: |
           cd ${{ github.workspace }}/sdks/python
           export OPIK_URL_OVERRIDE=http://localhost:5173/api
+          export OPIK_CONSOLE_LOGGING_LEVEL=DEBUG
           pytest tests/e2e -vv

       - name: Keep BE log in case of failure

diff --git a/deployment/docker-compose/docker-compose.yaml b/deployment/docker-compose/docker-compose.yaml
index 382a67493b..bbe9b65016 100644
--- a/deployment/docker-compose/docker-compose.yaml
+++ b/deployment/docker-compose/docker-compose.yaml
@@ -62,6 +62,8 @@ services:
     build:
       context: ../../apps/opik-backend
       dockerfile: Dockerfile
+      args:
+        OPIK_VERSION: ${OPIK_VERSION:-latest}
     platform: linux/amd64
     hostname: backend
     command: [ "bash", "-c", "./run_db_migrations.sh && ./entrypoint.sh" ]

diff --git a/sdks/python/src/opik/api_objects/opik_client.py b/sdks/python/src/opik/api_objects/opik_client.py
index 5d4f49a985..80fe2e4e85 100644
--- a/sdks/python/src/opik/api_objects/opik_client.py
+++ b/sdks/python/src/opik/api_objects/opik_client.py
@@ -29,6 +29,7 @@ def __init__(
         project_name: Optional[str] = None,
         workspace: Optional[str] = None,
         host: Optional[str] = None,
+        _use_batching: bool = False,
     ) -> None:
         """
         Initialize an Opik object that can be used to log traces and spans manually to Opik server.
@@ -37,6 +38,8 @@ def __init__(
             project_name: The name of the project. If not provided, traces and spans will be logged to the `Default Project`.
             workspace: The name of the workspace. If not provided, `default` will be used.
             host: The host URL for the Opik server. If not provided, it will default to `https://www.comet.com/opik/api`.
+            _use_batching: intended for internal use in specific conditions only.
+                Enabling it elsewhere is unsafe and can lead to data loss.
         Returns:
             None
         """
@@ -51,11 +54,16 @@ def __init__(
             base_url=config_.url_override,
             workers=config_.background_workers,
             api_key=config_.api_key,
+            use_batching=_use_batching,
         )
         atexit.register(self.end, timeout=self._flush_timeout)

     def _initialize_streamer(
-        self, base_url: str, workers: int, api_key: Optional[str]
+        self,
+        base_url: str,
+        workers: int,
+        api_key: Optional[str],
+        use_batching: bool,
     ) -> None:
         httpx_client_ = httpx_client.get(workspace=self._workspace, api_key=api_key)
         self._rest_client = rest_api_client.OpikApi(
@@ -66,6 +74,7 @@ def _initialize_streamer(
         self._streamer = streamer_constructors.construct_online_streamer(
             n_consumers=workers,
             rest_client=self._rest_client,
+            use_batching=use_batching,
         )

     def trace(
@@ -437,6 +446,6 @@ def get_span_content(self, id: str) -> span_public.SpanPublic:
 @functools.lru_cache()
 def get_client_cached() -> Opik:
-    client = Opik()
+    client = Opik(_use_batching=True)

     return client

diff --git a/sdks/python/src/opik/message_processing/batching/__init__.py b/sdks/python/src/opik/message_processing/batching/__init__.py
new file mode 100644
index 0000000000..e69de29bb2

diff --git a/sdks/python/src/opik/message_processing/batching/base_batcher.py b/sdks/python/src/opik/message_processing/batching/base_batcher.py
new file mode 100644
index 0000000000..77926e50ec
--- /dev/null
+++ b/sdks/python/src/opik/message_processing/batching/base_batcher.py
@@ -0,0 +1,49 @@
+import threading
+import time
+import abc
+
+from typing import List, Callable
+from .. import messages
+
+
+class BaseBatcher(abc.ABC):
+    def __init__(
+        self,
+        flush_callback: Callable[[messages.BaseMessage], None],
+        max_batch_size: int,
+        flush_interval_seconds: float,
+    ):
+        self._flush_interval_seconds: float = flush_interval_seconds
+        self._flush_callback: Callable[[messages.BaseMessage], None] = flush_callback
+        self._accumulated_messages: List[messages.BaseMessage] = []
+        self._max_batch_size: int = max_batch_size
+
+        self._last_time_flush_callback_called: float = time.time()
+        self._lock = threading.RLock()
+
+    def add(self, message: messages.BaseMessage) -> None:
+        with self._lock:
+            self._accumulated_messages.append(message)
+            if len(self._accumulated_messages) == self._max_batch_size:
+                self.flush()
+
+    def flush(self) -> None:
+        with self._lock:
+            if len(self._accumulated_messages) > 0:
+                batch_message = self._create_batch_from_accumulated_messages()
+                self._accumulated_messages = []
+
+                self._flush_callback(batch_message)
+                self._last_time_flush_callback_called = time.time()
+
+    def is_ready_to_flush(self) -> bool:
+        return (
+            time.time() - self._last_time_flush_callback_called
+        ) >= self._flush_interval_seconds
+
+    def is_empty(self) -> bool:
+        with self._lock:
+            return len(self._accumulated_messages) == 0
+
+    @abc.abstractmethod
+    def _create_batch_from_accumulated_messages(self) -> messages.BaseMessage: ...
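BaseBatcher keeps all the shared mechanics in one place: the lock, the size trigger (add() flushes as soon as the accumulated list reaches max_batch_size) and the flush timestamp that is_ready_to_flush() compares against. A subclass only decides how the accumulated messages are wrapped. As a sketch of the extension point, a batcher for a hypothetical trace-creation batch could look like the following (only span batching exists in this patch; CreateTracesBatchMessage is an invented name):

    from opik.message_processing import messages
    from opik.message_processing.batching import base_batcher


    class CreateTraceMessageBatcher(base_batcher.BaseBatcher):
        def _create_batch_from_accumulated_messages(self) -> messages.BaseMessage:
            # locking, the size limit and flush timing are inherited from BaseBatcher;
            # this method only wraps the accumulated messages into one batch message
            return messages.CreateTracesBatchMessage(batch=self._accumulated_messages)  # hypothetical type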
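Routing here is by exact type: message_supports_batching checks `message.__class__ in mapping`, so a subclass of a registered message type would not match, and process_message raises KeyError for an unregistered type. Callers are expected to check first, which is exactly what Streamer.put does further down. A minimal wiring sketch using only classes from this patch (the CreateSpanMessage instance itself is not constructed here):

    import queue

    from opik.message_processing import messages
    from opik.message_processing.batching import batch_manager, create_span_message_batcher

    message_queue = queue.Queue()
    span_batcher = create_span_message_batcher.CreateSpanMessageBatcher(
        flush_callback=message_queue.put,
        max_batch_size=1000,
        flush_interval_seconds=1.0,
    )
    manager = batch_manager.BatchManager(
        message_to_batcher_mapping={messages.CreateSpanMessage: span_batcher},
    )

    manager.start()  # starts the daemon flushing thread
    if manager.message_supports_batching(span_message):  # span_message: some CreateSpanMessage
        manager.process_message(span_message)
    manager.stop()  # flushes whatever is still accumulated, then stops the thread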
diff --git a/sdks/python/src/opik/message_processing/batching/batch_manager_constuctors.py b/sdks/python/src/opik/message_processing/batching/batch_manager_constuctors.py
new file mode 100644
index 0000000000..8d6e57668d
--- /dev/null
+++ b/sdks/python/src/opik/message_processing/batching/batch_manager_constuctors.py
@@ -0,0 +1,29 @@
+import queue
+from typing import Type, Dict
+
+from .. import messages
+
+from . import base_batcher
+from . import create_span_message_batcher
+from . import batch_manager
+
+CREATE_SPANS_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS = 1.0
+CREATE_SPANS_MESSAGE_BATCHER_MAX_BATCH_SIZE = 1000
+
+
+def create_batch_manager(message_queue: queue.Queue) -> batch_manager.BatchManager:
+    create_span_message_batcher_ = create_span_message_batcher.CreateSpanMessageBatcher(
+        flush_interval_seconds=CREATE_SPANS_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS,
+        max_batch_size=CREATE_SPANS_MESSAGE_BATCHER_MAX_BATCH_SIZE,
+        flush_callback=message_queue.put,
+    )
+
+    MESSAGE_TO_BATCHER_MAPPING: Dict[
+        Type[messages.BaseMessage], base_batcher.BaseBatcher
+    ] = {messages.CreateSpanMessage: create_span_message_batcher_}
+
+    batch_manager_ = batch_manager.BatchManager(
+        message_to_batcher_mapping=MESSAGE_TO_BATCHER_MAPPING
+    )
+
+    return batch_manager_

diff --git a/sdks/python/src/opik/message_processing/batching/create_span_message_batcher.py b/sdks/python/src/opik/message_processing/batching/create_span_message_batcher.py
new file mode 100644
index 0000000000..f683d4c465
--- /dev/null
+++ b/sdks/python/src/opik/message_processing/batching/create_span_message_batcher.py
@@ -0,0 +1,9 @@
+from . import base_batcher
+from .. import messages
+
+
+class CreateSpanMessageBatcher(base_batcher.BaseBatcher):
+    def _create_batch_from_accumulated_messages(
+        self,
+    ) -> messages.CreateSpansBatchMessage:
+        return messages.CreateSpansBatchMessage(batch=self._accumulated_messages)  # type: ignore

diff --git a/sdks/python/src/opik/message_processing/batching/flushing_thread.py b/sdks/python/src/opik/message_processing/batching/flushing_thread.py
new file mode 100644
index 0000000000..82de7460b5
--- /dev/null
+++ b/sdks/python/src/opik/message_processing/batching/flushing_thread.py
@@ -0,0 +1,30 @@
+import threading
+import time
+from typing import List
+
+from . import base_batcher
+
+
+class FlushingThread(threading.Thread):
+    def __init__(
+        self,
+        batchers: List[base_batcher.BaseBatcher],
+        probe_interval_seconds: float = 0.1,
+    ) -> None:
+        threading.Thread.__init__(self, daemon=True)
+        self._batchers = batchers
+        self._probe_interval_seconds = probe_interval_seconds
+        self._closed = False
+
+    def close(self) -> None:
+        for batcher in self._batchers:
+            batcher.flush()
+
+        self._closed = True
+
+    def run(self) -> None:
+        while not self._closed:
+            for batcher in self._batchers:
+                if batcher.is_ready_to_flush():
+                    batcher.flush()
+            time.sleep(self._probe_interval_seconds)
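Two properties of this thread are worth spelling out. It is a daemon thread, so a streamer that is never closed cannot hang interpreter shutdown. And because it only probes every probe_interval_seconds, a batch can wait up to roughly flush_interval_seconds + probe_interval_seconds before the callback fires; with the defaults above that is about 1.0 + 0.1 seconds. close() performs a final synchronous flush of every batcher before letting the loop stop, so nothing accumulated is dropped. A lifecycle sketch, reusing span_batcher from the earlier note:

    from opik.message_processing.batching import flushing_thread

    thread = flushing_thread.FlushingThread(batchers=[span_batcher])
    thread.start()  # probes span_batcher.is_ready_to_flush() every 0.1 s
    # ... messages are added to the batcher elsewhere ...
    thread.close()  # final flush of every batcher, then run() exits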
diff --git a/sdks/python/src/opik/message_processing/message_processors.py b/sdks/python/src/opik/message_processing/message_processors.py
index 9c3164854e..678a17317c 100644
--- a/sdks/python/src/opik/message_processing/message_processors.py
+++ b/sdks/python/src/opik/message_processing/message_processors.py
@@ -1,6 +1,6 @@
 import abc
 import logging
-from typing import Callable, Dict, Type
+from typing import Callable, Dict, Type, List

 from opik import logging_messages
 from . import messages
@@ -8,6 +8,7 @@
 from .. import dict_utils
 from ..rest_api import client as rest_api_client
 from ..rest_api.types import feedback_score_batch_item
+from ..rest_api.types import span_write

 LOGGER = logging.getLogger(__name__)
@@ -29,6 +30,7 @@ def __init__(self, rest_client: rest_api_client.OpikApi):
             messages.UpdateTraceMessage: self._process_update_trace_message,  # type: ignore
             messages.AddTraceFeedbackScoresBatchMessage: self._process_add_trace_feedback_scores_batch_message,  # type: ignore
             messages.AddSpanFeedbackScoresBatchMessage: self._process_add_span_feedback_scores_batch_message,  # type: ignore
+            messages.CreateSpansBatchMessage: self._process_create_span_batch_message,  # type: ignore
         }

     def process(self, message: messages.BaseMessage) -> None:
@@ -144,7 +146,7 @@ def _process_add_span_feedback_scores_batch_message(
             for score_message in message.batch
         ]

-        LOGGER.debug("Score batch of spans request: %s", scores)
+        LOGGER.debug("Span feedback scores batch request: %s", scores)

         self._rest_client.spans.score_batch_of_spans(
             scores=scores,
@@ -158,8 +160,37 @@ def _process_add_trace_feedback_scores_batch_message(
             for score_message in message.batch
         ]

-        LOGGER.debug("Score batch of traces request: %s", scores)
+        LOGGER.debug("Trace feedback scores batch request: %s", scores)

         self._rest_client.traces.score_batch_of_traces(
             scores=scores,
         )
+
+    def _process_create_span_batch_message(
+        self, message: messages.CreateSpansBatchMessage
+    ) -> None:
+        span_write_batch: List[span_write.SpanWrite] = []
+        for item in message.batch:
+            span_write_kwargs = {
+                "id": item.span_id,
+                "trace_id": item.trace_id,
+                "project_name": item.project_name,
+                "parent_span_id": item.parent_span_id,
+                "name": item.name,
+                "start_time": item.start_time,
+                "end_time": item.end_time,
+                "type": item.type,
+                "input": item.input,
+                "output": item.output,
+                "metadata": item.metadata,
+                "tags": item.tags,
+                "usage": item.usage,
+            }
+            cleaned_span_write_kwargs = dict_utils.remove_none_from_dict(
+                span_write_kwargs
+            )
+            span_write_batch.append(span_write.SpanWrite(**cleaned_span_write_kwargs))
+
+        LOGGER.debug("Create spans batch request: %s", span_write_batch)
+        self._rest_client.spans.create_spans(spans=span_write_batch)

diff --git a/sdks/python/src/opik/message_processing/messages.py b/sdks/python/src/opik/message_processing/messages.py
index aee51e5df7..11b2e20edc 100644
--- a/sdks/python/src/opik/message_processing/messages.py
+++ b/sdks/python/src/opik/message_processing/messages.py
@@ -88,3 +88,8 @@ class AddTraceFeedbackScoresBatchMessage(BaseMessage):
 @dataclasses.dataclass
 class AddSpanFeedbackScoresBatchMessage(BaseMessage):
     batch: List[FeedbackScoreMessage]
+
+
+@dataclasses.dataclass
+class CreateSpansBatchMessage(BaseMessage):
+    batch: List[CreateSpanMessage]
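On the wire, one CreateSpansBatchMessage becomes a single spans.create_spans REST call instead of one create_span call per span. Before the payload is built, remove_none_from_dict drops every unset field, so SpanWrite only receives keys that actually carry values. Roughly (field values invented for illustration):

    from opik import dict_utils

    span_write_kwargs = {
        "id": "span-1",
        "trace_id": "trace-1",
        "name": "llm-call",
        "parent_span_id": None,  # root span: never set
        "end_time": None,        # span still running
    }
    cleaned = dict_utils.remove_none_from_dict(span_write_kwargs)
    # cleaned == {"id": "span-1", "trace_id": "trace-1", "name": "llm-call"}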
diff --git a/sdks/python/src/opik/message_processing/streamer.py b/sdks/python/src/opik/message_processing/streamer.py
index 24d68ca61d..4c5c0dac1c 100644
--- a/sdks/python/src/opik/message_processing/streamer.py
+++ b/sdks/python/src/opik/message_processing/streamer.py
@@ -4,6 +4,7 @@
 from . import messages, queue_consumer
 from .. import synchronization
+from .batching import batch_manager


 class Streamer:
@@ -11,17 +12,31 @@ def __init__(
         self,
         message_queue: "queue.Queue[Any]",
         queue_consumers: List[queue_consumer.QueueConsumer],
+        batch_manager: Optional[batch_manager.BatchManager],
     ) -> None:
-        self._lock = threading.Lock()
+        self._lock = threading.RLock()
         self._message_queue = message_queue
         self._queue_consumers = queue_consumers
+        self._batch_manager = batch_manager
+
         self._drain = False

         self._start_queue_consumers()

+        if self._batch_manager is not None:
+            self._batch_manager.start()
+
     def put(self, message: messages.BaseMessage) -> None:
         with self._lock:
-            if not self._drain:
+            if self._drain:
+                return
+
+            if (
+                self._batch_manager is not None
+                and self._batch_manager.message_supports_batching(message)
+            ):
+                self._batch_manager.process_message(message)
+            else:
                 self._message_queue.put(message)

     def close(self, timeout: Optional[int]) -> bool:
@@ -31,15 +46,23 @@ def close(self, timeout: Optional[int]) -> bool:
         with self._lock:
             self._drain = True

+        if self._batch_manager is not None:
+            self._batch_manager.stop()  # stopping flushes all remaining batched messages into the queue
+
         self.flush(timeout)
         self._close_queue_consumers()

         return self._message_queue.empty()

     def flush(self, timeout: Optional[int]) -> None:
+        if self._batch_manager is not None:
+            self._batch_manager.flush()
+
         synchronization.wait_for_done(
             check_function=lambda: (
-                self.workers_waiting() and self._message_queue.empty()
+                self.workers_waiting()
+                and self._message_queue.empty()
+                and (self._batch_manager is None or self._batch_manager.is_empty())
             ),
             timeout=timeout,
             sleep_time=0.1,
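The flush/close contract with batching enabled, in user-facing terms: flush() synchronously drains every batcher into the queue, then waits until the consumers are idle, the queue is empty and the batchers are empty, or until the timeout expires; close() additionally sets _drain first, so any put() arriving afterwards is silently ignored. A usage sketch (create_span_message stands in for a real CreateSpanMessage):

    streamer.put(create_span_message)  # accumulated by the span batcher, not yet on the queue
    streamer.flush(timeout=5)          # batcher -> queue -> MessageSender, or give up after 5 s
    streamer.close(timeout=5)          # final flush, consumers stopped, later put() calls dropped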
diff --git a/sdks/python/src/opik/message_processing/streamer_constructors.py b/sdks/python/src/opik/message_processing/streamer_constructors.py
index dde6956aba..7840720b04 100644
--- a/sdks/python/src/opik/message_processing/streamer_constructors.py
+++ b/sdks/python/src/opik/message_processing/streamer_constructors.py
@@ -3,19 +3,23 @@
 from . import queue_consumer, message_processors, streamer
 from ..rest_api import client as rest_api_client
+from .batching import batch_manager_constuctors


 def construct_online_streamer(
-    rest_client: rest_api_client.OpikApi, n_consumers: int = 1
+    rest_client: rest_api_client.OpikApi,
+    use_batching: bool,
+    n_consumers: int = 1,
 ) -> streamer.Streamer:
     message_processor = message_processors.MessageSender(rest_client=rest_client)
-    return construct_streamer(message_processor, n_consumers)
+    return construct_streamer(message_processor, n_consumers, use_batching)


 def construct_streamer(
     message_processor: message_processors.BaseMessageProcessor,
     n_consumers: int,
+    use_batching: bool,
 ) -> streamer.Streamer:
     message_queue: "queue.Queue[Any]" = queue.Queue()
@@ -28,8 +32,16 @@
         for i in range(n_consumers)
     ]

+    batch_manager = (
+        batch_manager_constuctors.create_batch_manager(message_queue)
+        if use_batching
+        else None
+    )
+
     streamer_ = streamer.Streamer(
-        message_queue=message_queue, queue_consumers=queue_consumers
+        message_queue=message_queue,
+        queue_consumers=queue_consumers,
+        batch_manager=batch_manager,
     )

     return streamer_

diff --git a/sdks/python/tests/conftest.py b/sdks/python/tests/conftest.py
index 26358f8ba9..5410df9ead 100644
--- a/sdks/python/tests/conftest.py
+++ b/sdks/python/tests/conftest.py
@@ -28,6 +28,7 @@ def fake_streamer():
         streamer = streamer_constructors.construct_streamer(
             message_processor=fake_message_processor_,
             n_consumers=1,
+            use_batching=True,
         )

         yield streamer, fake_message_processor_

diff --git a/sdks/python/tests/e2e/test_experiment.py b/sdks/python/tests/e2e/test_experiment.py
index 5baa6d8f50..d167fb023f 100644
--- a/sdks/python/tests/e2e/test_experiment.py
+++ b/sdks/python/tests/e2e/test_experiment.py
@@ -59,7 +59,6 @@ def task(item: dataset_item.DatasetItem):
         traces_amount=3,  # one trace per dataset item
         feedback_scores_amount=1,  # an average value of all Equals metric scores
     )
-
     # TODO: check more content of the experiment

     # # EXPECTED_DATASET_ITEMS = [

diff --git a/sdks/python/tests/e2e/test_feedback_scores.py b/sdks/python/tests/e2e/test_feedback_scores.py
index 0a0f295d37..c44efe6787 100644
--- a/sdks/python/tests/e2e/test_feedback_scores.py
+++ b/sdks/python/tests/e2e/test_feedback_scores.py
@@ -24,8 +24,6 @@ def test_feedbacks_are_logged_via_trace_and_span__happyflow(opik_client: opik.Op
     span.log_feedback_score(
         "span-metric-2", value=0.25, category_name="category-4", reason="some-reason-4"
     )
-    span.end()
-    trace.end()

     opik_client.flush()
@@ -62,6 +60,7 @@ def test_feedbacks_are_logged_via_trace_and_span__happyflow(opik_client: opik.Op
             "reason": "some-reason-4",
         },
     ]
+
     verifiers.verify_trace(
         opik_client=opik_client,
         trace_id=trace.id,

diff --git a/sdks/python/tests/testlib/backend_emulator_message_processor.py b/sdks/python/tests/testlib/backend_emulator_message_processor.py
index d943c2b9cf..d24e4ce883 100644
--- a/sdks/python/tests/testlib/backend_emulator_message_processor.py
+++ b/sdks/python/tests/testlib/backend_emulator_message_processor.py
@@ -21,6 +21,12 @@ def __init__(self) -> None:
         self._span_to_parent_span: Dict[str, Optional[str]] = {}
         self._span_to_trace: Dict[str, Optional[str]] = {}

+        self._trace_to_feedback_scores: Dict[str, List[FeedbackScoreModel]] = (
+            collections.defaultdict(list)
+        )
+        self._span_to_feedback_scores: Dict[str, List[FeedbackScoreModel]] = (
+            collections.defaultdict(list)
+        )

     @property
     def trace_trees(self):
@@ -41,6 +47,9 @@
             trace.spans.append(self._observations[span_id])
             trace.spans.sort(key=lambda x: x.start_time)

+        for trace in self._trace_trees:
+            trace.feedback_scores = self._trace_to_feedback_scores[trace.id]
+
         self._trace_trees.sort(key=lambda x: x.start_time)

         return self._trace_trees
@@ -59,6 +68,11 @@ def span_trees(self):
             parent_span.spans.append(self._observations[span_id])
             parent_span.spans.sort(key=lambda x: x.start_time)

+        all_span_ids = self._span_to_trace
+        for span_id in all_span_ids:
+            span = self._observations[span_id]
+            span.feedback_scores = self._span_to_feedback_scores[span_id]
+
         self._span_trees.sort(key=lambda x: x.start_time)

         return self._span_trees
@@ -99,6 +113,9 @@ def process(self, message: messages.BaseMessage) -> None:
             self._span_to_trace[span.id] = message.trace_id

             self._observations[message.span_id] = span
+        elif isinstance(message, messages.CreateSpansBatchMessage):
+            for item in message.batch:
+                self.process(item)
         elif isinstance(message, messages.UpdateSpanMessage):
             span: SpanModel = self._observations[message.span_id]
             span.output = message.output
@@ -110,11 +127,6 @@ def process(self, message: messages.BaseMessage) -> None:
             current_trace.end_time = message.end_time
         elif isinstance(message, messages.AddSpanFeedbackScoresBatchMessage):
             for feedback_score_message in message.batch:
-                span_or_trace = self._observations[feedback_score_message.id]
-                if not isinstance(span_or_trace, SpanModel):
-                    continue
-
-                span: SpanModel = span_or_trace
                 feedback_model = FeedbackScoreModel(
                     id=feedback_score_message.id,
                     name=feedback_score_message.name,
                     value=feedback_score_message.value,
                     category_name=feedback_score_message.category_name,
                     reason=feedback_score_message.reason,
                 )
-                span.feedback_scores.append(feedback_model)
+                self._span_to_feedback_scores[feedback_score_message.id].append(
+                    feedback_model
+                )
         elif isinstance(message, messages.AddTraceFeedbackScoresBatchMessage):
             for feedback_score_message in message.batch:
-                span_or_trace = self._observations[feedback_score_message.id]
-                if not isinstance(span_or_trace, TraceModel):
-                    continue
-
-                trace: TraceModel = span_or_trace
-
                 feedback_model = FeedbackScoreModel(
                     id=feedback_score_message.id,
                     name=feedback_score_message.name,
                     value=feedback_score_message.value,
                     category_name=feedback_score_message.category_name,
                     reason=feedback_score_message.reason,
                 )
-                trace.feedback_scores.append(feedback_model)
+                self._trace_to_feedback_scores[feedback_score_message.id].append(
+                    feedback_model
+                )

         self.processed_messages.append(message)

diff --git a/sdks/python/tests/unit/message_processing/batching/__init__.py b/sdks/python/tests/unit/message_processing/batching/__init__.py
new file mode 100644
index 0000000000..e69de29bb2

diff --git a/sdks/python/tests/unit/message_processing/batching/test_batch_manager.py b/sdks/python/tests/unit/message_processing/batching/test_batch_manager.py
new file mode 100644
index 0000000000..2d4400072c
--- /dev/null
+++ b/sdks/python/tests/unit/message_processing/batching/test_batch_manager.py
@@ -0,0 +1,116 @@
+import mock
+import time
+
+from opik.message_processing import messages
+from opik.message_processing.batching import batch_manager
+from opik.message_processing.batching import create_span_message_batcher
+
+NOT_USED = None
+
+
+def test_batch_manager__messages_processing_methods():
+    integers_batcher = mock.Mock()
+    strings_batcher = mock.Mock()
+    MESSAGE_BATCHERS = {
+        int: integers_batcher,
+        str: strings_batcher,
+    }
+
+    tested = batch_manager.BatchManager(MESSAGE_BATCHERS)
+
+    assert tested.message_supports_batching("a-string")
+    assert tested.message_supports_batching(42)
+    assert not tested.message_supports_batching(float(42.5))
+
+    tested.process_message(42)
+    integers_batcher.add.assert_called_once_with(42)
+
+    tested.process_message("a-string")
+    strings_batcher.add.assert_called_once_with("a-string")
+
+
+def test_batch_manager__all_batchers_are_empty__batch_manager_is_empty():
+    integers_batcher = mock.Mock()
+    integers_batcher.is_empty.return_value = True
+    strings_batcher = mock.Mock()
+    strings_batcher.is_empty.return_value = True
+
+    MESSAGE_BATCHERS = {
+        int: integers_batcher,
+        str: strings_batcher,
+    }
+
+    tested = batch_manager.BatchManager(MESSAGE_BATCHERS)
+
+    assert tested.is_empty()
+    strings_batcher.is_empty.assert_called_once()
+    integers_batcher.is_empty.assert_called_once()
+
+
+def test_batch_manager__at_least_one_batcher_is_not_empty__batch_manager_is_not_empty():
+    integers_batcher = mock.Mock()
+    integers_batcher.is_empty.return_value = True
+    strings_batcher = mock.Mock()
+    strings_batcher.is_empty.return_value = False
+
+    MESSAGE_BATCHERS = {
+        int: integers_batcher,
+        str: strings_batcher,
+    }
+
+    tested = batch_manager.BatchManager(MESSAGE_BATCHERS)
+
+    assert not tested.is_empty()
+    strings_batcher.is_empty.assert_called_once()
+    integers_batcher.is_empty.assert_called_once()
+
+
+def test_batch_manager__flush_is_called__all_batchers_are_flushed():
+    integers_batcher = mock.Mock()
+    strings_batcher = mock.Mock()
+
+    MESSAGE_BATCHERS = {
+        int: integers_batcher,
+        str: strings_batcher,
+    }
+
+    tested = batch_manager.BatchManager(MESSAGE_BATCHERS)
+
+    tested.flush()
+    integers_batcher.flush.assert_called_once()
+    strings_batcher.flush.assert_called_once()
+
+
+def test_batch_manager__start_and_stop_were_called__accumulated_data_is_flushed():
+    flush_callback = mock.Mock()
+
+    CREATE_SPAN_MESSAGE = messages.CreateSpanMessage(
+        span_id=NOT_USED,
+        trace_id=NOT_USED,
+        parent_span_id=NOT_USED,
+        project_name=NOT_USED,
+        start_time=NOT_USED,
+        end_time=NOT_USED,
+        name=NOT_USED,
+        input=NOT_USED,
+        output=NOT_USED,
+        metadata=NOT_USED,
+        tags=NOT_USED,
+        type=NOT_USED,
+        usage=NOT_USED,
+    )
+
+    example_span_batcher = create_span_message_batcher.CreateSpanMessageBatcher(
+        flush_callback=flush_callback, max_batch_size=42, flush_interval_seconds=0.1
+    )
+    tested = batch_manager.BatchManager(
+        {messages.CreateSpanMessage: example_span_batcher}
+    )
+
+    tested.start()
+    time.sleep(0.1)
+    flush_callback.assert_not_called()
+    tested.process_message(CREATE_SPAN_MESSAGE)
+    tested.stop()
+    flush_callback.assert_called_once_with(
+        messages.CreateSpansBatchMessage(batch=[CREATE_SPAN_MESSAGE])
+    )
diff --git a/sdks/python/tests/unit/message_processing/batching/test_flushing_thread.py b/sdks/python/tests/unit/message_processing/batching/test_flushing_thread.py
new file mode 100644
index 0000000000..097b323d2f
--- /dev/null
+++ b/sdks/python/tests/unit/message_processing/batching/test_flushing_thread.py
@@ -0,0 +1,34 @@
+import time
+import mock
+from opik.message_processing.batching import (
+    flushing_thread,
+    create_span_message_batcher,
+)
+
+
+def test_flushing_thread__batcher_is_flushed__every_time_flush_interval_time_passes():
+    flush_callback = mock.Mock()
+    FLUSH_INTERVAL = 0.2
+    very_big_batch_size = float("inf")
+    batcher = create_span_message_batcher.CreateSpanMessageBatcher(
+        flush_callback=flush_callback,
+        max_batch_size=very_big_batch_size,
+        flush_interval_seconds=FLUSH_INTERVAL,
+    )
+    tested = flushing_thread.FlushingThread(batchers=[batcher])
+
+    tested.start()
+    batcher.add("some-value-to-make-batcher-not-empty")
+    flush_callback.assert_not_called()
+
+    time.sleep(FLUSH_INTERVAL + 0.01)
+    # the flush interval has passed since the batcher was created, so it is ready to be flushed
+    # (the extra 0.01 gives the thread, which probes every 0.1 seconds, time to run its check)
+    flush_callback.assert_called_once()
+
+    flush_callback.reset_mock()
+
+    batcher.add("some-value-to-make-batcher-not-empty")
+    time.sleep(FLUSH_INTERVAL)
+    # the flush interval has passed since the previous flush, so the batcher is flushed again
+    flush_callback.assert_called_once()

diff --git a/sdks/python/tests/unit/message_processing/batching/test_span_batcher.py b/sdks/python/tests/unit/message_processing/batching/test_span_batcher.py
new file mode 100644
index 0000000000..f7770a41a1
--- /dev/null
+++ b/sdks/python/tests/unit/message_processing/batching/test_span_batcher.py
@@ -0,0 +1,109 @@
+import mock
+import time
+
+from opik.message_processing.batching import create_span_message_batcher
+from opik.message_processing import messages
+
+NOT_USED = None
+
+
+def test_create_span_message_batcher__exactly_max_batch_size_reached__batch_is_flushed():
+    flush_callback = mock.Mock()
+
+    MAX_BATCH_SIZE = 5
+
+    batcher = create_span_message_batcher.CreateSpanMessageBatcher(
+        max_batch_size=MAX_BATCH_SIZE,
+        flush_callback=flush_callback,
+        flush_interval_seconds=NOT_USED,
+    )
+
+    assert batcher.is_empty()
+    span_messages = [
+        1,
+        2,
+        3,
+        4,
+        5,
+    ]  # the batcher doesn't inspect message content, so plain integers work here
+
+    for span_message in span_messages:
+        batcher.add(span_message)
+    assert batcher.is_empty()
+
+    flush_callback.assert_called_once_with(
+        messages.CreateSpansBatchMessage(batch=[1, 2, 3, 4, 5])
+    )
+
+
+def test_create_span_message_batcher__more_than_max_batch_size_items_added__one_batch_flushed__some_data_remains_in_batcher():
+    flush_callback = mock.Mock()
+
+    MAX_BATCH_SIZE = 5
+
+    batcher = create_span_message_batcher.CreateSpanMessageBatcher(
+        max_batch_size=MAX_BATCH_SIZE,
+        flush_callback=flush_callback,
+        flush_interval_seconds=NOT_USED,
+    )
+
+    assert batcher.is_empty()
+    span_messages = [
+        1,
+        2,
+        3,
+        4,
+        5,
+        6,
+        7,
+    ]  # the batcher doesn't inspect message content, so plain integers work here
+
+    for span_message in span_messages:
+        batcher.add(span_message)
+
+    assert not batcher.is_empty()
+    flush_callback.assert_called_once_with(
+        messages.CreateSpansBatchMessage(batch=[1, 2, 3, 4, 5])
+    )
+    flush_callback.reset_mock()
+
+    batcher.flush()
+    flush_callback.assert_called_once_with(
+        messages.CreateSpansBatchMessage(batch=[6, 7])
+    )
+
+
+def test_create_span_message_batcher__batcher_doesnt_have_items__flush_is_called__flush_callback_NOT_called():
+    flush_callback = mock.Mock()
+
+    MAX_BATCH_SIZE = 5
+
+    batcher = create_span_message_batcher.CreateSpanMessageBatcher(
+        max_batch_size=MAX_BATCH_SIZE,
+        flush_callback=flush_callback,
+        flush_interval_seconds=NOT_USED,
+    )
+
+    assert batcher.is_empty()
+    batcher.flush()
+    flush_callback.assert_not_called()
+
+
+def test_create_span_message_batcher__ready_to_flush_returns_True__if_flush_interval_passed():
+    flush_callback = mock.Mock()
+
+    MAX_BATCH_SIZE = 5
+    FLUSH_INTERVAL = 0.1
+
+    batcher = create_span_message_batcher.CreateSpanMessageBatcher(
+        max_batch_size=MAX_BATCH_SIZE,
+        flush_callback=flush_callback,
+        flush_interval_seconds=FLUSH_INTERVAL,
+    )
+    assert not batcher.is_ready_to_flush()
+    time.sleep(0.1)
+    assert batcher.is_ready_to_flush()
+    batcher.flush()
+    assert not batcher.is_ready_to_flush()
+    time.sleep(0.1)
+    assert batcher.is_ready_to_flush()
diff --git a/sdks/python/tests/unit/message_processing/test_message_streaming.py b/sdks/python/tests/unit/message_processing/test_message_streaming.py
index 9662fbcc6f..22c7c54a94 100644
--- a/sdks/python/tests/unit/message_processing/test_message_streaming.py
+++ b/sdks/python/tests/unit/message_processing/test_message_streaming.py
@@ -1,14 +1,19 @@
 import pytest
 import mock
 from opik.message_processing import streamer_constructors
+from opik.message_processing import messages
+
+NOT_USED = None


 @pytest.fixture
-def streamer_and_mock_message_processor():
+def batched_streamer_and_mock_message_processor():
     try:
         mock_message_processor = mock.Mock()
         tested = streamer_constructors.construct_streamer(
-            message_processor=mock_message_processor, n_consumers=1
+            message_processor=mock_message_processor,
+            n_consumers=1,
+            use_batching=True,
         )

         yield tested, mock_message_processor
@@ -16,8 +21,8 @@
         tested.close(timeout=5)


-def test_streamer__happyflow(streamer_and_mock_message_processor):
-    tested, mock_message_processor = streamer_and_mock_message_processor
+def test_streamer__happyflow(batched_streamer_and_mock_message_processor):
+    tested, mock_message_processor = batched_streamer_and_mock_message_processor

     tested.put("message-1")
     tested.put("message-2")
@@ -26,3 +31,81 @@
     mock_message_processor.process.assert_has_calls(
         [mock.call("message-1"), mock.call("message-2")]
     )
+
+
+def test_streamer__batching_disabled__messages_that_support_batching_are_processed_independently():
+    mock_message_processor = mock.Mock()
+    try:
+        tested = streamer_constructors.construct_streamer(
+            message_processor=mock_message_processor,
+            n_consumers=1,
+            use_batching=False,
+        )
+
+        CREATE_SPAN_MESSAGE = messages.CreateSpanMessage(
+            span_id=NOT_USED,
+            trace_id=NOT_USED,
+            parent_span_id=NOT_USED,
+            project_name=NOT_USED,
+            start_time=NOT_USED,
+            end_time=NOT_USED,
+            name=NOT_USED,
+            input=NOT_USED,
+            output=NOT_USED,
+            metadata=NOT_USED,
+            tags=NOT_USED,
+            type=NOT_USED,
+            usage=NOT_USED,
+        )
+
+        tested.put(CREATE_SPAN_MESSAGE)
+        tested.put(CREATE_SPAN_MESSAGE)
+        tested.put(CREATE_SPAN_MESSAGE)
+        tested.flush(0.1)
+
+        mock_message_processor.process.assert_has_calls(
+            [
+                mock.call(CREATE_SPAN_MESSAGE),
+                mock.call(CREATE_SPAN_MESSAGE),
+                mock.call(CREATE_SPAN_MESSAGE),
+            ]
+        )
+    finally:
+        tested.close(timeout=1)
+
+
+def test_streamer__batching_enabled__messages_that_support_batching_are_processed_in_batch():
+    mock_message_processor = mock.Mock()
+    try:
+        tested = streamer_constructors.construct_streamer(
+            message_processor=mock_message_processor,
+            n_consumers=1,
+            use_batching=True,
+        )
+
+        CREATE_SPAN_MESSAGE = messages.CreateSpanMessage(
+            span_id=NOT_USED,
+            trace_id=NOT_USED,
+            parent_span_id=NOT_USED,
+            project_name=NOT_USED,
+            start_time=NOT_USED,
+            end_time=NOT_USED,
+            name=NOT_USED,
+            input=NOT_USED,
+            output=NOT_USED,
+            metadata=NOT_USED,
+            tags=NOT_USED,
+            type=NOT_USED,
+            usage=NOT_USED,
+        )
+
+        tested.put(CREATE_SPAN_MESSAGE)
+        tested.put(CREATE_SPAN_MESSAGE)
+        tested.put(CREATE_SPAN_MESSAGE)
+        tested.flush(1.1)
+
+        mock_message_processor.process.assert_called_once_with(
+            messages.CreateSpansBatchMessage(batch=[CREATE_SPAN_MESSAGE] * 3)
+        )
+    finally:
+        tested.close(timeout=1)
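Taken together with the opik_client.py change at the top of the patch: only clients created through get_client_cached() (the path the SDK's own instrumentation uses) batch span creation, while a manually constructed client keeps the one-request-per-span behavior unless the private flag is set, and the docstring explicitly marks that flag as unsafe for anything else. A final sketch of the two modes (the constructor calls below are the public Opik API, not part of this diff):

    import opik

    client = opik.Opik()                     # batching off: each span produces its own REST call
    batched = opik.Opik(_use_batching=True)  # what get_client_cached() does under the hood;
                                             # spans are shipped together via spans.create_spans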