Skip to content

Commit

Permalink
Make batch processor fork aware and reinit when needed (#2242)
Browse files Browse the repository at this point in the history
Since version 3.7, Python provides `os.register_at_fork`, which can be used to make our batch processor fork-safe.
  • Loading branch information
srikanthccv authored Nov 5, 2021
1 parent 41c5f99 commit 29e4bab
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#2153](https://github.com/open-telemetry/opentelemetry-python/pull/2153))
- Add metrics API
([#1887](https://github.com/open-telemetry/opentelemetry-python/pull/1887))
- Make batch processor fork aware and reinit when needed
([#2242](https://github.com/open-telemetry/opentelemetry-python/pull/2242))

## [1.6.2-0.25b2](https://github.com/open-telemetry/opentelemetry-python/releases/tag/v1.6.2-0.25b2) - 2021-10-19

Expand Down
12 changes: 12 additions & 0 deletions opentelemetry-sdk/src/opentelemetry/sdk/_logs/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import collections
import enum
import logging
import os
import sys
import threading
from os import linesep
Expand Down Expand Up @@ -154,6 +155,17 @@ def __init__(
None
] * self._max_export_batch_size # type: List[Optional[LogData]]
self._worker_thread.start()
# Only available in *nix since py37.
if hasattr(os, "register_at_fork"):
os.register_at_fork(
after_in_child=self._at_fork_reinit
) # pylint: disable=protected-access

def _at_fork_reinit(self):
    """Rebuild the processor's process-local state after a fork.

    This method is registered as the ``after_in_child`` hook through
    ``os.register_at_fork`` (where that function exists), so it runs in
    the child process. It installs a fresh condition variable, empties
    the queue, and starts a new daemon worker thread running
    ``self.worker``.
    """
    # NOTE(review): presumably the inherited condition's lock could have
    # been held by a parent thread that does not exist in the child, so a
    # brand-new one is installed rather than reusing it -- confirm.
    self._condition = threading.Condition(threading.Lock())
    self._queue.clear()
    # Name the thread, as the span processor does for its own worker, so
    # it is identifiable in thread dumps instead of showing as "Thread-N".
    self._worker_thread = threading.Thread(
        name="OtelBatchLogProcessor", target=self.worker, daemon=True
    )
    self._worker_thread.start()

def worker(self):
timeout = self._schedule_delay_millis / 1e3
Expand Down
17 changes: 17 additions & 0 deletions opentelemetry-sdk/src/opentelemetry/sdk/trace/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import collections
import logging
import os
import sys
import threading
import typing
Expand Down Expand Up @@ -197,6 +198,11 @@ def __init__(
None
] * self.max_export_batch_size # type: typing.List[typing.Optional[Span]]
self.worker_thread.start()
# Only available in *nix since py37.
if hasattr(os, "register_at_fork"):
os.register_at_fork(
after_in_child=self._at_fork_reinit
) # pylint: disable=protected-access

def on_start(
self, span: Span, parent_context: typing.Optional[Context] = None
Expand All @@ -220,6 +226,17 @@ def on_end(self, span: ReadableSpan) -> None:
with self.condition:
self.condition.notify()

def _at_fork_reinit(self):
    """Reset per-process state; meant to run in a freshly forked child.

    Registered as the ``after_in_child`` hook via ``os.register_at_fork``
    when that function is available. Replaces the condition variable,
    empties the queue, and launches a new daemon worker thread.
    """
    # Threads do not cross a fork: only the thread that called fork()
    # exists in the child, so the worker has to be created again here.
    replacement = threading.Thread(
        target=self.worker, name="OtelBatchSpanProcessor", daemon=True
    )
    self.queue.clear()
    self.condition = threading.Condition(threading.Lock())
    self.worker_thread = replacement
    # Start last, once the new condition and the emptied queue are in place.
    replacement.start()

def worker(self):
timeout = self.schedule_delay_millis / 1e3
flush_request = None # type: typing.Optional[_FlushRequest]
Expand Down
53 changes: 52 additions & 1 deletion opentelemetry-sdk/tests/logs/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

# pylint: disable=protected-access
import logging
import multiprocessing
import os
import sys
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -38,6 +40,7 @@
from opentelemetry.sdk._logs.severity import SeverityNumber
from opentelemetry.sdk.resources import Resource as SDKResource
from opentelemetry.sdk.util.instrumentation import InstrumentationInfo
from opentelemetry.test.concurrency_test import ConcurrencyTestBase
from opentelemetry.trace import TraceFlags
from opentelemetry.trace.span import INVALID_SPAN_CONTEXT

Expand Down Expand Up @@ -158,7 +161,7 @@ def test_simple_log_processor_shutdown(self):
self.assertEqual(len(finished_logs), 0)


class TestBatchLogProcessor(unittest.TestCase):
class TestBatchLogProcessor(ConcurrencyTestBase):
def test_emit_call_log_record(self):
exporter = InMemoryLogExporter()
log_processor = Mock(wraps=BatchLogProcessor(exporter))
Expand Down Expand Up @@ -269,6 +272,54 @@ def bulk_log_and_flush(num_logs):
finished_logs = exporter.get_finished_logs()
self.assertEqual(len(finished_logs), 2415)

@unittest.skipUnless(
    hasattr(os, "fork") and sys.version_info >= (3, 7),
    "needs *nix and minor version 7 or later",
)
def test_batch_log_processor_fork(self):
    """Check that a BatchLogProcessor keeps exporting in a forked child.

    One record is logged and flushed in the parent first. A child process
    is then forked, emits records through ``run_with_many_threads`` and
    reports over a pipe whether exactly 100 records reached the exporter.
    """
    # pylint: disable=invalid-name
    exporter = InMemoryLogExporter()
    log_processor = BatchLogProcessor(
        exporter,
        max_export_batch_size=64,
        schedule_delay_millis=10,
    )
    provider = LogEmitterProvider()
    provider.add_log_processor(log_processor)

    emitter = provider.get_log_emitter(__name__)
    logger = logging.getLogger("test-fork")
    logger.addHandler(OTLPHandler(log_emitter=emitter))

    logger.critical("yolo")
    time.sleep(0.5)  # give some time for the exporter to upload

    self.assertTrue(log_processor.force_flush())
    self.assertEqual(len(exporter.get_finished_logs()), 1)
    exporter.clear()

    # Use a private "fork" context instead of set_start_method("fork"):
    # the latter changes process-wide state for every other test and
    # raises RuntimeError when the start method has already been set.
    fork_context = multiprocessing.get_context("fork")

    def child(conn):
        def _target():
            logger.critical("Critical message child")

        self.run_with_many_threads(_target, 100)

        time.sleep(0.5)  # let the child's worker thread export the batch

        logs = exporter.get_finished_logs()
        conn.send(len(logs) == 100)
        conn.close()

    parent_conn, child_conn = fork_context.Pipe()
    p = fork_context.Process(target=child, args=(child_conn,))
    p.start()
    # The child owns its own copy of this end after the fork. Closing the
    # parent's copy makes recv() raise EOFError, rather than block
    # forever, if the child exits before sending a result.
    child_conn.close()
    self.assertTrue(parent_conn.recv())
    p.join()

    log_processor.shutdown()


class TestConsoleExporter(unittest.TestCase):
def test_export(self): # pylint: disable=no-self-use
Expand Down
62 changes: 61 additions & 1 deletion opentelemetry-sdk/tests/trace/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing
import os
import sys
import threading
import time
import unittest
Expand All @@ -30,6 +32,10 @@
OTEL_BSP_SCHEDULE_DELAY,
)
from opentelemetry.sdk.trace import export
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
InMemorySpanExporter,
)
from opentelemetry.test.concurrency_test import ConcurrencyTestBase


class MySpanExporter(export.SpanExporter):
Expand Down Expand Up @@ -157,7 +163,7 @@ def _create_start_and_end_span(name, span_processor):
span.end()


class TestBatchSpanProcessor(unittest.TestCase):
class TestBatchSpanProcessor(ConcurrencyTestBase):
@mock.patch.dict(
"os.environ",
{
Expand Down Expand Up @@ -356,6 +362,60 @@ def test_batch_span_processor_not_sampled(self):
self.assertEqual(len(spans_names_list), 0)
span_processor.shutdown()

def _check_fork_trace(self, exporter, expected):
    """Assert that each span finished by ``exporter`` has an expected name.

    Args:
        exporter: object exposing ``get_finished_spans()``.
        expected: container of span names considered acceptable.
    """
    # Pause first so the exporter has had some time to upload spans.
    time.sleep(0.5)
    for finished in exporter.get_finished_spans():
        self.assertIn(finished.name, expected)

@unittest.skipUnless(
    hasattr(os, "fork") and sys.version_info >= (3, 7),
    "needs *nix and minor version 7 or later",
)
def test_batch_span_processor_fork(self):
    """Check that a BatchSpanProcessor keeps exporting in a forked child.

    One span is created and flushed in the parent first. A child process
    is then forked, creates spans through ``run_with_many_threads`` (two
    per ``_target`` call) and reports over a pipe whether exactly 200
    spans reached the exporter.
    """
    # pylint: disable=invalid-name
    tracer_provider = trace.TracerProvider()
    tracer = tracer_provider.get_tracer(__name__)

    exporter = InMemorySpanExporter()
    span_processor = export.BatchSpanProcessor(
        exporter,
        max_queue_size=256,
        max_export_batch_size=64,
        schedule_delay_millis=10,
    )
    tracer_provider.add_span_processor(span_processor)
    with tracer.start_as_current_span("foo"):
        pass
    time.sleep(0.5)  # give some time for the exporter to upload spans

    self.assertTrue(span_processor.force_flush())
    self.assertEqual(len(exporter.get_finished_spans()), 1)
    exporter.clear()

    # Ask for the "fork" start method explicitly. The platform default is
    # not always fork (e.g. "spawn" on macOS), and any other method would
    # have to pickle the local ``child`` closure and would not run the
    # after-fork hook this test is meant to exercise.
    fork_context = multiprocessing.get_context("fork")

    def child(conn):
        def _target():
            with tracer.start_as_current_span("span") as s:
                s.set_attribute("i", "1")
                with tracer.start_as_current_span("temp"):
                    pass

        self.run_with_many_threads(_target, 100)

        time.sleep(0.5)  # let the child's worker thread export the batch

        spans = exporter.get_finished_spans()
        conn.send(len(spans) == 200)
        conn.close()

    parent_conn, child_conn = fork_context.Pipe()
    p = fork_context.Process(target=child, args=(child_conn,))
    p.start()
    # The child owns its own copy of this end after the fork. Closing the
    # parent's copy makes recv() raise EOFError, rather than block
    # forever, if the child exits before sending a result.
    child_conn.close()
    self.assertTrue(parent_conn.recv())
    p.join()

    span_processor.shutdown()

def test_batch_span_processor_scheduled_delay(self):
"""Test that spans are exported each schedule_delay_millis"""
spans_names_list = []
Expand Down

0 comments on commit 29e4bab

Please sign in to comment.