Skip to content

Commit

Permalink
Fix race in set_tracer_provider
Browse files Browse the repository at this point in the history
  • Loading branch information
aabmass committed Oct 5, 2021
1 parent c9b18c6 commit 7edb787
Show file tree
Hide file tree
Showing 8 changed files with 277 additions and 34 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased](https://github.com/open-telemetry/opentelemetry-python/compare/v1.5.0-0.24b0...HEAD)
- Fix race in `set_tracer_provider()`
([#2182](https://github.com/open-telemetry/opentelemetry-python/pull/2182))
- Automatically load OTEL environment variables as options for `opentelemetry-instrument`
([#1969](https://github.com/open-telemetry/opentelemetry-python/pull/1969))
- `opentelemetry-semantic-conventions` Update to semantic conventions v1.6.1
Expand Down
50 changes: 31 additions & 19 deletions opentelemetry-api/src/opentelemetry/trace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
)
from opentelemetry.trace.status import Status, StatusCode
from opentelemetry.util import types
from opentelemetry.util._once import Once
from opentelemetry.util._providers import _load_provider

logger = getLogger(__name__)
Expand Down Expand Up @@ -452,8 +453,19 @@ def start_as_current_span(
yield INVALID_SPAN


_TRACER_PROVIDER = None
_PROXY_TRACER_PROVIDER = None
_TRACER_PROVIDER_SET_ONCE = Once()
_TRACER_PROVIDER: Optional[TracerProvider] = None
_PROXY_TRACER_PROVIDER = ProxyTracerProvider()


def _reset_globals() -> None:
"""WARNING: only use this for tests."""
global _TRACER_PROVIDER_SET_ONCE # pylint: disable=global-statement
global _TRACER_PROVIDER # pylint: disable=global-statement
global _PROXY_TRACER_PROVIDER # pylint: disable=global-statement
_TRACER_PROVIDER_SET_ONCE = Once()
_TRACER_PROVIDER = None
_PROXY_TRACER_PROVIDER = ProxyTracerProvider()


def get_tracer(
Expand All @@ -476,40 +488,40 @@ def get_tracer(
)


def _set_tracer_provider(tracer_provider: TracerProvider, log: bool) -> None:
def set_tp() -> None:
global _TRACER_PROVIDER # pylint: disable=global-statement
_TRACER_PROVIDER = tracer_provider

did_set = _TRACER_PROVIDER_SET_ONCE.do_once(set_tp)

if not did_set:
logger.warning("Overriding of current TracerProvider is not allowed")


def set_tracer_provider(tracer_provider: TracerProvider) -> None:
"""Sets the current global :class:`~.TracerProvider` object.
This can only be done once, a warning will be logged if any furter attempt
is made.
"""
global _TRACER_PROVIDER # pylint: disable=global-statement

if _TRACER_PROVIDER is not None:
logger.warning("Overriding of current TracerProvider is not allowed")
return

_TRACER_PROVIDER = tracer_provider
_set_tracer_provider(tracer_provider, log=True)


def get_tracer_provider() -> TracerProvider:
"""Gets the current global :class:`~.TracerProvider` object."""
# pylint: disable=global-statement
global _TRACER_PROVIDER
global _PROXY_TRACER_PROVIDER

if _TRACER_PROVIDER is None:
# if a global tracer provider has not been set either via code or env
# vars, return a proxy tracer provider
if OTEL_PYTHON_TRACER_PROVIDER not in os.environ:
if not _PROXY_TRACER_PROVIDER:
_PROXY_TRACER_PROVIDER = ProxyTracerProvider()
return _PROXY_TRACER_PROVIDER

_TRACER_PROVIDER = cast( # type: ignore
"TracerProvider",
_load_provider(OTEL_PYTHON_TRACER_PROVIDER, "tracer_provider"),
tracer_provider: TracerProvider = _load_provider(
OTEL_PYTHON_TRACER_PROVIDER, "tracer_provider"
)
return _TRACER_PROVIDER
_set_tracer_provider(tracer_provider, log=False)
# _TRACER_PROVIDER will have been set by one thread
return cast("TracerProvider", _TRACER_PROVIDER)


@contextmanager # type: ignore
Expand Down
47 changes: 47 additions & 0 deletions opentelemetry-api/src/opentelemetry/util/_once.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from threading import Lock
from typing import Callable


class Once:
"""Execute a function exactly once and block all callers until the function returns
Same as golang's `sync.Once <https://pkg.go.dev/sync#Once>`_
"""

def __init__(self) -> None:
self._lock = Lock()
self._done = False

def do_once(self, func: Callable[[], None]) -> bool:
"""Execute ``func`` if it hasn't been executed or return.
Will block until ``func`` has been called by one thread.
Returns:
Whether or not ``func`` was executed in this call
"""

# fast path, try to avoid locking
if self._done:
return False

with self._lock:
if not self._done:
func()
self._done = True
return True
return False
72 changes: 62 additions & 10 deletions opentelemetry-api/tests/trace/test_globals.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import unittest
from unittest.mock import patch
from unittest.mock import Mock, patch

from opentelemetry import context, trace
from opentelemetry.test.concurrency_test import ConcurrencyTestBase, MockFunc
from opentelemetry.trace.status import Status, StatusCode


Expand All @@ -27,23 +28,74 @@ def record_exception(

class TestGlobals(unittest.TestCase):
def setUp(self):
self._patcher = patch("opentelemetry.trace._TRACER_PROVIDER")
self._mock_tracer_provider = self._patcher.start()
super().setUp()
trace._reset_globals() # pylint: disable=protected-access

def tearDown(self) -> None:
self._patcher.stop()
def tearDown(self):
super().tearDown()
trace._reset_globals() # pylint: disable=protected-access

def test_get_tracer(self):
@staticmethod
@patch("opentelemetry.trace._TRACER_PROVIDER")
def test_get_tracer(mock_tracer_provider): # type: ignore
"""trace.get_tracer should proxy to the global tracer provider."""
trace.get_tracer("foo", "var")
self._mock_tracer_provider.get_tracer.assert_called_with(
"foo", "var", None
)
mock_provider = unittest.mock.Mock()
mock_tracer_provider.get_tracer.assert_called_with("foo", "var", None)
mock_provider = Mock()
trace.get_tracer("foo", "var", mock_provider)
mock_provider.get_tracer.assert_called_with("foo", "var", None)


class TestGlobalsConcurrency(ConcurrencyTestBase):
def setUp(self):
super().setUp()
trace._reset_globals() # pylint: disable=protected-access

def tearDown(self):
super().tearDown()
trace._reset_globals() # pylint: disable=protected-access

@patch("opentelemetry.trace.logger")
def test_set_tracer_provider_many_threads(self, mock_logger) -> None: # type: ignore
mock_logger.warning = MockFunc()

def do_concurrently() -> Mock:
# first get a proxy tracer
proxy_tracer = trace.ProxyTracerProvider().get_tracer("foo")

# try to set the global tracer provider
mock_tracer_provider = Mock(get_tracer=MockFunc())
trace.set_tracer_provider(mock_tracer_provider)

# start a span through the proxy which will call through to the mock provider
proxy_tracer.start_span("foo")

return mock_tracer_provider

num_threads = 100
mock_tracer_providers = self.run_with_many_threads(
do_concurrently,
num_threads=num_threads,
)

# despite trying to set tracer provider many times, only one of the
# mock_tracer_providers should have stuck and been called from
# proxy_tracer.start_span()
mock_tps_with_any_call = [
mock
for mock in mock_tracer_providers
if mock.get_tracer.call_count > 0
]

self.assertEqual(len(mock_tps_with_any_call), 1)
self.assertEqual(
mock_tps_with_any_call[0].get_tracer.call_count, num_threads
)

# should have warned everytime except for the successful set
self.assertEqual(mock_logger.warning.call_count, num_threads - 1)


class TestTracer(unittest.TestCase):
def setUp(self):
# pylint: disable=protected-access
Expand Down
15 changes: 11 additions & 4 deletions opentelemetry-api/tests/trace/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,15 @@ class TestSpan(NonRecordingSpan):


class TestProxy(unittest.TestCase):
def test_proxy_tracer(self):
original_provider = trace._TRACER_PROVIDER
def setUp(self) -> None:
super().setUp()
trace._reset_globals()

def tearDown(self) -> None:
super().tearDown()
trace._reset_globals()

def test_proxy_tracer(self):
provider = trace.get_tracer_provider()
# proxy provider
self.assertIsInstance(provider, trace.ProxyTracerProvider)
Expand All @@ -60,6 +66,9 @@ def test_proxy_tracer(self):
# set a real provider
trace.set_tracer_provider(TestProvider())

# get_tracer_provider() now returns the real provider
self.assertIsInstance(trace.get_tracer_provider(), TestProvider)

# tracer provider now returns real instance
self.assertIsInstance(trace.get_tracer_provider(), TestProvider)

Expand All @@ -71,5 +80,3 @@ def test_proxy_tracer(self):
# creates real spans
with tracer.start_span("") as span:
self.assertIsInstance(span, TestSpan)

trace._TRACER_PROVIDER = original_provider
48 changes: 48 additions & 0 deletions opentelemetry-api/tests/util/test_once.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from opentelemetry.test.concurrency_test import ConcurrencyTestBase, MockFunc
from opentelemetry.util._once import Once


class TestOnce(ConcurrencyTestBase):
def test_once_single_thread(self):
once_func = MockFunc()
once = Once()

self.assertEqual(once_func.call_count, 0)

# first call should run
called = once.do_once(once_func)
self.assertTrue(called)
self.assertEqual(once_func.call_count, 1)

# subsequent calls do nothing
called = once.do_once(once_func)
self.assertFalse(called)
self.assertEqual(once_func.call_count, 1)

def test_once_many_threads(self):
once_func = MockFunc()
once = Once()

def run_concurrently() -> bool:
return once.do_once(once_func)

results = self.run_with_many_threads(run_concurrently, num_threads=100)

self.assertEqual(once_func.call_count, 1)

# check that only one of the threads got True
self.assertEqual(results.count(True), 1)
75 changes: 75 additions & 0 deletions tests/util/src/opentelemetry/test/concurrency_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import threading
import unittest
from functools import partial
from typing import Callable, List, Optional, TypeVar
from unittest.mock import Mock

ReturnT = TypeVar("ReturnT")


# Can't use Mock directly because its call count is not thread safe
class MockFunc:
def __init__(self) -> None:
self.lock = threading.Lock()
self.call_count = 0
self.mock = Mock()

def __call__(self, *args, **kwargs):
with self.lock:
self.call_count += 1
return self.mock


class ConcurrencyTestBase(unittest.TestCase):
orig_switch_interval = sys.getswitchinterval()

@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
# switch threads more often to increase chance of contention
sys.setswitchinterval(1e-12)

@classmethod
def tearDownClass(cls) -> None:
super().tearDownClass()
sys.setswitchinterval(cls.orig_switch_interval)

@staticmethod
def run_with_many_threads(
func_to_test: Callable[[], ReturnT],
num_threads: int = 100,
) -> List[ReturnT]:
barrier = threading.Barrier(num_threads)
results: List[Optional[ReturnT]] = [None] * num_threads

def thread_start(idx: int) -> None:
nonlocal results
# Get all threads here before releasing them to create contention
barrier.wait()
results[idx] = func_to_test()

threads = [
threading.Thread(target=partial(thread_start, i))
for i in range(num_threads)
]
for thread in threads:
thread.start()
for thread in threads:
thread.join()

return results # type: ignore
Loading

0 comments on commit 7edb787

Please sign in to comment.