Skip to content

Commit

Permalink
improves trackers thread pool (#1340)
Browse files Browse the repository at this point in the history
* adds mandatory prefix to managed pool

* starts and stops pools on init, not lazy

* simplifies at_exit, adds end2end anon telemetry test

* fixes platform test

* skips windows forked tests
  • Loading branch information
rudolfix authored May 8, 2024
1 parent dd3ae4b commit 80684d1
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 44 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/test_common.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ jobs:
if: runner.os != 'Windows'
name: Run pipeline smoke tests with minimum deps Linux/MAC
- run: |
poetry run pytest tests/pipeline/test_pipeline.py tests/pipeline/test_import_export_schema.py
poetry run pytest tests/pipeline/test_pipeline.py tests/pipeline/test_import_export_schema.py -m "not forked"
if: runner.os == 'Windows'
name: Run smoke tests with minimum deps Windows
shell: cmd
Expand All @@ -117,7 +117,7 @@ jobs:
if: runner.os != 'Windows'
name: Run pipeline tests with pyarrow but no pandas installed
- run: |
poetry run pytest tests/pipeline/test_pipeline_extra.py -k arrow
poetry run pytest tests/pipeline/test_pipeline_extra.py -k arrow -m "not forked"
if: runner.os == 'Windows'
name: Run pipeline tests with pyarrow but no pandas installed Windows
shell: cmd
Expand All @@ -130,7 +130,7 @@ jobs:
if: runner.os != 'Windows'
name: Run extract and pipeline tests Linux/MAC
- run: |
poetry run pytest tests/extract tests/pipeline tests/libs tests/cli/common tests/destinations
poetry run pytest tests/extract tests/pipeline tests/libs tests/cli/common tests/destinations -m "not forked"
if: runner.os == 'Windows'
name: Run extract tests Windows
shell: cmd
Expand Down
10 changes: 5 additions & 5 deletions dlt/common/managed_thread_pool.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from typing import Optional

import atexit
from concurrent.futures import ThreadPoolExecutor


class ManagedThreadPool:
    """A ``ThreadPoolExecutor`` wrapper with an explicit lifecycle.

    The executor is not created in ``__init__``: it is materialized either
    eagerly via ``_create_thread_pool`` (the controlled path used by the
    telemetry trackers) or lazily on first access of ``thread_pool``.
    ``stop`` shuts the executor down and clears it so the instance can be
    started again later.
    """

    def __init__(self, thread_name_prefix: str, max_workers: int = 1) -> None:
        """
        Args:
            thread_name_prefix: mandatory prefix for worker thread names, making
                each managed pool identifiable in thread dumps.
            max_workers: number of worker threads, defaults to 1.
        """
        self.thread_name_prefix = thread_name_prefix
        self._max_workers = max_workers
        self._thread_pool: Optional[ThreadPoolExecutor] = None

    def _create_thread_pool(self) -> None:
        """Create the underlying executor; the pool must not already exist."""
        assert not self._thread_pool, "Thread pool already created"
        self._thread_pool = ThreadPoolExecutor(
            self._max_workers, thread_name_prefix=self.thread_name_prefix
        )

    @property
    def thread_pool(self) -> ThreadPoolExecutor:
        """Return the executor, creating it lazily on first access."""
        if not self._thread_pool:
            self._create_thread_pool()
        return self._thread_pool

    def stop(self, wait: bool = True) -> None:
        """Shut the executor down (flushing queued work when ``wait``) and clear it."""
        if self._thread_pool:
            self._thread_pool.shutdown(wait=wait)
            self._thread_pool = None
Expand Down
51 changes: 26 additions & 25 deletions dlt/common/runtime/anon_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,28 @@

# several code fragments come from https://github.com/RasaHQ/rasa/blob/main/rasa/telemetry.py
import os

import atexit
import base64
import requests
from typing import Literal, Optional
from dlt.common.configuration.paths import get_dlt_data_dir
from requests import Session

from dlt.common import logger
from dlt.common.managed_thread_pool import ManagedThreadPool
from dlt.common.configuration.specs import RunConfiguration
from dlt.common.configuration.paths import get_dlt_data_dir
from dlt.common.runtime.exec_info import get_execution_context, TExecutionContext
from dlt.common.typing import DictStrAny, StrAny
from dlt.common.utils import uniq_id

from dlt.version import __version__

TEventCategory = Literal["pipeline", "command", "helper"]

_THREAD_POOL: ManagedThreadPool = ManagedThreadPool(1)
_SESSION: requests.Session = None
_THREAD_POOL: ManagedThreadPool = None
_WRITE_KEY: str = None
_REQUEST_TIMEOUT = (1.0, 1.0) # short connect & send timeouts
_ANON_TRACKER_ENDPOINT: str = None
_TRACKER_CONTEXT: TExecutionContext = None
requests: Session = None


def init_anon_tracker(config: RunConfiguration) -> None:
Expand All @@ -36,12 +35,20 @@ def init_anon_tracker(config: RunConfiguration) -> None:
config.dlthub_telemetry_segment_write_key
), "dlthub_telemetry_segment_write_key not present in RunConfiguration"

