improves trackers thread pool #1340

Merged: 5 commits, May 8, 2024
6 changes: 3 additions & 3 deletions .github/workflows/test_common.yml
@@ -104,7 +104,7 @@ jobs:
if: runner.os != 'Windows'
name: Run pipeline smoke tests with minimum deps Linux/MAC
- run: |
poetry run pytest tests/pipeline/test_pipeline.py tests/pipeline/test_import_export_schema.py
poetry run pytest tests/pipeline/test_pipeline.py tests/pipeline/test_import_export_schema.py -m "not forked"
if: runner.os == 'Windows'
name: Run smoke tests with minimum deps Windows
shell: cmd
@@ -117,7 +117,7 @@ jobs:
if: runner.os != 'Windows'
name: Run pipeline tests with pyarrow but no pandas installed
- run: |
poetry run pytest tests/pipeline/test_pipeline_extra.py -k arrow
poetry run pytest tests/pipeline/test_pipeline_extra.py -k arrow -m "not forked"
if: runner.os == 'Windows'
name: Run pipeline tests with pyarrow but no pandas installed Windows
shell: cmd
@@ -130,7 +130,7 @@ jobs:
if: runner.os != 'Windows'
name: Run extract and pipeline tests Linux/MAC
- run: |
poetry run pytest tests/extract tests/pipeline tests/libs tests/cli/common tests/destinations
poetry run pytest tests/extract tests/pipeline tests/libs tests/cli/common tests/destinations -m "not forked"
Collaborator: Why do we need to skip these tests on Windows?

Collaborator (author): Because they run in a forked process, which is not available on Windows, and the test messes with global objects.
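For reference, a minimal sketch of the mechanism being discussed, assuming the `forked` marker comes from the pytest-forked plugin (which relies on `os.fork()` and is therefore unavailable on Windows); the test body is illustrative only:

```python
# illustrative sketch, assuming pytest-forked provides the "forked" marker
import pytest


@pytest.mark.forked  # pytest-forked runs this test in a child process created with os.fork()
def test_mutates_global_tracker_state():
    # the fork isolates any changes to module-level globals from the main test process
    ...
```

Because `os.fork()` does not exist on Windows, the workflow deselects these tests there with `pytest -m "not forked"` rather than letting them fail.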

if: runner.os == 'Windows'
name: Run extract tests Windows
shell: cmd
10 changes: 5 additions & 5 deletions dlt/common/managed_thread_pool.py
@@ -1,19 +1,19 @@
from typing import Optional

import atexit
from concurrent.futures import ThreadPoolExecutor


class ManagedThreadPool:
def __init__(self, max_workers: int = 1) -> None:
def __init__(self, thread_name_prefix: str, max_workers: int = 1) -> None:
self.thread_name_prefix = thread_name_prefix
self._max_workers = max_workers
self._thread_pool: Optional[ThreadPoolExecutor] = None

def _create_thread_pool(self) -> None:
assert not self._thread_pool, "Thread pool already created"
self._thread_pool = ThreadPoolExecutor(self._max_workers)
# flush pool on exit
atexit.register(self.stop)
self._thread_pool = ThreadPoolExecutor(
self._max_workers, thread_name_prefix=self.thread_name_prefix
)

@property
def thread_pool(self) -> ThreadPoolExecutor:
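A rough usage sketch of the changed constructor, based only on what is visible in this PR; the submitted work is illustrative, and the lifecycle calls (`_create_thread_pool`, `stop`) mirror how the other files in this PR use the class:

```python
# sketch of the new calling convention: callers name the pool and own its lifecycle
from dlt.common.managed_thread_pool import ManagedThreadPool

pool = ManagedThreadPool("anon_tracker", 1)  # thread_name_prefix first, then max_workers
pool._create_thread_pool()                   # created explicitly; no longer registered with atexit
future = pool.thread_pool.submit(print, "queued work")
future.result()
pool.stop(True)                              # the caller now flushes and joins the pool on shutdown
```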
51 changes: 26 additions & 25 deletions dlt/common/runtime/anon_tracker.py
@@ -2,29 +2,28 @@

# several code fragments come from https://github.com/RasaHQ/rasa/blob/main/rasa/telemetry.py
import os

import atexit
import base64
import requests
from typing import Literal, Optional
from dlt.common.configuration.paths import get_dlt_data_dir
from requests import Session

from dlt.common import logger
from dlt.common.managed_thread_pool import ManagedThreadPool
from dlt.common.configuration.specs import RunConfiguration
from dlt.common.configuration.paths import get_dlt_data_dir
from dlt.common.runtime.exec_info import get_execution_context, TExecutionContext
from dlt.common.typing import DictStrAny, StrAny
from dlt.common.utils import uniq_id

from dlt.version import __version__

TEventCategory = Literal["pipeline", "command", "helper"]

_THREAD_POOL: ManagedThreadPool = ManagedThreadPool(1)
_SESSION: requests.Session = None
_THREAD_POOL: ManagedThreadPool = None
_WRITE_KEY: str = None
_REQUEST_TIMEOUT = (1.0, 1.0) # short connect & send timeouts
_ANON_TRACKER_ENDPOINT: str = None
_TRACKER_CONTEXT: TExecutionContext = None
requests: Session = None


def init_anon_tracker(config: RunConfiguration) -> None:
@@ -36,12 +35,20 @@ def init_anon_tracker(config: RunConfiguration) -> None:
config.dlthub_telemetry_segment_write_key
), "dlthub_telemetry_segment_write_key not present in RunConfiguration"

global _WRITE_KEY, _SESSION, _ANON_TRACKER_ENDPOINT
# lazily import requests to avoid binding config before initialization
global requests
from dlt.sources.helpers import requests as r_

requests = r_ # type: ignore[assignment]

global _WRITE_KEY, _ANON_TRACKER_ENDPOINT, _THREAD_POOL
# start the pool
# create thread pool to send telemetry to anonymous tracker
if not _SESSION:
_SESSION = requests.Session()
# flush pool on exit
atexit.register(_at_exit_cleanup)
if _THREAD_POOL is None:
_THREAD_POOL = ManagedThreadPool("anon_tracker", 1)
# do not instantiate lazy: this happens in _future_send which may be triggered
# from a thread and also in parallel when many pipelines are run at once
_THREAD_POOL._create_thread_pool()
# store write key if present
if config.dlthub_telemetry_segment_write_key:
key_bytes = (config.dlthub_telemetry_segment_write_key + ":").encode("ascii")
@@ -53,8 +60,13 @@


def disable_anon_tracker() -> None:
_at_exit_cleanup()
atexit.unregister(_at_exit_cleanup)
global _WRITE_KEY, _TRACKER_CONTEXT, _ANON_TRACKER_ENDPOINT, _THREAD_POOL
if _THREAD_POOL is not None:
_THREAD_POOL.stop(True)
_ANON_TRACKER_ENDPOINT = None
_WRITE_KEY = None
_TRACKER_CONTEXT = None
_THREAD_POOL = None


def track(event_category: TEventCategory, event_name: str, properties: DictStrAny) -> None:
@@ -84,17 +96,6 @@ def before_send(event: DictStrAny) -> Optional[DictStrAny]:
return event


def _at_exit_cleanup() -> None:
global _SESSION, _WRITE_KEY, _TRACKER_CONTEXT, _ANON_TRACKER_ENDPOINT
if _SESSION:
_THREAD_POOL.stop(True)
_SESSION.close()
_SESSION = None
_ANON_TRACKER_ENDPOINT = None
_WRITE_KEY = None
_TRACKER_CONTEXT = None


def _tracker_request_header(write_key: str) -> StrAny:
"""Use a segment write key to create authentication headers for the segment API.

@@ -188,7 +189,7 @@ def _send_event(event_name: str, properties: StrAny, context: StrAny) -> None:
def _future_send() -> None:
# import time
# start_ts = time.time_ns()
resp = _SESSION.post(
resp = requests.post(
_ANON_TRACKER_ENDPOINT, headers=headers, json=payload, timeout=_REQUEST_TIMEOUT
)
# end_ts = time.time_ns()
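Taken together, the hunks above follow one pattern: bind `requests` lazily so importing the module does not resolve configuration, and create the thread pool eagerly in `init_anon_tracker` because `_future_send` may run from several threads in parallel. A condensed, hypothetical sketch of that pattern (not the actual module):

```python
# condensed, hypothetical sketch of the init/teardown pattern above (not the dlt source)
from typing import Optional

from dlt.common.managed_thread_pool import ManagedThreadPool

requests = None  # bound lazily in init so importing this module does not touch config
_THREAD_POOL: Optional[ManagedThreadPool] = None


def init_tracker() -> None:
    global requests, _THREAD_POOL
    from dlt.sources.helpers import requests as r_  # lazy import; config is resolved by now

    requests = r_
    if _THREAD_POOL is None:
        _THREAD_POOL = ManagedThreadPool("anon_tracker", 1)
        # created eagerly: creating it lazily inside the send path could race
        # when many pipelines start at once
        _THREAD_POOL._create_thread_pool()


def disable_tracker() -> None:
    global _THREAD_POOL
    if _THREAD_POOL is not None:
        _THREAD_POOL.stop(True)  # wait for queued telemetry sends to flush
        _THREAD_POOL = None
```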
7 changes: 7 additions & 0 deletions dlt/common/runtime/telemetry.py
@@ -1,3 +1,4 @@
import atexit
import time
import contextlib
import inspect
@@ -12,6 +13,7 @@
disable_anon_tracker,
track,
)
from dlt.pipeline.platform import disable_platform_tracker, init_platform_tracker

_TELEMETRY_STARTED = False

@@ -32,9 +34,13 @@ def start_telemetry(config: RunConfiguration) -> None:
if config.dlthub_telemetry:
init_anon_tracker(config)

if config.dlthub_dsn:
init_platform_tracker()

_TELEMETRY_STARTED = True


@atexit.register
def stop_telemetry() -> None:
global _TELEMETRY_STARTED
if not _TELEMETRY_STARTED:
@@ -48,6 +54,7 @@ def stop_telemetry() -> None:
pass

disable_anon_tracker()
disable_platform_tracker()

_TELEMETRY_STARTED = False

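The net effect of this file's changes, sketched end to end; this is simplified (the real functions also guard re-entry with `_TELEMETRY_STARTED` and tolerate tracker errors):

```python
# simplified lifecycle sketch of the wiring above (guards and error handling omitted)
import atexit

from dlt.common.runtime.anon_tracker import init_anon_tracker, disable_anon_tracker
from dlt.pipeline.platform import init_platform_tracker, disable_platform_tracker


def start_telemetry(config) -> None:
    if config.dlthub_telemetry:
        init_anon_tracker(config)   # eagerly creates the "anon_tracker" pool
    if config.dlthub_dsn:
        init_platform_tracker()     # eagerly creates the "platform_tracker" pool


@atexit.register
def stop_telemetry() -> None:
    disable_anon_tracker()          # stops its pool and flushes queued sends
    disable_platform_tracker()
```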
27 changes: 24 additions & 3 deletions dlt/pipeline/platform.py
@@ -1,7 +1,6 @@
"""Implements SupportsTracking"""
from typing import Any, cast, TypedDict, List
import requests
from urllib.parse import urljoin
from requests import Session

from dlt.common import logger
from dlt.common.json import json
@@ -11,9 +10,10 @@

from dlt.pipeline.trace import PipelineTrace, PipelineStepTrace, TPipelineStep, SupportsPipeline

_THREAD_POOL: ManagedThreadPool = ManagedThreadPool(1)
_THREAD_POOL: ManagedThreadPool = None
TRACE_URL_SUFFIX = "/trace"
STATE_URL_SUFFIX = "/state"
requests: Session = None


class TPipelineSyncPayload(TypedDict):
@@ -25,6 +25,27 @@ class TPipelineSyncPayload(TypedDict):
schemas: List[TStoredSchema]


def init_platform_tracker() -> None:
# lazily import requests to avoid binding config before initialization
global requests
from dlt.sources.helpers import requests as r_

requests = r_ # type: ignore[assignment]

global _THREAD_POOL
if _THREAD_POOL is None:
_THREAD_POOL = ManagedThreadPool("platform_tracker", 1)
# create thread pool in controlled way, not lazy
_THREAD_POOL._create_thread_pool()


def disable_platform_tracker() -> None:
global _THREAD_POOL
if _THREAD_POOL:
_THREAD_POOL.stop()
_THREAD_POOL = None


def _send_trace_to_platform(trace: PipelineTrace, pipeline: SupportsPipeline) -> None:
"""
Send the full trace after a run operation to the platform
25 changes: 22 additions & 3 deletions tests/common/runtime/test_telemetry.py
@@ -6,6 +6,8 @@
import base64
from unittest.mock import patch

from pytest_mock import MockerFixture

from dlt.common import logger
from dlt.common.runtime.anon_tracker import get_anonymous_id, track, disable_anon_tracker
from dlt.common.typing import DictStrAny, DictStrStr
@@ -30,7 +32,7 @@ class SentryLoggerConfiguration(RunConfiguration):
sentry_dsn: str = (
"https://[email protected]/4504819859914752"
)
dlthub_telemetry_segment_write_key: str = "TLJiyRkGVZGCi2TtjClamXpFcxAA1rSB"
# dlthub_telemetry_segment_write_key: str = "TLJiyRkGVZGCi2TtjClamXpFcxAA1rSB"


@configspec
@@ -133,17 +135,34 @@ def test_sentry_init(environment: DictStrStr) -> None:


@pytest.mark.forked
def test_track_anon_event() -> None:
def test_track_anon_event(mocker: MockerFixture) -> None:
from dlt.sources.helpers import requests
from dlt.common.runtime import anon_tracker

mock_github_env(os.environ)
mock_pod_env(os.environ)
config = SentryLoggerConfiguration()

requests_post = mocker.spy(requests, "post")

props = {"destination_name": "duckdb", "elapsed_time": 712.23123, "success": True}
with patch("dlt.common.runtime.anon_tracker.before_send", _mock_before_send):
start_test_telemetry(SentryLoggerConfiguration())
start_test_telemetry(config)
track("pipeline", "run", props)
# this will send stuff
disable_anon_tracker()

event = SENT_ITEMS[0]
# requests were really called
requests_post.assert_called_once_with(
config.dlthub_telemetry_endpoint,
headers=anon_tracker._tracker_request_header(None),
json=event,
timeout=anon_tracker._REQUEST_TIMEOUT,
)
# was actually delivered
assert requests_post.spy_return.status_code == 204

assert event["anonymousId"] == get_anonymous_id()
assert event["event"] == "pipeline_run"
assert props.items() <= event["properties"].items()
16 changes: 11 additions & 5 deletions tests/pipeline/test_platform_connection.py
@@ -1,16 +1,20 @@
import dlt
import os
import time
import pytest
import requests_mock

import dlt

from tests.utils import start_test_telemetry, stop_telemetry

TRACE_URL_SUFFIX = "/trace"
STATE_URL_SUFFIX = "/state"


@pytest.mark.forked
def test_platform_connection() -> None:
mock_platform_url = "http://platform.com/endpoint"

os.environ["RUNTIME__DLTHUB_DSN"] = mock_platform_url
start_test_telemetry()

trace_url = mock_platform_url + TRACE_URL_SUFFIX
state_url = mock_platform_url + STATE_URL_SUFFIX
@@ -42,11 +46,13 @@ def data():
m.put(mock_platform_url, json={}, status_code=200)
p.run([my_source(), my_source_2()])

# sleep a bit and find trace in mock requests
time.sleep(2)
# this will flush all telemetry queues
stop_telemetry()

trace_result = None
state_result = None

# mock tracks all calls
for call in m.request_history:
if call.url == trace_url:
assert not trace_result, "Multiple calls to trace endpoint"
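The last hunk swaps a fixed `time.sleep(2)` for an explicit flush via `stop_telemetry()`, which stops the tracker thread pools and so guarantees every queued request has been sent before the mock's history is inspected. A standalone sketch of that pattern; the endpoint URLs mirror the test, and the pipeline run itself is elided:

```python
# sketch: flush telemetry deterministically instead of sleeping (mirrors the test above)
import os

import requests_mock

from tests.utils import start_test_telemetry, stop_telemetry

os.environ["RUNTIME__DLTHUB_DSN"] = "http://platform.com/endpoint"
start_test_telemetry()

with requests_mock.Mocker() as m:
    m.put("http://platform.com/endpoint/trace", json={}, status_code=200)
    m.put("http://platform.com/endpoint/state", json={}, status_code=200)
    # ... run a pipeline here so the platform tracker queues trace/state uploads ...
    stop_telemetry()  # stops the tracker pools, so all queued requests are sent by now
    sent_urls = [call.url for call in m.request_history]  # now safe to inspect
```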