global _WRITE_KEY, _SESSION, _ANON_TRACKER_ENDPOINT
# lazily import requests to avoid binding config before initialization
global requests
from dlt.sources.helpers import requests as r_

requests = r_ # type: ignore[assignment]

global _WRITE_KEY, _ANON_TRACKER_ENDPOINT, _THREAD_POOL
# start the pool
# create thread pool to send telemetry to anonymous tracker
if not _SESSION:
_SESSION = requests.Session()
# flush pool on exit
atexit.register(_at_exit_cleanup)
if _THREAD_POOL is None:
_THREAD_POOL = ManagedThreadPool("anon_tracker", 1)
# do not instantiate lazy: this happens in _future_send which may be triggered
# from a thread and also in parallel when many pipelines are run at once
_THREAD_POOL._create_thread_pool()
# store write key if present
if config.dlthub_telemetry_segment_write_key:
key_bytes = (config.dlthub_telemetry_segment_write_key + ":").encode("ascii")
Expand All @@ -53,8 +60,13 @@ def init_anon_tracker(config: RunConfiguration) -> None:


def disable_anon_tracker() -> None:
    """Stop anonymous telemetry and reset module state.

    Flushes the tracker thread pool (waiting for queued events to be
    delivered) and clears the write key, endpoint, execution context and
    pool so the tracker can be re-initialized via ``init_anon_tracker``.
    """
    # NOTE: the scraped diff merged the removed body (calls to
    # `_at_exit_cleanup` / `atexit.unregister`, both deleted in this commit)
    # with the added one; this is the post-commit version only.
    global _WRITE_KEY, _TRACKER_CONTEXT, _ANON_TRACKER_ENDPOINT, _THREAD_POOL
    if _THREAD_POOL is not None:
        # wait=True so events still queued in the pool are sent before shutdown
        _THREAD_POOL.stop(True)
    _ANON_TRACKER_ENDPOINT = None
    _WRITE_KEY = None
    _TRACKER_CONTEXT = None
    _THREAD_POOL = None


def track(event_category: TEventCategory, event_name: str, properties: DictStrAny) -> None:
Expand Down Expand Up @@ -84,17 +96,6 @@ def before_send(event: DictStrAny) -> Optional[DictStrAny]:
return event


def _at_exit_cleanup() -> None:
    """Interpreter-exit hook: flush the tracker pool and reset module state.

    No-op when the session was never created (telemetry not initialized).
    """
    global _SESSION, _WRITE_KEY, _TRACKER_CONTEXT, _ANON_TRACKER_ENDPOINT
    if not _SESSION:
        return
    # wait for queued telemetry events before closing the session
    _THREAD_POOL.stop(True)
    _SESSION.close()
    _SESSION = None
    _WRITE_KEY = None
    _TRACKER_CONTEXT = None
    _ANON_TRACKER_ENDPOINT = None


def _tracker_request_header(write_key: str) -> StrAny:
"""Use a segment write key to create authentication headers for the segment API.
Expand Down Expand Up @@ -188,7 +189,7 @@ def _send_event(event_name: str, properties: StrAny, context: StrAny) -> None:
def _future_send() -> None:
# import time
# start_ts = time.time_ns()
resp = _SESSION.post(
resp = requests.post(
_ANON_TRACKER_ENDPOINT, headers=headers, json=payload, timeout=_REQUEST_TIMEOUT
)
# end_ts = time.time_ns()
Expand Down
7 changes: 7 additions & 0 deletions dlt/common/runtime/telemetry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import atexit
import time
import contextlib
import inspect
Expand All @@ -12,6 +13,7 @@
disable_anon_tracker,
track,
)
from dlt.pipeline.platform import disable_platform_tracker, init_platform_tracker

_TELEMETRY_STARTED = False

Expand All @@ -32,9 +34,13 @@ def start_telemetry(config: RunConfiguration) -> None:
if config.dlthub_telemetry:
init_anon_tracker(config)

if config.dlthub_dsn:
init_platform_tracker()

_TELEMETRY_STARTED = True


@atexit.register
def stop_telemetry() -> None:
global _TELEMETRY_STARTED
if not _TELEMETRY_STARTED:
Expand All @@ -48,6 +54,7 @@ def stop_telemetry() -> None:
pass

disable_anon_tracker()
disable_platform_tracker()

_TELEMETRY_STARTED = False

Expand Down
27 changes: 24 additions & 3 deletions dlt/pipeline/platform.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Implements SupportsTracking"""
from typing import Any, cast, TypedDict, List
import requests
from urllib.parse import urljoin
from requests import Session

from dlt.common import logger
from dlt.common.json import json
Expand All @@ -11,9 +10,10 @@

from dlt.pipeline.trace import PipelineTrace, PipelineStepTrace, TPipelineStep, SupportsPipeline

_THREAD_POOL: ManagedThreadPool = ManagedThreadPool(1)
_THREAD_POOL: ManagedThreadPool = None
TRACE_URL_SUFFIX = "/trace"
STATE_URL_SUFFIX = "/state"
requests: Session = None


class TPipelineSyncPayload(TypedDict):
Expand All @@ -25,6 +25,27 @@ class TPipelineSyncPayload(TypedDict):
schemas: List[TStoredSchema]


def init_platform_tracker() -> None:
    """Initialize the platform tracker: bind the requests helper and start its pool."""
    global requests, _THREAD_POOL

    # deferred import so configuration is not bound before initialization
    from dlt.sources.helpers import requests as _requests

    requests = _requests  # type: ignore[assignment]

    if _THREAD_POOL is None:
        _THREAD_POOL = ManagedThreadPool("platform_tracker", 1)
        # start the pool eagerly here, in a controlled way, rather than lazily
        _THREAD_POOL._create_thread_pool()


def disable_platform_tracker() -> None:
    """Shut down the platform tracker thread pool, if one was started."""
    global _THREAD_POOL
    if not _THREAD_POOL:
        return
    _THREAD_POOL.stop()
    _THREAD_POOL = None


def _send_trace_to_platform(trace: PipelineTrace, pipeline: SupportsPipeline) -> None:
"""
Send the full trace after a run operation to the platform
Expand Down
25 changes: 22 additions & 3 deletions tests/common/runtime/test_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import base64
from unittest.mock import patch

from pytest_mock import MockerFixture

from dlt.common import logger
from dlt.common.runtime.anon_tracker import get_anonymous_id, track, disable_anon_tracker
from dlt.common.typing import DictStrAny, DictStrStr
Expand All @@ -30,7 +32,7 @@ class SentryLoggerConfiguration(RunConfiguration):
sentry_dsn: str = (
"https://[email protected]/4504819859914752"
)
dlthub_telemetry_segment_write_key: str = "TLJiyRkGVZGCi2TtjClamXpFcxAA1rSB"
# dlthub_telemetry_segment_write_key: str = "TLJiyRkGVZGCi2TtjClamXpFcxAA1rSB"


@configspec
Expand Down Expand Up @@ -133,17 +135,34 @@ def test_sentry_init(environment: DictStrStr) -> None:


@pytest.mark.forked
def test_track_anon_event() -> None:
def test_track_anon_event(mocker: MockerFixture) -> None:
from dlt.sources.helpers import requests
from dlt.common.runtime import anon_tracker

mock_github_env(os.environ)
mock_pod_env(os.environ)
config = SentryLoggerConfiguration()

requests_post = mocker.spy(requests, "post")

props = {"destination_name": "duckdb", "elapsed_time": 712.23123, "success": True}
with patch("dlt.common.runtime.anon_tracker.before_send", _mock_before_send):
start_test_telemetry(SentryLoggerConfiguration())
start_test_telemetry(config)
track("pipeline", "run", props)
# this will send stuff
disable_anon_tracker()

event = SENT_ITEMS[0]
# requests were really called
requests_post.assert_called_once_with(
config.dlthub_telemetry_endpoint,
headers=anon_tracker._tracker_request_header(None),
json=event,
timeout=anon_tracker._REQUEST_TIMEOUT,
)
# was actually delivered
assert requests_post.spy_return.status_code == 204

assert event["anonymousId"] == get_anonymous_id()
assert event["event"] == "pipeline_run"
assert props.items() <= event["properties"].items()
Expand Down
16 changes: 11 additions & 5 deletions tests/pipeline/test_platform_connection.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import dlt
import os
import time
import pytest
import requests_mock

import dlt

from tests.utils import start_test_telemetry, stop_telemetry

TRACE_URL_SUFFIX = "/trace"
STATE_URL_SUFFIX = "/state"


@pytest.mark.forked
def test_platform_connection() -> None:
mock_platform_url = "http://platform.com/endpoint"

os.environ["RUNTIME__DLTHUB_DSN"] = mock_platform_url
start_test_telemetry()

trace_url = mock_platform_url + TRACE_URL_SUFFIX
state_url = mock_platform_url + STATE_URL_SUFFIX
Expand Down Expand Up @@ -42,11 +46,13 @@ def data():
m.put(mock_platform_url, json={}, status_code=200)
p.run([my_source(), my_source_2()])

# sleep a bit and find trace in mock requests
time.sleep(2)
# this will flush all telemetry queues
stop_telemetry()

trace_result = None
state_result = None

# mock tracks all calls
for call in m.request_history:
if call.url == trace_url:
assert not trace_result, "Multiple calls to trace endpoint"
Expand Down

0 comments on commit 80684d1

Please sign in to comment.